80 changes: 77 additions & 3 deletions auto_round/compressors/base.py
@@ -37,6 +37,7 @@
from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
from auto_round.compressors.shard_writer import shard_writer
from auto_round.compressors.utils import (
    DDPIndexSampler,
    IndexSampler,
    block_forward,
    check_need_act_calibration,
@@ -145,6 +146,58 @@
"to_quant_block_names",
)

import torch.distributed as dist


def rank_log(msg: str) -> None:
"""Log message with rank information in distributed setting."""
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
logger.info(f"[Rank {rank}] {msg}")


# Source - https://stackoverflow.com/a
# Posted by Romuald Brunet, modified by community. See post 'Timeline' for change history
# Retrieved 2026-01-24, License - CC BY-SA 3.0

# import pdb


# class ForkedPdb(pdb.Pdb):
#     """A Pdb subclass that may be used
#     from a forked multiprocessing child

#     """

#     def interaction(self, *args, **kwargs):
#         _stdin = sys.stdin
#         try:
#             sys.stdin = open("/dev/stdin")
#             pdb.Pdb.interaction(self, *args, **kwargs)
#         finally:
#             sys.stdin = _stdin


def check_grad(block, msg):
    """Debug helper: log the first parameter in ``block`` that has a gradient, or note that none was found."""
    the_first_param_with_grad = None
    for name, param in block.named_parameters():
        if param.grad is not None:
            the_first_param_with_grad = (name, param)
            break
    if the_first_param_with_grad is None:
        rank_log(f"{msg} No grad found in block.")
    else:
        name, param = the_first_param_with_grad
        rank_log(
            f"{msg} Grad found in block. Param name: {name}, Grad norm: {param.grad.norm().item()}. grad: {param.grad}"
        )


def is_ddp():
    """Return True when a distributed process group is initialized with world size > 1."""
    return dist.is_initialized() and dist.get_world_size() > 1


from torch.nn.parallel import DistributedDataParallel as DDP


class BaseCompressor(object):
"""Base compressor for LLM quantization
@@ -2910,9 +2963,24 @@ def _quantize_block(
        whole_indices = torch.arange(global_batch_size)
        num_elm = self._get_current_num_elm(input_ids, whole_indices)

        index_sampler = IndexSampler(nsamples, global_batch_size)
        if is_ddp():
            # index_sampler = DDPIndexSampler(nsamples, global_batch_size, iters=self.iters)
            index_sampler = IndexSampler(nsamples, global_batch_size)
            rank = dist.get_rank() if dist.is_initialized() else 0
            rank_log(f"DDP rank {rank} is quantizing the block")
            # ForkedPdb().set_trace()
            rank_log(
                f"device info: device: {device}, loss_device: {loss_device}, block's device: {next(block.parameters()).device}"
            )
            block = DDP(block, device_ids=[rank], find_unused_parameters=True)
            dist.barrier()
            logger.warning_once(f"DistributedDataParallel (DDP) is used for block quantization. block: {block}")
        else:
            index_sampler = IndexSampler(nsamples, global_batch_size)
        batch_size = self.batch_size

        for i in range(self.iters):
            rank_log(f"starting iteration {i} of block quantization, best loss so far: {best_loss}")
            if self.enable_alg_ext and self.data_type.endswith("dq"):
                for n, m in block.named_modules():
                    m.cur_iter = i
@@ -2922,6 +2990,7 @@
                num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices)

            for tmp_step in range(self.gradient_accumulate_steps):
                # rank_log(f"iteration {i} tmp_step {tmp_step} start")
                indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                current_output = self._get_current_output(output, indices)
                current_output = to_device(current_output, loss_device)
@@ -2933,9 +3002,11 @@
                if self.low_gpu_mem_usage and card_0_in_high_risk:
                    # clear memory to avoid OOM due to memory fragmentation
                    clear_memory_if_reached_threshold(threshold=0.5, device_list=self.device_list)

                rank_log(
                    f"iteration {i} tmp_step {tmp_step} loss: {loss.item()}, loss device {loss.device}, starting backward"
                )
                self._scale_loss_and_backward(scaler, loss)

                # check_grad(block, msg=f"block quantization iter {i} tmp_step {tmp_step}")
                if self.low_gpu_mem_usage and card_0_in_high_risk:
                    # clear memory to avoid OOM due to memory fragmentation
                    clear_memory_if_reached_threshold(threshold=0.8, device_list=self.device_list)
@@ -2958,6 +3029,9 @@
                    break
            self._step(scaler, optimizer, lr_schedule)

            rank_log(f"finished iteration {i} of block quantization")
            # dist.barrier()

        last_loss = total_loss
        best_iter = self.iters
        if not self.not_use_best_mse:
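For context, the `is_ddp()` branch above follows the standard PyTorch DDP pattern: each rank wraps the block in `DistributedDataParallel`, `backward()` then averages gradients across ranks, and `dist.barrier()` keeps the ranks in lockstep. Below is a minimal, self-contained sketch of that pattern; it is not part of this patch, and the toy `Linear` block, tensor shapes, and `torchrun` launch command are illustrative assumptions.

# Sketch only: the DDP pattern relied on by the is_ddp() branch in _quantize_block.
# Assumes a launch such as: torchrun --nproc_per_node=2 ddp_sketch.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # Stand-in for the transformer block being quantized (hypothetical toy module).
    block = torch.nn.Linear(16, 16).to(device)
    block = DDP(block, device_ids=[rank] if torch.cuda.is_available() else None)
    dist.barrier()  # keep ranks in lockstep before the tuning loop starts

    optimizer = torch.optim.SGD(block.parameters(), lr=1e-3)
    for step in range(3):
        # Each rank feeds its own micro-batch; backward() all-reduces (averages) the gradients.
        x = torch.randn(4, 16, device=device)
        loss = (block(x) - x).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()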
31 changes: 31 additions & 0 deletions auto_round/compressors/utils.py
@@ -978,3 +978,34 @@ def next_batch(self) -> list[int]:
        batch = self.indices[self.index : self.index + self.batch_size]
        self.index += self.batch_size
        return batch


def rank_in_ddp() -> int:
    """Returns the rank of the current process in a DDP setup.

    Returns:
        int: The rank of the current process. Returns 0 if not in DDP.
    """
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()


class DDPIndexSampler(IndexSampler):
    """Index sampler for DDP runs: each rank skips ahead by ``rank * iters`` batches so ranks draw different samples."""

    def __init__(self, nsamples: int, batch_size: int, iters: int) -> None:
        self.iters = iters
        super().__init__(nsamples, batch_size)
        rank = rank_in_ddp()
        # advance the sampler by `rank * iters` batches so that each rank draws a different batch sequence
        for _ in range(rank * iters):
            self.next_batch()

    def next_batch(self) -> list[int]:
        if self.index + self.batch_size > self.nsamples:
            random.shuffle(self.indices)
            self.index = 0

        batch = self.indices[self.index : self.index + self.batch_size]
        self.index += self.batch_size
        return batch
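For reference, a small usage sketch (not part of the patch) of how the rank-dependent offset in `DDPIndexSampler.__init__` plays out: rank 0 starts at the head of the shuffled index pool, while rank 1 has already skipped `1 * iters` batches, so the two ranks train on different samples. The two-process `torchrun` launch and the values `nsamples=16, batch_size=4, iters=5` are illustrative assumptions.

# Sketch only; launch with e.g.: torchrun --nproc_per_node=2 sampler_sketch.py
import torch.distributed as dist

from auto_round.compressors.utils import DDPIndexSampler, rank_in_ddp


def main():
    dist.init_process_group(backend="gloo")
    # In __init__, each rank advances the sampler by rank * iters batches,
    # so rank 1 starts 5 batches into the shuffled sequence here.
    sampler = DDPIndexSampler(nsamples=16, batch_size=4, iters=5)
    for step in range(3):
        batch = sampler.next_batch()
        print(f"[rank {rank_in_ddp()}] step {step}: indices {batch}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()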