diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 5f35e9ad10..6df6dabd29 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -77,7 +77,7 @@ def check_group_quantization_nvfp4_versus_reference( reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) ) - split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) + split_quantize_outputs, _ = tex.split_quantize(x, split_sections, quantizers) if return_identity: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815b..92efd9e42a 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -194,7 +194,7 @@ def group_quantize_fp4( ] if use_tex_split_quantize: - outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers) + outputs, _ = tex.split_quantize(x, split_sections, nvfp4_quantizers) qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs] sx_list = [output._rowwise_scale_inv for output in outputs] qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs] diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 05219b7b18..e1eaa9fa46 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -45,8 +45,6 @@ def mark_activation_offload(*tensors): def mark_not_offload(*tensors: torch.Tensor): """Marks tensors to prevent them from being offloaded.""" - if NVTE_CPU_OFFLOAD_V1: - return tensors, tensor_obj = prepare_for_saving(*tensors) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b2b0751b04..db74c9efd3 100644 --- 
a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -260,10 +260,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation = false); +std::tuple, std::vector> split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation = false); /*************************************************************************************************** * Bias gradient fusions diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f8f793f036..c38075b9ed 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -345,13 +345,15 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { -std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +std::tuple, std::vector, std::vector> +bulk_allocate_fp8_blockwise_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector> retval; + std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); + auto &buffer_list = std::get<2>(retval); // Buffers for offload // Number of tensors const size_t num_tensors = shape_list.size(); @@ -412,6 +414,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + 
buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -455,6 +458,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -497,13 +501,15 @@ std::tuple, std::vector> bulk_allocate_fp return retval; } -std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +std::tuple, std::vector, std::vector> +bulk_allocate_mxfp8_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector> retval; + std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); + auto &buffer_list = std::get<2>(retval); // Buffers for offload // Number of tensors const size_t num_tensors = shape_list.size(); @@ -565,6 +571,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -605,6 +612,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -650,14 +658,17 @@ std::tuple, std::vector> bulk_allocate_mx // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, 
fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate -std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +std::tuple, std::vector, bool, std::vector> +bulk_allocate_nvfp4_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector, bool> retval; + std::tuple, std::vector, bool, std::vector> + retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); auto &contiguous_data_and_scale = std::get<2>(retval); + auto &buffer_list = std::get<3>(retval); // Buffers for offload contiguous_data_and_scale = true; // Number of tensors @@ -742,6 +753,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -804,6 +816,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -1250,10 +1263,9 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, } // namespace -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation) { +std::tuple, std::vector> split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation) { init_extension(); // Check number of tensors @@ -1261,7 +1273,7 @@ 
std::vector split_quantize(const at::Tensor &tensor, NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ", quantizer_list.size()); if (num_splits == 0) { - return {}; + return {{}, {}}; } // Input tensor properties @@ -1328,6 +1340,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Allocate output tensors std::vector output_cpp_list; std::vector output_py_list; + std::vector buffer_list; // Buffers for offload (can be used for record_stream) switch (allocation_method) { case AllocationMethod::BULK_FP8_BLOCKWISE: { // Bulk allocation for FP8 block-scaling tensors @@ -1335,7 +1348,7 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { blockwise_quantizers.push_back(static_cast(quantizer.get())); } - std::tie(output_py_list, output_cpp_list) = + std::tie(output_py_list, output_cpp_list, buffer_list) = bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers); break; } @@ -1345,7 +1358,7 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { mxfp8_quantizers.push_back(static_cast(quantizer.get())); } - std::tie(output_py_list, output_cpp_list) = + std::tie(output_py_list, output_cpp_list, buffer_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); break; } @@ -1356,7 +1369,7 @@ std::vector split_quantize(const at::Tensor &tensor, nvfp4_quantizers.push_back(static_cast(quantizer.get())); } bool contiguous_data_and_scale; - std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = + std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale, buffer_list) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); if (!contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous @@ -1393,7 +1406,7 @@ std::vector split_quantize(const at::Tensor &tensor, multi_tensor_quantize_impl(input_list, 
quantizer_list, quantizer_cpp_list, output_cpp_list); } - return output_py_list; + return {output_py_list, buffer_list}; } } // namespace pytorch diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f4b1fb23ae..a29236cbb6 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -108,6 +108,8 @@ def _make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -445,6 +447,8 @@ def hook_fn( for module in func.modules(): hook = module.register_forward_hook(hook_fn) hooks.append(hook) + if pre_warmup_hook is not None: + pre_warmup_hook() outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: hook.remove() @@ -507,6 +511,8 @@ def hook_fn( else: grad_inputs = None del outputs, grad_inputs + if post_warmup_hook is not None: + post_warmup_hook() # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from altering the control flow. for module in func.modules(): @@ -517,7 +523,6 @@ def hook_fn( # All captures here share a mempool. To avoid replays corrupting each other's memory, # the safest approach is to capture all passes in the same order they'll run: # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. 
- if _order is not None: # pylint: disable=too-many-nested-blocks per_callable_static_outputs = [None] * len(flatten_sample_args) per_callable_output_unflatten_spec = [None] * len(flatten_sample_args) @@ -782,14 +787,15 @@ class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" @staticmethod - def forward(ctx, skip_fp8_weight_update, *inputs): + def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs): # pylint: disable=missing-function-docstring # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) - + ctx.cuda_graph_stream = cuda_graph_stream + ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors for i in range(len_user_args): if ( @@ -799,7 +805,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs): static_input_surface[i].copy_(inputs[i]) # Replay forward graph - fwd_graph.replay() + cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with cuda_graph_stream: + fwd_graph.replay() + torch.cuda.current_stream().wait_event(cuda_graph_event) assert isinstance(static_outputs, tuple) return tuple(o.detach() if o is not None else o for o in static_outputs) @@ -816,7 +825,10 @@ def backward(ctx, *grads): # incoming grad is already in the right place if g.data_ptr() != grad.data_ptr(): g.copy_(grad) - bwd_graph.replay() + ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with ctx.cuda_graph_stream: + bwd_graph.replay() + torch.cuda.current_stream().wait_event(ctx.cuda_graph_event) # Update FP8 scale factors if needed if ctx.is_first_module: @@ -824,7 +836,7 @@ def backward(ctx, *grads): # Input args that didn't require grad expect a None gradient. 
assert isinstance(static_grad_inputs, tuple) - return (None,) + tuple( + return (None, None, None) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) @@ -839,6 +851,16 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + if "cuda_graph_stream" in user_kwargs: + cuda_graph_stream = user_kwargs["cuda_graph_stream"] + user_kwargs.pop("cuda_graph_stream") + else: + cuda_graph_stream = torch.cuda.current_stream() + if "cuda_graph_event" in user_kwargs: + cuda_graph_event = user_kwargs["cuda_graph_event"] + user_kwargs.pop("cuda_graph_event") + else: + cuda_graph_event = torch.cuda.Event() # Check that required kwargs are provided for key in kwargs_keys: if key not in user_kwargs: @@ -854,7 +876,9 @@ def functionalized(*user_args, **user_kwargs): flatten_user_args, _ = _tree_flatten(user_args) flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params - out = Graphed.apply(skip_fp8_weight_update, *func_args) + out = Graphed.apply( + skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args + ) return _tree_unflatten(out, output_unflatten_spec) return functionalized @@ -867,6 +891,9 @@ def make_graphed_attribute_functions(graph_idx): def backward_dw(): if need_bwd_dw_graph.get(graph_idx, False): bwd_dw_graphs[graph_idx].replay() + for module in te_modules: + if hasattr(module, "trigger_backward_dw"): + module.trigger_backward_dw() # Trigger the grad accumulation hook for wgrad graphs. 
for module in te_modules: @@ -1040,6 +1067,8 @@ def make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1264,6 +1293,8 @@ def call_func(self, *args, **kwargs): pool=pool, retain_graph_in_backward=retain_graph_in_backward, _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, + pre_warmup_hook=pre_warmup_hook, + post_warmup_hook=post_warmup_hook, ) # Ensures warmup does not affect numerics for ops such as dropout. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4858383c26..fbb765c045 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1512,6 +1512,13 @@ def need_backward_dw(self): """ return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() + def trigger_backward_dw(self): + """ + Trigger the wgrad accumulation and reduce hooks. + """ + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() + def backward_dw(self): """ Execute the delayed weight gradient computation. 
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b381073d78..885b34cf2a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -51,6 +51,8 @@ Quantizer, prepare_for_saving, restore_from_saved, + get_columnwise_subview_info, + restore_columnwise_subviews, ) from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_state import TEDebugState @@ -145,16 +147,24 @@ def forward( ) inp_view = inp.reshape(-1, in_features) inputmats: list + offload_buffer: torch.Tensor = None + subview_restore_info: list = [] if fp8 and not debug: # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. - inputmats = tex.split_quantize( - inp_view, - m_splits, - input_quantizers, - disable_bulk_allocation=cpu_offloading, - ) + inputmats, buffer_list = tex.split_quantize(inp_view, m_splits, input_quantizers) + if cpu_offloading: + # Mark inputmats as not offload - we offload the buffer instead + mark_not_offload(*inputmats) + # buffer_list layout: [rowwise_buffer?, columnwise_buffer?] 
+ # columnwise buffer is always last if present; we only offload it + # since rowwise data is discarded when weight_requires_grad is True + if buffer_list and input_quantizers[0].columnwise_usage: + offload_buffer = buffer_list[-1] + # Get subview boundary info for restoration in backward + if offload_buffer is not None: + subview_restore_info = get_columnwise_subview_info(inputmats, offload_buffer) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype @@ -163,7 +173,12 @@ def forward( inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) if cpu_offloading: - start_offload(*inputmats) + if offload_buffer is not None: + # Offload the buffer instead of individual tensors + # (rowwise data is discarded when weight_requires_grad is True) + start_offload(offload_buffer) + else: + start_offload(*inputmats) # Initialize weights weights_fp8: list @@ -238,10 +253,15 @@ def forward( if save_original_input: inputmats = [None] * num_gemms inputmats[0] = inp + offload_buffer = None + subview_restore_info = [] else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if cpu_offloading and offload_buffer is not None: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms @@ -263,6 +283,7 @@ def forward( *weights_fp8, *weights, *biases, + offload_buffer, ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects @@ -309,6 +330,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.subview_restore_info = subview_restore_info # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -323,8 +345,13 @@ def 
backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] + offload_buffer = saved_tensors[4 * N] if len(saved_tensors) > 4 * N else None main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + # Restore subviews from reloaded buffer + if ctx.cpu_offloading and ctx.subview_restore_info and offload_buffer is not None: + restore_columnwise_subviews(inputmats, offload_buffer, ctx.subview_restore_info) + if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: for i, weight in enumerate(ctx.weight_objects): @@ -354,14 +381,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Unfused bias grad and multi-tensor quantize for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) - grad_output = tex.split_quantize( + grad_output, _ = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, ) else: # Multi-tensor quantize - grad_output = tex.split_quantize( + grad_output, _ = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, @@ -453,7 +480,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + inputmats, _ = tex.split_quantize( + inp_view, ctx.m_splits, ctx.input_quantizers + ) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index d78677bc83..c9d848a1e4 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -173,6 +173,108 @@ def restore_from_saved( return tensor_objects +def 
# Attributes of a quantized tensor storage object that are recorded/restored,
# in the order they appear in each info tuple.
_COLUMNWISE_SUBVIEW_ATTRS = ("_columnwise_data", "_columnwise_scale_inv", "_columnwise_amax")


def _describe_subview(subview, buffer_ptr):
    """Return (byte_offset, shape, stride, dtype) for `subview`, or None if it is absent."""
    if subview is None:
        return None
    # data_ptr() difference gives the offset in bytes from the start of the buffer data.
    return (subview.data_ptr() - buffer_ptr, subview.shape, subview.stride(), subview.dtype)


def _rebuild_subview(buffer, info):
    """Recreate one sub-tensor as a view into `buffer` from its recorded boundary info."""
    byte_offset, shape, stride, dtype = info
    if dtype == buffer.dtype:
        # Same element type as the (uint8) buffer: byte offset == element offset.
        # as_strided takes an offset into the underlying storage, so add the
        # buffer's own storage offset to stay correct when the buffer is a view.
        return buffer.as_strided(shape, stride, buffer.storage_offset() + byte_offset)

    # Different element type: the recorded strides are in units of `dtype`
    # elements, so the bytes must be reinterpreted before applying them.
    itemsize = torch.empty(0, dtype=dtype).element_size()
    if any(dim == 0 for dim in shape):
        span = 0
    else:
        # Number of elements between the first and last element touched (inclusive).
        span = 1 + sum((dim - 1) * step for dim, step in zip(shape, stride))
    typed = buffer[byte_offset : byte_offset + span * itemsize].view(dtype)
    # No explicit offset: keep the storage offset of the typed slice.
    return typed.as_strided(shape, stride)


def get_columnwise_subview_info(inputmats: list, columnwise_buffer: torch.Tensor) -> list:
    """
    Get boundary information for columnwise internal tensors in inputmats.

    For every object in `inputmats`, records where its `_columnwise_data`,
    `_columnwise_scale_inv` and `_columnwise_amax` attributes live relative to
    the start of `columnwise_buffer`, so the views can be recreated on a
    different copy of the buffer (see `restore_columnwise_subviews`).

    Only columnwise attributes are inspected; rowwise attributes are ignored.
    Plain tuples are used instead of dicts to keep CPU overhead low.

    Args:
        inputmats: List of objects that may carry the columnwise attributes
            listed above. A missing or None attribute is recorded as None.
        columnwise_buffer: The buffer that contains the columnwise data
            (expected to be a 1-D uint8 tensor).

    Returns:
        List of tuples: [(columnwise_data_info, columnwise_scale_info, columnwise_amax_info), ...]
        Each info is (byte_offset, shape, stride, dtype) or None if not present.
        byte_offset is the offset in bytes from the start of the buffer.
        Returns an empty list if `columnwise_buffer` is None.
    """
    if columnwise_buffer is None:
        return []

    buffer_ptr = columnwise_buffer.data_ptr()
    return [
        tuple(
            _describe_subview(getattr(tensor, attr, None), buffer_ptr)
            for attr in _COLUMNWISE_SUBVIEW_ATTRS
        )
        for tensor in inputmats
    ]


def restore_columnwise_subviews(
    inputmats: list, columnwise_buffer: torch.Tensor, info_list: list
) -> None:
    """
    Restore columnwise internal tensors from reloaded buffer.

    The buffer passed here may live at a different memory location than the one
    the info was recorded against. Each recorded attribute of each object in
    `inputmats` is re-pointed at the matching region of `columnwise_buffer`,
    with the recorded shape, stride and dtype. Attributes whose info is None
    are left untouched.

    Args:
        inputmats: List of objects to restore (modified in place).
        columnwise_buffer: The reloaded columnwise buffer (expected to be a
            1-D uint8 tensor).
        info_list: Boundary info returned by get_columnwise_subview_info().
            Paired with `inputmats` positionally; extra entries on either
            side are ignored.

    Raises:
        RuntimeError: Propagated from torch if a recorded region does not fit
            in the buffer or a non-uint8 region is not aligned to its
            element size.
    """
    if columnwise_buffer is None or not info_list:
        return

    for tensor, info in zip(inputmats, info_list):
        for attr, subview_info in zip(_COLUMNWISE_SUBVIEW_ATTRS, info):
            if subview_info is not None:
                setattr(tensor, attr, _rebuild_subview(columnwise_buffer, subview_info))