Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5fc9857
support cuda graph capture offloading module
lhb8125 Dec 1, 2025
913fbe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2025
e04bc00
remove reset_hook and init_chunk_handler_hook
lhb8125 Dec 8, 2025
dda34c2
remove reset_hook and init_chunk_handler_hook
lhb8125 Dec 8, 2025
6ed4b91
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Dec 8, 2025
2f61c00
Merge branch 'main' into hongbinl/offload_activation_cuda_graph
lhb8125 Dec 8, 2025
88295b4
minor fix
Dec 18, 2025
ed2ee6a
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
Jan 13, 2026
09d0801
temp fix overlap-grad-reduce
lhb8125 Jan 19, 2026
8641228
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Jan 19, 2026
c3e341a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
6cd4af9
reuse mark_not_offload() and do not offload scale_inv
lhb8125 Jan 20, 2026
b54e77c
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Jan 20, 2026
ba065fc
temp fix for mxfp8
lhb8125 Jan 22, 2026
e00db5e
fix bug for record_stream and from_blob
lhb8125 Feb 2, 2026
f47b543
disable offloading core_attn_out and refine cpu overhead of at::empty
lhb8125 Feb 3, 2026
7ca3618
minor fix
lhb8125 Feb 5, 2026
12cf77b
Merge branch 'main' into hongbinl/offload_activation_cuda_graph
lhb8125 Feb 5, 2026
8c8fe59
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2026
8421cf9
return ptr of whole buffer and offload the whole buffer
lhb8125 Feb 6, 2026
2e47119
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Feb 6, 2026
25dbad1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2026
24d22cf
Merge branch 'main' into hongbinl/offload_activation_cuda_graph_mxfp8…
lhb8125 Feb 27, 2026
d65b416
remove mark_not_offload for core_attn_out
lhb8125 Feb 27, 2026
484b0d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def check_group_quantization_nvfp4_versus_reference(
reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose)
)

split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers)
split_quantize_outputs, _ = tex.split_quantize(x, split_sections, quantizers)

if return_identity:
x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs]
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def group_quantize_fp4(
]

if use_tex_split_quantize:
outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers)
outputs, _ = tex.split_quantize(x, split_sections, nvfp4_quantizers)
qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs]
sx_list = [output._rowwise_scale_inv for output in outputs]
qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs]
Expand Down
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/cpu_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def mark_activation_offload(*tensors):

def mark_not_offload(*tensors: torch.Tensor):
"""Marks tensors to prevent them from being offloaded."""
if NVTE_CPU_OFFLOAD_V1:
return

tensors, tensor_obj = prepare_for_saving(*tensors)

Expand Down
7 changes: 3 additions & 4 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation = false);
std::tuple<std::vector<py::object>, std::vector<at::Tensor>> split_quantize(
const at::Tensor &tensor, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list, bool disable_bulk_allocation = false);

