diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 40f700781..1929ef2ce 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -28,6 +28,7 @@ from accelerate.utils import get_max_memory
 from transformers import (
     AutoConfig,
+    AutoModel,
     AutoModelForCausalLM,
     AutoProcessor,
     AutoTokenizer,
@@ -64,27 +65,39 @@ def run_nemotron_vl_preview(
     """
     from vlm_utils import run_text_only_generation, run_vl_preview_generation

-    print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
-    question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-    generation_config = {
-        "max_new_tokens": 100,
-        "do_sample": False,
-        "eos_token_id": tokenizer.eos_token_id,
-    }
-
-    # Try text-only generation
-    text_response = run_text_only_generation(
-        full_model, tokenizer, question, generation_config, pyt_ckpt_path
-    )
+    # Check if this is Nemotron-Parse (encoder-decoder model that requires images)
+    config = full_model.config
+    architectures = getattr(config, "architectures", [])
+    is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
+    generated_ids = None
+
+    if not is_nemotron_parse:
+        # Only try text-only generation for models that support it (not Nemotron-Parse)
+        print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
+        question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        generation_config = {
+            "max_new_tokens": 100,
+            "do_sample": False,
+            "eos_token_id": tokenizer.eos_token_id,
+        }
+
+        # Try text-only generation
+        text_response = run_text_only_generation(
+            full_model, tokenizer, question, generation_config, pyt_ckpt_path
+        )

-    if text_response is not None:
-        print(f"✅ Text-only generation successful: {text_response[:100]}...")
-        generated_ids = text_response
-    elif allow_fallback:
-        print("Text-only generation failed, falling back to standard generate...")
-        generated_ids = full_model.generate(input_ids, max_new_tokens=100)
+        if text_response is not None:
+            print(f"✅ Text-only generation successful: {text_response[:100]}...")
+            generated_ids = text_response
+        elif allow_fallback:
+            print("Text-only generation failed, falling back to standard generate...")
+            generated_ids = full_model.generate(input_ids, max_new_tokens=100)
     else:
-        generated_ids = None
+        print(
+            f"Skipping text-only generation for Nemotron-Parse ({stage_name}) - "
+            "this encoder-decoder model requires images for all operations."
+        )

     # Run additional VL test with images
     print(f"Running additional VL test with images ({stage_name})...")
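The `config.architectures` check above is repeated in several files of this patch; a minimal standalone sketch of that detection logic (the architecture strings and `SimpleNamespace` configs below are illustrative placeholders, not taken from the patch):

```python
from types import SimpleNamespace


def detect_nemotron_parse(config) -> bool:
    """Return True if the HF config lists a Nemotron-Parse-style architecture."""
    architectures = getattr(config, "architectures", []) or []
    return any("nemotronparse" in arch.lower() for arch in architectures)


# Illustrative stand-ins; real configs come from AutoConfig.from_pretrained(...)
assert detect_nemotron_parse(SimpleNamespace(architectures=["NemotronParseForConditionalGeneration"]))
assert not detect_nemotron_parse(SimpleNamespace(architectures=["LlamaForCausalLM"]))
```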
@@ -95,6 +108,10 @@ def run_nemotron_vl_preview(


 def _is_multimodal_config(config):
     """Check if a config indicates a multimodal model (config-only version of is_multimodal_model)."""
+    # Check for Nemotron-Parse encoder-decoder architecture
+    architectures = getattr(config, "architectures", [])
+    is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
     return (
         hasattr(config, "vision_config")  # Standard vision config (e.g., Qwen2.5-VL)
         or getattr(config, "model_type", "") == "phi4mm"  # Phi-4 multimodal
@@ -103,6 +120,7 @@ def _is_multimodal_config(config):
         or (
             hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
         )  # Image embedding layers
+        or is_nemotron_parse  # Nemotron-Parse conditional generation model
     )


@@ -257,8 +275,19 @@ def get_processor(
         )

         return MllamaImageProcessor(processor, device)
-
-    return None
+    else:
+        # Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
+        # This will only work if the model has a processor config
+        try:
+            processor = AutoProcessor.from_pretrained(
+                ckpt_path,
+                **model_kwargs,
+            )
+            print(f"Loaded AutoProcessor for model type: {model_type}")
+            return processor
+        except Exception as e:
+            print(f"Could not load processor for {model_type}: {e}")
+            return None


 def get_dtype(dtype):
@@ -320,8 +349,6 @@ def get_model(
         model_kwargs.setdefault("torch_dtype", "auto")

     if "vila" in ckpt_path.lower():
-        from transformers import AutoModel
-
         hf_vila = AutoModel.from_pretrained(
             ckpt_path,
             device_map=device_map,
@@ -353,13 +380,13 @@ def get_model(
         if not hasattr(transformers, architecture):
             warnings.warn(
                 f"Architecture {architecture} not found in transformers: {transformers.__version__}. "
-                "Falling back to AutoModelForCausalLM."
+                "Falling back to AutoModel."
             )
             assert trust_remote_code, (
                 "Please set trust_remote_code to True if you want to use this architecture"
             )
-            auto_model_module = AutoModelForCausalLM
+            auto_model_module = AutoModel
             from_config = auto_model_module.from_config
         else:
             auto_model_module = getattr(transformers, architecture)
@@ -370,7 +397,7 @@ def get_model(
         # unless specified by the hf_config.
         torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
         model_kwargs2 = model_kwargs.copy()
-        if auto_model_module != AutoModelForCausalLM:
+        if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
            model_kwargs2.pop("trust_remote_code", None)
         model_kwargs2["torch_dtype"] = torch_dtype
         model_kwargs2.pop("max_memory", None)
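The new `else` branch in `get_processor` boils down to "try `AutoProcessor`, return `None` if the checkpoint has no processor config". A trimmed, hedged sketch of that fallback outside the example script (the helper name and the broad `except` mirror the patch; this is not a ModelOpt API):

```python
from transformers import AutoProcessor


def load_optional_processor(ckpt_path: str, trust_remote_code: bool = True):
    """Best-effort processor loading; returns None when no processor config ships with the checkpoint."""
    try:
        return AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
    except Exception as exc:  # broad on purpose: missing preprocessor_config.json, remote-code issues, ...
        print(f"Could not load processor for {ckpt_path}: {exc}")
        return None
```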
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index a9862a742..b583fe48b 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -97,6 +97,75 @@
 mto.enable_huggingface_checkpointing()


+def create_nemotron_parse_calib_wrapper(base_dataloader, processor, device, decoder_only=False):
+    """Wrap a text-only dataloader to add dummy images for Nemotron-Parse calibration.
+
+    Nemotron-Parse is an encoder-decoder model that requires pixel_values (for encoder)
+    and decoder_input_ids (for decoder) during calibration. This wrapper adds properly
+    formatted dummy images and decoder inputs.
+
+    Args:
+        base_dataloader: The base text-only dataloader
+        processor: The Nemotron-Parse processor
+        device: Device to place tensors on
+        decoder_only: If True, only provide decoder inputs (for when quantizing just the decoder)
+    """
+    from PIL import Image
+
+    class NemotronParseCalibWrapper:
+        def __init__(self, base_dataloader, processor, device, decoder_only=False):
+            self.base_dataloader = base_dataloader
+            self.processor = processor
+            self.device = device
+            self.decoder_only = decoder_only
+            # Create a simple dummy image (will be processed by the model's processor)
+            self.dummy_image = Image.new("RGB", (1024, 1280), color="white")
+
+        def __iter__(self):
+            for batch in self.base_dataloader:
+                # batch contains input_ids and attention_mask from text data
+                batch_size = batch["input_ids"].shape[0]
+
+                if self.decoder_only:
+                    # When calibrating just the decoder, it expects input_ids directly
+                    # (not decoder_input_ids, as that's only for the full encoder-decoder forward)
+                    # Just pass through the original batch
+                    yield batch
+                else:
+                    # When calibrating the full model, we need pixel_values and decoder_input_ids
+                    # Process dummy images using the Nemotron-Parse processor
+                    dummy_images = [self.dummy_image] * batch_size
+
+                    # Use the processor to get properly formatted pixel_values
+                    prompts = [
+                        ""
+                    ] * batch_size
+                    processed = self.processor(
+                        text=prompts, images=dummy_images, return_tensors="pt"
+                    )
+
+                    # For encoder-decoder models like Nemotron-Parse:
+                    # - pixel_values go to the vision encoder
+                    # - decoder_input_ids are needed for the decoder
+                    batch["pixel_values"] = processed["pixel_values"].to(self.device)
+                    batch["decoder_input_ids"] = processed["input_ids"].to(self.device)
+                    batch["decoder_attention_mask"] = processed["attention_mask"].to(self.device)
+
+                    # Remove the encoder input_ids and attention_mask as they're not needed
+                    # The model will use pixel_values for the encoder
+                    if "input_ids" in batch:
+                        del batch["input_ids"]
+                    if "attention_mask" in batch:
+                        del batch["attention_mask"]
+
+                    yield batch
+
+        def __len__(self):
+            return len(self.base_dataloader)
+
+    return NemotronParseCalibWrapper(base_dataloader, processor, device, decoder_only)
+
+
 def make_calib_dataloader(
     args: argparse.Namespace,
     language_model: torch.nn.Module,
@@ -317,6 +386,18 @@ def load_model(args: argparse.Namespace):
         args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
             : len(args.dataset)
         ]
+
+    # Check if this is a Nemotron VL model that needs a processor
+    is_nemotron_vl_model = is_nemotron_vl(full_model)
+    if is_nemotron_vl_model:
+        # Load processor for Nemotron VL models (like Nemotron-Parse)
+        processor = get_processor(
+            args.pyt_ckpt_path,
+            model_type,
+            device,
+            trust_remote_code=args.trust_remote_code,
+        )
+
     tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
     default_padding_side = tokenizer.padding_side
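To show how `create_nemotron_parse_calib_wrapper` is meant to be consumed, here is a hedged usage sketch with a toy processor and a toy text batch; `_ToyProcessor`, the tensor shapes, and the two-sample batch are invented for illustration and do not reflect the real Nemotron-Parse processor output:

```python
import torch


class _ToyProcessor:
    """Stand-in mimicking the (text=..., images=..., return_tensors="pt") interface used by the wrapper."""

    def __call__(self, text, images, return_tensors="pt"):
        n = len(images)
        return {
            "pixel_values": torch.zeros(n, 3, 1280, 1024),  # illustrative shape only
            "input_ids": torch.zeros(n, 8, dtype=torch.long),
            "attention_mask": torch.ones(n, 8, dtype=torch.long),
        }


# One toy text-only calibration batch, as produced by the existing text dataloader
text_batches = [
    {"input_ids": torch.zeros(2, 16, dtype=torch.long), "attention_mask": torch.ones(2, 16, dtype=torch.long)}
]

wrapper = create_nemotron_parse_calib_wrapper(text_batches, _ToyProcessor(), device="cpu", decoder_only=False)
for batch in wrapper:
    # Full-model batches swap the text inputs for encoder pixel_values plus decoder inputs
    print(sorted(batch))  # ['decoder_attention_mask', 'decoder_input_ids', 'pixel_values']
```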
""" + # Check if this is Nemotron-Parse (encoder-decoder model) + config = full_model.config + architectures = getattr(config, "architectures", []) + is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures) + # Only run single sample for preview - preview_input_ids = next(iter(calib_dataloader))[ - "input_features" if model_type == "whisper" else "input_ids" - ][0:1] + # For Nemotron-Parse, use decoder_input_ids instead of input_ids + sample_batch = next(iter(calib_dataloader)) + if is_nemotron_parse and "decoder_input_ids" in sample_batch: + preview_input_ids = sample_batch["decoder_input_ids"][0:1] + elif model_type == "whisper": + preview_input_ids = sample_batch["input_features"][0:1] + else: + preview_input_ids = sample_batch["input_ids"][0:1] # Generate preview before quantization if is_nemotron_vl_model and tokenizer is not None: @@ -693,36 +784,46 @@ def quantize_main( device: torch.device, ): if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. - sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = language_model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( - language_model.device - ) - * 100 + # Check if this is a vision-language model + # For VL models, skip automatic batch size detection and use a conservative default + # since proper multimodal input preparation is complex + if is_multimodal_model(full_model) or is_nemotron_vl(full_model): + print( + "Vision-language model detected. Using default batch_size=1 for calibration " + "to ensure proper handling of multimodal inputs." ) + args.batch_size = 1 else: - sample_input_single_batch = None + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. 
@@ -693,36 +784,46 @@ def quantize_main(
     device: torch.device,
 ):
     if args.batch_size == 0:
-        # Calibration/sparsification will actually take much more memory than regular inference
-        # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
-        # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
-        sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
-        # Whisper model expects mel-spectrogram input features of length 3000
-        # Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
-        # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
-        # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
-        if model_type == "whisper":
-            max_sample_length = 3000
-            num_mel_bins = language_model.config.num_mel_bins
-            sample_input_single_batch = (
-                torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
-                    language_model.device
-                )
-                * 100
+        # Check if this is a vision-language model
+        # For VL models, skip automatic batch size detection and use a conservative default
+        # since proper multimodal input preparation is complex
+        if is_multimodal_model(full_model) or is_nemotron_vl(full_model):
+            print(
+                "Vision-language model detected. Using default batch_size=1 for calibration "
+                "to ensure proper handling of multimodal inputs."
             )
+            args.batch_size = 1
         else:
-            sample_input_single_batch = None
+            # Calibration/sparsification will actually take much more memory than regular inference
+            # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
+            # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
+            sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
+            # Whisper model expects mel-spectrogram input features of length 3000
+            # Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
+            # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
+            # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
+            if model_type == "whisper":
+                max_sample_length = 3000
+                num_mel_bins = language_model.config.num_mel_bins
+                sample_input_single_batch = (
+                    torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
+                        language_model.device
+                    )
+                    * 100
+                )
+            else:
+                sample_input_single_batch = None

-        run_auto_quant = args.auto_quantize_bits is not None
+            run_auto_quant = args.auto_quantize_bits is not None

-        args.batch_size = get_max_batch_size(
-            language_model,
-            max_sample_length=args.calib_seq,
-            sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
-            sample_input_single_batch=sample_input_single_batch,
-            enable_grad=run_auto_quant,
-        )
-        args.batch_size = min(args.batch_size, sum(args.calib_size))
+            args.batch_size = get_max_batch_size(
+                language_model,
+                max_sample_length=args.calib_seq,
+                sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
+                sample_input_single_batch=sample_input_single_batch,
+                enable_grad=run_auto_quant,
+            )
+            args.batch_size = min(args.batch_size, sum(args.calib_size))

     print(f"Use calib batch_size {args.batch_size}")
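For reference, the Whisper branch retained in the `else` path probes memory with a float mel-spectrogram rather than token ids; a minimal sketch of that sample input, with `num_mel_bins` hard-coded here as an assumption (the script reads it from `language_model.config`):

```python
import torch

num_mel_bins = 128        # assumption for illustration; taken from the model config in hf_ptq.py
max_sample_length = 3000  # fixed mel-spectrogram length the Whisper encoder expects

# Whisper's encoder has no embedding layer, so the probe input must already be floating point
sample_input_single_batch = torch.ones(1, num_mel_bins, max_sample_length) * 100
print(sample_input_single_batch.shape)  # torch.Size([1, 128, 3000])
```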
@@ -733,6 +834,32 @@ def quantize_main(
     # Detect if this is a Nemotron VL model using architecture-based detection
     is_nemotron_vl_model = is_nemotron_vl(full_model)

+    # For Nemotron-Parse, wrap the text-only dataloader to add dummy images
+    # Nemotron-Parse is an encoder-decoder model that requires pixel_values
+    if is_nemotron_vl_model and processor is not None:
+        config = full_model.config
+        architectures = getattr(config, "architectures", [])
+        is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
+        if is_nemotron_parse:
+            # Check if we're quantizing just the decoder or the full model
+            decoder_only = language_model is not full_model
+
+            if decoder_only:
+                print(
+                    "Calibration will use text-only inputs for Nemotron-Parse decoder. "
+                    "Vision encoder is excluded from quantization."
+                )
+            else:
+                print(
+                    "Wrapping calibration dataloader for Nemotron-Parse to add dummy images. "
+                    "Nemotron-Parse requires pixel_values for full model calibration."
+                )
+
+            calib_dataloader = create_nemotron_parse_calib_wrapper(
+                calib_dataloader, processor, device, decoder_only=decoder_only
+            )
+
     preview_input_ids, generated_ids_before_ptq = pre_quantize(
         args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
     )
diff --git a/examples/llm_ptq/vlm_utils.py b/examples/llm_ptq/vlm_utils.py
index 6c9d921b8..2d3d9f82c 100644
--- a/examples/llm_ptq/vlm_utils.py
+++ b/examples/llm_ptq/vlm_utils.py
@@ -18,7 +18,7 @@
 import os

 from PIL import Image
-from transformers import AutoImageProcessor, AutoProcessor
+from transformers import AutoImageProcessor, AutoProcessor, GenerationConfig


 def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
@@ -73,13 +73,34 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
         print("  Skipping VL preview generation.")
         return None

+    # Check if this is Nemotron-Parse early to set up proper generation config
+    config = model.config
+    architectures = getattr(config, "architectures", [])
+    is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
     # Generate response
     question = "Describe this image briefly."  # Updated for single image
-    generation_config = {
-        "max_new_tokens": 50,
-        "do_sample": False,
-        "eos_token_id": tokenizer.eos_token_id,
-    }
+
+    # Use model's GenerationConfig for Nemotron-Parse, dict for others
+    if is_nemotron_parse:
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_path, trust_remote_code=True
+            )
+            print("Using Nemotron-Parse GenerationConfig from model")
+        except Exception as e:
+            print(f"Warning: Could not load GenerationConfig: {e}, using defaults")
+            generation_config = {
+                "max_new_tokens": 50,
+                "do_sample": False,
+                "eos_token_id": tokenizer.eos_token_id,
+            }
+    else:
+        generation_config = {
+            "max_new_tokens": 50,
+            "do_sample": False,
+            "eos_token_id": tokenizer.eos_token_id,
+        }

     print(f"Generating VL response ({stage_name})...")
@@ -105,27 +126,39 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
         else:
             processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

-        messages = [
-            {"role": "system", "content": "/no_think"},
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image",
-                        "image": "",
-                    },
-                    {
-                        "type": "text",
-                        "text": question,
-                    },
-                ],
-            },
-        ]
+        # Check if this is Nemotron-Parse (uses task prompts instead of chat templates)
+        config = model.config
+        architectures = getattr(config, "architectures", [])
+        is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

-        # Apply chat template
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        if is_nemotron_parse:
+            # Nemotron-Parse uses a specific task prompt format
+            # See: https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1#usage-example
+            prompt = ""
+            print(f"Using Nemotron-Parse task prompt: {prompt}")
+        else:
+            # Other VL models use chat templates
+            messages = [
+                {"role": "system", "content": "/no_think"},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",
+                            "image": "",
+                        },
+                        {
+                            "type": "text",
+                            "text": question,
+                        },
+                    ],
+                },
+            ]
+
+            # Apply chat template
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )

         # Process inputs using the processor with single image
         inputs = processor(
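The generation-config handling above reduces to "prefer the checkpoint's own `GenerationConfig`, otherwise fall back to a plain dict"; a hedged standalone sketch (the helper name is mine, and the dict defaults mirror the patch):

```python
from transformers import GenerationConfig


def load_generation_config(model_path: str, eos_token_id: int):
    """Prefer the checkpoint's GenerationConfig; fall back to conservative greedy decoding settings."""
    try:
        return GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
    except Exception as exc:  # e.g. no generation_config.json in the checkpoint
        print(f"Warning: Could not load GenerationConfig: {exc}, using defaults")
        return {"max_new_tokens": 50, "do_sample": False, "eos_token_id": eos_token_id}
```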
{model_device}") + # Verify we have pixel_values for the vision encoder + if not hasattr(inputs, "pixel_values") or inputs.pixel_values is None: + raise ValueError( + "Processor did not generate pixel_values. Check processor configuration." + ) + # Generate response using model.generate - generated_ids = model.generate( - pixel_values=inputs.pixel_values, - input_ids=inputs.input_ids, - attention_mask=inputs.attention_mask, - **generation_config, - ) + if isinstance(generation_config, GenerationConfig): + # For Nemotron-Parse with GenerationConfig object + generated_ids = model.generate( + pixel_values=inputs.pixel_values, + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + generation_config=generation_config, + ) + else: + # For other models with dict generation config + generated_ids = model.generate( + pixel_values=inputs.pixel_values, + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + **generation_config, + ) # Decode the response (trim input tokens like in the working example) + if generated_ids is None: + raise ValueError("Model generate returned None") + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + + # For Nemotron-Parse, use tokenizer.batch_decode instead of processor.batch_decode + if is_nemotron_parse and hasattr(tokenizer, "batch_decode"): + output_text = tokenizer.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + else: + output_text = processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + if output_text is None or len(output_text) == 0: + raise ValueError("Decoding returned empty output") + response = output_text[0] print(f"✅ VL generation {stage_name} successful!") diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5a24429ad..40c313ad2 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -85,6 +85,7 @@ def is_multimodal_model(model): - Vision LoRA configurations - Audio processing capabilities - Image embedding layers + - Nemotron-Parse conditional generation models Args: model: The HuggingFace model instance to check @@ -103,6 +104,10 @@ def is_multimodal_model(model): """ config = model.config + # Check for Nemotron-Parse encoder-decoder architecture + architectures = getattr(config, "architectures", []) + is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures) + return ( hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA) @@ -112,6 +117,7 @@ def is_multimodal_model(model): or ( hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers + or is_nemotron_parse # Nemotron-Parse conditional generation model ) @@ -141,5 +147,9 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None: if hasattr(model, "language_model"): return [model, model.language_model] - # Pattern 3: No language_model found + # Pattern 3: For encoder-decoder VL models (e.g., Nemotron-Parse), the decoder is the language model + if hasattr(model, "decoder"): + return [model, model.decoder] + + # Pattern 4: No language_model found return None diff --git 
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ccfc01200..eaefcbe9d 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -155,12 +155,14 @@ def _output_hook(module, input, output):

     # Run forward pass so that all modules sharing the same input are collected using forward hook.
+    # Check if this is Nemotron-Parse (encoder-decoder VL model)
+    architectures = getattr(model.config, "architectures", [])
+    is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
+
     with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
-        if getattr(model.config, "is_encoder_decoder", False):
-            # For encoder-decoder models, we need to pass both the encoder and decoder input ids
-            model(fake_input, decoder_input_ids=decoder_fake_input)
-        elif is_vl_model and "nemotron" in model_type:
-            # For Nemotron VL models, try to run optimization on just the language model part
+        if is_vl_model and ("nemotron" in model_type or is_nemotron_parse):
+            # For Nemotron VL models (including Nemotron-Parse), run optimization on just the language model/decoder
+            # This avoids needing to create proper pixel_values for the vision encoder
             language_model_lineage = get_language_model_from_vl(model)

             if language_model_lineage is not None:
@@ -177,6 +179,9 @@ def _output_hook(module, input, output):
                     "This is required for requantization/resmoothing optimization. "
                     "Please ensure the model architecture is supported or file an issue."
                 )
+        elif getattr(model.config, "is_encoder_decoder", False):
+            # For other encoder-decoder models (non-VL), we need to pass both encoder and decoder input ids
+            model(fake_input, decoder_input_ids=decoder_fake_input)
         else:
             model(fake_input)
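A schematic of the routing decision above with the hook plumbing stripped away; the decoder-only forward on the lineage's last element is an assumption about what the unchanged lines inside the hunk do, and `fake_input`/`decoder_fake_input` follow the surrounding code:

```python
def _collection_forward_schematic(model, model_type, is_vl_model, fake_input, decoder_fake_input):
    """Illustrative control flow only; not the ModelOpt implementation."""
    from modelopt.torch.export.model_utils import get_language_model_from_vl

    architectures = getattr(model.config, "architectures", []) or []
    is_nemotron_parse = any("nemotronparse" in a.lower() for a in architectures)

    if is_vl_model and ("nemotron" in model_type or is_nemotron_parse):
        lineage = get_language_model_from_vl(model)
        if lineage is None:
            raise RuntimeError("Could not isolate the language model/decoder for optimization")
        lineage[-1](fake_input)  # decoder-only forward: no pixel_values required
    elif getattr(model.config, "is_encoder_decoder", False):
        model(fake_input, decoder_input_ids=decoder_fake_input)  # non-VL encoder-decoder
    else:
        model(fake_input)
```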
@@ -257,25 +262,42 @@ def _export_quantized_weight(
     if quantization_format == QUANTIZATION_FP8:
         # Convert amax to float32
-        weight_quantizer._amax = weight_quantizer._amax.to(torch.float32)
-
-        if weight_quantizer._amax.dim() == 1:
-            # Per-tensor amax
-            weight_scaling_factor = torch.tensor(
-                weight_quantizer.amax.item() / weight_quantizer.maxbound
-            )
+        # Note: Use the public 'amax' property, not the private '_amax' attribute
+        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+            weight_quantizer._amax = weight_quantizer._amax.to(torch.float32)
+            amax_tensor = weight_quantizer._amax
         else:
-            # Per-channel amax
-            weight_scaling_factor = torch.tensor(weight_quantizer.amax / weight_quantizer.maxbound)
+            # Fallback to public amax property
+            amax_tensor = weight_quantizer.amax
+            if amax_tensor is not None and hasattr(amax_tensor, "to"):
+                amax_tensor = amax_tensor.to(torch.float32)
+
+        # Only compute scaling factor if amax_tensor is valid
+        if amax_tensor is not None and hasattr(amax_tensor, "dim"):
+            if amax_tensor.dim() == 1:
+                # Per-tensor amax
+                weight_scaling_factor = torch.tensor(
+                    weight_quantizer.amax.item() / weight_quantizer.maxbound
+                )
+            else:
+                # Per-channel amax
+                weight_scaling_factor = torch.tensor(
+                    weight_quantizer.amax / weight_quantizer.maxbound
+                )

-        sub_module.register_buffer(
-            quantizer_attrs.weight_scale,
-            weight_scaling_factor,
-        )
+            sub_module.register_buffer(
+                quantizer_attrs.weight_scale,
+                weight_scaling_factor,
+            )

-        if hasattr(input_quantizer, "_amax"):
+        if hasattr(input_quantizer, "_amax") or (
+            input_quantizer is not None
+            and hasattr(input_quantizer, "amax")
+            and input_quantizer.amax is not None
+        ):
             assert input_quantizer is not None
-            input_quantizer._amax = input_quantizer._amax.to(torch.float32)
+            if hasattr(input_quantizer, "_amax") and input_quantizer._amax is not None:
+                input_quantizer._amax = input_quantizer._amax.to(torch.float32)

             sub_module.register_buffer(
                 quantizer_attrs.input_scale,
@@ -284,9 +306,14 @@
                 ).squeeze(),
             )

-        if hasattr(output_quantizer, "_amax"):
+        if hasattr(output_quantizer, "_amax") or (
+            output_quantizer is not None
+            and hasattr(output_quantizer, "amax")
+            and output_quantizer.amax is not None
+        ):
             assert output_quantizer is not None
-            output_quantizer._amax = output_quantizer._amax.to(torch.float32)
+            if hasattr(output_quantizer, "_amax") and output_quantizer._amax is not None:
+                output_quantizer._amax = output_quantizer._amax.to(torch.float32)
     else:
         # Register weight_scale and input_scale
         if quantization_format == QUANTIZATION_FP8_PB_REAL:
@@ -327,6 +354,18 @@
     weight_scale: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale, None)
     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)

+    # If weight_scale is None (e.g., quantizer wasn't calibrated), skip quantization for this module
+    # This can happen for modules that were disabled from quantization or have invalid calibration data
+    if weight_scale is None and quantization_format not in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+    ]:
+        # For NVFP4, weight_scale is computed later, so we can't check here
+        print(
+            f"Warning: Skipping quantization for {type(sub_module).__name__} - no weight_scale found"
+        )
+        return
+
     # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
     # Check if this is a BMM-style expert weight that needs transposition
     is_bmm_expert_weight = weight.dim() == 3 and any(
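All of the guards added in this file follow one rule: never assume a quantizer was calibrated. A compact helper capturing that rule as a standalone sketch (`maxbound` and the per-tensor/per-channel split mirror the FP8 branch above; this is not a ModelOpt API):

```python
import torch


def safe_weight_scale(quantizer, maxbound: float) -> torch.Tensor | None:
    """Return amax/maxbound as float32, or None when no calibration data is available."""
    amax = getattr(quantizer, "_amax", None)
    if amax is None:
        amax = getattr(quantizer, "amax", None)  # fall back to the public property
    if amax is None or not hasattr(amax, "dim"):
        return None  # uncalibrated or disabled quantizer: caller should skip this module
    amax = amax.to(torch.float32)
    if amax.dim() == 1:
        return torch.tensor(amax.item() / maxbound)  # per-tensor scale (single-element amax)
    return (amax / maxbound).clone()                 # per-channel scale
```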