diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 40f700781..1929ef2ce 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -28,6 +28,7 @@
from accelerate.utils import get_max_memory
from transformers import (
AutoConfig,
+ AutoModel,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
@@ -64,27 +65,39 @@ def run_nemotron_vl_preview(
"""
from vlm_utils import run_text_only_generation, run_vl_preview_generation
- print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
- question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
- generation_config = {
- "max_new_tokens": 100,
- "do_sample": False,
- "eos_token_id": tokenizer.eos_token_id,
- }
-
- # Try text-only generation
- text_response = run_text_only_generation(
- full_model, tokenizer, question, generation_config, pyt_ckpt_path
- )
+ # Check if this is Nemotron-Parse (encoder-decoder model that requires images)
+ config = full_model.config
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
+ generated_ids = None
+
+ if not is_nemotron_parse:
+ # Only try text-only generation for models that support it (not Nemotron-Parse)
+ print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
+ question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+ generation_config = {
+ "max_new_tokens": 100,
+ "do_sample": False,
+ "eos_token_id": tokenizer.eos_token_id,
+ }
+
+ # Try text-only generation
+ text_response = run_text_only_generation(
+ full_model, tokenizer, question, generation_config, pyt_ckpt_path
+ )
- if text_response is not None:
- print(f"✅ Text-only generation successful: {text_response[:100]}...")
- generated_ids = text_response
- elif allow_fallback:
- print("Text-only generation failed, falling back to standard generate...")
- generated_ids = full_model.generate(input_ids, max_new_tokens=100)
+ if text_response is not None:
+ print(f"✅ Text-only generation successful: {text_response[:100]}...")
+ generated_ids = text_response
+ elif allow_fallback:
+ print("Text-only generation failed, falling back to standard generate...")
+ generated_ids = full_model.generate(input_ids, max_new_tokens=100)
else:
- generated_ids = None
+ print(
+ f"Skipping text-only generation for Nemotron-Parse ({stage_name}) - "
+ "this encoder-decoder model requires images for all operations."
+ )
# Run additional VL test with images
print(f"Running additional VL test with images ({stage_name})...")
@@ -95,6 +108,10 @@ def run_nemotron_vl_preview(
def _is_multimodal_config(config):
"""Check if a config indicates a multimodal model (config-only version of is_multimodal_model)."""
+ # Check for Nemotron-Parse encoder-decoder architecture
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
return (
hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL)
or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal
@@ -103,6 +120,7 @@ def _is_multimodal_config(config):
or (
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
) # Image embedding layers
+ or is_nemotron_parse # Nemotron-Parse conditional generation model
)
@@ -257,8 +275,19 @@ def get_processor(
)
return MllamaImageProcessor(processor, device)
-
- return None
+ else:
+ # Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
+ # This will only work if the model has a processor config
+ try:
+ processor = AutoProcessor.from_pretrained(
+ ckpt_path,
+ **model_kwargs,
+ )
+ print(f"Loaded AutoProcessor for model type: {model_type}")
+ return processor
+ except Exception as e:
+ print(f"Could not load processor for {model_type}: {e}")
+ return None
def get_dtype(dtype):
@@ -320,8 +349,6 @@ def get_model(
model_kwargs.setdefault("torch_dtype", "auto")
if "vila" in ckpt_path.lower():
- from transformers import AutoModel
-
hf_vila = AutoModel.from_pretrained(
ckpt_path,
device_map=device_map,
@@ -353,13 +380,13 @@ def get_model(
if not hasattr(transformers, architecture):
warnings.warn(
f"Architecture {architecture} not found in transformers: {transformers.__version__}. "
- "Falling back to AutoModelForCausalLM."
+ "Falling back to AutoModel."
)
assert trust_remote_code, (
"Please set trust_remote_code to True if you want to use this architecture"
)
- auto_model_module = AutoModelForCausalLM
+ auto_model_module = AutoModel
from_config = auto_model_module.from_config
else:
auto_model_module = getattr(transformers, architecture)
@@ -370,7 +397,7 @@ def get_model(
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
model_kwargs2 = model_kwargs.copy()
- if auto_model_module != AutoModelForCausalLM:
+ if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["torch_dtype"] = torch_dtype
model_kwargs2.pop("max_memory", None)
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index a9862a742..b583fe48b 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -97,6 +97,75 @@
mto.enable_huggingface_checkpointing()
+def create_nemotron_parse_calib_wrapper(base_dataloader, processor, device, decoder_only=False):
+ """Wrap a text-only dataloader to add dummy images for Nemotron-Parse calibration.
+
+ Nemotron-Parse is an encoder-decoder model that requires pixel_values (for encoder)
+ and decoder_input_ids (for decoder) during calibration. This wrapper adds properly
+ formatted dummy images and decoder inputs.
+
+ Args:
+ base_dataloader: The base text-only dataloader
+ processor: The Nemotron-Parse processor
+ device: Device to place tensors on
+ decoder_only: If True, only provide decoder inputs (for when quantizing just the decoder)
+ """
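+ # Illustrative sketch of what the wrapper yields (keys and shapes assumed from the code below):
+ #   full model:   {"pixel_values": float[B, C, H, W], "decoder_input_ids": int[B, T],
+ #                  "decoder_attention_mask": int[B, T]}
+ #   decoder-only: the original text batch {"input_ids", "attention_mask"} passed through unchanged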
+ from PIL import Image
+
+ class NemotronParseCalibWrapper:
+ def __init__(self, base_dataloader, processor, device, decoder_only=False):
+ self.base_dataloader = base_dataloader
+ self.processor = processor
+ self.device = device
+ self.decoder_only = decoder_only
+ # Create a simple dummy image (will be processed by the model's processor)
+ self.dummy_image = Image.new("RGB", (1024, 1280), color="white")
+
+ def __iter__(self):
+ for batch in self.base_dataloader:
+ # batch contains input_ids and attention_mask from text data
+ batch_size = batch["input_ids"].shape[0]
+
+ if self.decoder_only:
+ # When calibrating just the decoder, it expects input_ids directly
+ # (not decoder_input_ids, as that's only for the full encoder-decoder forward)
+ # Just pass through the original batch
+ yield batch
+ else:
+ # When calibrating the full model, we need pixel_values and decoder_input_ids
+ # Process dummy images using the Nemotron-Parse processor
+ dummy_images = [self.dummy_image] * batch_size
+
+ # Use the processor to get properly formatted pixel_values
+ prompts = [
+ ""
+ ] * batch_size
+ processed = self.processor(
+ text=prompts, images=dummy_images, return_tensors="pt"
+ )
+
+ # For encoder-decoder models like Nemotron-Parse:
+ # - pixel_values go to the vision encoder
+ # - decoder_input_ids are needed for the decoder
+ batch["pixel_values"] = processed["pixel_values"].to(self.device)
+ batch["decoder_input_ids"] = processed["input_ids"].to(self.device)
+ batch["decoder_attention_mask"] = processed["attention_mask"].to(self.device)
+
+ # Remove the encoder input_ids and attention_mask as they're not needed
+ # The model will use pixel_values for the encoder
+ if "input_ids" in batch:
+ del batch["input_ids"]
+ if "attention_mask" in batch:
+ del batch["attention_mask"]
+
+ yield batch
+
+ def __len__(self):
+ return len(self.base_dataloader)
+
+ return NemotronParseCalibWrapper(base_dataloader, processor, device, decoder_only)
+
+
def make_calib_dataloader(
args: argparse.Namespace,
language_model: torch.nn.Module,
@@ -317,6 +386,18 @@ def load_model(args: argparse.Namespace):
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]
+
+ # Check if this is a Nemotron VL model that needs a processor
+ is_nemotron_vl_model = is_nemotron_vl(full_model)
+ if is_nemotron_vl_model:
+ # Load processor for Nemotron VL models (like Nemotron-Parse)
+ processor = get_processor(
+ args.pyt_ckpt_path,
+ model_type,
+ device,
+ trust_remote_code=args.trust_remote_code,
+ )
+
tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
default_padding_side = tokenizer.padding_side
@@ -569,10 +650,20 @@ def pre_quantize(
post-quantize generation.
"""
+ # Check if this is Nemotron-Parse (encoder-decoder model)
+ config = full_model.config
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
# Only run single sample for preview
- preview_input_ids = next(iter(calib_dataloader))[
- "input_features" if model_type == "whisper" else "input_ids"
- ][0:1]
+ # For Nemotron-Parse, use decoder_input_ids instead of input_ids
+ sample_batch = next(iter(calib_dataloader))
+ if is_nemotron_parse and "decoder_input_ids" in sample_batch:
+ preview_input_ids = sample_batch["decoder_input_ids"][0:1]
+ elif model_type == "whisper":
+ preview_input_ids = sample_batch["input_features"][0:1]
+ else:
+ preview_input_ids = sample_batch["input_ids"][0:1]
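+ # Note: for full-model Nemotron-Parse calibration the wrapped dataloader drops "input_ids"
+ # entirely (see create_nemotron_parse_calib_wrapper), so the decoder ids are the only usable
+ # preview input in that case.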
# Generate preview before quantization
if is_nemotron_vl_model and tokenizer is not None:
@@ -693,36 +784,46 @@ def quantize_main(
device: torch.device,
):
if args.batch_size == 0:
- # Calibration/sparsification will actually take much more memory than regular inference
- # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
- # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
- sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
- # Whisper model expects mel-spectrogram input features of length 3000
- # Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
- # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
- # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
- if model_type == "whisper":
- max_sample_length = 3000
- num_mel_bins = language_model.config.num_mel_bins
- sample_input_single_batch = (
- torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
- language_model.device
- )
- * 100
+ # Check if this is a vision-language model
+ # For VL models, skip automatic batch size detection and use a conservative default
+ # since proper multimodal input preparation is complex
+ if is_multimodal_model(full_model) or is_nemotron_vl(full_model):
+ print(
+ "Vision-language model detected. Using default batch_size=1 for calibration "
+ "to ensure proper handling of multimodal inputs."
)
+ args.batch_size = 1
else:
- sample_input_single_batch = None
+ # Calibration/sparsification will actually take much more memory than regular inference
+ # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
+ # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
+ sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
+ # Whisper model expects mel-spectrogram input features of length 3000
+ # Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
+ # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
+ # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
+ if model_type == "whisper":
+ max_sample_length = 3000
+ num_mel_bins = language_model.config.num_mel_bins
+ sample_input_single_batch = (
+ torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
+ language_model.device
+ )
+ * 100
+ )
+ else:
+ sample_input_single_batch = None
- run_auto_quant = args.auto_quantize_bits is not None
+ run_auto_quant = args.auto_quantize_bits is not None
- args.batch_size = get_max_batch_size(
- language_model,
- max_sample_length=args.calib_seq,
- sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
- sample_input_single_batch=sample_input_single_batch,
- enable_grad=run_auto_quant,
- )
- args.batch_size = min(args.batch_size, sum(args.calib_size))
+ args.batch_size = get_max_batch_size(
+ language_model,
+ max_sample_length=args.calib_seq,
+ sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
+ sample_input_single_batch=sample_input_single_batch,
+ enable_grad=run_auto_quant,
+ )
+ args.batch_size = min(args.batch_size, sum(args.calib_size))
print(f"Use calib batch_size {args.batch_size}")
@@ -733,6 +834,32 @@ def quantize_main(
# Detect if this is a Nemotron VL model using architecture-based detection
is_nemotron_vl_model = is_nemotron_vl(full_model)
+ # For Nemotron-Parse, wrap the text-only dataloader to add dummy images
+ # Nemotron-Parse is an encoder-decoder model that requires pixel_values
+ if is_nemotron_vl_model and processor is not None:
+ config = full_model.config
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
+ if is_nemotron_parse:
+ # Check if we're quantizing just the decoder or the full model
+ decoder_only = language_model is not full_model
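+ # get_language_model_from_vl returns the decoder for Nemotron-Parse, so a language_model
+ # that differs from full_model means only the decoder is being calibrated and quantized here.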
+
+ if decoder_only:
+ print(
+ "Calibration will use text-only inputs for Nemotron-Parse decoder. "
+ "Vision encoder is excluded from quantization."
+ )
+ else:
+ print(
+ "Wrapping calibration dataloader for Nemotron-Parse to add dummy images. "
+ "Nemotron-Parse requires pixel_values for full model calibration."
+ )
+
+ calib_dataloader = create_nemotron_parse_calib_wrapper(
+ calib_dataloader, processor, device, decoder_only=decoder_only
+ )
+
preview_input_ids, generated_ids_before_ptq = pre_quantize(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)
diff --git a/examples/llm_ptq/vlm_utils.py b/examples/llm_ptq/vlm_utils.py
index 6c9d921b8..2d3d9f82c 100644
--- a/examples/llm_ptq/vlm_utils.py
+++ b/examples/llm_ptq/vlm_utils.py
@@ -18,7 +18,7 @@
import os
from PIL import Image
-from transformers import AutoImageProcessor, AutoProcessor
+from transformers import AutoImageProcessor, AutoProcessor, GenerationConfig
def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
@@ -73,13 +73,34 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
print(" Skipping VL preview generation.")
return None
+ # Check if this is Nemotron-Parse early to set up proper generation config
+ config = model.config
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
# Generate response
question = "Describe this image briefly." # Updated for single image
- generation_config = {
- "max_new_tokens": 50,
- "do_sample": False,
- "eos_token_id": tokenizer.eos_token_id,
- }
+
+ # Use model's GenerationConfig for Nemotron-Parse, dict for others
+ if is_nemotron_parse:
+ try:
+ generation_config = GenerationConfig.from_pretrained(
+ model_path, trust_remote_code=True
+ )
+ print("Using Nemotron-Parse GenerationConfig from model")
+ except Exception as e:
+ print(f"Warning: Could not load GenerationConfig: {e}, using defaults")
+ generation_config = {
+ "max_new_tokens": 50,
+ "do_sample": False,
+ "eos_token_id": tokenizer.eos_token_id,
+ }
+ else:
+ generation_config = {
+ "max_new_tokens": 50,
+ "do_sample": False,
+ "eos_token_id": tokenizer.eos_token_id,
+ }
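+ # The GenerationConfig object is later passed via generation_config=..., while the plain dict
+ # is unpacked as **kwargs in the generate call further below.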
print(f"Generating VL response ({stage_name})...")
@@ -105,27 +126,39 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
else:
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
- messages = [
- {"role": "system", "content": "/no_think"},
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "image": "",
- },
- {
- "type": "text",
- "text": question,
- },
- ],
- },
- ]
+ # Check if this is Nemotron-Parse (uses task prompts instead of chat templates)
+ config = model.config
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
- # Apply chat template
- prompt = tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
+ if is_nemotron_parse:
+ # Nemotron-Parse uses a specific task prompt format
+ # See: https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1#usage-example
+ prompt = ""
+ print(f"Using Nemotron-Parse task prompt: {prompt}")
+ else:
+ # Other VL models use chat templates
+ messages = [
+ {"role": "system", "content": "/no_think"},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "",
+ },
+ {
+ "type": "text",
+ "text": question,
+ },
+ ],
+ },
+ ]
+
+ # Apply chat template
+ prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
# Process inputs using the processor with single image
inputs = processor(
@@ -139,21 +172,55 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
inputs = inputs.to(model_device)
print(f" Moved inputs to {model_device}")
+ # Verify we have pixel_values for the vision encoder
+ if not hasattr(inputs, "pixel_values") or inputs.pixel_values is None:
+ raise ValueError(
+ "Processor did not generate pixel_values. Check processor configuration."
+ )
+
# Generate response using model.generate
- generated_ids = model.generate(
- pixel_values=inputs.pixel_values,
- input_ids=inputs.input_ids,
- attention_mask=inputs.attention_mask,
- **generation_config,
- )
+ if isinstance(generation_config, GenerationConfig):
+ # For Nemotron-Parse with GenerationConfig object
+ generated_ids = model.generate(
+ pixel_values=inputs.pixel_values,
+ input_ids=inputs.input_ids,
+ attention_mask=inputs.attention_mask,
+ generation_config=generation_config,
+ )
+ else:
+ # For other models with dict generation config
+ generated_ids = model.generate(
+ pixel_values=inputs.pixel_values,
+ input_ids=inputs.input_ids,
+ attention_mask=inputs.attention_mask,
+ **generation_config,
+ )
# Decode the response (trim input tokens like in the working example)
+ if generated_ids is None:
+ raise ValueError("Model generate returned None")
+
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
- output_text = processor.batch_decode(
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
- )
+
+ # For Nemotron-Parse, use tokenizer.batch_decode instead of processor.batch_decode
+ if is_nemotron_parse and hasattr(tokenizer, "batch_decode"):
+ output_text = tokenizer.batch_decode(
+ generated_ids_trimmed,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+ else:
+ output_text = processor.batch_decode(
+ generated_ids_trimmed,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+
+ if output_text is None or len(output_text) == 0:
+ raise ValueError("Decoding returned empty output")
+
response = output_text[0]
print(f"✅ VL generation {stage_name} successful!")
diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 5a24429ad..40c313ad2 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -85,6 +85,7 @@ def is_multimodal_model(model):
- Vision LoRA configurations
- Audio processing capabilities
- Image embedding layers
+ - Nemotron-Parse conditional generation models
Args:
model: The HuggingFace model instance to check
@@ -103,6 +104,10 @@ def is_multimodal_model(model):
"""
config = model.config
+ # Check for Nemotron-Parse encoder-decoder architecture
+ architectures = getattr(config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
return (
hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL)
or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA)
@@ -112,6 +117,7 @@ def is_multimodal_model(model):
or (
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
) # Image embedding layers
+ or is_nemotron_parse # Nemotron-Parse conditional generation model
)
@@ -141,5 +147,9 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None:
if hasattr(model, "language_model"):
return [model, model.language_model]
- # Pattern 3: No language_model found
+ # Pattern 3: For encoder-decoder VL models (e.g., Nemotron-Parse), the decoder is the language model
+ if hasattr(model, "decoder"):
+ return [model, model.decoder]
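+ # Returning [model, model.decoder] lets callers treat the decoder as the language-model part
+ # of the VL model, consistent with the patterns above.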
+
+ # Pattern 4: No language_model found
return None
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ccfc01200..eaefcbe9d 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -155,12 +155,14 @@ def _output_hook(module, input, output):
# Run forward pass so that all modules sharing the same input are collected using forward hook.
+ # Check if this is Nemotron-Parse (encoder-decoder VL model)
+ architectures = getattr(model.config, "architectures", [])
+ is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
- if getattr(model.config, "is_encoder_decoder", False):
- # For encoder-decoder models, we need to pass both the encoder and decoder input ids
- model(fake_input, decoder_input_ids=decoder_fake_input)
- elif is_vl_model and "nemotron" in model_type:
- # For Nemotron VL models, try to run optimization on just the language model part
+ if is_vl_model and ("nemotron" in model_type or is_nemotron_parse):
+ # For Nemotron VL models (including Nemotron-Parse), run optimization on just the language model/decoder
+ # This avoids needing to create proper pixel_values for the vision encoder
language_model_lineage = get_language_model_from_vl(model)
if language_model_lineage is not None:
@@ -177,6 +179,9 @@ def _output_hook(module, input, output):
"This is required for requantization/resmoothing optimization. "
"Please ensure the model architecture is supported or file an issue."
)
+ elif getattr(model.config, "is_encoder_decoder", False):
+ # For other encoder-decoder models (non-VL), we need to pass both encoder and decoder input ids
+ model(fake_input, decoder_input_ids=decoder_fake_input)
else:
model(fake_input)
@@ -257,25 +262,42 @@ def _export_quantized_weight(
if quantization_format == QUANTIZATION_FP8:
# Convert amax to float32
- weight_quantizer._amax = weight_quantizer._amax.to(torch.float32)
-
- if weight_quantizer._amax.dim() == 1:
- # Per-tensor amax
- weight_scaling_factor = torch.tensor(
- weight_quantizer.amax.item() / weight_quantizer.maxbound
- )
+ # Prefer the private '_amax' attribute when it is set (cast it to float32 in place);
+ # otherwise fall back to the public 'amax' property below.
+ if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+ weight_quantizer._amax = weight_quantizer._amax.to(torch.float32)
+ amax_tensor = weight_quantizer._amax
else:
- # Per-channel amax
- weight_scaling_factor = torch.tensor(weight_quantizer.amax / weight_quantizer.maxbound)
+ # Fallback to public amax property
+ amax_tensor = weight_quantizer.amax
+ if amax_tensor is not None and hasattr(amax_tensor, "to"):
+ amax_tensor = amax_tensor.to(torch.float32)
+
+ # Only compute scaling factor if amax_tensor is valid
+ if amax_tensor is not None and hasattr(amax_tensor, "dim"):
+ if amax_tensor.dim() == 1:
+ # Per-tensor amax
+ weight_scaling_factor = torch.tensor(
+ weight_quantizer.amax.item() / weight_quantizer.maxbound
+ )
+ else:
+ # Per-channel amax
+ weight_scaling_factor = torch.tensor(
+ weight_quantizer.amax / weight_quantizer.maxbound
+ )
- sub_module.register_buffer(
- quantizer_attrs.weight_scale,
- weight_scaling_factor,
- )
+ sub_module.register_buffer(
+ quantizer_attrs.weight_scale,
+ weight_scaling_factor,
+ )
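+ # In both the per-tensor and per-channel cases the registered scale is amax / maxbound;
+ # for FP8 (E4M3) maxbound is typically 448, so e.g. amax = 896 would register a weight
+ # scale of 2.0 (illustrative numbers).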
- if hasattr(input_quantizer, "_amax"):
+ if hasattr(input_quantizer, "_amax") or (
+ input_quantizer is not None
+ and hasattr(input_quantizer, "amax")
+ and input_quantizer.amax is not None
+ ):
assert input_quantizer is not None
- input_quantizer._amax = input_quantizer._amax.to(torch.float32)
+ if hasattr(input_quantizer, "_amax") and input_quantizer._amax is not None:
+ input_quantizer._amax = input_quantizer._amax.to(torch.float32)
sub_module.register_buffer(
quantizer_attrs.input_scale,
@@ -284,9 +306,14 @@ def _export_quantized_weight(
).squeeze(),
)
- if hasattr(output_quantizer, "_amax"):
+ if hasattr(output_quantizer, "_amax") or (
+ output_quantizer is not None
+ and hasattr(output_quantizer, "amax")
+ and output_quantizer.amax is not None
+ ):
assert output_quantizer is not None
- output_quantizer._amax = output_quantizer._amax.to(torch.float32)
+ if hasattr(output_quantizer, "_amax") and output_quantizer._amax is not None:
+ output_quantizer._amax = output_quantizer._amax.to(torch.float32)
else:
# Register weight_scale and input_scale
if quantization_format == QUANTIZATION_FP8_PB_REAL:
@@ -327,6 +354,18 @@ def _export_quantized_weight(
weight_scale: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale, None)
weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)
+ # If weight_scale is None (e.g., the quantizer wasn't calibrated), skip quantization for this module.
+ # This can happen for modules that were disabled from quantization or have invalid calibration data.
+ # NVFP4 formats are excluded from this check because their weight_scale is computed later.
+ if weight_scale is None and quantization_format not in [
+ QUANTIZATION_NVFP4,
+ QUANTIZATION_NVFP4_AWQ,
+ ]:
+ print(
+ f"Warning: Skipping quantization for {type(sub_module).__name__} - no weight_scale found"
+ )
+ return
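+ # A skipped module keeps its original (unquantized) weight tensor in the exported checkpoint.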
+
# Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
# Check if this is a BMM-style expert weight that needs transposition
is_bmm_expert_weight = weight.dim() == 3 and any(