From b9acc43b0140a584b8951eed68a73d1c5b81d535 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 9 Jan 2026 01:21:53 -0800 Subject: [PATCH 01/15] Add support for VLM calibration with image-text pair data Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 97 +++++++++++++++++- modelopt/torch/utils/vlm_dataset_utils.py | 115 ++++++++++++++++++---- 2 files changed, 190 insertions(+), 22 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a9862a742..bc7530f68 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -65,7 +65,10 @@ from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader -from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader +from modelopt.torch.utils.vlm_dataset_utils import ( + get_supported_vlm_datasets, + get_vlm_dataset_dataloader, +) RAND_SEED = 1234 @@ -107,7 +110,25 @@ def make_calib_dataloader( ) -> tuple[DataLoader, str | None]: calib_dataloader = None first_text_speech_dataset = None - if model_type == "mllama": + if getattr(args, "calib_with_images", False): + # Generic multimodal calibration path (used for Nemotron VL and other HF VLMs). + assert processor is not None, ( + "Please provide a processor (e.g., AutoProcessor) for image calibration." + ) + assert len(args.calib_size) == 1, ( + "Image calibration currently supports a single dataset. " + "Please pass --calib_size with one value (e.g., --calib_size 256)." + ) + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=getattr(args, "vlm_dataset", "scienceqa"), + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + device=device, + max_length=args.calib_seq, + require_image=True, + ) + elif model_type == "mllama": assert processor is not None and isinstance(processor, MllamaImageProcessor), ( "The MllamaImageProcessor must be set." ) @@ -164,6 +185,12 @@ def auto_quantize( ): """Auto search quantization of multiple formats.""" + if getattr(args, "calib_with_images", False): + raise NotImplementedError( + "AutoQuantize with image-text calibration is not supported yet. " + "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images." + ) + assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( "Auto Quantization is not supported for pipeline parallel size > 1" ) @@ -292,6 +319,7 @@ def load_model(args: argparse.Namespace): language_model = full_model default_padding_side = None + is_nemotron_vl_model = is_nemotron_vl(full_model) if model_type == "mllama": processor = get_processor( args.pyt_ckpt_path, @@ -307,6 +335,41 @@ def load_model(args: argparse.Namespace): device, trust_remote_code=args.trust_remote_code, ) + elif is_nemotron_vl_model and getattr(args, "calib_with_images", False): + # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs. + try: + processor = AutoProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left" + ) + except Exception as e: + raise RuntimeError( + "Failed to load AutoProcessor for Nemotron VL image calibration. " + "Please ensure the checkpoint provides a compatible processor." 
+ ) from e + + if hasattr(processor, "tokenizer") and processor.tokenizer is not None: + tokenizer = processor.tokenizer + else: + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + + default_padding_side = tokenizer.padding_side + tokenizer.padding_side = "left" + + # Quantize only the language model, but keep the full_model for calibration forward. + language_model_lineage = get_language_model_from_vl(full_model) + if language_model_lineage is not None: + language_model = language_model_lineage.pop(-1) + ancestors = language_model_lineage + disabled_quant_cfg = {"quant_cfg": {"default": {"enable": False}}, "algorithm": "max"} + + memo = set(ancestors) | {language_model} + for ancestor in ancestors: + for _, module in ancestor.named_children(): + if module not in memo: + mtq.quantize(module, disabled_quant_cfg, forward_loop=None) + memo.add(module) + + model_type = get_model_type(language_model) else: if args.dataset is None: args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] @@ -432,9 +495,19 @@ def mono_quantize( if not use_calibration: warnings.warn("Dynamic quantization. Calibration skipped.") - calibrate_loop = ( - create_forward_loop(dataloader=calib_dataloader) if use_calibration else None - ) + calibrate_loop = None + if use_calibration: + base_forward_loop = create_forward_loop(dataloader=calib_dataloader) + # For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values). + # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model. + if getattr(args, "calib_with_images", False) and is_nemotron_vl_model: + + def calibrate_full_model(_model): + return base_forward_loop(full_model) + + calibrate_loop = calibrate_full_model + else: + calibrate_loop = base_forward_loop if calibration_only: language_model = mtq.calibrate( @@ -856,6 +929,20 @@ def parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--calib_with_images", + action="store_true", + help=( + "Calibrate with image-text pairs (for VLMs). " + "For Nemotron VL this enables multimodal calibration using --vlm_dataset." + ), + ) + parser.add_argument( + "--vlm_dataset", + type=str, + default="scienceqa", + help=f"VLM calibration dataset name (choices: {get_supported_vlm_datasets()}).", + ) parser.add_argument("--inference_tensor_parallel", type=int, default=1) parser.add_argument("--inference_pipeline_parallel", type=int, default=1) parser.add_argument("--awq_block_size", default=0, type=int) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 1d9f59484..321e0f469 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility functions for getting samples and forward loop function for different vlm datasets.""" +"""Utility functions for getting samples and dataloader for different VLM calibration datasets.""" from typing import Any +import torch from torch.utils.data import DataLoader from .image_processor import MllamaImageProcessor @@ -30,12 +31,13 @@ __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] -def _get_vlm_dataset(dataset_name: str, num_samples: int): +def _get_vlm_dataset(dataset_name: str, num_samples: int, require_image: bool = True): """Load a portion of train dataset with the dataset name and a given size. 
Args: dataset_name: Name of the dataset to load. num_samples: Number of samples to load from the dataset. + require_image: If True, keep only samples that have an image field. Returns: A hugging face Dataset. @@ -53,7 +55,31 @@ def _get_vlm_dataset(dataset_name: str, num_samples: int): f"dataset {dataset_name} is not supported. Please use one of the following:" f" {get_supported_vlm_datasets()}." ) - return dataset.select(range(num_samples)) + + # `load_dataset` returns a DatasetDict. Use the configured split. + split = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].get("split", "train") + ds = dataset[split] if hasattr(dataset, "__getitem__") and split in dataset else dataset + + if require_image: + # Keep only samples with a non-null image field (ScienceQA has both). + try: + ds = ds.filter(lambda ex: ex.get("image", None) is not None) + except Exception: + # Some dataset backends may not support filter; fall back to best-effort selection below. + pass + + # Select the first `num_samples` entries (or fewer if dataset is smaller). + try: + return ds.select(range(min(num_samples, len(ds)))) + except Exception: + # For iterable datasets without __len__/select, take first N items. + collected = [] + for i, ex in enumerate(ds): + if i >= num_samples: + break + if not require_image or ex.get("image", None) is not None: + collected.append(ex) + return collected def get_supported_vlm_datasets() -> list[str]: @@ -75,9 +101,12 @@ def get_supported_vlm_datasets() -> list[str]: def get_vlm_dataset_dataloader( dataset_name: str = "scienceqa", - processor: MllamaImageProcessor = None, + processor: Any = None, batch_size: int = 1, num_samples: int = 512, + device: str | torch.device | None = None, + max_length: int | None = None, + require_image: bool = True, ) -> DataLoader: """Get a dataloader with the dataset name and processor of the target model. @@ -86,22 +115,74 @@ def get_vlm_dataset_dataloader( processor: Processor used for encoding images and text data. batch_size: Batch size of the returned dataloader. num_samples: Number of samples from the dataset. + device: Device to move returned tensors to. If None, keep on CPU. + max_length: Optional max length for text tokenization (if supported by the processor). + require_image: If True, keep only samples that have an image field. Returns: An instance of dataloader. """ assert processor is not None, "Please provide a valid processor." - dataset = _get_vlm_dataset(dataset_name, num_samples=num_samples) - # Apply the preprocessing function to the dataset - processed_dataset = dataset.map( - processor.preprocess_function, batched=False, remove_columns=dataset.column_names - ) - - # Create DataLoader with the custom collate function - return DataLoader( - processed_dataset, - batch_size=batch_size, - shuffle=False, - collate_fn=processor.collate_function, - ) + if device is not None: + device = torch.device(device) + + dataset = _get_vlm_dataset(dataset_name, num_samples=num_samples, require_image=require_image) + + # Legacy path: our internal image processor wrapper (e.g., Mllama). + if isinstance(processor, MllamaImageProcessor): + processed_dataset = dataset.map( + processor.preprocess_function, batched=False, remove_columns=dataset.column_names + ) + return DataLoader( + processed_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=processor.collate_function, + ) + + # Generic HF ProcessorMixin / AutoProcessor path: tokenize & process images at collate-time. 
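+    # Each collated batch is a plain dict of tensors (e.g., input_ids, attention_mask, pixel_values)
+    # that can be passed to the model forward via **batch during calibration.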
+ # This works well for models that need extra multimodal kwargs (e.g., image_flags) in addition to pixel_values. + def _build_prompt(proc: Any, question: str) -> str: + tok = getattr(proc, "tokenizer", None) + # Prefer a chat template if present; it typically inserts the correct image placeholder tokens. + if tok is not None and getattr(tok, "chat_template", None) is not None: + try: + return tok.apply_chat_template( + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": question}], + } + ], + add_generation_prompt=True, + ) + except Exception: + pass + # Fallback: plain question. Many processors still correctly handle `images=...`. + return question + + def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]: + questions = [ex.get("question", "Describe this image.") for ex in examples] + images = [ex.get("image", None) for ex in examples] + prompts = [_build_prompt(processor, q) for q in questions] + + kwargs: dict[str, Any] = {"text": prompts, "images": images, "return_tensors": "pt", "padding": True} + if max_length is not None: + kwargs.update({"truncation": True, "max_length": max_length}) + + enc = processor(**kwargs) + + # Some processors return BatchEncoding; normalize to plain dict of tensors. + if hasattr(enc, "data"): + enc = enc.data + out: dict[str, Any] = dict(enc) + + # Move tensors to device if requested. + if device is not None: + for k, v in list(out.items()): + if torch.is_tensor(v): + out[k] = v.to(device) + return out + + return DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate_fn) From 528b51d35965edf2b955f8baa193bc4920a25f98 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 9 Jan 2026 01:25:22 -0800 Subject: [PATCH 02/15] Add support for VLM calibration with image-text pair data Signed-off-by: Zhiyu Cheng --- modelopt/torch/utils/vlm_dataset_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 321e0f469..9a71e2d52 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -15,6 +15,7 @@ """Utility functions for getting samples and dataloader for different VLM calibration datasets.""" +import contextlib from typing import Any import torch @@ -62,11 +63,8 @@ def _get_vlm_dataset(dataset_name: str, num_samples: int, require_image: bool = if require_image: # Keep only samples with a non-null image field (ScienceQA has both). - try: + with contextlib.suppress(Exception): ds = ds.filter(lambda ex: ex.get("image", None) is not None) - except Exception: - # Some dataset backends may not support filter; fall back to best-effort selection below. - pass # Select the first `num_samples` entries (or fewer if dataset is smaller). 
try: @@ -167,7 +165,12 @@ def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dic images = [ex.get("image", None) for ex in examples] prompts = [_build_prompt(processor, q) for q in questions] - kwargs: dict[str, Any] = {"text": prompts, "images": images, "return_tensors": "pt", "padding": True} + kwargs: dict[str, Any] = { + "text": prompts, + "images": images, + "return_tensors": "pt", + "padding": True, + } if max_length is not None: kwargs.update({"truncation": True, "max_length": max_length}) From 3ef4b9d1302da29ed43f781ae31d00304204a12c Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 9 Jan 2026 18:09:34 -0800 Subject: [PATCH 03/15] add support for sampling from Nemotron-VLM-Dataset-v2 Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 27 ++++ modelopt/torch/utils/vlm_dataset_utils.py | 153 +++++++++++++++++++--- 2 files changed, 159 insertions(+), 21 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index bc7530f68..4248856d2 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -127,6 +127,9 @@ def make_calib_dataloader( device=device, max_length=args.calib_seq, require_image=True, + subsets=getattr(args, "vlm_subsets", None), + shuffle_buffer_size=getattr(args, "vlm_shuffle_buffer", 10_000), + seed=getattr(args, "vlm_shuffle_seed", 42), ) elif model_type == "mllama": assert processor is not None and isinstance(processor, MllamaImageProcessor), ( @@ -943,6 +946,27 @@ def parse_args() -> argparse.Namespace: default="scienceqa", help=f"VLM calibration dataset name (choices: {get_supported_vlm_datasets()}).", ) + parser.add_argument( + "--vlm_subsets", + type=str, + default="docvqa_cot,chartqa_cot", + help=( + "Comma-separated subset/config names for multi-subset VLM datasets " + "(e.g., nemotron_vlm_dataset_v2)." + ), + ) + parser.add_argument( + "--vlm_shuffle_buffer", + type=int, + default=10_000, + help="Shuffle buffer size for streaming VLM datasets (higher is more random but downloads more).", + ) + parser.add_argument( + "--vlm_shuffle_seed", + type=int, + default=42, + help="Random seed for streaming VLM dataset shuffle.", + ) parser.add_argument("--inference_tensor_parallel", type=int, default=1) parser.add_argument("--inference_pipeline_parallel", type=int, default=1) parser.add_argument("--awq_block_size", default=0, type=int) @@ -1109,4 +1133,7 @@ def main(args: argparse.Namespace): args.dataset = args.dataset.split(",") if args.dataset else None args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + args.vlm_subsets = ( + [s.strip() for s in args.vlm_subsets.split(",") if s.strip()] if args.vlm_subsets else None + ) main(args) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 9a71e2d52..d75cce578 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -13,9 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility functions for getting samples and dataloader for different VLM calibration datasets.""" +"""Utility functions for getting samples and dataloader for different VLM calibration datasets. + +This module supports both: +- Small non-streaming VLM datasets (e.g., ScienceQA) +- Large streaming VLM datasets (e.g., Nemotron-VLM-Dataset-v2) where we want to avoid downloading everything. 
+""" import contextlib +import itertools from typing import Any import torch @@ -27,18 +33,75 @@ # If we want to export more options to user like target languages, we need more standardized approach like dataclass. SUPPORTED_VLM_DATASET_CONFIG: dict[str, dict[str, Any]] = { "scienceqa": {"config": {"path": "derek-thomas/ScienceQA", "split": "train"}}, + # Large multi-subset dataset (use streaming to avoid downloading the entire dataset) + "nemotron_vlm_dataset_v2": { + "config": {"path": "nvidia/Nemotron-VLM-Dataset-v2", "split": "train", "streaming": True}, + # Provide a sane default that is easy to extend from the CLI. + "default_subsets": ["docvqa_cot", "chartqa_cot"], + }, } __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] -def _get_vlm_dataset(dataset_name: str, num_samples: int, require_image: bool = True): +class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): + """Wrap a HF streaming IterableDataset to be compatible with torch DataLoader.""" + + def __init__(self, hf_iterable, num_samples: int): + super().__init__() + self._hf_iterable = hf_iterable + self._num_samples = num_samples + + def __iter__(self): + return itertools.islice(iter(self._hf_iterable), self._num_samples) + + def __len__(self): + return self._num_samples + + +def _extract_text_from_messages(messages: Any) -> str | None: + """Best-effort extraction of a user text prompt from a chat-style `messages` field.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + if msg.get("role") != "user": + continue + content = msg.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + # Common multimodal format: [{"type":"image"}, {"type":"text","text":"..."}] + texts = [ + part["text"] + for part in content + if isinstance(part, dict) + and part.get("type") == "text" + and isinstance(part.get("text"), str) + ] + if texts: + return "\n".join(texts) + return None + + +def _get_vlm_dataset( + dataset_name: str, + num_samples: int, + require_image: bool = True, + subsets: list[str] | None = None, + shuffle_buffer_size: int = 10_000, + seed: int = 42, +): """Load a portion of train dataset with the dataset name and a given size. Args: dataset_name: Name of the dataset to load. num_samples: Number of samples to load from the dataset. require_image: If True, keep only samples that have an image field. + subsets: Optional subset/config names for multi-subset datasets (e.g., Nemotron-VLM-Dataset-v2). + shuffle_buffer_size: Shuffle buffer size for streaming datasets (higher is "more random"). + seed: RNG seed for streaming dataset shuffle. Returns: A hugging face Dataset. @@ -47,37 +110,61 @@ def _get_vlm_dataset(dataset_name: str, num_samples: int, require_image: bool = if dataset_name in SUPPORTED_VLM_DATASET_CONFIG: from datasets import load_dataset - # Use streaming can reduce the downloading time for large datasets - dataset = load_dataset( - **SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"], - ) + cfg = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].copy() + streaming = bool(cfg.pop("streaming", False)) + + if dataset_name == "nemotron_vlm_dataset_v2": + # This dataset contains many subsets; load only the requested ones via `name=...`. 
+ if not subsets: + subsets = SUPPORTED_VLM_DATASET_CONFIG[dataset_name].get("default_subsets", []) + if not subsets: + raise ValueError("No VLM subsets provided for nemotron_vlm_dataset_v2.") + + # Load each subset as a separate (streaming) dataset, then interleave. + streams = [ + load_dataset( + cfg["path"], + name=subset, + split=cfg.get("split", "train"), + streaming=streaming, + ) + for subset in subsets + ] + try: + from datasets import interleave_datasets + + ds = interleave_datasets(streams) + except Exception: + # Fallback: round-robin by chaining (less balanced than interleave). + ds = itertools.chain.from_iterable(streams) + else: + dataset = load_dataset(**cfg, streaming=streaming) + split = cfg.get("split", "train") + ds = dataset[split] if hasattr(dataset, "__getitem__") and split in dataset else dataset else: raise NotImplementedError( f"dataset {dataset_name} is not supported. Please use one of the following:" f" {get_supported_vlm_datasets()}." ) - # `load_dataset` returns a DatasetDict. Use the configured split. - split = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].get("split", "train") - ds = dataset[split] if hasattr(dataset, "__getitem__") and split in dataset else dataset + # Streaming datasets: shuffle with bounded buffer and wrap into a torch IterableDataset. + if dataset_name == "nemotron_vlm_dataset_v2": + with contextlib.suppress(Exception): + ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer_size) if require_image: # Keep only samples with a non-null image field (ScienceQA has both). with contextlib.suppress(Exception): - ds = ds.filter(lambda ex: ex.get("image", None) is not None) + ds = ds.filter( + lambda ex: ex.get("image", None) is not None or ex.get("images", None) is not None + ) # Select the first `num_samples` entries (or fewer if dataset is smaller). try: return ds.select(range(min(num_samples, len(ds)))) except Exception: - # For iterable datasets without __len__/select, take first N items. - collected = [] - for i, ex in enumerate(ds): - if i >= num_samples: - break - if not require_image or ex.get("image", None) is not None: - collected.append(ex) - return collected + # For streaming/iterable datasets without __len__/select, wrap for DataLoader iteration. + return _HFDatasetsIterableWrapper(ds, num_samples=num_samples) def get_supported_vlm_datasets() -> list[str]: @@ -105,6 +192,9 @@ def get_vlm_dataset_dataloader( device: str | torch.device | None = None, max_length: int | None = None, require_image: bool = True, + subsets: list[str] | None = None, + shuffle_buffer_size: int = 10_000, + seed: int = 42, ) -> DataLoader: """Get a dataloader with the dataset name and processor of the target model. @@ -125,7 +215,14 @@ def get_vlm_dataset_dataloader( if device is not None: device = torch.device(device) - dataset = _get_vlm_dataset(dataset_name, num_samples=num_samples, require_image=require_image) + dataset = _get_vlm_dataset( + dataset_name, + num_samples=num_samples, + require_image=require_image, + subsets=subsets, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) # Legacy path: our internal image processor wrapper (e.g., Mllama). 
if isinstance(processor, MllamaImageProcessor): @@ -161,8 +258,22 @@ def _build_prompt(proc: Any, question: str) -> str: return question def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]: - questions = [ex.get("question", "Describe this image.") for ex in examples] - images = [ex.get("image", None) for ex in examples] + questions = [] + images = [] + for ex in examples: + q = ex.get("question") + if q is None and "messages" in ex: + q = _extract_text_from_messages(ex.get("messages")) + if q is None: + q = "Describe this image." + questions.append(q) + + img = ex.get("image", None) + if img is None: + img = ex.get("images", None) + if isinstance(img, list) and img: + img = img[0] + images.append(img) prompts = [_build_prompt(processor, q) for q in questions] kwargs: dict[str, Any] = { From 2d60f988aec1a84dc1c0d18e2e48d6a9571b5a97 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 9 Jan 2026 21:26:30 -0800 Subject: [PATCH 04/15] update readme Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 23ab1ecf9..3ed11ebab 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -326,6 +326,24 @@ with torch.inference_mode(): python hf_ptq.py --pyt_ckpt_path --qformat fp8 --export_path --trust_remote_code ``` +#### VLM calibration with image-text pairs (e.g., Nemotron VL) + +For vision-language models, calibration quality can improve by using image-text pairs instead of text-only data: + +```bash +python hf_ptq.py \ + --pyt_ckpt_path \ + --qformat nvfp4 \ + --export_path \ + --trust_remote_code \ + --calib_with_images \ + --vlm_dataset nemotron_vlm_dataset_v2 \ + --vlm_subsets docvqa_cot,chartqa_cot \ + --calib_size 256 +``` + +> Note: when `--calib_with_images` is set, `--calib_size` must be a single value. + ### Hugging Face framework [Script](./scripts/huggingface_example.sh) Alternatively, the framework script `huggingface_example.sh` also supports quantize and export: From 42a8406a9d9bd4cb2694d36cb9eac069a7216e4e Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 10:57:02 -0800 Subject: [PATCH 05/15] fix issues when calibrate with image data for Nemotron Nano VL Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/README.md | 2 +- examples/llm_ptq/hf_ptq.py | 105 +++- modelopt/torch/utils/vlm_dataset_utils.py | 559 ++++++++++++++++++++-- 3 files changed, 627 insertions(+), 39 deletions(-) diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 3ed11ebab..271b684f4 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -338,7 +338,7 @@ python hf_ptq.py \ --trust_remote_code \ --calib_with_images \ --vlm_dataset nemotron_vlm_dataset_v2 \ - --vlm_subsets docvqa_cot,chartqa_cot \ + --vlm_subsets sparsetables,plotqa_cot,wiki_en \ --calib_size 256 ``` diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 4248856d2..34c4d7b57 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,8 @@ # limitations under the License. 
import argparse +import contextlib +import inspect import random import time import warnings @@ -130,6 +132,9 @@ def make_calib_dataloader( subsets=getattr(args, "vlm_subsets", None), shuffle_buffer_size=getattr(args, "vlm_shuffle_buffer", 10_000), seed=getattr(args, "vlm_shuffle_seed", 42), + image_root=getattr(args, "vlm_image_root", None), + use_media_shards=not getattr(args, "vlm_disable_media_shards", False), + max_shards=getattr(args, "vlm_max_shards", None), ) elif model_type == "mllama": assert processor is not None and isinstance(processor, MllamaImageProcessor), ( @@ -355,6 +360,11 @@ def load_model(args: argparse.Namespace): else: tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + # Some Nemotron tokenizers may not define pad_token by default; but we use padding=True during calibration. + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!" + default_padding_side = tokenizer.padding_side tokenizer.padding_side = "left" @@ -506,7 +516,72 @@ def mono_quantize( if getattr(args, "calib_with_images", False) and is_nemotron_vl_model: def calibrate_full_model(_model): - return base_forward_loop(full_model) + forward_params = inspect.signature(full_model.forward).parameters + accepts_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in forward_params.values() + ) + allowed_keys = set(forward_params.keys()) + + full_model.eval() + with torch.no_grad(): + for batch in calib_dataloader: + if accepts_kwargs: + call_kwargs = batch + else: + call_kwargs = {k: v for k, v in batch.items() if k in allowed_keys} + call_kwargs = {k: v for k, v in call_kwargs.items() if v is not None} + + # Nemotron VL v2 vision encoder expects pixel_values in its torch_dtype (typically bf16). + # The processor often returns float32 pixel_values; also, ModelOpt may wrap some vision + # layers with quant modules (even if disabled), which can be stricter about dtype. + vision_dtype = None + with contextlib.suppress(Exception): + vision_dtype = getattr(full_model.vision_model.config, "torch_dtype", None) + if vision_dtype is None: + vision_dtype = getattr(full_model.language_model.config, "torch_dtype", None) + if vision_dtype is not None and "pixel_values" in call_kwargs: + pv = call_kwargs["pixel_values"] + if torch.is_tensor(pv) and pv.dtype != vision_dtype: + call_kwargs["pixel_values"] = pv.to(dtype=vision_dtype) + # Avoid calling Nemotron VL `full_model.forward()` directly during calibration: + # - Some versions call `torch.distributed.get_rank()` unconditionally. + # - Some versions construct an output object that assumes `past_key_values` exists. + # Instead, reproduce the minimal forward needed to exercise both vision + language paths. 
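+                        # Minimal forward below: (1) embed the text tokens, (2) run the vision tower via
+                        # `extract_feature`, (3) scatter the vision embeddings into the image-placeholder
+                        # token positions, (4) run the language model on the merged embeddings.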
+ pixel_values = call_kwargs.get("pixel_values", None) + input_ids = call_kwargs.get("input_ids", None) + attention_mask = call_kwargs.get("attention_mask", None) + position_ids = call_kwargs.get("position_ids", None) + image_flags = call_kwargs.get("image_flags", None) + + if pixel_values is None or input_ids is None or image_flags is None: + continue + + inputs_embeds = full_model.language_model.get_input_embeddings()(input_ids) + image_flags_s = image_flags.squeeze(-1) + + B, N, C = inputs_embeds.shape + flat_embeds = inputs_embeds.reshape(B * N, C) + flat_ids = input_ids.reshape(B * N) + selected = flat_ids == full_model.img_context_token_id + + vit_embeds = full_model.extract_feature(pixel_values) + vit_embeds = vit_embeds[image_flags_s == 1] + try: + flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) + except Exception: + vit_embeds = vit_embeds.reshape(-1, C) + n_token = selected.sum() + flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds[:n_token] + + inputs_embeds = flat_embeds.reshape(B, N, C) + + full_model.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=False, + ) calibrate_loop = calibrate_full_model else: @@ -949,7 +1024,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--vlm_subsets", type=str, - default="docvqa_cot,chartqa_cot", + default="sparsetables,plotqa_cot,wiki_en", help=( "Comma-separated subset/config names for multi-subset VLM datasets " "(e.g., nemotron_vlm_dataset_v2)." @@ -967,6 +1042,32 @@ def parse_args() -> argparse.Namespace: default=42, help="Random seed for streaming VLM dataset shuffle.", ) + parser.add_argument( + "--vlm_image_root", + type=str, + default=None, + help=( + "Local directory containing image files referenced by the VLM dataset annotations. " + "Required for nemotron_vlm_dataset_v2 subsets that only ship JSONL (e.g., docvqa_cot, chartqa_cot)." + ), + ) + parser.add_argument( + "--vlm_max_shards", + type=int, + default=1, + help=( + "For VLM subsets that include in-repo tar shards under `/media/*.tar`, " + "limit how many shards to download/use for calibration. Increase if you don't get enough samples." + ), + ) + parser.add_argument( + "--vlm_disable_media_shards", + action="store_true", + help=( + "Disable reading in-repo `media/shard_*.tar` files for nemotron_vlm_dataset_v2. " + "Useful if you want to use JSONL-only subsets together with --vlm_image_root." + ), + ) parser.add_argument("--inference_tensor_parallel", type=int, default=1) parser.add_argument("--inference_pipeline_parallel", type=int, default=1) parser.add_argument("--awq_block_size", default=0, type=int) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index d75cce578..a2849fa77 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -21,7 +21,15 @@ """ import contextlib +import copy +import functools import itertools +import json +import os +import random +import tarfile +from io import BytesIO +from pathlib import Path from typing import Any import torch @@ -36,13 +44,42 @@ # Large multi-subset dataset (use streaming to avoid downloading the entire dataset) "nemotron_vlm_dataset_v2": { "config": {"path": "nvidia/Nemotron-VLM-Dataset-v2", "split": "train", "streaming": True}, - # Provide a sane default that is easy to extend from the CLI. 
- "default_subsets": ["docvqa_cot", "chartqa_cot"], + # Provide a sane default that (a) includes in-repo media shards and (b) is document-centric. + # Subsets like docvqa_cot/chartqa_cot are JSONL-only in the dataset repo and require --vlm_image_root. + "default_subsets": ["sparsetables", "plotqa_cot", "wiki_en"], }, } __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] +_IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} + + +@functools.lru_cache(maxsize=8) +def _list_repo_files(repo_id: str, repo_type: str = "dataset") -> list[str]: + from huggingface_hub import list_repo_files + + return list_repo_files(repo_id=repo_id, repo_type=repo_type) + + +def _iter_shuffle_buffer(items_iter, buffer_size: int, seed: int): + """Shuffle an iterator with a bounded buffer (approximate shuffle for streaming sources).""" + if buffer_size <= 1: + yield from items_iter + return + + rng = random.Random(seed) + buf = [] + for item in items_iter: + buf.append(item) + if len(buf) >= buffer_size: + rng.shuffle(buf) + yield from buf + buf = [] + if buf: + rng.shuffle(buf) + yield from buf + class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): """Wrap a HF streaming IterableDataset to be compatible with torch DataLoader.""" @@ -59,6 +96,267 @@ def __len__(self): return self._num_samples +class _TarShardIterable(torch.utils.data.IterableDataset): + """Iterate a list of tar shards and yield decoded samples with PIL images + metadata. + + This is a lightweight alternative to webdataset/energon for calibration use-cases. + """ + + def __init__( + self, + repo_id: str, + shard_paths: list[str], + num_samples: int, + seed: int = 42, + shuffle_buffer_size: int = 0, + ): + super().__init__() + self.repo_id = repo_id + self.shard_paths = shard_paths + self.num_samples = num_samples + self.seed = seed + self.shuffle_buffer_size = shuffle_buffer_size + + def __iter__(self): + from huggingface_hub import hf_hub_download + from PIL import Image + + rng = random.Random(self.seed) + shard_paths = list(self.shard_paths) + rng.shuffle(shard_paths) + + def _raw_samples(): + yielded = 0 + for shard in shard_paths: + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) + with tarfile.open(local_tar, "r:*") as tf: + current_key = None + bucket: dict[str, bytes] = {} + + def _emit(key: str, data_bucket: dict[str, bytes]): + nonlocal yielded + if yielded >= self.num_samples: + return None + img_bytes = data_bucket.get("img") + meta_bytes = data_bucket.get("json") + if img_bytes is None or meta_bytes is None: + return None + try: + meta = json.loads(meta_bytes.decode("utf-8", errors="ignore")) + except Exception: + return None + try: + img = Image.open(BytesIO(img_bytes)).convert("RGB") + except Exception: + return None + + sample: dict[str, Any] = {"id": meta.get("id", key), "image": img} + if "messages" in meta: + sample["messages"] = meta["messages"] + else: + text = meta.get("text") or meta.get("question") or "Describe the image." 
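+                            # Fallback: synthesize a single-turn user message so prompt building via the
+                            # chat template works the same as for samples that ship `messages`.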
+ sample["messages"] = [ + { + "role": "user", + "content": [ + {"type": "image", "image": ""}, + {"type": "text", "text": text}, + ], + } + ] + + yielded += 1 + return sample + + for member in tf: + if not member.isfile(): + continue + name = member.name + base, ext = os.path.splitext(name) + ext = ext.lower() + if ext not in _IMG_EXTS and ext != ".json": + continue + f = tf.extractfile(member) + if f is None: + continue + data = f.read() + + if current_key is None: + current_key = base + if base != current_key: + out = _emit(current_key, bucket) + if out is not None: + yield out + if yielded >= self.num_samples: + return + bucket = {} + current_key = base + + if ext == ".json": + bucket["json"] = data + else: + bucket["img"] = data + + if current_key is not None: + out = _emit(current_key, bucket) + if out is not None: + yield out + if yielded >= self.num_samples: + return + + it = _raw_samples() + if self.shuffle_buffer_size and self.shuffle_buffer_size > 1: + it = _iter_shuffle_buffer(it, buffer_size=self.shuffle_buffer_size, seed=self.seed) + yield from itertools.islice(it, self.num_samples) + + +class _NemotronTarPlusJsonlIterable(torch.utils.data.IterableDataset): + """Join Nemotron VLM `media/shard_*.tar` (images-only) with `/.jsonl` (messages). + + Many Nemotron-VLM-Dataset-v2 subsets store PNGs in tar shards and store messages in a separate JSONL where the + image content item references the PNG filename (e.g., `{"type":"image","image":"292180.png"}`). + """ + + def __init__( + self, + repo_id: str, + subsets: list[str], + shard_paths: list[str], + num_samples: int, + seed: int, + shuffle_buffer_size: int, + max_shards: int | None, + ): + super().__init__() + self.repo_id = repo_id + self.subsets = subsets + self.shard_paths = shard_paths + self.num_samples = num_samples + self.seed = seed + self.shuffle_buffer_size = shuffle_buffer_size + self.max_shards = max_shards + + def __iter__(self): + from huggingface_hub import hf_hub_download + from PIL import Image + + rng = random.Random(self.seed) + + # Partition shards by subset. + shards_by_subset: dict[str, list[str]] = {s: [] for s in self.subsets} + for p in self.shard_paths: + subset = p.split("/", 1)[0] + if subset in shards_by_subset: + shards_by_subset[subset].append(p) + + for subset in list(shards_by_subset.keys()): + shard_list = sorted(shards_by_subset[subset]) + if self.max_shards is not None: + shard_list = shard_list[: max(0, self.max_shards)] + shards_by_subset[subset] = shard_list + + # Roughly split sample budget across subsets. + per_subset_target = max(1, self.num_samples // max(1, len(self.subsets))) + yielded_total = 0 + + for subset in self.subsets: + if yielded_total >= self.num_samples: + break + + shard_list = list(shards_by_subset.get(subset, [])) + if not shard_list: + continue + rng.shuffle(shard_list) + + # 1) Collect candidate image filenames from tar headers (no payload reads). 
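+            # Reading only tar member headers keeps this pass cheap; matched image payloads are
+            # extracted later in step (3).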
+ candidate_names: list[str] = [] + header_limit = per_subset_target * 50 + for shard in shard_list: + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if not member.isfile(): + continue + name = member.name + _, ext = os.path.splitext(name) + if ext.lower() not in _IMG_EXTS: + continue + candidate_names.append(name) + if len(candidate_names) >= header_limit: + break + if len(candidate_names) >= header_limit: + break + + if not candidate_names: + continue + + rng.shuffle(candidate_names) + lookup_limit = per_subset_target * 10 + candidate_set = set(candidate_names[:lookup_limit]) + + # 2) Scan JSONL to map image filename -> messages. + jsonl_path = hf_hub_download( + repo_id=self.repo_id, filename=f"{subset}/{subset}.jsonl", repo_type="dataset" + ) + meta_by_image: dict[str, dict[str, Any]] = {} + with open(jsonl_path, encoding="utf-8") as f: + for line in f: + try: + obj = json.loads(line) + except Exception: + continue + msgs = obj.get("messages") + img_name = ( + _extract_first_image_from_messages(msgs) if msgs is not None else None + ) + if isinstance(img_name, str) and img_name in candidate_set: + meta_by_image[img_name] = {"id": obj.get("id"), "messages": msgs} + if len(meta_by_image) >= per_subset_target: + break + + if not meta_by_image: + continue + + # 3) Extract matched images and yield samples. + needed = set(meta_by_image.keys()) + for shard in shard_list: + if yielded_total >= self.num_samples or not needed: + break + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if yielded_total >= self.num_samples or not needed: + break + if not member.isfile(): + continue + name = member.name + if name not in needed: + continue + f = tf.extractfile(member) + if f is None: + continue + try: + img = Image.open(BytesIO(f.read())).convert("RGB") + except Exception: + continue + meta = meta_by_image.get(name) + if not meta: + continue + yield { + "id": meta.get("id", name), + "messages": meta.get("messages"), + "image": img, + } + needed.discard(name) + yielded_total += 1 + + def _extract_text_from_messages(messages: Any) -> str | None: """Best-effort extraction of a user text prompt from a chat-style `messages` field.""" if not isinstance(messages, list): @@ -85,6 +383,118 @@ def _extract_text_from_messages(messages: Any) -> str | None: return None +def _messages_up_to_last_user(messages: Any) -> list[dict[str, Any]] | None: + """Return messages truncated to the last user turn (inclusive).""" + if not isinstance(messages, list): + return None + last_user_idx = None + for i, msg in enumerate(messages): + if isinstance(msg, dict) and msg.get("role") == "user": + last_user_idx = i + if last_user_idx is None: + return None + trimmed = messages[: last_user_idx + 1] + return [m for m in trimmed if isinstance(m, dict)] + + +def _messages_has_image(messages: Any) -> bool: + """Return True if `messages` contains an image content item.""" + if not isinstance(messages, list): + return False + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + return True + return False + + +def _extract_first_image_from_messages(messages: Any) -> Any: + """Best-effort extraction of an image object from a chat-style `messages` 
field.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not (isinstance(part, dict) and part.get("type") == "image"): + continue + # Common keys used by HF datasets / chat templates + for key in ("image", "images", "value", "data", "path", "image_url", "url"): + if key in part: + val = part[key] + if isinstance(val, list) and val: + return val[0] + return val + # Fallback: return the dict itself (some processors may accept it) + return part + return None + + +def _maybe_load_image(image_obj: Any, repo_id: str | None, image_root: str | Path | None) -> Any: + """Convert common image references (path/bytes) into a PIL image if possible. + + For some streaming datasets, images are stored as file paths inside the dataset repo. + In that case, we lazily download just the referenced files via `hf_hub_download`. + """ + if image_obj is None: + return None + + # If it's a list, take the first (some formats store a list for multi-image samples). + if isinstance(image_obj, list) and image_obj: + image_obj = image_obj[0] + + # Path-like reference + if isinstance(image_obj, str): + # First, try resolving against a local image root (best option for datasets that only ship JSONL refs). + if image_root is not None: + try: + from PIL import Image + + local_path = Path(image_root) / image_obj + if local_path.exists(): + return Image.open(local_path).convert("RGB") + except Exception: + pass + + if repo_id is None: + return image_obj + try: + from huggingface_hub import hf_hub_download + from PIL import Image + + local_path = hf_hub_download(repo_id=repo_id, filename=image_obj, repo_type="dataset") + return Image.open(local_path).convert("RGB") + except Exception: + return None + + # Dict-like reference (common in chat content items) + if isinstance(image_obj, dict): + # bytes payload + if "bytes" in image_obj and isinstance(image_obj["bytes"], (bytes, bytearray)): + try: + from PIL import Image + + return Image.open(BytesIO(image_obj["bytes"])).convert("RGB") + except Exception: + return None + + # path/url-ish payloads + for key in ("path", "image", "image_path", "file", "url", "image_url"): + if key in image_obj and isinstance(image_obj[key], str): + return _maybe_load_image(image_obj[key], repo_id=repo_id, image_root=image_root) + + # If it's already a PIL/numpy/torch image-like object, just return it and let the processor validate. + return image_obj + + def _get_vlm_dataset( dataset_name: str, num_samples: int, @@ -92,6 +502,8 @@ def _get_vlm_dataset( subsets: list[str] | None = None, shuffle_buffer_size: int = 10_000, seed: int = 42, + use_media_shards: bool = True, + max_shards: int | None = None, ): """Load a portion of train dataset with the dataset name and a given size. @@ -102,6 +514,8 @@ def _get_vlm_dataset( subsets: Optional subset/config names for multi-subset datasets (e.g., Nemotron-VLM-Dataset-v2). shuffle_buffer_size: Shuffle buffer size for streaming datasets (higher is "more random"). seed: RNG seed for streaming dataset shuffle. + use_media_shards: If True, prefer reading in-repo `media/shard_*.tar` files when available. + max_shards: Optional cap on the number of tar shards to download/use. Returns: A hugging face Dataset. 
@@ -120,6 +534,34 @@ def _get_vlm_dataset( if not subsets: raise ValueError("No VLM subsets provided for nemotron_vlm_dataset_v2.") + repo_id = cfg["path"] + + # Prefer in-repo media tar shards when present. HF `datasets` streaming alone does not join media. + if use_media_shards: + all_files = _list_repo_files(repo_id, repo_type="dataset") + shard_paths: list[str] = [] + for subset in subsets: + prefix = f"{subset}/media/" + shard_paths.extend( + [ + p + for p in all_files + if p.startswith(prefix) and p.lower().endswith(".tar") + ] + ) + + shard_paths = sorted(set(shard_paths)) + if shard_paths: + return _NemotronTarPlusJsonlIterable( + repo_id=repo_id, + subsets=subsets, + shard_paths=shard_paths, + num_samples=num_samples, + seed=seed, + shuffle_buffer_size=shuffle_buffer_size, + max_shards=max_shards, + ) + # Load each subset as a separate (streaming) dataset, then interleave. streams = [ load_dataset( @@ -156,7 +598,9 @@ def _get_vlm_dataset( # Keep only samples with a non-null image field (ScienceQA has both). with contextlib.suppress(Exception): ds = ds.filter( - lambda ex: ex.get("image", None) is not None or ex.get("images", None) is not None + lambda ex: ex.get("image", None) is not None + or ex.get("images", None) is not None + or _messages_has_image(ex.get("messages")) ) # Select the first `num_samples` entries (or fewer if dataset is smaller). @@ -195,6 +639,9 @@ def get_vlm_dataset_dataloader( subsets: list[str] | None = None, shuffle_buffer_size: int = 10_000, seed: int = 42, + image_root: str | Path | None = None, + use_media_shards: bool = True, + max_shards: int | None = None, ) -> DataLoader: """Get a dataloader with the dataset name and processor of the target model. @@ -212,6 +659,11 @@ def get_vlm_dataset_dataloader( """ assert processor is not None, "Please provide a valid processor." + # Optional: allow callers to set a local image root for datasets that only ship JSON references. + # We store it on the processor instance to avoid threading it through a bunch of nested closures. + if image_root is not None: + setattr(processor, "_modelopt_vlm_image_root", image_root) + if device is not None: device = torch.device(device) @@ -222,6 +674,8 @@ def get_vlm_dataset_dataloader( subsets=subsets, shuffle_buffer_size=shuffle_buffer_size, seed=seed, + use_media_shards=use_media_shards, + max_shards=max_shards, ) # Legacy path: our internal image processor wrapper (e.g., Mllama). @@ -237,48 +691,72 @@ def get_vlm_dataset_dataloader( ) # Generic HF ProcessorMixin / AutoProcessor path: tokenize & process images at collate-time. - # This works well for models that need extra multimodal kwargs (e.g., image_flags) in addition to pixel_values. - def _build_prompt(proc: Any, question: str) -> str: - tok = getattr(proc, "tokenizer", None) - # Prefer a chat template if present; it typically inserts the correct image placeholder tokens. - if tok is not None and getattr(tok, "chat_template", None) is not None: - try: - return tok.apply_chat_template( - [ - { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "text": question}], - } - ], - add_generation_prompt=True, - ) - except Exception: - pass - # Fallback: plain question. Many processors still correctly handle `images=...`. - return question + # For Nemotron VLM datasets, we prefer to follow the model-card flow: + # prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # inputs = processor(text=[prompt], images=[pil_image], ...) 
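+    # Expected chat-style sample shape (see `_NemotronTarPlusJsonlIterable`), e.g.:
+    #   {"image": <PIL.Image>,
+    #    "messages": [{"role": "user",
+    #                  "content": [{"type": "image", "image": "292180.png"},
+    #                              {"type": "text", "text": "..."}]}]}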
def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]: - questions = [] - images = [] + repo_id = None + if dataset_name == "nemotron_vlm_dataset_v2": + repo_id = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"]["path"] + image_root = getattr(processor, "_modelopt_vlm_image_root", None) + + pairs: list[tuple[str, Any]] = [] for ex in examples: - q = ex.get("question") - if q is None and "messages" in ex: - q = _extract_text_from_messages(ex.get("messages")) - if q is None: - q = "Describe this image." - questions.append(q) + messages = ex.get("messages") + # Image extraction img = ex.get("image", None) if img is None: img = ex.get("images", None) - if isinstance(img, list) and img: - img = img[0] - images.append(img) - prompts = [_build_prompt(processor, q) for q in questions] + if img is None and messages is not None: + img = _extract_first_image_from_messages(messages) + img = _maybe_load_image(img, repo_id=repo_id, image_root=image_root) + if require_image and img is None: + continue + + # Prompt extraction + prompt = None + tok = getattr(processor, "tokenizer", None) + if tok is not None and messages is not None: + trimmed = _messages_up_to_last_user(messages) or [] + # For some Nemotron-style templates, the image content expects an empty string. + # Keep the actual image path separate for loading; blank it in the prompt message. + prompt_msgs = copy.deepcopy(trimmed) + for msg in prompt_msgs: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + part["image"] = "" + with contextlib.suppress(Exception): + prompt = tok.apply_chat_template( + prompt_msgs, tokenize=False, add_generation_prompt=True + ) + + if prompt is None: + # Fallback: best-effort question-only prompt. + q = ex.get("question") + if q is None and messages is not None: + q = _extract_text_from_messages(messages) + prompt = q or "Describe the image." + + pairs.append((prompt, img)) + + if not pairs: + raise ValueError( + "No usable images found in the current batch. " + "If you're using JSONL-only subsets (e.g., docvqa_cot/chartqa_cot), provide " + "`--vlm_image_root ` so referenced paths can be resolved. " + "If you're using asset-included subsets, keep media shard loading enabled " + "(default) and consider increasing `--vlm_max_shards`." + ) + + prompts, images = zip(*pairs) kwargs: dict[str, Any] = { - "text": prompts, - "images": images, + "text": list(prompts), + "images": list(images), "return_tensors": "pt", "padding": True, } @@ -292,6 +770,15 @@ def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dic enc = enc.data out: dict[str, Any] = dict(enc) + # Nemotron Nano VL v2 expects `image_flags` in forward(), but the processor does not emit it. + # `pixel_values` is flattened across batch*images, so `image_flags` should align with pixel_values.shape[0]. + if out.get("pixel_values") is not None and out.get("image_flags") is None: + pv = out["pixel_values"] + if torch.is_tensor(pv): + out["image_flags"] = torch.ones( + (pv.shape[0], 1), device=pv.device, dtype=torch.long + ) + # Move tensors to device if requested. 
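+        # Non-tensor entries (if any) are passed through unchanged.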
if device is not None: for k, v in list(out.items()): From 7489a363fde903acdeae947a96d6dc2e42b75234 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 10:58:04 -0800 Subject: [PATCH 06/15] fix issues when calibrate with image data for Nemotron Nano VL Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 34c4d7b57..b75aefa8c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -536,9 +536,13 @@ def calibrate_full_model(_model): # layers with quant modules (even if disabled), which can be stricter about dtype. vision_dtype = None with contextlib.suppress(Exception): - vision_dtype = getattr(full_model.vision_model.config, "torch_dtype", None) + vision_dtype = getattr( + full_model.vision_model.config, "torch_dtype", None + ) if vision_dtype is None: - vision_dtype = getattr(full_model.language_model.config, "torch_dtype", None) + vision_dtype = getattr( + full_model.language_model.config, "torch_dtype", None + ) if vision_dtype is not None and "pixel_values" in call_kwargs: pv = call_kwargs["pixel_values"] if torch.is_tensor(pv) and pv.dtype != vision_dtype: @@ -547,16 +551,18 @@ def calibrate_full_model(_model): # - Some versions call `torch.distributed.get_rank()` unconditionally. # - Some versions construct an output object that assumes `past_key_values` exists. # Instead, reproduce the minimal forward needed to exercise both vision + language paths. - pixel_values = call_kwargs.get("pixel_values", None) - input_ids = call_kwargs.get("input_ids", None) - attention_mask = call_kwargs.get("attention_mask", None) - position_ids = call_kwargs.get("position_ids", None) - image_flags = call_kwargs.get("image_flags", None) + pixel_values = call_kwargs.get("pixel_values") + input_ids = call_kwargs.get("input_ids") + attention_mask = call_kwargs.get("attention_mask") + position_ids = call_kwargs.get("position_ids") + image_flags = call_kwargs.get("image_flags") if pixel_values is None or input_ids is None or image_flags is None: continue - inputs_embeds = full_model.language_model.get_input_embeddings()(input_ids) + inputs_embeds = full_model.language_model.get_input_embeddings()( + input_ids + ) image_flags_s = image_flags.squeeze(-1) B, N, C = inputs_embeds.shape @@ -567,11 +573,15 @@ def calibrate_full_model(_model): vit_embeds = full_model.extract_feature(pixel_values) vit_embeds = vit_embeds[image_flags_s == 1] try: - flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) + flat_embeds[selected] = flat_embeds[ + selected + ] * 0.0 + vit_embeds.reshape(-1, C) except Exception: vit_embeds = vit_embeds.reshape(-1, C) n_token = selected.sum() - flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds[:n_token] + flat_embeds[selected] = ( + flat_embeds[selected] * 0.0 + vit_embeds[:n_token] + ) inputs_embeds = flat_embeds.reshape(B, N, C) From bd871549afb78e9245690bcd79b89cd635839f80 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 10:59:51 -0800 Subject: [PATCH 07/15] fix issues when calibrate with image data for Nemotron Nano VL Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 12 ++++++------ modelopt/torch/utils/vlm_dataset_utils.py | 6 +++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index b75aefa8c..41276d707 100755 --- a/examples/llm_ptq/hf_ptq.py +++ 
b/examples/llm_ptq/hf_ptq.py @@ -565,9 +565,9 @@ def calibrate_full_model(_model): ) image_flags_s = image_flags.squeeze(-1) - B, N, C = inputs_embeds.shape - flat_embeds = inputs_embeds.reshape(B * N, C) - flat_ids = input_ids.reshape(B * N) + b, n, c = inputs_embeds.shape + flat_embeds = inputs_embeds.reshape(b * n, c) + flat_ids = input_ids.reshape(b * n) selected = flat_ids == full_model.img_context_token_id vit_embeds = full_model.extract_feature(pixel_values) @@ -575,15 +575,15 @@ def calibrate_full_model(_model): try: flat_embeds[selected] = flat_embeds[ selected - ] * 0.0 + vit_embeds.reshape(-1, C) + ] * 0.0 + vit_embeds.reshape(-1, c) except Exception: - vit_embeds = vit_embeds.reshape(-1, C) + vit_embeds = vit_embeds.reshape(-1, c) n_token = selected.sum() flat_embeds[selected] = ( flat_embeds[selected] * 0.0 + vit_embeds[:n_token] ) - inputs_embeds = flat_embeds.reshape(B, N, C) + inputs_embeds = flat_embeds.reshape(b, n, c) full_model.language_model( inputs_embeds=inputs_embeds, diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index a2849fa77..2ee26885c 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -342,7 +342,11 @@ def __iter__(self): if f is None: continue try: - img = Image.open(BytesIO(f.read())).convert("RGB") + raw = f.read() + # Some tarfile stubs type `read()` as returning `str`; normalize to bytes for mypy. + if isinstance(raw, str): + raw = raw.encode() + img = Image.open(BytesIO(raw)).convert("RGB") except Exception: continue meta = meta_by_image.get(name) From 3200a6399d2cd2e2319e614a75ebe42f93641954 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 11:01:13 -0800 Subject: [PATCH 08/15] fix issues when calibrate with image data for Nemotron Nano VL Signed-off-by: Zhiyu Cheng --- modelopt/torch/utils/vlm_dataset_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 2ee26885c..14d7018f4 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -346,7 +346,9 @@ def __iter__(self): # Some tarfile stubs type `read()` as returning `str`; normalize to bytes for mypy. if isinstance(raw, str): raw = raw.encode() - img = Image.open(BytesIO(raw)).convert("RGB") + # Help mypy: BytesIO expects a bytes-like buffer. 
+ raw_bytes: bytes = raw + img = Image.open(BytesIO(raw_bytes)).convert("RGB") except Exception: continue meta = meta_by_image.get(name) From 8964aa5eaf94b63437161bd8252e064403c27506 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 11:58:58 -0800 Subject: [PATCH 09/15] simplify Signed-off-by: Zhiyu Cheng --- modelopt/torch/utils/vlm_dataset_utils.py | 135 ---------------------- 1 file changed, 135 deletions(-) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 14d7018f4..7dd2870d7 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -62,25 +62,6 @@ def _list_repo_files(repo_id: str, repo_type: str = "dataset") -> list[str]: return list_repo_files(repo_id=repo_id, repo_type=repo_type) -def _iter_shuffle_buffer(items_iter, buffer_size: int, seed: int): - """Shuffle an iterator with a bounded buffer (approximate shuffle for streaming sources).""" - if buffer_size <= 1: - yield from items_iter - return - - rng = random.Random(seed) - buf = [] - for item in items_iter: - buf.append(item) - if len(buf) >= buffer_size: - rng.shuffle(buf) - yield from buf - buf = [] - if buf: - rng.shuffle(buf) - yield from buf - - class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): """Wrap a HF streaming IterableDataset to be compatible with torch DataLoader.""" @@ -96,122 +77,6 @@ def __len__(self): return self._num_samples -class _TarShardIterable(torch.utils.data.IterableDataset): - """Iterate a list of tar shards and yield decoded samples with PIL images + metadata. - - This is a lightweight alternative to webdataset/energon for calibration use-cases. - """ - - def __init__( - self, - repo_id: str, - shard_paths: list[str], - num_samples: int, - seed: int = 42, - shuffle_buffer_size: int = 0, - ): - super().__init__() - self.repo_id = repo_id - self.shard_paths = shard_paths - self.num_samples = num_samples - self.seed = seed - self.shuffle_buffer_size = shuffle_buffer_size - - def __iter__(self): - from huggingface_hub import hf_hub_download - from PIL import Image - - rng = random.Random(self.seed) - shard_paths = list(self.shard_paths) - rng.shuffle(shard_paths) - - def _raw_samples(): - yielded = 0 - for shard in shard_paths: - local_tar = hf_hub_download( - repo_id=self.repo_id, filename=shard, repo_type="dataset" - ) - with tarfile.open(local_tar, "r:*") as tf: - current_key = None - bucket: dict[str, bytes] = {} - - def _emit(key: str, data_bucket: dict[str, bytes]): - nonlocal yielded - if yielded >= self.num_samples: - return None - img_bytes = data_bucket.get("img") - meta_bytes = data_bucket.get("json") - if img_bytes is None or meta_bytes is None: - return None - try: - meta = json.loads(meta_bytes.decode("utf-8", errors="ignore")) - except Exception: - return None - try: - img = Image.open(BytesIO(img_bytes)).convert("RGB") - except Exception: - return None - - sample: dict[str, Any] = {"id": meta.get("id", key), "image": img} - if "messages" in meta: - sample["messages"] = meta["messages"] - else: - text = meta.get("text") or meta.get("question") or "Describe the image." 
- sample["messages"] = [ - { - "role": "user", - "content": [ - {"type": "image", "image": ""}, - {"type": "text", "text": text}, - ], - } - ] - - yielded += 1 - return sample - - for member in tf: - if not member.isfile(): - continue - name = member.name - base, ext = os.path.splitext(name) - ext = ext.lower() - if ext not in _IMG_EXTS and ext != ".json": - continue - f = tf.extractfile(member) - if f is None: - continue - data = f.read() - - if current_key is None: - current_key = base - if base != current_key: - out = _emit(current_key, bucket) - if out is not None: - yield out - if yielded >= self.num_samples: - return - bucket = {} - current_key = base - - if ext == ".json": - bucket["json"] = data - else: - bucket["img"] = data - - if current_key is not None: - out = _emit(current_key, bucket) - if out is not None: - yield out - if yielded >= self.num_samples: - return - - it = _raw_samples() - if self.shuffle_buffer_size and self.shuffle_buffer_size > 1: - it = _iter_shuffle_buffer(it, buffer_size=self.shuffle_buffer_size, seed=self.seed) - yield from itertools.islice(it, self.num_samples) - - class _NemotronTarPlusJsonlIterable(torch.utils.data.IterableDataset): """Join Nemotron VLM `media/shard_*.tar` (images-only) with `/.jsonl` (messages). From 3b7373d3f109e47fdfa3952dc92052501489366b Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 15:01:06 -0800 Subject: [PATCH 10/15] refactor to make hf_ptq cleaner, create a separate vlm dataset utils for Nemotron-VLM-Dataset-v2 Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 65 +----- examples/llm_ptq/nemotron_vl_calib.py | 89 ++++++++ .../torch/utils/nemotron_vlm_dataset_utils.py | 203 ++++++++++++++++++ modelopt/torch/utils/vlm_dataset_utils.py | 169 +-------------- 4 files changed, 297 insertions(+), 229 deletions(-) create mode 100644 examples/llm_ptq/nemotron_vl_calib.py create mode 100644 modelopt/torch/utils/nemotron_vlm_dataset_utils.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 41276d707..344ad74f2 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,7 +14,6 @@ # limitations under the License. import argparse -import contextlib import inspect import random import time @@ -34,6 +33,7 @@ is_nemotron_vl, run_nemotron_vl_preview, ) +from nemotron_vl_calib import safe_nemotron_vl_forward from torch.utils.data import DataLoader from transformers import ( AutoConfig, @@ -530,68 +530,7 @@ def calibrate_full_model(_model): else: call_kwargs = {k: v for k, v in batch.items() if k in allowed_keys} call_kwargs = {k: v for k, v in call_kwargs.items() if v is not None} - - # Nemotron VL v2 vision encoder expects pixel_values in its torch_dtype (typically bf16). - # The processor often returns float32 pixel_values; also, ModelOpt may wrap some vision - # layers with quant modules (even if disabled), which can be stricter about dtype. - vision_dtype = None - with contextlib.suppress(Exception): - vision_dtype = getattr( - full_model.vision_model.config, "torch_dtype", None - ) - if vision_dtype is None: - vision_dtype = getattr( - full_model.language_model.config, "torch_dtype", None - ) - if vision_dtype is not None and "pixel_values" in call_kwargs: - pv = call_kwargs["pixel_values"] - if torch.is_tensor(pv) and pv.dtype != vision_dtype: - call_kwargs["pixel_values"] = pv.to(dtype=vision_dtype) - # Avoid calling Nemotron VL `full_model.forward()` directly during calibration: - # - Some versions call `torch.distributed.get_rank()` unconditionally. 
- # - Some versions construct an output object that assumes `past_key_values` exists. - # Instead, reproduce the minimal forward needed to exercise both vision + language paths. - pixel_values = call_kwargs.get("pixel_values") - input_ids = call_kwargs.get("input_ids") - attention_mask = call_kwargs.get("attention_mask") - position_ids = call_kwargs.get("position_ids") - image_flags = call_kwargs.get("image_flags") - - if pixel_values is None or input_ids is None or image_flags is None: - continue - - inputs_embeds = full_model.language_model.get_input_embeddings()( - input_ids - ) - image_flags_s = image_flags.squeeze(-1) - - b, n, c = inputs_embeds.shape - flat_embeds = inputs_embeds.reshape(b * n, c) - flat_ids = input_ids.reshape(b * n) - selected = flat_ids == full_model.img_context_token_id - - vit_embeds = full_model.extract_feature(pixel_values) - vit_embeds = vit_embeds[image_flags_s == 1] - try: - flat_embeds[selected] = flat_embeds[ - selected - ] * 0.0 + vit_embeds.reshape(-1, c) - except Exception: - vit_embeds = vit_embeds.reshape(-1, c) - n_token = selected.sum() - flat_embeds[selected] = ( - flat_embeds[selected] * 0.0 + vit_embeds[:n_token] - ) - - inputs_embeds = flat_embeds.reshape(b, n, c) - - full_model.language_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=False, - return_dict=False, - ) + safe_nemotron_vl_forward(full_model, call_kwargs) calibrate_loop = calibrate_full_model else: diff --git a/examples/llm_ptq/nemotron_vl_calib.py b/examples/llm_ptq/nemotron_vl_calib.py new file mode 100644 index 000000000..1efa3db5a --- /dev/null +++ b/examples/llm_ptq/nemotron_vl_calib.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Nemotron VL calibration helpers. 
+
+The Nemotron Nano VL v2 remote-code wrapper's `forward()` is not ideal to call during PTQ calibration because it may:
+- Call `torch.distributed.get_rank()` unconditionally
+- Assume `past_key_values` exists in the language model output
+
+Instead, we run a "safe multimodal forward" that exercises:
+- Vision encoder feature extraction (C-RADIOv2-H)
+- Insertion of vision embeddings into token embeddings at `img_context_token_id`
+- Language model forward pass (to trigger quantizer calibration)
+"""
+
+from __future__ import annotations
+
+import contextlib
+from typing import Any
+
+import torch
+
+
+def safe_nemotron_vl_forward(full_model: torch.nn.Module, batch: dict[str, Any]) -> None:
+    """Run a minimal multimodal forward for Nemotron VL that avoids wrapper output packaging."""
+    pixel_values = batch.get("pixel_values")
+    input_ids = batch.get("input_ids")
+    attention_mask = batch.get("attention_mask")
+    position_ids = batch.get("position_ids")
+    image_flags = batch.get("image_flags")
+
+    if pixel_values is None or input_ids is None or image_flags is None:
+        return
+
+    # Match the model's preferred vision dtype (usually bf16).
+    vision_dtype = None
+    with contextlib.suppress(Exception):
+        vision_dtype = getattr(full_model.vision_model.config, "torch_dtype", None)
+    if vision_dtype is None:
+        with contextlib.suppress(Exception):
+            vision_dtype = getattr(full_model.language_model.config, "torch_dtype", None)
+    if (
+        vision_dtype is not None
+        and torch.is_tensor(pixel_values)
+        and pixel_values.dtype != vision_dtype
+    ):
+        pixel_values = pixel_values.to(dtype=vision_dtype)
+
+    # Token embeddings
+    inputs_embeds = full_model.language_model.get_input_embeddings()(input_ids)
+    image_flags_s = image_flags.squeeze(-1)
+
+    b, n, c = inputs_embeds.shape
+    flat_embeds = inputs_embeds.reshape(b * n, c)
+    flat_ids = input_ids.reshape(b * n)
+    selected = flat_ids == full_model.img_context_token_id
+
+    # Vision embeddings
+    vit_embeds = full_model.extract_feature(pixel_values)
+    vit_embeds = vit_embeds[image_flags_s == 1]
+    try:
+        flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
+    except Exception:
+        vit_embeds = vit_embeds.reshape(-1, c)
+        n_token = selected.sum()
+        flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds[:n_token]
+
+    inputs_embeds = flat_embeds.reshape(b, n, c)
+
+    # LLM forward (drives activation stats)
+    full_model.language_model(
+        inputs_embeds=inputs_embeds,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        use_cache=False,
+        return_dict=False,
+    )
diff --git a/modelopt/torch/utils/nemotron_vlm_dataset_utils.py b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py
new file mode 100644
index 000000000..6a2dcdb44
--- /dev/null
+++ b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py
@@ -0,0 +1,203 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Nemotron VLM dataset utilities.
+ +This module contains the Nemotron-VLM-Dataset-v2 specific logic: +- Subsets can store images in `media/shard_*.tar` (images only) +- Prompts/messages live in `/.jsonl` and reference the image filename (e.g. `292180.png`) + +We join the tar images with the JSONL messages by the shared filename and yield samples compatible with our +VLM calibration pipeline. +""" + +from __future__ import annotations + +import functools +import json +import os +import random +import tarfile +from io import BytesIO +from typing import Any + +import torch + +_IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} + + +@functools.lru_cache(maxsize=8) +def list_repo_files_cached(repo_id: str, repo_type: str = "dataset") -> list[str]: + from huggingface_hub import list_repo_files + + return list_repo_files(repo_id=repo_id, repo_type=repo_type) + + +def extract_first_image_from_messages(messages: Any) -> Any: + """Best-effort extraction of an image reference from Nemotron-style `messages`.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not (isinstance(part, dict) and part.get("type") == "image"): + continue + if "image" in part: + return part["image"] + # fallback + for key in ("images", "path", "image_url", "url", "value", "data"): + if key in part: + return part[key] + return None + + +class NemotronTarPlusJsonlIterable(torch.utils.data.IterableDataset): + """Join Nemotron VLM `media/shard_*.tar` (images-only) with `/.jsonl` (messages).""" + + def __init__( + self, + repo_id: str, + subsets: list[str], + shard_paths: list[str], + num_samples: int, + seed: int, + shuffle_buffer_size: int, + max_shards: int | None, + ): + super().__init__() + self.repo_id = repo_id + self.subsets = subsets + self.shard_paths = shard_paths + self.num_samples = num_samples + self.seed = seed + self.shuffle_buffer_size = shuffle_buffer_size + self.max_shards = max_shards + + def __iter__(self): + from huggingface_hub import hf_hub_download + from PIL import Image + + rng = random.Random(self.seed) + + # Partition shards by subset. + shards_by_subset: dict[str, list[str]] = {s: [] for s in self.subsets} + for p in self.shard_paths: + subset = p.split("/", 1)[0] + if subset in shards_by_subset: + shards_by_subset[subset].append(p) + + for subset in list(shards_by_subset.keys()): + shard_list = sorted(shards_by_subset[subset]) + if self.max_shards is not None: + shard_list = shard_list[: max(0, self.max_shards)] + shards_by_subset[subset] = shard_list + + # Roughly split sample budget across subsets. + per_subset_target = max(1, self.num_samples // max(1, len(self.subsets))) + yielded_total = 0 + + for subset in self.subsets: + if yielded_total >= self.num_samples: + break + + shard_list = list(shards_by_subset.get(subset, [])) + if not shard_list: + continue + rng.shuffle(shard_list) + + # 1) Collect candidate image filenames from tar headers (no payload reads). 
+ candidate_names: list[str] = [] + header_limit = per_subset_target * 50 + for shard in shard_list: + local_tar = hf_hub_download(repo_id=self.repo_id, filename=shard, repo_type="dataset") + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if not member.isfile(): + continue + name = member.name + _, ext = os.path.splitext(name) + if ext.lower() not in _IMG_EXTS: + continue + candidate_names.append(name) + if len(candidate_names) >= header_limit: + break + if len(candidate_names) >= header_limit: + break + + if not candidate_names: + continue + + rng.shuffle(candidate_names) + lookup_limit = per_subset_target * 10 + candidate_set = set(candidate_names[:lookup_limit]) + + # 2) Scan JSONL to map image filename -> messages. + jsonl_path = hf_hub_download( + repo_id=self.repo_id, filename=f"{subset}/{subset}.jsonl", repo_type="dataset" + ) + meta_by_image: dict[str, dict[str, Any]] = {} + with open(jsonl_path, encoding="utf-8") as f: + for line in f: + try: + obj = json.loads(line) + except Exception: + continue + msgs = obj.get("messages") + img_name = extract_first_image_from_messages(msgs) if msgs is not None else None + if isinstance(img_name, str) and img_name in candidate_set: + meta_by_image[img_name] = {"id": obj.get("id"), "messages": msgs} + if len(meta_by_image) >= per_subset_target: + break + + if not meta_by_image: + continue + + # 3) Extract matched images and yield samples. + needed = set(meta_by_image.keys()) + for shard in shard_list: + if yielded_total >= self.num_samples or not needed: + break + local_tar = hf_hub_download(repo_id=self.repo_id, filename=shard, repo_type="dataset") + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if yielded_total >= self.num_samples or not needed: + break + if not member.isfile(): + continue + name = member.name + if name not in needed: + continue + f = tf.extractfile(member) + if f is None: + continue + try: + raw = f.read() + if isinstance(raw, str): + raw = raw.encode() + raw_bytes: bytes = raw + img = Image.open(BytesIO(raw_bytes)).convert("RGB") + except Exception: + continue + meta = meta_by_image.get(name) + if not meta: + continue + yield {"id": meta.get("id", name), "messages": meta.get("messages"), "image": img} + needed.discard(name) + yielded_total += 1 + diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 7dd2870d7..b964c5046 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -22,12 +22,7 @@ import contextlib import copy -import functools import itertools -import json -import os -import random -import tarfile from io import BytesIO from pathlib import Path from typing import Any @@ -52,14 +47,7 @@ __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] -_IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} - - -@functools.lru_cache(maxsize=8) -def _list_repo_files(repo_id: str, repo_type: str = "dataset") -> list[str]: - from huggingface_hub import list_repo_files - - return list_repo_files(repo_id=repo_id, repo_type=repo_type) +from .nemotron_vlm_dataset_utils import NemotronTarPlusJsonlIterable, list_repo_files_cached class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): @@ -77,157 +65,6 @@ def __len__(self): return self._num_samples -class _NemotronTarPlusJsonlIterable(torch.utils.data.IterableDataset): - """Join Nemotron VLM `media/shard_*.tar` (images-only) with `/.jsonl` (messages). 
- - Many Nemotron-VLM-Dataset-v2 subsets store PNGs in tar shards and store messages in a separate JSONL where the - image content item references the PNG filename (e.g., `{"type":"image","image":"292180.png"}`). - """ - - def __init__( - self, - repo_id: str, - subsets: list[str], - shard_paths: list[str], - num_samples: int, - seed: int, - shuffle_buffer_size: int, - max_shards: int | None, - ): - super().__init__() - self.repo_id = repo_id - self.subsets = subsets - self.shard_paths = shard_paths - self.num_samples = num_samples - self.seed = seed - self.shuffle_buffer_size = shuffle_buffer_size - self.max_shards = max_shards - - def __iter__(self): - from huggingface_hub import hf_hub_download - from PIL import Image - - rng = random.Random(self.seed) - - # Partition shards by subset. - shards_by_subset: dict[str, list[str]] = {s: [] for s in self.subsets} - for p in self.shard_paths: - subset = p.split("/", 1)[0] - if subset in shards_by_subset: - shards_by_subset[subset].append(p) - - for subset in list(shards_by_subset.keys()): - shard_list = sorted(shards_by_subset[subset]) - if self.max_shards is not None: - shard_list = shard_list[: max(0, self.max_shards)] - shards_by_subset[subset] = shard_list - - # Roughly split sample budget across subsets. - per_subset_target = max(1, self.num_samples // max(1, len(self.subsets))) - yielded_total = 0 - - for subset in self.subsets: - if yielded_total >= self.num_samples: - break - - shard_list = list(shards_by_subset.get(subset, [])) - if not shard_list: - continue - rng.shuffle(shard_list) - - # 1) Collect candidate image filenames from tar headers (no payload reads). - candidate_names: list[str] = [] - header_limit = per_subset_target * 50 - for shard in shard_list: - local_tar = hf_hub_download( - repo_id=self.repo_id, filename=shard, repo_type="dataset" - ) - with tarfile.open(local_tar, "r:*") as tf: - for member in tf: - if not member.isfile(): - continue - name = member.name - _, ext = os.path.splitext(name) - if ext.lower() not in _IMG_EXTS: - continue - candidate_names.append(name) - if len(candidate_names) >= header_limit: - break - if len(candidate_names) >= header_limit: - break - - if not candidate_names: - continue - - rng.shuffle(candidate_names) - lookup_limit = per_subset_target * 10 - candidate_set = set(candidate_names[:lookup_limit]) - - # 2) Scan JSONL to map image filename -> messages. - jsonl_path = hf_hub_download( - repo_id=self.repo_id, filename=f"{subset}/{subset}.jsonl", repo_type="dataset" - ) - meta_by_image: dict[str, dict[str, Any]] = {} - with open(jsonl_path, encoding="utf-8") as f: - for line in f: - try: - obj = json.loads(line) - except Exception: - continue - msgs = obj.get("messages") - img_name = ( - _extract_first_image_from_messages(msgs) if msgs is not None else None - ) - if isinstance(img_name, str) and img_name in candidate_set: - meta_by_image[img_name] = {"id": obj.get("id"), "messages": msgs} - if len(meta_by_image) >= per_subset_target: - break - - if not meta_by_image: - continue - - # 3) Extract matched images and yield samples. 
- needed = set(meta_by_image.keys()) - for shard in shard_list: - if yielded_total >= self.num_samples or not needed: - break - local_tar = hf_hub_download( - repo_id=self.repo_id, filename=shard, repo_type="dataset" - ) - with tarfile.open(local_tar, "r:*") as tf: - for member in tf: - if yielded_total >= self.num_samples or not needed: - break - if not member.isfile(): - continue - name = member.name - if name not in needed: - continue - f = tf.extractfile(member) - if f is None: - continue - try: - raw = f.read() - # Some tarfile stubs type `read()` as returning `str`; normalize to bytes for mypy. - if isinstance(raw, str): - raw = raw.encode() - # Help mypy: BytesIO expects a bytes-like buffer. - raw_bytes: bytes = raw - img = Image.open(BytesIO(raw_bytes)).convert("RGB") - except Exception: - continue - meta = meta_by_image.get(name) - if not meta: - continue - yield { - "id": meta.get("id", name), - "messages": meta.get("messages"), - "image": img, - } - needed.discard(name) - yielded_total += 1 - - def _extract_text_from_messages(messages: Any) -> str | None: """Best-effort extraction of a user text prompt from a chat-style `messages` field.""" if not isinstance(messages, list): @@ -409,7 +246,7 @@ def _get_vlm_dataset( # Prefer in-repo media tar shards when present. HF `datasets` streaming alone does not join media. if use_media_shards: - all_files = _list_repo_files(repo_id, repo_type="dataset") + all_files = list_repo_files_cached(repo_id, repo_type="dataset") shard_paths: list[str] = [] for subset in subsets: prefix = f"{subset}/media/" @@ -423,7 +260,7 @@ def _get_vlm_dataset( shard_paths = sorted(set(shard_paths)) if shard_paths: - return _NemotronTarPlusJsonlIterable( + return NemotronTarPlusJsonlIterable( repo_id=repo_id, subsets=subsets, shard_paths=shard_paths, From 5c774f9c84abdff79faaace09aaaa3166951cabd Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 15:03:56 -0800 Subject: [PATCH 11/15] refactor to make hf_ptq cleaner, create a separate vlm dataset utils for Nemotron-VLM-Dataset-v2 Signed-off-by: Zhiyu Cheng --- .../torch/utils/nemotron_vlm_dataset_utils.py | 32 ++++++++++++++++--- modelopt/torch/utils/vlm_dataset_utils.py | 3 +- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/utils/nemotron_vlm_dataset_utils.py b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py index 6a2dcdb44..36690d3b5 100644 --- a/modelopt/torch/utils/nemotron_vlm_dataset_utils.py +++ b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py @@ -40,6 +40,12 @@ @functools.lru_cache(maxsize=8) def list_repo_files_cached(repo_id: str, repo_type: str = "dataset") -> list[str]: + """List files in a HuggingFace repo (cached). + + Args: + repo_id: HF repo id (e.g., a dataset repo). + repo_type: HF repo type, usually "dataset" here. + """ from huggingface_hub import list_repo_files return list_repo_files(repo_id=repo_id, repo_type=repo_type) @@ -80,6 +86,17 @@ def __init__( shuffle_buffer_size: int, max_shards: int | None, ): + """Create an iterable dataset for Nemotron-VLM-Dataset-v2. + + Args: + repo_id: Dataset repo id, e.g. "nvidia/Nemotron-VLM-Dataset-v2". + subsets: Subset names to draw from (e.g., "sparsetables"). + shard_paths: Tar shard paths under `/media/`. + num_samples: Total number of samples to yield. + seed: RNG seed for sampling. + shuffle_buffer_size: Unused for now (kept for API compatibility). + max_shards: Max number of shards to use per subset (limits downloads). 
+ """ super().__init__() self.repo_id = repo_id self.subsets = subsets @@ -125,7 +142,9 @@ def __iter__(self): candidate_names: list[str] = [] header_limit = per_subset_target * 50 for shard in shard_list: - local_tar = hf_hub_download(repo_id=self.repo_id, filename=shard, repo_type="dataset") + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) with tarfile.open(local_tar, "r:*") as tf: for member in tf: if not member.isfile(): @@ -173,7 +192,9 @@ def __iter__(self): for shard in shard_list: if yielded_total >= self.num_samples or not needed: break - local_tar = hf_hub_download(repo_id=self.repo_id, filename=shard, repo_type="dataset") + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) with tarfile.open(local_tar, "r:*") as tf: for member in tf: if yielded_total >= self.num_samples or not needed: @@ -197,7 +218,10 @@ def __iter__(self): meta = meta_by_image.get(name) if not meta: continue - yield {"id": meta.get("id", name), "messages": meta.get("messages"), "image": img} + yield { + "id": meta.get("id", name), + "messages": meta.get("messages"), + "image": img, + } needed.discard(name) yielded_total += 1 - diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index b964c5046..120a700a7 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -31,6 +31,7 @@ from torch.utils.data import DataLoader from .image_processor import MllamaImageProcessor +from .nemotron_vlm_dataset_utils import NemotronTarPlusJsonlIterable, list_repo_files_cached # Use dict to store the config for each dataset. # If we want to export more options to user like target languages, we need more standardized approach like dataclass. @@ -47,8 +48,6 @@ __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] -from .nemotron_vlm_dataset_utils import NemotronTarPlusJsonlIterable, list_repo_files_cached - class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): """Wrap a HF streaming IterableDataset to be compatible with torch DataLoader.""" From f2774fc640e2de870fff203af12c2c68e32e53b1 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Tue, 13 Jan 2026 15:09:16 -0800 Subject: [PATCH 12/15] update readme Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 271b684f4..25850fdf1 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -328,7 +328,7 @@ python hf_ptq.py --pyt_ckpt_path --qformat fp8 --export #### VLM calibration with image-text pairs (e.g., Nemotron VL) -For vision-language models, calibration quality can improve by using image-text pairs instead of text-only data: +For vision-language models, calibration quality can likely improve by using image-text pairs instead of text-only data, especially on visual understanding tasks: ```bash python hf_ptq.py \ @@ -339,10 +339,11 @@ python hf_ptq.py \ --calib_with_images \ --vlm_dataset nemotron_vlm_dataset_v2 \ --vlm_subsets sparsetables,plotqa_cot,wiki_en \ - --calib_size 256 + --calib_size 512 ``` > Note: when `--calib_with_images` is set, `--calib_size` must be a single value. +This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`. 
 ### Hugging Face framework [Script](./scripts/huggingface_example.sh)

 Alternatively, the framework script `huggingface_example.sh` also supports quantize and export:

From e2e59f65881ea05b4b68242461c2d6815d35f937 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 13 Jan 2026 15:16:35 -0800
Subject: [PATCH 14/15] update readme

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 8018c5414..2cf8c0afb 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -180,7 +180,6 @@ python hf_ptq.py \
 > Note: when `--calib_with_images` is set, `--calib_size` must be a single value.
 This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`.
-
 ### NeMo Example Script

 NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo.
 Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details.

From 2a3868aa5f40ff62bdd26b8707c60ae385efd0a5 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Tue, 13 Jan 2026 15:48:59 -0800
Subject: [PATCH 15/15] minor refactor

Signed-off-by: Zhiyu Cheng
---
 examples/llm_ptq/nemotron_vl_calib.py      | 11 ++++++++++-
 modelopt/torch/utils/vlm_dataset_utils.py  | 27 +--------------------
 2 files changed, 11 insertions(+), 27 deletions(-)

diff --git a/examples/llm_ptq/nemotron_vl_calib.py b/examples/llm_ptq/nemotron_vl_calib.py
index 1efa3db5a..1b6187008 100644
--- a/examples/llm_ptq/nemotron_vl_calib.py
+++ b/examples/llm_ptq/nemotron_vl_calib.py
@@ -41,7 +41,16 @@ def safe_nemotron_vl_forward(full_model: torch.nn.Module, batch: dict[str, Any])
     position_ids = batch.get("position_ids")
     image_flags = batch.get("image_flags")
 
-    if pixel_values is None or input_ids is None or image_flags is None:
+    if pixel_values is None or input_ids is None:
+        return
+
+    # Nemotron Nano VL v2 expects `image_flags` in forward(), but the processor doesn't always emit it.
+    # `pixel_values` is flattened across batch*images, so `image_flags` should align with pixel_values.shape[0].
+    if image_flags is None and torch.is_tensor(pixel_values):
+        image_flags = torch.ones(
+            (pixel_values.shape[0], 1), device=pixel_values.device, dtype=torch.long
+        )
+    if image_flags is None:
         return
 
     # Match the model's preferred vision dtype (usually bf16).
diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py
index 120a700a7..841b82f65 100644
--- a/modelopt/torch/utils/vlm_dataset_utils.py
+++ b/modelopt/torch/utils/vlm_dataset_utils.py
@@ -104,22 +104,6 @@ def _messages_up_to_last_user(messages: Any) -> list[dict[str, Any]] | None:
     return [m for m in trimmed if isinstance(m, dict)]
 
 
-def _messages_has_image(messages: Any) -> bool:
-    """Return True if `messages` contains an image content item."""
-    if not isinstance(messages, list):
-        return False
-    for msg in messages:
-        if not isinstance(msg, dict):
-            continue
-        content = msg.get("content")
-        if not isinstance(content, list):
-            continue
-        for part in content:
-            if isinstance(part, dict) and part.get("type") == "image":
-                return True
-    return False
-
-
 def _extract_first_image_from_messages(messages: Any) -> Any:
     """Best-effort extraction of an image object from a chat-style `messages` field."""
     if not isinstance(messages, list):
@@ -307,7 +307,7 @@ def _get_vlm_dataset(
         ds = ds.filter(
             lambda ex: ex.get("image", None) is not None
             or ex.get("images", None) is not None
-            or _messages_has_image(ex.get("messages"))
+            or _extract_first_image_from_messages(ex.get("messages")) is not None
         )
 
         # Select the first `num_samples` entries (or fewer if dataset is smaller).
@@ -477,15 +461,6 @@ def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dic
             enc = enc.data
         out: dict[str, Any] = dict(enc)
 
-        # Nemotron Nano VL v2 expects `image_flags` in forward(), but the processor does not emit it.
-        # `pixel_values` is flattened across batch*images, so `image_flags` should align with pixel_values.shape[0].
-        if out.get("pixel_values") is not None and out.get("image_flags") is None:
-            pv = out["pixel_values"]
-            if torch.is_tensor(pv):
-                out["image_flags"] = torch.ones(
-                    (pv.shape[0], 1), device=pv.device, dtype=torch.long
-                )
-
         # Move tensors to device if requested.
         if device is not None:
             for k, v in list(out.items()):