79 changes: 53 additions & 26 deletions examples/llm_ptq/example_utils.py
@@ -28,6 +28,7 @@
from accelerate.utils import get_max_memory
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
@@ -64,27 +65,39 @@ def run_nemotron_vl_preview(
"""
from vlm_utils import run_text_only_generation, run_vl_preview_generation

print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generation_config = {
"max_new_tokens": 100,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

# Try text-only generation
text_response = run_text_only_generation(
full_model, tokenizer, question, generation_config, pyt_ckpt_path
)
# Check if this is Nemotron-Parse (encoder-decoder model that requires images)
config = full_model.config
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

generated_ids = None

if not is_nemotron_parse:
# Only try text-only generation for models that support it (not Nemotron-Parse)
print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generation_config = {
"max_new_tokens": 100,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

# Try text-only generation
text_response = run_text_only_generation(
full_model, tokenizer, question, generation_config, pyt_ckpt_path
)

if text_response is not None:
print(f"✅ Text-only generation successful: {text_response[:100]}...")
generated_ids = text_response
elif allow_fallback:
print("Text-only generation failed, falling back to standard generate...")
generated_ids = full_model.generate(input_ids, max_new_tokens=100)
if text_response is not None:
print(f"✅ Text-only generation successful: {text_response[:100]}...")
generated_ids = text_response
elif allow_fallback:
print("Text-only generation failed, falling back to standard generate...")
generated_ids = full_model.generate(input_ids, max_new_tokens=100)
else:
generated_ids = None
print(
f"Skipping text-only generation for Nemotron-Parse ({stage_name}) - "
"this encoder-decoder model requires images for all operations."
)

# Run additional VL test with images
print(f"Running additional VL test with images ({stage_name})...")
@@ -95,6 +108,10 @@

def _is_multimodal_config(config):
"""Check if a config indicates a multimodal model (config-only version of is_multimodal_model)."""
# Check for Nemotron-Parse encoder-decoder architecture
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

return (
hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL)
or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal
@@ -103,6 +120,7 @@ def _is_multimodal_config(config):
or (
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
) # Image embedding layers
or is_nemotron_parse # Nemotron-Parse conditional generation model
)

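For illustration, a minimal sketch of the new architecture check run in isolation against a mocked config; the architecture string below is an assumption chosen to match the "conditional generation" comment above, not a name confirmed by this diff:

# Standalone sketch of the substring-based detection above; the architecture
# name is illustrative only.
from types import SimpleNamespace

config = SimpleNamespace(architectures=["NemotronParseForConditionalGeneration"])
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
print(is_nemotron_parse)  # True: "nemotronparse" is a substring of the lowercased name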

@@ -257,8 +275,19 @@ def get_processor(
)

return MllamaImageProcessor(processor, device)

return None
else:
# Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
# This will only work if the model has a processor config
try:
processor = AutoProcessor.from_pretrained(
ckpt_path,
**model_kwargs,
)
print(f"Loaded AutoProcessor for model type: {model_type}")
return processor
except Exception as e:
print(f"Could not load processor for {model_type}: {e}")
return None

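As a usage sketch of this new fallback path: AutoProcessor.from_pretrained is the real transformers API, but the checkpoint path is a placeholder, and whether a given checkpoint ships a processor config is an assumption.

# Hedged sketch: AutoProcessor resolves the processor class from the checkpoint's
# processor config when one exists; otherwise we fall through to None, matching
# the behavior added above. The path is a placeholder.
from transformers import AutoProcessor

try:
    processor = AutoProcessor.from_pretrained(
        "/path/to/nemotron-parse-ckpt", trust_remote_code=True
    )
except Exception as err:
    print(f"Could not load processor: {err}")
    processor = None  # caller falls back to the tokenizer-only flow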

def get_dtype(dtype):
@@ -320,8 +349,6 @@ def get_model(
model_kwargs.setdefault("torch_dtype", "auto")

if "vila" in ckpt_path.lower():
from transformers import AutoModel

hf_vila = AutoModel.from_pretrained(
ckpt_path,
device_map=device_map,
@@ -353,13 +380,13 @@
if not hasattr(transformers, architecture):
warnings.warn(
f"Architecture {architecture} not found in transformers: {transformers.__version__}. "
"Falling back to AutoModelForCausalLM."
"Falling back to AutoModel."
)
assert trust_remote_code, (
"Please set trust_remote_code to True if you want to use this architecture"
)

auto_model_module = AutoModelForCausalLM
auto_model_module = AutoModel
from_config = auto_model_module.from_config
else:
auto_model_module = getattr(transformers, architecture)
@@ -370,7 +397,7 @@
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
model_kwargs2 = model_kwargs.copy()
if auto_model_module != AutoModelForCausalLM:
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["torch_dtype"] = torch_dtype
model_kwargs2.pop("max_memory", None)
185 changes: 156 additions & 29 deletions examples/llm_ptq/hf_ptq.py
@@ -97,6 +97,75 @@
mto.enable_huggingface_checkpointing()


def create_nemotron_parse_calib_wrapper(base_dataloader, processor, device, decoder_only=False):
"""Wrap a text-only dataloader to add dummy images for Nemotron-Parse calibration.

Nemotron-Parse is an encoder-decoder model that requires pixel_values (for encoder)
and decoder_input_ids (for decoder) during calibration. This wrapper adds properly
formatted dummy images and decoder inputs.

Args:
base_dataloader: The base text-only dataloader
processor: The Nemotron-Parse processor
device: Device to place tensors on
decoder_only: If True, only provide decoder inputs (for when quantizing just the decoder)
"""
from PIL import Image

class NemotronParseCalibWrapper:
def __init__(self, base_dataloader, processor, device, decoder_only=False):
self.base_dataloader = base_dataloader
self.processor = processor
self.device = device
self.decoder_only = decoder_only
# Create a simple dummy image (will be processed by the model's processor)
self.dummy_image = Image.new("RGB", (1024, 1280), color="white")

def __iter__(self):
for batch in self.base_dataloader:
# batch contains input_ids and attention_mask from text data
batch_size = batch["input_ids"].shape[0]

if self.decoder_only:
# When calibrating just the decoder, it expects input_ids directly
# (not decoder_input_ids, as that's only for the full encoder-decoder forward)
# Just pass through the original batch
yield batch
else:
# When calibrating the full model, we need pixel_values and decoder_input_ids
# Process dummy images using the Nemotron-Parse processor
dummy_images = [self.dummy_image] * batch_size

# Use the processor to get properly formatted pixel_values
prompts = [
"</s><s><predict_bbox><predict_classes><output_markdown>"
] * batch_size
processed = self.processor(
text=prompts, images=dummy_images, return_tensors="pt"
)

# For encoder-decoder models like Nemotron-Parse:
# - pixel_values go to the vision encoder
# - decoder_input_ids are needed for the decoder
batch["pixel_values"] = processed["pixel_values"].to(self.device)
batch["decoder_input_ids"] = processed["input_ids"].to(self.device)
batch["decoder_attention_mask"] = processed["attention_mask"].to(self.device)

# Remove the encoder input_ids and attention_mask as they're not needed
# The model will use pixel_values for the encoder
if "input_ids" in batch:
del batch["input_ids"]
if "attention_mask" in batch:
del batch["attention_mask"]

yield batch

def __len__(self):
return len(self.base_dataloader)

return NemotronParseCalibWrapper(base_dataloader, processor, device, decoder_only)

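A rough usage sketch of the wrapper, assuming a text-only dataloader that yields dicts with input_ids and attention_mask, and a processor obtained via the get_processor path added in example_utils.py; the checkpoint path, model_type value, and stand-in dataloader are placeholders, not code from this PR:

# Hypothetical wiring; only create_nemotron_parse_calib_wrapper comes from this PR.
# get_processor is the helper from example_utils.py; path and model_type are placeholders.
import torch
from example_utils import get_processor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_loader = [  # stand-in for a real text-only calibration dataloader
    {
        "input_ids": torch.randint(0, 1000, (1, 16)),
        "attention_mask": torch.ones(1, 16, dtype=torch.long),
    }
]
processor = get_processor(
    "/path/to/nemotron-parse-ckpt", "nemotron_vl", device, trust_remote_code=True
)
calib_loader = create_nemotron_parse_calib_wrapper(
    text_loader, processor, device, decoder_only=False
)

for batch in calib_loader:
    # Full-model batches now carry encoder pixels plus decoder inputs; the original
    # text input_ids/attention_mask are dropped by the wrapper.
    assert "pixel_values" in batch and "decoder_input_ids" in batch
    break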

def make_calib_dataloader(
args: argparse.Namespace,
language_model: torch.nn.Module,
@@ -317,6 +386,18 @@ def load_model(args: argparse.Namespace):
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]

# Check if this is a Nemotron VL model that needs a processor
is_nemotron_vl_model = is_nemotron_vl(full_model)
if is_nemotron_vl_model:
# Load processor for Nemotron VL models (like Nemotron-Parse)
processor = get_processor(
args.pyt_ckpt_path,
model_type,
device,
trust_remote_code=args.trust_remote_code,
)

tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

default_padding_side = tokenizer.padding_side
@@ -569,10 +650,20 @@ def pre_quantize(
post-quantize generation.

"""
# Check if this is Nemotron-Parse (encoder-decoder model)
config = full_model.config
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

# Only run single sample for preview
preview_input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
# For Nemotron-Parse, use decoder_input_ids instead of input_ids
sample_batch = next(iter(calib_dataloader))
if is_nemotron_parse and "decoder_input_ids" in sample_batch:
preview_input_ids = sample_batch["decoder_input_ids"][0:1]
elif model_type == "whisper":
preview_input_ids = sample_batch["input_features"][0:1]
else:
preview_input_ids = sample_batch["input_ids"][0:1]

# Generate preview before quantization
if is_nemotron_vl_model and tokenizer is not None:
@@ -693,36 +784,46 @@ def quantize_main(
device: torch.device,
):
if args.batch_size == 0:
# Calibration/sparsification takes much more memory than regular inference because of the
# intermediate tensors created by fake quantization, so set sample_memory_usage_ratio to 2
# for AWQ/SmoothQuant to avoid OOM.
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
# Whisper expects mel-spectrogram input features of shape (batch_size, num_mel_bins, 3000)
# As the Whisper encoder has no embedding layer, the input dtype has to be float
# For non-Whisper models (language models), sample_input is set up inside get_max_batch_size()
if model_type == "whisper":
max_sample_length = 3000
num_mel_bins = language_model.config.num_mel_bins
sample_input_single_batch = (
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
language_model.device
)
* 100
# Check if this is a vision-language model
# For VL models, skip automatic batch size detection and use a conservative default
# since proper multimodal input preparation is complex
if is_multimodal_model(full_model) or is_nemotron_vl(full_model):
print(
"Vision-language model detected. Using default batch_size=1 for calibration "
"to ensure proper handling of multimodal inputs."
)
args.batch_size = 1
else:
sample_input_single_batch = None
# Calibration/sparsification takes much more memory than regular inference because of the
# intermediate tensors created by fake quantization, so set sample_memory_usage_ratio to 2
# for AWQ/SmoothQuant to avoid OOM.
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
# Whisper expects mel-spectrogram input features of shape (batch_size, num_mel_bins, 3000)
# As the Whisper encoder has no embedding layer, the input dtype has to be float
# For non-Whisper models (language models), sample_input is set up inside get_max_batch_size()
if model_type == "whisper":
max_sample_length = 3000
num_mel_bins = language_model.config.num_mel_bins
sample_input_single_batch = (
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
language_model.device
)
* 100
)
else:
sample_input_single_batch = None

run_auto_quant = args.auto_quantize_bits is not None
run_auto_quant = args.auto_quantize_bits is not None

args.batch_size = get_max_batch_size(
language_model,
max_sample_length=args.calib_seq,
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
sample_input_single_batch=sample_input_single_batch,
enable_grad=run_auto_quant,
)
args.batch_size = min(args.batch_size, sum(args.calib_size))
args.batch_size = get_max_batch_size(
language_model,
max_sample_length=args.calib_seq,
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
sample_input_single_batch=sample_input_single_batch,
enable_grad=run_auto_quant,
)
args.batch_size = min(args.batch_size, sum(args.calib_size))

print(f"Use calib batch_size {args.batch_size}")

@@ -733,6 +834,32 @@ def quantize_main(
# Detect if this is a Nemotron VL model using architecture-based detection
is_nemotron_vl_model = is_nemotron_vl(full_model)

# For Nemotron-Parse, wrap the text-only dataloader to add dummy images
# Nemotron-Parse is an encoder-decoder model that requires pixel_values
if is_nemotron_vl_model and processor is not None:
config = full_model.config
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

if is_nemotron_parse:
# Check if we're quantizing just the decoder or the full model
decoder_only = language_model is not full_model

if decoder_only:
print(
"Calibration will use text-only inputs for Nemotron-Parse decoder. "
"Vision encoder is excluded from quantization."
)
else:
print(
"Wrapping calibration dataloader for Nemotron-Parse to add dummy images. "
"Nemotron-Parse requires pixel_values for full model calibration."
)

calib_dataloader = create_nemotron_parse_calib_wrapper(
calib_dataloader, processor, device, decoder_only=decoder_only
)

preview_input_ids, generated_ids_before_ptq = pre_quantize(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)
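To make the decoder_only switch above concrete, a standalone illustration with a mocked module; MockVL is hypothetical, and in the real flow language_model comes from the repo's model-loading helpers:

# Standalone illustration of the decoder_only decision made above, using mocks.
# When a VL model exposes its text decoder as a submodule, quantization targets
# that submodule and `language_model is not full_model` evaluates True.
import torch.nn as nn

class MockVL(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(4, 4)  # stand-in for the text decoder

full_model = MockVL()
language_model = full_model.language_model  # decoder-only quantization target
print(language_model is not full_model)     # True -> wrapper passes batches through

language_model = full_model                 # quantizing the full encoder-decoder
print(language_model is not full_model)     # False -> wrapper adds dummy images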