28 changes: 28 additions & 0 deletions examples/diffusers/quantization/quantize.py
@@ -66,6 +66,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint


class ModelType(str, Enum):
@@ -348,6 +349,7 @@ class ExportConfig:

quantized_torch_ckpt_path: Path | None = None
onnx_dir: Path | None = None
hf_ckpt_dir: Path | None = None
restore_from: Path | None = None

def validate(self) -> None:
@@ -363,6 +365,9 @@ def validate(self) -> None:
if self.onnx_dir and not self.onnx_dir.exists():
self.onnx_dir.mkdir(parents=True, exist_ok=True)

if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists():
self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True)


def setup_logging(verbose: bool = False) -> logging.Logger:
"""
@@ -862,6 +867,20 @@ def restore_checkpoint(self, backbone: nn.Module) -> None:
mto.restore(backbone, str(self.config.restore_from))
self.logger.info("Model restored successfully")

def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None:
"""
Export quantized model to HuggingFace checkpoint format.

Args:
pipe: Diffusion pipeline containing the quantized model
"""
if not self.config.hf_ckpt_dir:
return

self.logger.info(f"Exporting HuggingFace checkpoint to {self.config.hf_ckpt_dir}")
export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir)
self.logger.info("HuggingFace checkpoint export completed successfully")


def create_argument_parser() -> argparse.ArgumentParser:
"""
@@ -994,6 +1013,11 @@ def create_argument_parser() -> argparse.ArgumentParser:
help="Path to save quantized PyTorch checkpoint",
)
export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export")
export_group.add_argument(
"--hf-ckpt-dir",
type=str,
help="Directory for HuggingFace checkpoint export",
)
export_group.add_argument(
"--restore-from", type=str, help="Path to restore from previous checkpoint"
)
@@ -1070,6 +1094,7 @@ def main() -> None:
if args.quantized_torch_ckpt_save_path
else None,
onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None,
hf_ckpt_dir=Path(args.hf_ckpt_dir) if args.hf_ckpt_dir else None,
restore_from=Path(args.restore_from) if args.restore_from else None,
)

@@ -1125,6 +1150,9 @@ def forward_loop(mod):
model_config.model_type,
quant_config.format,
)

export_manager.export_hf_ckpt(pipe)

logger.info(
f"Quantization process completed successfully! Time taken = {time.time() - s} seconds"
)
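For context, a minimal sketch of how the new --hf-ckpt-dir path is meant to be exercised end to end. The pipeline construction and quantization step are placeholders (the model id is an assumption); only export_hf_checkpoint(pipe, export_dir=...) and the directory-creation behavior of ExportConfig.validate() come from this diff:

from pathlib import Path

from diffusers import DiffusionPipeline

from modelopt.torch.export import export_hf_checkpoint

# Stand-in for args.hf_ckpt_dir; ExportConfig.validate() creates it if missing.
hf_ckpt_dir = Path("./flux_hf_ckpt")
hf_ckpt_dir.mkdir(parents=True, exist_ok=True)

# Assumed model id; any diffusers pipeline quantized in-place by quantize.py works here.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
# ... mtq.quantize(...) runs on the pipeline backbone at this point ...

# ExportManager.export_hf_ckpt() reduces to this call when the directory is set:
export_hf_checkpoint(pipe, export_dir=hf_ckpt_dir)
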
4 changes: 2 additions & 2 deletions examples/llm_ptq/multinode_ptq.py
@@ -36,7 +36,7 @@
import modelopt.torch.quantization as mtq
from modelopt.torch.export import get_model_type
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
from modelopt.torch.quantization.config import need_calibration
from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets
@@ -243,7 +243,7 @@ def export_model(
export_dir = Path(export_path)
export_dir.mkdir(parents=True, exist_ok=True)

post_state_dict, hf_quant_config = _export_hf_checkpoint(
post_state_dict, hf_quant_config = _export_transformers_checkpoint(
model, torch.bfloat16, accelerator=accelerator
)

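The rename from _export_hf_checkpoint to _export_transformers_checkpoint is mechanical; the call contract visible here is unchanged. A minimal sketch of the pattern, assuming model is a quantized transformers model and accelerator an accelerate.Accelerator as multinode_ptq.py prepares them:

import torch
from accelerate import Accelerator
from torch import nn

from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint


def export_quantized_state(model: nn.Module, accelerator: Accelerator):
    """Sketch of the multinode export step after the rename.

    Only the helper's name changed in this PR; it still returns the
    post-processed state dict plus the HF quantization config dict.
    """
    post_state_dict, hf_quant_config = _export_transformers_checkpoint(
        model, torch.bfloat16, accelerator=accelerator
    )
    return post_state_dict, hf_quant_config
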
6 changes: 4 additions & 2 deletions examples/llm_qat/export.py
@@ -23,7 +23,7 @@

import modelopt.torch.opt as mto
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict
from modelopt.torch.utils import print_rank_0
@@ -81,7 +81,9 @@ def main(args):
base_model_dir = export_dir

try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)
post_state_dict, hf_quant_config = _export_transformers_checkpoint(
model, is_modelopt_qlora=is_qlora
)

with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)
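The QAT export script picks up the same rename, with the QLoRA switch passed through as a keyword. A sketch of the surrounding save logic as it appears in this diff (the wrapper function and its parameters are illustrative):

import json

from torch import nn

from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint


def export_qat_checkpoint(model: nn.Module, base_model_dir: str, is_qlora: bool) -> dict:
    """Mirrors llm_qat/export.py after the rename (names here are illustrative)."""
    post_state_dict, hf_quant_config = _export_transformers_checkpoint(
        model, is_modelopt_qlora=is_qlora
    )
    # The script persists the quantization config alongside the checkpoint:
    with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
        json.dump(hf_quant_config, file, indent=4)
    return post_state_dict
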