From 5fc98579ffb4f9de0690d39f7531d00c8242d877 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 30 Nov 2025 21:33:30 -0800 Subject: [PATCH 01/17] support cuda graph capture offloading module Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/graph.py | 40 ++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f55f1dd128..32661742ad 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -92,6 +92,9 @@ def _make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, + init_chunk_handler: Optional[Callable] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -407,6 +410,8 @@ def hook_fn( for module in func.modules(): hook = module.register_forward_hook(hook_fn) hooks.append(hook) + if pre_warmup_hook is not None: + pre_warmup_hook() outputs, _ = _tree_flatten(func(*args, **kwargs)) for hook in hooks: hook.remove() @@ -458,6 +463,8 @@ def hook_fn( else: grad_inputs = None del outputs, grad_inputs + if post_warmup_hook is not None: + post_warmup_hook() # The following code is added specifically for MCore's special requirements, # aimed at preventing warmup from altering the control flow. for module in func.modules(): @@ -481,6 +488,8 @@ def hook_fn( for c_id in _order: if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] + if init_chunk_handler is not None: + init_chunk_handler(vp_stage=c_id-1) m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): func = callables[_prefix_num_layers[m_chunk] + l_no] @@ -679,14 +688,15 @@ class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" @staticmethod - def forward(ctx, skip_fp8_weight_update, *inputs): + def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs): # pylint: disable=missing-function-docstring # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) - + ctx.cuda_graph_stream = cuda_graph_stream + ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors for i in range(len_user_args): if ( @@ -696,7 +706,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs): static_input_surface[i].copy_(inputs[i]) # Replay forward graph - fwd_graph.replay() + cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with cuda_graph_stream: + fwd_graph.replay() + torch.cuda.current_stream().wait_event(cuda_graph_event) assert isinstance(static_outputs, tuple) return tuple(o.detach() for o in static_outputs) @@ -713,7 +726,10 @@ def backward(ctx, *grads): # incoming grad is already in the right place if g.data_ptr() != grad.data_ptr(): g.copy_(grad) - bwd_graph.replay() + ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with ctx.cuda_graph_stream: + bwd_graph.replay() + torch.cuda.current_stream().wait_event(ctx.cuda_graph_event) # Update FP8 scale factors if needed if ctx.is_first_module: @@ -721,7 +737,7 @@ def backward(ctx, *grads): # Input args that didn't require grad expect a None gradient. assert isinstance(static_grad_inputs, tuple) - return (None,) + tuple( + return (None, None, None) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) @@ -736,6 +752,12 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + assert "cuda_graph_stream" in user_kwargs + assert "cuda_graph_event" in user_kwargs + cuda_graph_stream = user_kwargs["cuda_graph_stream"] + cuda_graph_event = user_kwargs["cuda_graph_event"] + user_kwargs.pop("cuda_graph_stream") + user_kwargs.pop("cuda_graph_event") # Check that required kwargs are provided for key in kwargs_keys: if key not in user_kwargs: @@ -751,7 +773,7 @@ def functionalized(*user_args, **user_kwargs): flatten_user_args, _ = _tree_flatten(user_args) flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params - out = Graphed.apply(skip_fp8_weight_update, *func_args) + out = Graphed.apply(skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args) return _tree_unflatten(out, output_unflatten_spec) return functionalized @@ -927,6 +949,9 @@ def make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, + init_chunk_handler: Optional[Callable] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1151,6 +1176,9 @@ def call_func(self, *args, **kwargs): pool=pool, retain_graph_in_backward=retain_graph_in_backward, _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, + pre_warmup_hook=pre_warmup_hook, + post_warmup_hook=post_warmup_hook, + init_chunk_handler=init_chunk_handler, ) # Ensures warmup does not affect numerics for ops such as dropout. From 913fbe8adf37bd4aee4f525901a1f76f44a4ed40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 05:35:03 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 32661742ad..a8543778a5 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -489,7 +489,7 @@ def hook_fn( if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] if init_chunk_handler is not None: - init_chunk_handler(vp_stage=c_id-1) + init_chunk_handler(vp_stage=c_id - 1) m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): func = callables[_prefix_num_layers[m_chunk] + l_no] @@ -773,7 +773,9 @@ def functionalized(*user_args, **user_kwargs): flatten_user_args, _ = _tree_flatten(user_args) flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params - out = Graphed.apply(skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args) + out = Graphed.apply( + skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args + ) return _tree_unflatten(out, output_unflatten_spec) return functionalized From e04bc00139ff500d60eb4aef6aff43a86f281dd4 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 06:35:16 -0800 Subject: [PATCH 03/17] remove reset_hook and init_chunk_handler_hook Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 32661742ad..1c192d8bba 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -475,7 +475,6 @@ def hook_fn( # All captures here share a mempool. To avoid replays corrupting each other's memory, # the safest approach is to capture all passes in the same order they'll run: # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. - if _order is not None: # pylint: disable=too-many-nested-blocks per_callable_static_outputs = [None] * len(flatten_sample_args) per_callable_output_unflatten_spec = [None] * len(flatten_sample_args) From dda34c26bd038923e117a5704a8952b7ade7671a Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Dec 2025 06:35:51 -0800 Subject: [PATCH 04/17] remove reset_hook and init_chunk_handler_hook Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/graph.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 1c192d8bba..5a5c715e7e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -94,7 +94,6 @@ def _make_graphed_callables( _reuse_graph_input_output_buffers: bool = False, pre_warmup_hook: Optional[Callable] = None, post_warmup_hook: Optional[Callable] = None, - init_chunk_handler: Optional[Callable] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -487,8 +486,6 @@ def hook_fn( for c_id in _order: if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] - if init_chunk_handler is not None: - init_chunk_handler(vp_stage=c_id-1) m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): func = callables[_prefix_num_layers[m_chunk] + l_no] @@ -950,7 +947,6 @@ def make_graphed_callables( _reuse_graph_input_output_buffers: bool = False, pre_warmup_hook: Optional[Callable] = None, post_warmup_hook: Optional[Callable] = None, - init_chunk_handler: Optional[Callable] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1177,7 +1173,6 @@ def call_func(self, *args, **kwargs): _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, pre_warmup_hook=pre_warmup_hook, post_warmup_hook=post_warmup_hook, - init_chunk_handler=init_chunk_handler, ) # Ensures warmup does not affect numerics for ops such as dropout. From 88295b4d72ff41b035a57641482001ee7bf6d5c5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 21:50:01 -0800 Subject: [PATCH 05/17] minor fix Signed-off-by: root --- transformer_engine/pytorch/graph.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 8fe8798138..536b1bb224 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -748,12 +748,16 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] - assert "cuda_graph_stream" in user_kwargs - assert "cuda_graph_event" in user_kwargs - cuda_graph_stream = user_kwargs["cuda_graph_stream"] - cuda_graph_event = user_kwargs["cuda_graph_event"] - user_kwargs.pop("cuda_graph_stream") - user_kwargs.pop("cuda_graph_event") + if "cuda_graph_stream" in user_kwargs: + cuda_graph_stream = user_kwargs["cuda_graph_stream"] + user_kwargs.pop("cuda_graph_stream") + else: + cuda_graph_stream = torch.cuda.current_stream() + if "cuda_graph_event" in user_kwargs: + cuda_graph_event = user_kwargs["cuda_graph_event"] + user_kwargs.pop("cuda_graph_event") + else: + cuda_graph_event = torch.cuda.Event() # Check that required kwargs are provided for key in kwargs_keys: if key not in user_kwargs: From 09d0801f1daad559c8648700fbf3a562dad67d2f Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 19 Jan 2026 06:39:08 -0800 Subject: [PATCH 06/17] temp fix overlap-grad-reduce Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/graph.py | 7 +++++-- transformer_engine/pytorch/module/base.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 38bd802007..435398957e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -842,12 +842,15 @@ def functionalized(*user_args, **user_kwargs): return functionalized - def make_graphed_attribute_functions(graph_idx): + def make_graphed_attribute_functions(graph_idx, te_modules): # Attach backward_dw as an attribute to the graphed callable. def backward_dw(): if need_bwd_dw_graph.get(graph_idx, False): bwd_dw_graphs[graph_idx].replay() + for module in te_modules: + if hasattr(module, "trigger_backward_dw"): + module.trigger_backward_dw() # Attach reset as an attribute to the graphed callable. def reset(): @@ -932,7 +935,7 @@ def new_fwd(*user_args, **user_kwargs): else: ret.append(graphed) - backward_dw_func, reset_func = make_graphed_attribute_functions(i) + backward_dw_func, reset_func = make_graphed_attribute_functions(i, te_modules) setattr(ret[-1], "backward_dw", backward_dw_func) setattr(ret[-1], "reset", reset_func) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index acf9233281..55ef5c5184 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1479,6 +1479,9 @@ def need_backward_dw(self): """ return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() + def trigger_backward_dw(self): + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() def backward_dw(self): """ Execute the delayed weight gradient computation. From c3e341aaea00a30dfecd45f9fd20e7d74e47af0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 14:41:45 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 55ef5c5184..b966611595 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1482,6 +1482,7 @@ def need_backward_dw(self): def trigger_backward_dw(self): for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: wgrad_accumulation_and_reduce_hook() + def backward_dw(self): """ Execute the delayed weight gradient computation. From 6cd4af99ce3f91a548b56e2d3d641ebc3fc68ff1 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 20 Jan 2026 05:03:27 -0800 Subject: [PATCH 08/17] reuse mark_not_offload() and do not offload scale_inv Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/cpu_offload.py | 4 ++-- .../pytorch/tensor/storage/mxfp8_tensor_storage.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 58ed063066..568ecb0a24 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -44,8 +44,8 @@ def mark_activation_offload(*tensors): def mark_not_offload(*tensors: torch.Tensor): """Marks tensors to prevent them from being offloaded.""" - if NVTE_CPU_OFFLOAD_V1: - return + # if NVTE_CPU_OFFLOAD_V1: + # return tensors, tensor_obj = prepare_for_saving(*tensors) diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index e7840d2c43..cd4fe257fd 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -18,6 +18,7 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...utils import _empty_tensor +from ...cpu_offload import mark_not_offload class _FromMXFP8Func(torch.autograd.Function): @@ -85,6 +86,7 @@ def __new__( instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv + mark_not_offload(instance._rowwise_scale_inv, instance._columnwise_scale_inv) return instance From ba065fcffc25134f78705876484936d466c26528 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 22 Jan 2026 06:50:55 -0800 Subject: [PATCH 09/17] temp fix for mxfp8 Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/csrc/extensions/cast.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index fb5f0b55d4..641f367e78 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -985,6 +985,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Allocate output tensors std::vector output_cpp_list; std::vector output_py_list; + allocation_method = AllocationMethod::UNFUSED; switch (allocation_method) { case AllocationMethod::BULK_FP8_BLOCKWISE: { // Bulk allocation for FP8 block-scaling tensors From e00db5e5ec7ecccad6bea70c799240bf2925303f Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 1 Feb 2026 17:46:17 -0800 Subject: [PATCH 10/17] fix bug for record_stream and from_blob Signed-off-by: Hongbin Liu --- .../pytorch/csrc/extensions/cast.cpp | 99 +++++++++++-------- .../tensor/storage/mxfp8_tensor_storage.py | 2 +- 2 files changed, 61 insertions(+), 40 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 641f367e78..4f036f1f81 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -216,20 +216,29 @@ std::tuple, std::vector> bulk_allocate_fp constexpr size_t fp8_elem_size = 1; constexpr size_t scale_elem_size = 4; - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + // Helper function to construct tensor view using storage sharing + // Note: All views share the same storage, so record_stream() works correctly. + auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { + if (buffer.data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); + // Calculate storage offset based on dtype element size + size_t elem_size = at::elementSize(dtype); + int64_t storage_offset = static_cast(offset / elem_size); + // Compute default strides for the shape + std::vector strides(shape_int64.size()); + if (!strides.empty()) { + strides.back() = 1; + for (int64_t i = static_cast(strides.size()) - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape_int64[i + 1]; + } + } + // Create view sharing the same storage + return at::empty({0}, buffer.options().dtype(dtype)) + .set_(buffer.storage(), storage_offset, shape_int64, strides); }; // Allocate row-wise data @@ -258,8 +267,7 @@ std::tuple, std::vector> bulk_allocate_fp } // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -301,8 +309,7 @@ std::tuple, std::vector> bulk_allocate_fp } // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -367,20 +374,29 @@ std::tuple, std::vector> bulk_allocate_mx constexpr size_t fp8_elem_size = 1; constexpr size_t scale_elem_size = 1; - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + // Helper function to construct tensor view using storage sharing + // Note: All views share the same storage, so record_stream() works correctly. + auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { + if (buffer.data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); + // Calculate storage offset based on dtype element size + size_t elem_size = at::elementSize(dtype); + int64_t storage_offset = static_cast(offset / elem_size); + // Compute default strides for the shape + std::vector strides(shape_int64.size()); + if (!strides.empty()) { + strides.back() = 1; + for (int64_t i = static_cast(strides.size()) - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape_int64[i + 1]; + } + } + // Create view sharing the same storage + return at::empty({0}, buffer.options().dtype(dtype)) + .set_(buffer.storage(), storage_offset, shape_int64, strides); }; // Allocate row-wise data @@ -409,8 +425,7 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -449,8 +464,7 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -518,20 +532,29 @@ std::tuple, std::vector, bool> bulk_alloc const auto fp4_dtype = quantizer_cpp_list[0]->dtype; constexpr size_t scale_elem_size = 1; - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + // Helper function to construct tensor view using storage sharing + // Note: All views share the same storage, so record_stream() works correctly. + auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { + if (buffer.data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); + // Calculate storage offset based on dtype element size + size_t elem_size = at::elementSize(dtype); + int64_t storage_offset = static_cast(offset / elem_size); + // Compute default strides for the shape + std::vector strides(shape_int64.size()); + if (!strides.empty()) { + strides.back() = 1; + for (int64_t i = static_cast(strides.size()) - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape_int64[i + 1]; + } + } + // Create view sharing the same storage + return at::empty({0}, buffer.options().dtype(dtype)) + .set_(buffer.storage(), storage_offset, shape_int64, strides); }; // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) @@ -584,8 +607,7 @@ std::tuple, std::vector, bool> bulk_alloc } // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -646,8 +668,7 @@ std::tuple, std::vector, bool> bulk_alloc } // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index cd4fe257fd..a2a11625e5 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -86,7 +86,7 @@ def __new__( instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv - mark_not_offload(instance._rowwise_scale_inv, instance._columnwise_scale_inv) + # mark_not_offload(instance._rowwise_scale_inv, instance._columnwise_scale_inv) return instance From f47b54382274973472deb5eba83ab629f2356f07 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 2 Feb 2026 19:23:29 -0800 Subject: [PATCH 11/17] disable offloading core_attn_out and refine cpu overhead of at::empty Signed-off-by: Hongbin Liu --- .../dot_product_attention/backends.py | 5 +++ .../pytorch/csrc/extensions/cast.cpp | 34 +++++++++++++------ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c1ff46c75a..6b457f619f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -59,6 +59,7 @@ is_cpu_offload_enabled, start_offload, mark_activation_offload, + mark_not_offload, NVTE_CPU_OFFLOAD_V1, ) from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded @@ -1249,6 +1250,9 @@ def forward( # return appropriate tensors out_ret = out_fp8 if is_output_fp8 else out + mark_not_offload(out_fp8) + mark_not_offload(out) + # save appropriate tensors fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) @@ -1298,6 +1302,7 @@ def forward( out = out_ out_ret = out_ fp8_tensors = (None, None, None, None) + mark_not_offload(out) qkvo_tensors = (q, k, v, out) nvtx_range_pop(f"{nvtx_label}") diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 4f036f1f81..f821aa2c18 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -236,9 +236,14 @@ std::tuple, std::vector> bulk_allocate_fp strides[i] = strides[i + 1] * shape_int64[i + 1]; } } - // Create view sharing the same storage - return at::empty({0}, buffer.options().dtype(dtype)) - .set_(buffer.storage(), storage_offset, shape_int64, strides); + // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) + auto impl = c10::make_intrusive( + c10::Storage(buffer.storage()), + buffer.key_set(), + caffe2::TypeMeta::fromScalarType(dtype)); + impl->set_storage_offset(storage_offset); + impl->set_sizes_and_strides(shape_int64, strides); + return at::Tensor(std::move(impl)); }; // Allocate row-wise data @@ -394,9 +399,14 @@ std::tuple, std::vector> bulk_allocate_mx strides[i] = strides[i + 1] * shape_int64[i + 1]; } } - // Create view sharing the same storage - return at::empty({0}, buffer.options().dtype(dtype)) - .set_(buffer.storage(), storage_offset, shape_int64, strides); + // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) + auto impl = c10::make_intrusive( + c10::Storage(buffer.storage()), + buffer.key_set(), + caffe2::TypeMeta::fromScalarType(dtype)); + impl->set_storage_offset(storage_offset); + impl->set_sizes_and_strides(shape_int64, strides); + return at::Tensor(std::move(impl)); }; // Allocate row-wise data @@ -552,9 +562,14 @@ std::tuple, std::vector, bool> bulk_alloc strides[i] = strides[i + 1] * shape_int64[i + 1]; } } - // Create view sharing the same storage - return at::empty({0}, buffer.options().dtype(dtype)) - .set_(buffer.storage(), storage_offset, shape_int64, strides); + // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) + auto impl = c10::make_intrusive( + c10::Storage(buffer.storage()), + buffer.key_set(), + caffe2::TypeMeta::fromScalarType(dtype)); + impl->set_storage_offset(storage_offset); + impl->set_sizes_and_strides(shape_int64, strides); + return at::Tensor(std::move(impl)); }; // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) @@ -1006,7 +1021,6 @@ std::vector split_quantize(const at::Tensor &tensor, // Allocate output tensors std::vector output_cpp_list; std::vector output_py_list; - allocation_method = AllocationMethod::UNFUSED; switch (allocation_method) { case AllocationMethod::BULK_FP8_BLOCKWISE: { // Bulk allocation for FP8 block-scaling tensors From 7ca3618c2d815a485542cc177e7d1917d315f35c Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 4 Feb 2026 19:18:05 -0800 Subject: [PATCH 12/17] minor fix Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/cpu_offload.py | 2 -- .../pytorch/tensor/storage/mxfp8_tensor_storage.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 568ecb0a24..4f07e08b5b 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -44,8 +44,6 @@ def mark_activation_offload(*tensors): def mark_not_offload(*tensors: torch.Tensor): """Marks tensors to prevent them from being offloaded.""" - # if NVTE_CPU_OFFLOAD_V1: - # return tensors, tensor_obj = prepare_for_saving(*tensors) diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index a2a11625e5..e7840d2c43 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -18,7 +18,6 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...utils import _empty_tensor -from ...cpu_offload import mark_not_offload class _FromMXFP8Func(torch.autograd.Function): @@ -86,7 +85,6 @@ def __new__( instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv - # mark_not_offload(instance._rowwise_scale_inv, instance._columnwise_scale_inv) return instance From 8c8fe59c0bcc042dba14daef74650f437fa08bca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 03:19:45 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/cast.cpp | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9ea5d492f2..343e976e5a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -219,8 +219,8 @@ std::tuple, std::vector> bulk_allocate_fp // Helper function to construct tensor view using storage sharing // Note: All views share the same storage, so record_stream() works correctly. - auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { + auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, size_t offset, + at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; if (buffer.data_ptr() == nullptr || is_empty_shape) { @@ -239,9 +239,7 @@ std::tuple, std::vector> bulk_allocate_fp } // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) auto impl = c10::make_intrusive( - c10::Storage(buffer.storage()), - buffer.key_set(), - caffe2::TypeMeta::fromScalarType(dtype)); + c10::Storage(buffer.storage()), buffer.key_set(), caffe2::TypeMeta::fromScalarType(dtype)); impl->set_storage_offset(storage_offset); impl->set_sizes_and_strides(shape_int64, strides); return at::Tensor(std::move(impl)); @@ -384,8 +382,8 @@ std::tuple, std::vector> bulk_allocate_mx // Helper function to construct tensor view using storage sharing // Note: All views share the same storage, so record_stream() works correctly. - auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { + auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, size_t offset, + at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; if (buffer.data_ptr() == nullptr || is_empty_shape) { @@ -404,9 +402,7 @@ std::tuple, std::vector> bulk_allocate_mx } // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) auto impl = c10::make_intrusive( - c10::Storage(buffer.storage()), - buffer.key_set(), - caffe2::TypeMeta::fromScalarType(dtype)); + c10::Storage(buffer.storage()), buffer.key_set(), caffe2::TypeMeta::fromScalarType(dtype)); impl->set_storage_offset(storage_offset); impl->set_sizes_and_strides(shape_int64, strides); return at::Tensor(std::move(impl)); @@ -549,8 +545,8 @@ std::tuple, std::vector, bool> bulk_alloc // Helper function to construct tensor view using storage sharing // Note: All views share the same storage, so record_stream() works correctly. - auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { + auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, size_t offset, + at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; if (buffer.data_ptr() == nullptr || is_empty_shape) { @@ -569,9 +565,7 @@ std::tuple, std::vector, bool> bulk_alloc } // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) auto impl = c10::make_intrusive( - c10::Storage(buffer.storage()), - buffer.key_set(), - caffe2::TypeMeta::fromScalarType(dtype)); + c10::Storage(buffer.storage()), buffer.key_set(), caffe2::TypeMeta::fromScalarType(dtype)); impl->set_storage_offset(storage_offset); impl->set_sizes_and_strides(shape_int64, strides); return at::Tensor(std::move(impl)); From 8421cf93413d3ce890cc7fcd4cdfaf007c3f66a5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 6 Feb 2026 07:25:12 -0800 Subject: [PATCH 14/17] return ptr of whole buffer and offload the whole buffer Signed-off-by: Hongbin Liu --- .../nvfp4/test_nvfp4_group_quantize.py | 2 +- tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 2 +- transformer_engine/pytorch/csrc/extensions.h | 7 +- .../pytorch/csrc/extensions/cast.cpp | 157 ++++++++---------- transformer_engine/pytorch/module/base.py | 3 + .../pytorch/module/grouped_linear.py | 48 +++++- .../pytorch/quantized_tensor.py | 99 +++++++++++ 7 files changed, 214 insertions(+), 104 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 01a4a01205..11802f1069 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -189,7 +189,7 @@ def check_group_quantization_nvfp4_versus_reference( reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) ) - split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) + split_quantize_outputs, _ = tex.split_quantize(x, split_sections, quantizers) if return_identity: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815b..92efd9e42a 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -194,7 +194,7 @@ def group_quantize_fp4( ] if use_tex_split_quantize: - outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers) + outputs, _ = tex.split_quantize(x, split_sections, nvfp4_quantizers) qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs] sx_list = [output._rowwise_scale_inv for output in outputs] qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..a7d170a9c5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -253,10 +253,9 @@ py::object dequantize(const py::handle &input, DType otype); std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation = false); +std::tuple, std::vector> split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation = false); /*************************************************************************************************** * Bias gradient fusions diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9ea5d492f2..005db84433 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -194,13 +194,15 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { -std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( +std::tuple, std::vector, std::vector> +bulk_allocate_fp8_blockwise_tensors( std::vector> &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector> retval; + std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); + auto &buffer_list = std::get<2>(retval); // Buffers for offload // Number of tensors const size_t num_tensors = shape_list.size(); @@ -217,34 +219,20 @@ std::tuple, std::vector> bulk_allocate_fp constexpr size_t fp8_elem_size = 1; constexpr size_t scale_elem_size = 4; - // Helper function to construct tensor view using storage sharing - // Note: All views share the same storage, so record_stream() works correctly. - auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; - if (buffer.data_ptr() == nullptr || is_empty_shape) { + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } - // Calculate storage offset based on dtype element size - size_t elem_size = at::elementSize(dtype); - int64_t storage_offset = static_cast(offset / elem_size); - // Compute default strides for the shape - std::vector strides(shape_int64.size()); - if (!strides.empty()) { - strides.back() = 1; - for (int64_t i = static_cast(strides.size()) - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * shape_int64[i + 1]; - } - } - // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) - auto impl = c10::make_intrusive( - c10::Storage(buffer.storage()), - buffer.key_set(), - caffe2::TypeMeta::fromScalarType(dtype)); - impl->set_storage_offset(storage_offset); - impl->set_sizes_and_strides(shape_int64, strides); - return at::Tensor(std::move(impl)); + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); }; // Allocate row-wise data @@ -273,7 +261,9 @@ std::tuple, std::vector> bulk_allocate_fp } // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -315,7 +305,9 @@ std::tuple, std::vector> bulk_allocate_fp } // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -358,13 +350,15 @@ std::tuple, std::vector> bulk_allocate_fp return retval; } -std::tuple, std::vector> bulk_allocate_mxfp8_tensors( +std::tuple, std::vector, std::vector> +bulk_allocate_mxfp8_tensors( std::vector> &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector> retval; + std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); + auto &buffer_list = std::get<2>(retval); // Buffers for offload // Number of tensors const size_t num_tensors = shape_list.size(); @@ -382,34 +376,20 @@ std::tuple, std::vector> bulk_allocate_mx constexpr size_t fp8_elem_size = 1; constexpr size_t scale_elem_size = 1; - // Helper function to construct tensor view using storage sharing - // Note: All views share the same storage, so record_stream() works correctly. - auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; - if (buffer.data_ptr() == nullptr || is_empty_shape) { + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } - // Calculate storage offset based on dtype element size - size_t elem_size = at::elementSize(dtype); - int64_t storage_offset = static_cast(offset / elem_size); - // Compute default strides for the shape - std::vector strides(shape_int64.size()); - if (!strides.empty()) { - strides.back() = 1; - for (int64_t i = static_cast(strides.size()) - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * shape_int64[i + 1]; - } - } - // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) - auto impl = c10::make_intrusive( - c10::Storage(buffer.storage()), - buffer.key_set(), - caffe2::TypeMeta::fromScalarType(dtype)); - impl->set_storage_offset(storage_offset); - impl->set_sizes_and_strides(shape_int64, strides); - return at::Tensor(std::move(impl)); + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); }; // Allocate row-wise data @@ -438,7 +418,9 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -477,7 +459,9 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -523,14 +507,17 @@ std::tuple, std::vector> bulk_allocate_mx // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate -std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( +std::tuple, std::vector, bool, std::vector> +bulk_allocate_nvfp4_tensors( std::vector> &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector, bool> retval; + std::tuple, std::vector, bool, std::vector> + retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); auto &contiguous_data_and_scale = std::get<2>(retval); + auto &buffer_list = std::get<3>(retval); // Buffers for offload contiguous_data_and_scale = true; // Number of tensors @@ -547,34 +534,20 @@ std::tuple, std::vector, bool> bulk_alloc const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; constexpr size_t scale_elem_size = 1; - // Helper function to construct tensor view using storage sharing - // Note: All views share the same storage, so record_stream() works correctly. - auto make_torch_view = [](at::Tensor &buffer, const std::vector &shape, + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; - if (buffer.data_ptr() == nullptr || is_empty_shape) { + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } - // Calculate storage offset based on dtype element size - size_t elem_size = at::elementSize(dtype); - int64_t storage_offset = static_cast(offset / elem_size); - // Compute default strides for the shape - std::vector strides(shape_int64.size()); - if (!strides.empty()) { - strides.back() = 1; - for (int64_t i = static_cast(strides.size()) - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * shape_int64[i + 1]; - } - } - // Directly create TensorImpl with shared storage (avoids empty + set_ overhead) - auto impl = c10::make_intrusive( - c10::Storage(buffer.storage()), - buffer.key_set(), - caffe2::TypeMeta::fromScalarType(dtype)); - impl->set_storage_offset(storage_offset); - impl->set_sizes_and_strides(shape_int64, strides); - return at::Tensor(std::move(impl)); + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); }; // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) @@ -627,7 +600,9 @@ std::tuple, std::vector, bool> bulk_alloc } // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -688,7 +663,9 @@ std::tuple, std::vector, bool> bulk_alloc } // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -1135,10 +1112,9 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, } // namespace -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation) { +std::tuple, std::vector> split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation) { init_extension(); // Check number of tensors @@ -1146,7 +1122,7 @@ std::vector split_quantize(const at::Tensor &tensor, NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ", quantizer_list.size()); if (num_splits == 0) { - return {}; + return {{}, {}}; } // Input tensor properties @@ -1213,6 +1189,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Allocate output tensors std::vector output_cpp_list; std::vector output_py_list; + std::vector buffer_list; // Buffers for offload (can be used for record_stream) switch (allocation_method) { case AllocationMethod::BULK_FP8_BLOCKWISE: { // Bulk allocation for FP8 block-scaling tensors @@ -1220,7 +1197,7 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { blockwise_quantizers.push_back(static_cast(quantizer.get())); } - std::tie(output_py_list, output_cpp_list) = + std::tie(output_py_list, output_cpp_list, buffer_list) = bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers); break; } @@ -1230,7 +1207,7 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { mxfp8_quantizers.push_back(static_cast(quantizer.get())); } - std::tie(output_py_list, output_cpp_list) = + std::tie(output_py_list, output_cpp_list, buffer_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); break; } @@ -1241,7 +1218,7 @@ std::vector split_quantize(const at::Tensor &tensor, nvfp4_quantizers.push_back(static_cast(quantizer.get())); } bool contiguous_data_and_scale; - std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = + std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale, buffer_list) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); if (!contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous @@ -1278,7 +1255,7 @@ std::vector split_quantize(const at::Tensor &tensor, multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); } - return output_py_list; + return {output_py_list, buffer_list}; } } // namespace pytorch diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8a7164bd0f..c45cd3eaee 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1509,6 +1509,9 @@ def need_backward_dw(self): return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() def trigger_backward_dw(self): + """ + Trigger the wgrad accumulation and reduce hooks. + """ for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: wgrad_accumulation_and_reduce_hook() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..d0891e20df 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -49,6 +49,8 @@ Quantizer, prepare_for_saving, restore_from_saved, + get_columnwise_subview_info, + restore_columnwise_subviews, ) from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_state import TEDebugState @@ -142,13 +144,24 @@ def forward( ) inp_view = inp.reshape(-1, in_features) inputmats: list + offload_buffer: torch.Tensor = None + subview_restore_info: list = [] if fp8 and not debug: # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. - inputmats = tex.split_quantize( - inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading - ) + inputmats, buffer_list = tex.split_quantize(inp_view, m_splits, input_quantizers) + if cpu_offloading: + # Mark inputmats as not offload - we offload the buffer instead + mark_not_offload(*inputmats) + # buffer_list layout: [rowwise_buffer?, columnwise_buffer?] + # columnwise buffer is always last if present; we only offload it + # since rowwise data is discarded when weight_requires_grad is True + if buffer_list and input_quantizers[0].columnwise_usage: + offload_buffer = buffer_list[-1] + # Get subview boundary info for restoration in backward + if offload_buffer is not None: + subview_restore_info = get_columnwise_subview_info(inputmats, offload_buffer) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype @@ -157,7 +170,12 @@ def forward( inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) if cpu_offloading: - start_offload(*inputmats) + if offload_buffer is not None: + # Offload the buffer instead of individual tensors + # (rowwise data is discarded when weight_requires_grad is True) + start_offload(offload_buffer) + else: + start_offload(*inputmats) # Initialize weights weights_fp8: list @@ -232,10 +250,15 @@ def forward( if save_original_input: inputmats = [None] * num_gemms inputmats[0] = inp + offload_buffer = None + subview_restore_info = [] else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if cpu_offloading and offload_buffer is not None: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms @@ -257,6 +280,7 @@ def forward( *weights_fp8, *weights, *biases, + offload_buffer, ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects @@ -303,6 +327,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.subview_restore_info = subview_restore_info # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -317,8 +342,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] + offload_buffer = saved_tensors[4 * N] if len(saved_tensors) > 4 * N else None main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + # Restore subviews from reloaded buffer + if ctx.cpu_offloading and ctx.subview_restore_info and offload_buffer is not None: + restore_columnwise_subviews( + inputmats, offload_buffer, ctx.subview_restore_info + ) + if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: for i, weight in enumerate(ctx.weight_objects): @@ -348,14 +380,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Unfused bias grad and multi-tensor quantize for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) - grad_output = tex.split_quantize( + grad_output, _ = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, ) else: # Multi-tensor quantize - grad_output = tex.split_quantize( + grad_output, _ = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, @@ -443,7 +475,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + inputmats, _ = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 0a6ad61ff0..aa84c8de31 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -161,6 +161,105 @@ def restore_from_saved( return tensor_objects +def get_columnwise_subview_info( + inputmats: list, columnwise_buffer: torch.Tensor +) -> list: + """ + Get boundary information for columnwise internal tensors in inputmats. + + This function extracts the byte offsets, shapes, strides, and dtypes of + columnwise internal tensors (_columnwise_data, _columnwise_scale_inv, + _columnwise_amax) relative to a shared buffer. This information is used + to restore subviews after CPU offload/reload. + + Only extracts columnwise data info since rowwise data is typically + discarded when weight_requires_grad is True. + + Uses tuples instead of dicts to minimize CPU overhead. + + Args: + inputmats: List of QuantizedTensorStorage objects created by bulk allocation. + columnwise_buffer: The buffer that contains the columnwise data (must be uint8). + + Returns: + List of tuples: [(columnwise_data_info, columnwise_scale_info, columnwise_amax_info), ...] + Each info is (byte_offset, shape, stride, dtype) or None if not present. + byte_offset is the offset in bytes from the start of the buffer. + """ + if columnwise_buffer is None: + return [] + + info_list = [] + buffer_ptr = columnwise_buffer.data_ptr() + + for tensor in inputmats: + # Get columnwise_data info + # Use data_ptr() difference to get the actual byte offset in buffer + col_data = getattr(tensor, "_columnwise_data", None) + col_data_info = ( + (col_data.data_ptr() - buffer_ptr, col_data.shape, col_data.stride(), col_data.dtype) + if col_data is not None + else None + ) + + # Get columnwise_scale_inv info + col_scale = getattr(tensor, "_columnwise_scale_inv", None) + col_scale_info = ( + (col_scale.data_ptr() - buffer_ptr, col_scale.shape, col_scale.stride(), col_scale.dtype) + if col_scale is not None + else None + ) + + # Get columnwise_amax info (for NVFP4) + col_amax = getattr(tensor, "_columnwise_amax", None) + col_amax_info = ( + (col_amax.data_ptr() - buffer_ptr, col_amax.shape, col_amax.stride(), col_amax.dtype) + if col_amax is not None + else None + ) + + info_list.append((col_data_info, col_scale_info, col_amax_info)) + return info_list + + +def restore_columnwise_subviews( + inputmats: list, columnwise_buffer: torch.Tensor, info_list: list +) -> None: + """ + Restore columnwise internal tensors from reloaded buffer. + + After CPU offload and reload, the columnwise_buffer may be at a new memory + location. This function restores the columnwise internal tensors of inputmats + to point to the correct locations in the reloaded buffer using as_strided. + + Args: + inputmats: List of QuantizedTensorStorage objects to restore. + columnwise_buffer: The reloaded columnwise buffer (must be uint8). + info_list: Boundary info returned by get_columnwise_subview_info(). + """ + if columnwise_buffer is None or not info_list: + return + + for tensor, info in zip(inputmats, info_list): + col_data_info, col_scale_info, col_amax_info = info + + # Restore columnwise_data using as_strided (avoids empty + set_ overhead) + # NOTE: byte_offset == element_offset because buffer dtype is uint8 + if col_data_info is not None: + byte_offset, shape, stride, _ = col_data_info + tensor._columnwise_data = columnwise_buffer.as_strided(shape, stride, byte_offset) + + # Restore columnwise_scale_inv + if col_scale_info is not None: + byte_offset, shape, stride, _ = col_scale_info + tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset) + + # Restore columnwise_amax (NVFP4) + if col_amax_info is not None: + byte_offset, shape, stride, _ = col_amax_info + tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset) + + class Quantizer(abc.ABC): """Builder class for quantized tensors. From 25dbad1d5efb2867bc66c1a28c9d75dbda698d52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:03:13 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/cast.cpp | 18 +++++++++--------- .../pytorch/module/grouped_linear.py | 8 ++++---- transformer_engine/pytorch/quantized_tensor.py | 11 +++++++---- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 005db84433..1805bdd73f 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -195,9 +195,9 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { std::tuple, std::vector, std::vector> -bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +bulk_allocate_fp8_blockwise_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); @@ -351,9 +351,9 @@ bulk_allocate_fp8_blockwise_tensors( } std::tuple, std::vector, std::vector> -bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +bulk_allocate_mxfp8_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); @@ -508,9 +508,9 @@ bulk_allocate_mxfp8_tensors( // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate std::tuple, std::vector, bool, std::vector> -bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +bulk_allocate_nvfp4_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, bool, std::vector> retval; diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d0891e20df..3d61ccd44b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -347,9 +347,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Restore subviews from reloaded buffer if ctx.cpu_offloading and ctx.subview_restore_info and offload_buffer is not None: - restore_columnwise_subviews( - inputmats, offload_buffer, ctx.subview_restore_info - ) + restore_columnwise_subviews(inputmats, offload_buffer, ctx.subview_restore_info) if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -475,7 +473,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats, _ = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + inputmats, _ = tex.split_quantize( + inp_view, ctx.m_splits, ctx.input_quantizers + ) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index aa84c8de31..12869cccc7 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -161,9 +161,7 @@ def restore_from_saved( return tensor_objects -def get_columnwise_subview_info( - inputmats: list, columnwise_buffer: torch.Tensor -) -> list: +def get_columnwise_subview_info(inputmats: list, columnwise_buffer: torch.Tensor) -> list: """ Get boundary information for columnwise internal tensors in inputmats. @@ -205,7 +203,12 @@ def get_columnwise_subview_info( # Get columnwise_scale_inv info col_scale = getattr(tensor, "_columnwise_scale_inv", None) col_scale_info = ( - (col_scale.data_ptr() - buffer_ptr, col_scale.shape, col_scale.stride(), col_scale.dtype) + ( + col_scale.data_ptr() - buffer_ptr, + col_scale.shape, + col_scale.stride(), + col_scale.dtype, + ) if col_scale is not None else None ) From d65b416c906a338687a7a5ae2b4cc7ba14aeb126 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 27 Feb 2026 02:40:25 -0800 Subject: [PATCH 16/17] remove mark_not_offload for core_attn_out Signed-off-by: Hongbin Liu --- .../pytorch/attention/dot_product_attention/backends.py | 4 ---- .../tensor/storage/float8_blockwise_tensor_storage.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 55c003ca26..7106f0ac7f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -59,7 +59,6 @@ is_cpu_offload_enabled, start_offload, mark_activation_offload, - mark_not_offload, NVTE_CPU_OFFLOAD_V1, ) from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded @@ -1308,8 +1307,6 @@ def forward( # return appropriate tensors out_ret = out_fp8 if is_output_fp8 else out - mark_not_offload(out_fp8) - mark_not_offload(out) # save appropriate tensors fp8_tensors = (None, None, None, None) @@ -1361,7 +1358,6 @@ def forward( out = out_ out_ret = out_ fp8_tensors = (None, None, None, None) - mark_not_offload(out) qkvo_tensors = (q, k, v, out) nvtx_range_pop(f"{nvtx_label}") diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 4cd6d19cd8..bcc8815082 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -358,9 +358,9 @@ def update_usage( rowwise_usage = self._rowwise_data is not None if columnwise_usage is None: columnwise_usage = self._columnwise_data is not None - assert ( - columnwise_usage or rowwise_usage - ), "Must retain some data either columnwise or rowwise" + # assert ( + # columnwise_usage or rowwise_usage + # ), "Must retain some data either columnwise or rowwise" if columnwise_usage and rowwise_usage: if not self._is_2D_scaled: From 484b0d5b837b281f6db75e772696996a39472b70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:11:19 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/backends.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 7106f0ac7f..aa6c063951 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1307,7 +1307,6 @@ def forward( # return appropriate tensors out_ret = out_fp8 if is_output_fp8 else out - # save appropriate tensors fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None)