34 changes: 22 additions & 12 deletions modelopt/torch/quantization/model_calib.py
@@ -27,7 +27,11 @@
from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
from modelopt.torch.utils.network import (
bind_forward_method,
get_module_device,
unpatch_forward_method,
)

from .calib import MseCalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -81,26 +85,36 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
forward_loop(model)
finish_stats_collection(model)

# amax sync for local experts in a SequentialMLP
for name, module in model.named_modules():
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()

if not distributed_sync:
return

def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, device):
"""Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp_ep(_q, parallel_state)
sync_quantizer_amax_across_dp_ep(_q, parallel_state, device)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group, device)
quantizer.sync_amax_across_distributed_group(
parallel_state.expert_model_parallel_group, device
)
# TODO: create sync_bias_across_distributed_group

# Step 1: Sync amax across data parallelism
for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
if not isinstance(child, (TensorQuantizer, SequentialQuantizer)):
continue
sync_quantizer_amax_across_dp_ep(
child, module.parallel_state, get_module_device(module)
)
Comment on lines +114 to +116
@realAsma (Contributor, Author), Jan 8, 2026:

could you please test if all MoE quantizers have amax after this line (locally)?

if "experts" in name and "weight_quantizer" in name:
    assert child.amax is not None
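For reference, a fuller version of that spot-check could look like the sketch below. This is a hypothetical snippet, not part of the PR; it assumes the Megatron MoE naming convention in which expert weight quantizers carry "experts" and "weight_quantizer" in their qualified module names.

# Hypothetical spot-check (per the review comment above): after the DP/EP sync,
# every MoE expert weight quantizer should hold a non-None amax.
for qname, submodule in model.named_modules():
    if "experts" in qname and "weight_quantizer" in qname and isinstance(submodule, TensorQuantizer):
        assert submodule.amax is not None, f"amax missing after DP/EP sync: {qname}"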


# TP sync:
# Objective: the quantization parameters should be the same when TP=8 is changed to TP=4 and then back to TP=8

@@ -182,10 +196,6 @@ def sync_quantizer_amax_across_tp(
parallel_state=module.parallel_state,
)

# MOE Quantization
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()

# KV Cache Quantization
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
# We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
70 changes: 56 additions & 14 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -22,7 +22,6 @@
from typing import Any, Protocol

import torch
import torch.distributed as dist

try:
from torch.distributed.tensor import DTensor
@@ -1178,19 +1177,62 @@ def set_from_modelopt_state(self, modelopt_state, properties_only: bool = False)
modelopt_state.get("_pytorch_state_metadata", {})
)

def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup):
"""Synchronize the amax across all ranks in the given group."""
if parallel_group.is_initialized() and getattr(self, "_amax", None) is not None:
try:
dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
except RuntimeError as e:
# This error happens if the distributed backend is using GPU and
# the tensor is not on GPU (or vice versa).
warnings.warn(
f"Failed to synchronize amax: {e}, probably because the tensor is on a device which is not"
"supported by the current distributed backend. This warning can be ignored"
"if happening during modelopt restore."
)
@realAsma (Contributor, Author):

the current sync_amax_across_distributed_group moves the amax to CPU. This is to accommodate the case where some amaxs are None and some are tensors; however, this typically happens only for MoEs.
So can we do the old method of sync for non-MoEs:

dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)

and sync as an object via CPU only for MoEs?

def sync_amax_across_distributed_group(
    self, parallel_group: DistributedProcessGroup, fallback_device: torch.device | None = None
):
"""Synchronize the amax across all ranks in the given group.

This handles the case where some ranks have amax=None (e.g., MoE experts that
haven't seen any tokens) while others have valid amax tensors.

Args:
parallel_group: The distributed process group to sync across.
fallback_device: Device to place the synced amax if local amax is None.
If None, defaults to the current CUDA device if available.
"""
if not parallel_group.is_initialized():
return

# Determine target device: prefer local amax device, then fallback, then current CUDA device
local_device = self.amax.device if self.amax is not None else None
if fallback_device is None and torch.cuda.is_available():
fallback_device = torch.device("cuda", torch.cuda.current_device())
target_device = local_device or fallback_device

def reduce_amaxs(amax_list: list) -> torch.Tensor | None:
"""Reduce amax values across ranks by taking element-wise maximum."""
valid_amaxs = [a for a in amax_list if a is not None]
if not valid_amaxs:
return None

# Iterative max handles both scalar and tensor amax values
result = valid_amaxs[0]
for amax in valid_amaxs[1:]:
result = torch.maximum(result, amax)
Contributor:

what happens if this line is comparing a scalar vs a tensor? how does it determine the max?

@realAsma (Contributor, Author):

see https://docs.pytorch.org/docs/stable/generated/torch.maximum.html

it simply performs an element-wise maximum (with broadcasting), so the shapes do not need to match as long as both are PyTorch tensors (including scalar tensors).

return result
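Regarding the scalar-vs-tensor question in the thread above, a quick illustration (not part of the diff): torch.maximum broadcasts a 0-dim scalar tensor against a shaped tensor, so mixed scalar and per-channel amax values still reduce element-wise.

# Illustration: broadcasting a scalar amax against a per-channel amax.
import torch

per_channel = torch.tensor([0.5, 2.0, 1.0])
per_tensor = torch.tensor(1.5)  # 0-dim scalar tensor
print(torch.maximum(per_channel, per_tensor))  # tensor([1.5000, 2.0000, 1.5000])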

try:
# All-gathering of objects happens on CPU, so let's move the local amax to CPU
local_amax_cpu = self.amax.cpu() if self.amax is not None else None

synced_amax = DistributedProcessGroup.get_dist_syncd_obj(
local_amax_cpu,
parallel_group,
reduce_amaxs,
)

if synced_amax is not None:
# Move to target device
if target_device is not None:
synced_amax = synced_amax.to(target_device)
Contributor:

need to add

synced_amax = synced_amax.clone().detach()

otherwise the sharding metadata of global_offset=(0, 0) on all ranks will be kept during checkpoint save.

@realAsma (Contributor, Author):

Good catch, I am hoping you could take over the PR and address this.

Contributor:

added below
self.amax = synced_amax.clone().detach()

except RuntimeError as e:
warnings.warn(
f"Failed to synchronize amax: {e}. This may occur if the tensor is on a "
"device not supported by the current distributed backend. This warning "
"can be ignored if happening during modelopt restore."
)
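As a possible follow-up to the review suggestion above, here is a minimal sketch of the hybrid path (fast in-place all_reduce when every rank already has an amax, CPU object sync only otherwise). The helper name and overall flow are assumptions for illustration, not the implementation in this PR; the sketch only relies on get_dist_syncd_obj gathering the object from all ranks and applying the given reduce function, as in the diff above.

import torch.distributed as dist

from modelopt.torch.utils.distributed import DistributedProcessGroup


def sync_amax_hybrid(quantizer, parallel_group: DistributedProcessGroup, fallback_device=None):
    """Hypothetical hybrid sync: fast path for non-MoE quantizers, object sync otherwise."""
    if not parallel_group.is_initialized():
        return
    # Cheap object sync of a boolean flag: do all ranks already hold a local amax?
    all_have_amax = DistributedProcessGroup.get_dist_syncd_obj(
        quantizer.amax is not None, parallel_group, lambda flags: all(flags)
    )
    if all_have_amax:
        # Non-MoE fast path: reduce the amax tensor in place on its current device.
        dist.all_reduce(quantizer._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
    else:
        # MoE path: some ranks may have amax=None, so fall back to the CPU object sync.
        quantizer.sync_amax_across_distributed_group(parallel_group, fallback_device)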

@contextlib.contextmanager
def disable_pre_quant_scale(self):
1 change: 0 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -581,7 +581,6 @@ def sync_moe_local_experts_amax(self):
This function is called to synchronize the amax values across local experts s.t. all local experts will
share the same amax.
"""
torch.distributed.barrier()
# Collect amax from all local experts
amax_dict = {}
for expert in self.local_experts:
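The remainder of sync_moe_local_experts_amax is collapsed here. Purely for orientation, a hedged sketch of what such a collect-then-broadcast pass over local experts could look like follows; the quantizer traversal and the use of named_modules()/TensorQuantizer are assumptions for illustration, not the actual implementation.

# Hedged sketch only; the real body is collapsed in the diff above.
# Take the per-quantizer maximum amax across all local experts, then write it back
# so that every local expert ends up sharing the same amax.
amax_dict = {}
for expert in self.local_experts:
    for name, quantizer in expert.named_modules():
        if isinstance(quantizer, TensorQuantizer) and quantizer.amax is not None:
            current = amax_dict.get(name)
            amax_dict[name] = (
                quantizer.amax if current is None else torch.maximum(current, quantizer.amax)
            )
for expert in self.local_experts:
    for name, quantizer in expert.named_modules():
        if isinstance(quantizer, TensorQuantizer) and name in amax_dict:
            quantizer.amax = amax_dict[name]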
3 changes: 3 additions & 0 deletions tests/_test_utils/torch/distributed/utils.py
@@ -46,6 +46,9 @@ def init_process(rank, size, job=None, backend="gloo", port=None):
torch.manual_seed(1234)
if job is not None:
job(rank, size)
# Explicitly destroy the process group to avoid leaving the process alive
if dist.is_initialized():
dist.destroy_process_group()


def spawn_multiprocess_job(size, job, backend="gloo"):
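The body of spawn_multiprocess_job is collapsed here. For orientation, a hedged usage sketch of these helpers follows; the job body and the import path are hypothetical, and the sketch only assumes that init_process sets up the default process group before calling job(rank, size), as the snippet above suggests.

import torch
import torch.distributed as dist

from _test_utils.torch.distributed.utils import spawn_multiprocess_job  # import path assumed from the file location


def _all_reduce_job(rank, size):
    # Every rank contributes (rank + 1); the all-reduce sum must be size * (size + 1) / 2.
    t = torch.ones(1) * (rank + 1)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert t.item() == size * (size + 1) / 2


# Spawns `size` processes, runs the job on each rank, and (with this PR) also tears
# down the process group explicitly when the job finishes.
spawn_multiprocess_job(size=2, job=_all_reduce_job, backend="gloo")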