/***************************************************************************************************
* Bias gradient fusions
Expand Down
55 changes: 34 additions & 21 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,15 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten

namespace {

std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp8_blockwise_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<Float8BlockQuantizer *> &quantizer_cpp_list) {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>>
bulk_allocate_fp8_blockwise_tensors(std::vector<std::vector<size_t>> &shape_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<Float8BlockQuantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &buffer_list = std::get<2>(retval); // Buffers for offload

// Number of tensors
const size_t num_tensors = shape_list.size();
Expand Down Expand Up @@ -412,6 +414,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -455,6 +458,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -497,13 +501,15 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
return retval;
}

std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mxfp8_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<MXFP8Quantizer *> &quantizer_cpp_list) {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>>
bulk_allocate_mxfp8_tensors(std::vector<std::vector<size_t>> &shape_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<MXFP8Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &buffer_list = std::get<2>(retval); // Buffers for offload

// Number of tensors
const size_t num_tensors = shape_list.size();
Expand Down Expand Up @@ -565,6 +571,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -605,6 +612,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -650,14 +658,17 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// allocate fp4 data, fp8 scalings, and amax values
// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN]
// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_allocate_nvfp4_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool, std::vector<at::Tensor>>
bulk_allocate_nvfp4_tensors(std::vector<std::vector<size_t>> &shape_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool, std::vector<at::Tensor>>
retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &contiguous_data_and_scale = std::get<2>(retval);
auto &buffer_list = std::get<3>(retval); // Buffers for offload
contiguous_data_and_scale = true;

// Number of tensors
Expand Down Expand Up @@ -742,6 +753,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -804,6 +816,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -1250,18 +1263,17 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,

} // namespace

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation) {
std::tuple<std::vector<py::object>, std::vector<at::Tensor>> split_quantize(
const at::Tensor &tensor, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list, bool disable_bulk_allocation) {
init_extension();

// Check number of tensors
const size_t num_splits = split_sections.size();
NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ",
quantizer_list.size());
if (num_splits == 0) {
return {};
return {{}, {}};
}

// Input tensor properties
Expand Down Expand Up @@ -1328,14 +1340,15 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
// Allocate output tensors
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
std::vector<at::Tensor> buffer_list; // Buffers for offload (can be used for record_stream)
switch (allocation_method) {
case AllocationMethod::BULK_FP8_BLOCKWISE: {
// Bulk allocation for FP8 block-scaling tensors
std::vector<Float8BlockQuantizer *> blockwise_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
std::tie(output_py_list, output_cpp_list, buffer_list) =
bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
break;
}
Expand All @@ -1345,7 +1358,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
for (auto &quantizer : quantizer_cpp_list) {
mxfp8_quantizers.push_back(static_cast<MXFP8Quantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
std::tie(output_py_list, output_cpp_list, buffer_list) =
bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
break;
}
Expand All @@ -1356,7 +1369,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
}
bool contiguous_data_and_scale;
std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) =
std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale, buffer_list) =
bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
if (!contiguous_data_and_scale) {
// Avoid fused quantize kernel if data is not contiguous
Expand Down Expand Up @@ -1393,7 +1406,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
}

return output_py_list;
return {output_py_list, buffer_list};
}

} // namespace pytorch
Expand Down
45 changes: 38 additions & 7 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def _make_graphed_callables(
pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
) -> SingleOrTuple[Callable]:
"""
Helper method for `make_graphed_callables`
Expand Down Expand Up @@ -445,6 +447,8 @@ def hook_fn(
for module in func.modules():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
if pre_warmup_hook is not None:
pre_warmup_hook()
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
Expand Down Expand Up @@ -507,6 +511,8 @@ def hook_fn(
else:
grad_inputs = None
del outputs, grad_inputs
if post_warmup_hook is not None:
post_warmup_hook()
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
Expand All @@ -517,7 +523,6 @@ def hook_fn(
# All captures here share a mempool. To avoid replays corrupting each other's memory,
# the safest approach is to capture all passes in the same order they'll run:
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

if _order is not None: # pylint: disable=too-many-nested-blocks
per_callable_static_outputs = [None] * len(flatten_sample_args)
per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
Expand Down Expand Up @@ -782,14 +787,15 @@ class Graphed(torch.autograd.Function):
"""Autograd function for graph replay."""

@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs):
# pylint: disable=missing-function-docstring

# Set flag for whether to update FP8 weight updates
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if ctx.is_first_module and skip_fp8_weight_update is not None:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)

ctx.cuda_graph_stream = cuda_graph_stream
ctx.cuda_graph_event = cuda_graph_event
# Copy values from new tensors into static tensors
for i in range(len_user_args):
if (
Expand All @@ -799,7 +805,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
static_input_surface[i].copy_(inputs[i])

# Replay forward graph
fwd_graph.replay()
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
Comment on lines +808 to +811
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing `cuda_graph_event.record(cuda_graph_stream)` after the replay. Without recording the event on `cuda_graph_stream`, `wait_event` synchronizes against the event's stale (or never-recorded) state rather than the completion of the forward graph replay, so the cross-stream ordering is not enforced.

Suggested change
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
cuda_graph_event.record(cuda_graph_stream)
torch.cuda.current_stream().wait_event(cuda_graph_event)

assert isinstance(static_outputs, tuple)
return tuple(o.detach() if o is not None else o for o in static_outputs)

Expand All @@ -816,15 +825,18 @@ def backward(ctx, *grads):
# incoming grad is already in the right place
if g.data_ptr() != grad.data_ptr():
g.copy_(grad)
bwd_graph.replay()
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
Comment on lines +828 to +831
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue: `ctx.cuda_graph_event.record(ctx.cuda_graph_stream)` is missing after the backward graph replay, so the subsequent `wait_event` does not actually wait for the replay to finish.

Suggested change
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
ctx.cuda_graph_event.record(ctx.cuda_graph_stream)
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)


# Update FP8 scale factors if needed
if ctx.is_first_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return (None,) + tuple(
return (None, None, None) + tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)

Expand All @@ -839,6 +851,16 @@ def functionalized(*user_args, **user_kwargs):

skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

if "cuda_graph_stream" in user_kwargs:
cuda_graph_stream = user_kwargs["cuda_graph_stream"]
user_kwargs.pop("cuda_graph_stream")
else:
cuda_graph_stream = torch.cuda.current_stream()
if "cuda_graph_event" in user_kwargs:
cuda_graph_event = user_kwargs["cuda_graph_event"]
user_kwargs.pop("cuda_graph_event")
else:
cuda_graph_event = torch.cuda.Event()
# Check that required kwargs are provided
for key in kwargs_keys:
if key not in user_kwargs:
Expand All @@ -854,7 +876,9 @@ def functionalized(*user_args, **user_kwargs):
flatten_user_args, _ = _tree_flatten(user_args)
flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
out = Graphed.apply(skip_fp8_weight_update, *func_args)
out = Graphed.apply(
skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args
)
return _tree_unflatten(out, output_unflatten_spec)

return functionalized
Expand All @@ -867,6 +891,9 @@ def make_graphed_attribute_functions(graph_idx):
def backward_dw():
if need_bwd_dw_graph.get(graph_idx, False):
bwd_dw_graphs[graph_idx].replay()
for module in te_modules:
if hasattr(module, "trigger_backward_dw"):
module.trigger_backward_dw()

# Trigger the grad accumulation hook for wgrad graphs.
for module in te_modules:
Expand Down Expand Up @@ -1040,6 +1067,8 @@ def make_graphed_callables(
pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
) -> Union[Callable, Tuple[Callable, ...]]:
"""
Make CUDA graph version of Transformer Engine modules
Expand Down Expand Up @@ -1264,6 +1293,8 @@ def call_func(self, *args, **kwargs):
pool=pool,
retain_graph_in_backward=retain_graph_in_backward,
_reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers,
pre_warmup_hook=pre_warmup_hook,
post_warmup_hook=post_warmup_hook,
)

# Ensures warmup does not affect numerics for ops such as dropout.
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,13 @@ def need_backward_dw(self):
"""
return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()

def trigger_backward_dw(self):
    """
    Trigger the wgrad accumulation and reduce hooks.

    Invokes every registered hook in ``self.wgrad_accumulation_and_reduce_hooks``
    in registration order.
    """
    for hook in self.wgrad_accumulation_and_reduce_hooks:
        hook()

def backward_dw(self):
"""
Execute the delayed weight gradient computation.
Expand Down
Loading