From 19ad7289249ba50f3f2bb8d3c6b2086fdfd64097 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 15 Oct 2025 15:46:51 +0200
Subject: [PATCH 01/78] fix: remove wrong trl imports

---
 optimum/neuron/utils/__init__.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py
index 56de4eb08..04e367429 100644
--- a/optimum/neuron/utils/__init__.py
+++ b/optimum/neuron/utils/__init__.py
@@ -85,7 +85,6 @@
         "patch_within_function",
         "replace_class_in_inheritance_hierarchy",
     ],
-    "trl_utils": ["NeuronSFTConfig", "NeuronORPOConfig"],
 }

 if TYPE_CHECKING:
@@ -155,7 +154,6 @@
         patch_within_function,
         replace_class_in_inheritance_hierarchy,
     )
-    from .trl_utils import NeuronORPOConfig, NeuronSFTConfig

 else:
     import sys

From 34f46985c92dbf12c9743f5aa2601be26868bad6 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 15 Oct 2025 18:58:16 +0200
Subject: [PATCH 02/78] feat: align to latest trl release

---
 examples/training/llama/finetune_llama.py |   4 +-
 examples/training/qwen3/finetune_qwen3.py |   4 +-
 examples/training/qwen3/finetune_qwen3.sh |   3 +-
 optimum/neuron/trainers/sft_config.py     |  17 +-
 optimum/neuron/trainers/sft_trainer.py    | 231 +++++++++++-----------
 optimum/neuron/trainers/trl_utils.py      |   2 +-
 tests/training/test_neuron_sft_trainer.py |   8 +-
 7 files changed, 146 insertions(+), 123 deletions(-)

diff --git a/examples/training/llama/finetune_llama.py b/examples/training/llama/finetune_llama.py
index 8daaa5a59..fe4095d0e 100755
--- a/examples/training/llama/finetune_llama.py
+++ b/examples/training/llama/finetune_llama.py
@@ -80,7 +80,7 @@ def train(model_id, tokenizer, dataset, training_args):
     args = training_args.to_dict()
     sft_config = NeuronSFTConfig(
-        max_seq_length=2048,
+        max_length=2048,
         packing=True,
         **args,
     )
@@ -91,7 +91,7 @@
         args=sft_config,
         model=model,
         peft_config=lora_config,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=dataset,
         formatting_func=lambda example: format_dolly(example, tokenizer),
     )
diff --git a/examples/training/qwen3/finetune_qwen3.py b/examples/training/qwen3/finetune_qwen3.py
index f7a27bbb6..8e3b25712 100644
--- a/examples/training/qwen3/finetune_qwen3.py
+++ b/examples/training/qwen3/finetune_qwen3.py
@@ -84,7 +84,7 @@ def train(model_id, tokenizer, dataset, training_args):
     args = training_args.to_dict()
     sft_config = NeuronSFTConfig(
-        max_seq_length=4096,
+        max_length=4096,
         packing=True,
         **args,
     )
@@ -98,7 +98,7 @@ def formatting_function(examples):
         args=sft_config,
         model=model,
         peft_config=lora_config,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=dataset,
         formatting_func=formatting_function,
     )
diff --git a/examples/training/qwen3/finetune_qwen3.sh b/examples/training/qwen3/finetune_qwen3.sh
index d64a6572d..b2d7568e3 100755
--- a/examples/training/qwen3/finetune_qwen3.sh
+++ b/examples/training/qwen3/finetune_qwen3.sh
@@ -13,7 +13,8 @@ TP_DEGREE=8
 BS=1
 GRADIENT_ACCUMULATION_STEPS=8
 LOGGING_STEPS=2
-MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
+# MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
+MODEL_NAME="Qwen/Qwen3-0.6B" # Change this to the desired model name
 OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-finetuned"
 DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE"
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
diff --git a/optimum/neuron/trainers/sft_config.py b/optimum/neuron/trainers/sft_config.py
index 8c50d033f..89f54c7f6 100644
--- a/optimum/neuron/trainers/sft_config.py
+++ b/optimum/neuron/trainers/sft_config.py
@@ -10,10 +10,11 @@
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# Seg the License for the specific language governing permissions and
+# See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Any

 from ..utils.import_utils import is_trl_available
 from .training_args import NeuronTrainingArguments
@@ -32,4 +33,14 @@ def __init__(self, *args, **kwargs):

 @dataclass
 class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
-    pass
+
+    def __post_init__(self):
+        # Handle max_seq_length -> max_length migration for backward compatibility
+        if hasattr(self, "max_seq_length") and self.max_seq_length is not None:
+            if self.max_length == 1024:  # 1024 is the default
+                self.max_length = self.max_seq_length
+
+        # Force padding_free to False for Neuron - critical for avoiding recompilation
+        self.padding_free = False
+
+        super().__post_init__()
diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
index b69e909d2..439b063b9 100644
--- a/optimum/neuron/trainers/sft_trainer.py
+++ b/optimum/neuron/trainers/sft_trainer.py
@@ -91,33 +91,38 @@ def __init__(
         self,
         model: PreTrainedModel | torch.nn.Module | str,
         args: SFTConfig | None = None,
         data_collator: DataCollator | None = None,  # type: ignore
         train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
         eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
-        processsing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
+        processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
+        compute_loss_func: Callable | None = None,
+        compute_metrics: Callable | None = None,
         callbacks: list[TrainerCallback] | None = None,
-        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
         optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
-        tokenizer: PreTrainedTokenizerBase | None = None,  # deprecated
+        preprocess_logits_for_metrics: Callable | None = None,
         peft_config: PeftConfig | None = None,
         formatting_func: Callable | None = None,
+        # Deprecated parameters for backward compatibility
+        tokenizer: PreTrainedTokenizerBase | None = None,  # Use processing_class instead
     ):
         if not is_trl_available(required_version=TRL_VERSION):
             raise RuntimeError(f"Using NeuronSFTTrainer requires trl=={TRL_VERSION}.")

         from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-        # This will be changed to : from trl.trainer.callbacks import RichProgressCallback
-        from trl.trainer.utils import (
-            DataCollatorForCompletionOnlyLM,
-            peft_module_casting_to_bf16,
-        )
+        from trl.trainer.utils import peft_module_casting_to_bf16

         if is_peft_available():
             from peft import PeftConfig

+        # Handle backward compatibility for tokenizer parameter
+        if tokenizer is not None and processing_class is None:
+            processing_class = tokenizer
+
         args_is_none = args is None
         if args is None:
-            output_dir = "tmp_trainer"
-            args = NeuronSFTConfig(output_dir=output_dir)
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = NeuronSFTConfig(f"{model_name}-SFT")
         elif args is not None and args.__class__.__name__ == "NeuronTrainingArguments":
             args_as_dict = args.to_dict()
             # Manually copy token values as TrainingArguments.to_dict() redacts them
@@ -132,34 +137,30 @@ def __init__(
         if args_is_none:
             logging.warning(f"No `SFTConfig` passed, using `output_dir={args.output_dir}`.")

-        if getattr(args, "model_init_kwargs", None) is None:
-            model_init_kwargs = {}
-        elif not isinstance(model, str):
-            raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
-        else:
-            model_init_kwargs = args.model_init_kwargs
-            torch_dtype = model_init_kwargs.get("torch_dtype")
-            if torch_dtype is not None:
-                # Convert to `torch.dtype` if an str is passed
-                if isinstance(torch_dtype, str) and torch_dtype != "auto":
-                    torch_dtype = getattr(torch, torch_dtype)
-                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
-                    raise ValueError(
-                        f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
-                    )
-                model_init_kwargs["torch_dtype"] = torch_dtype
-
+        # Model handling - use model_init_kwargs from args
+        model_init_kwargs = args.model_init_kwargs or {}
         if isinstance(model, str):
-            logging.warning(
-                "You passed a model_id to the SFTTrainer. This will automatically create an "
-                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
-            )
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+            model_id = model
+            dtype = model_init_kwargs.get("dtype")
+            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
+                pass  # dtype is already a torch.dtype or "auto" or None
+            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
+                dtype = getattr(torch, dtype)
+                model_init_kwargs["dtype"] = dtype
+            else:
+                raise ValueError(
+                    "Invalid `dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
+                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
+                )
+            model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                logger.warning(
+                    "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
+                    "The `model_init_kwargs` will be ignored."
+                )

-        if args.packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
-            raise ValueError(
-                "You passed a `DataCollatorForCompletionOnlyLM` to the NeuronSFTTrainer. This is not compatible with the `packing` argument."
-            )

         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
@@ -188,23 +189,26 @@ def make_inputs_require_grad(module, input, output):
             if args is not None and args.bf16:
                 peft_module_casting_to_bf16(model)

-        if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
-        if getattr(tokenizer, "pad_token", None) is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        # Processing class (tokenizer) handling
+        if processing_class is None:
+            from transformers import AutoProcessor
+            processing_class = AutoProcessor.from_pretrained(model_id)
+
+        # Ensure we have a pad token
+        if hasattr(processing_class, 'pad_token') and getattr(processing_class, "pad_token", None) is None:
+            processing_class.pad_token = processing_class.eos_token

-        if args.max_seq_length is None:
+        # Handle max_length (renamed from max_seq_length)
+        if args.max_length is None:
             # to overcome some issues with broken tokenizers
-            args.max_seq_length = min(tokenizer.model_max_length, 1024)
+            args.max_length = min(processing_class.model_max_length, 1024)

             logger.warning(
-                f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {args.max_seq_length}"
+                f"You didn't pass a `max_length` argument to the SFTTrainer, this will default to {args.max_length}"
             )

         self.dataset_num_proc = args.dataset_num_proc
-        self.dataset_batch_size = args.dataset_batch_size
-
         self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")

         if args.dataset_kwargs is None:
@@ -231,49 +235,39 @@ def make_inputs_require_grad(module, input, output):
             )

         if data_collator is None:
-            data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+            data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)

         # Pre-process the datasets only once per node. The remaining processes will use the cache.
         with NeuronPartialState().local_main_process_first():
             if train_dataset is not None:
                 train_dataset = self._prepare_dataset(
                     train_dataset,
-                    tokenizer,
+                    processing_class,
+                    args,
                     args.packing,
-                    args.dataset_text_field,
-                    args.max_seq_length,
                     formatting_func,
-                    args.num_of_sequences,
-                    args.chars_per_token,
-                    remove_unused_columns=args.remove_unused_columns if args is not None else True,
-                    **args.dataset_kwargs,
+                    "train"
                 )
             if eval_dataset is not None:
                 _multiple = isinstance(eval_dataset, dict)
                 _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}

-                eval_packing = args.packing if args.eval_packing is None else args.eval_packing
-
                 for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
                     _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
                         _eval_dataset,
-                        tokenizer,
-                        eval_packing,
-                        args.dataset_text_field,
-                        args.max_seq_length,
+                        processing_class,
+                        args,
+                        args.eval_packing if args.eval_packing is not None else args.packing,
                         formatting_func,
-                        args.num_of_sequences,
-                        args.chars_per_token,
-                        remove_unused_columns=args.remove_unused_columns if args is not None else True,
-                        **args.dataset_kwargs,
+                        _eval_dataset_name
                     )
                 if not _multiple:
                     eval_dataset = _eval_datasets["singleton"]

-        if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
+        if hasattr(processing_class, "padding_side") and processing_class.padding_side is not None and processing_class.padding_side != "right":
             logger.warning(
-                "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
-                "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
+                "You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
+                "overflow issues when training a model in half-precision. You might consider adding `processing_class.padding_side = \"right\"` to your code."
             )

         NeuronTrainer.__init__(
             self,
             model,
             args=args,
             data_collator=data_collator,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
-            processing_class=tokenizer,
+            processing_class=processing_class,
             callbacks=callbacks,
             optimizers=optimizers,
             optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
@@ -313,62 +307,79 @@ def train(
     ):
         return NeuronTrainer.train(self, resume_from_checkpoint=resume_from_checkpoint)

-    def _prepare_non_packed_dataloader(
+    def _prepare_dataset(
         self,
-        tokenizer,
         dataset,
-        dataset_text_field,
-        max_seq_length,
+        processing_class,
+        args,
+        packing,
         formatting_func=None,
-        add_special_tokens=True,
-        remove_unused_columns=True,
+        dataset_name="train",
     ):
-        use_formatting_func = formatting_func is not None and dataset_text_field is None
-        self._dataset_sanity_checked = False
-
-        # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
-        def tokenize(element):
-            outputs = tokenizer(
-                element[dataset_text_field] if not use_formatting_func else formatting_func(element),
-                add_special_tokens=add_special_tokens,
+        """
+        Prepare dataset for Neuron training with proper padding.
+
+        This method overrides the base TRL implementation to ensure consistent padding
+        for Neuron devices, which require fixed input shapes to avoid recompilation.
+        """
+        # For packing, delegate to parent implementation but ensure no padding_free
+        if packing:
+            # Temporarily disable padding_free for packing as well
+            original_padding_free = getattr(args, "padding_free", False)
+            args.padding_free = False
+            try:
+                result = super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name)
+            finally:
+                args.padding_free = original_padding_free
+            return result
+
+        # For non-packed datasets, use our custom implementation with forced padding
+        from datasets import Dataset, IterableDataset
+
+        # Apply formatting function if provided
+        if formatting_func is not None:
+            if isinstance(dataset, Dataset):
+                dataset = dataset.map(
+                    lambda example: {"text": formatting_func(example)},
+                    num_proc=args.dataset_num_proc,
+                    desc=f"Applying formatting function to {dataset_name} dataset"
+                )
+            else:  # IterableDataset
+                dataset = dataset.map(lambda example: {"text": formatting_func(example)})
+
+        # Tokenization function with forced padding for Neuron
+        def tokenize(examples):
+            # Handle both single examples and batches
+            if isinstance(examples[args.dataset_text_field], list):
+                texts = examples[args.dataset_text_field]
+            else:
+                texts = [examples[args.dataset_text_field]]
+
+            outputs = processing_class(
+                texts,
+                add_special_tokens=True,
                 truncation=True,
-                # For Neuron we need to pad because otherwise it will trigger compilation for each new sequence length.
+                # Critical for Neuron: always pad to max_length to avoid recompilation
                 padding="max_length",
-                max_length=max_seq_length,
+                max_length=args.max_length,
                 return_overflowing_tokens=False,
                 return_length=False,
             )

-            if use_formatting_func and not self._dataset_sanity_checked:
-                if not isinstance(formatting_func(element), list):
-                    raise ValueError(
-                        "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
-                    )
-                else:
-                    self._dataset_sanity_checked = True
-
-            return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
-
-        signature_columns = ["input_ids", "labels", "attention_mask"]
-
-        if dataset.column_names is not None:  # None for IterableDataset
-            extra_columns = list(set(dataset.column_names) - set(signature_columns))
-        else:
-            extra_columns = []
-
-        if not remove_unused_columns and len(extra_columns) > 0:
-            logger.warning(
-                "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to "
-                f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
-            )
+            return {
+                "input_ids": outputs["input_ids"],
+                "attention_mask": outputs["attention_mask"],
+                "labels": outputs["input_ids"].copy()  # For language modeling
+            }

+        # Build map kwargs
         map_kwargs = {
             "batched": True,
-            "remove_columns": dataset.column_names if remove_unused_columns else None,
-            "batch_size": self.dataset_batch_size,
+            "remove_columns": dataset.column_names if hasattr(dataset, "column_names") and dataset.column_names else None,
         }
-        if isinstance(dataset, datasets.Dataset):
-            map_kwargs["num_proc"] = self.dataset_num_proc  # this arg is not available for IterableDataset
-        tokenized_dataset = dataset.map(tokenize, **map_kwargs)
+        if isinstance(dataset, Dataset):
+            map_kwargs["num_proc"] = args.dataset_num_proc
+            map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

+        tokenized_dataset = dataset.map(tokenize, **map_kwargs)
         return tokenized_dataset
diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py
index dac42fb54..2a5d2af28 100644
--- a/optimum/neuron/trainers/trl_utils.py
+++ b/optimum/neuron/trainers/trl_utils.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-TRL_VERSION = "0.11.4"
+TRL_VERSION = "0.23.1"
diff --git a/tests/training/test_neuron_sft_trainer.py b/tests/training/test_neuron_sft_trainer.py
index dc4f9d15e..4be2a1a80 100644
--- a/tests/training/test_neuron_sft_trainer.py
+++ b/tests/training/test_neuron_sft_trainer.py
@@ -77,7 +77,7 @@ def format_dolly(sample):
     args = args.to_dict()
     sft_config = NeuronSFTConfig(
         # Using a small sequence-length since we are not validating the outputs.
-        max_seq_length=128,
+        max_length=128,
         packing=packing,
         dataset_num_proc=1,
         **args,
@@ -86,7 +86,7 @@ def format_dolly(sample):
     # Create Trainer instance
     trainer = NeuronSFTTrainer(
         model=model,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=dataset,
         formatting_func=format_dolly,
         args=sft_config,
@@ -172,7 +172,7 @@ def format_dolly(sample):
     args = args.to_dict()
     sft_config = NeuronSFTConfig(
-        max_seq_length=128,
+        max_length=128,
         packing=False,  # No packing for PEFT test simplicity
         dataset_num_proc=1,
         **args,
@@ -181,7 +181,7 @@ def format_dolly(sample):
     # Create SFT Trainer instance with PEFT model
     trainer = NeuronSFTTrainer(
         model=base_model,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=dataset,
         formatting_func=format_dolly,
         args=sft_config,

From 07437bcc0afd243cbeff546c87bd4147112c2b9e Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 15 Oct 2025 18:58:59 +0200
Subject: [PATCH 03/78] chore: update pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0db101f36..09e66972d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,7 +75,7 @@ quality = [
     "isort",
 ]
 training = [
-    "trl == 0.11.4",
+    "trl == 0.23.1",
     "peft == 0.17.0",
     "evaluate == 0.4.3",
 ]

From e5256bf18b08ae49c5cdd0d5696d1f9a4bab3482 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 16 Oct 2025 00:58:39 +0200
Subject: [PATCH 04/78] style

---
 optimum/neuron/trainers/sft_config.py  |  4 +--
 optimum/neuron/trainers/sft_trainer.py | 40 ++++++++++++++------------
 2 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/optimum/neuron/trainers/sft_config.py b/optimum/neuron/trainers/sft_config.py
index 89f54c7f6..b33763e72 100644
--- a/optimum/neuron/trainers/sft_config.py
+++ b/optimum/neuron/trainers/sft_config.py
@@ -13,8 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass, field
-from typing import Any
+from dataclasses import dataclass

 from ..utils.import_utils import is_trl_available
 from .training_args import NeuronTrainingArguments
@@ -33,7 +32,6 @@ def __init__(self, *args, **kwargs):

 @dataclass
 class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
-
     def __post_init__(self):
         # Handle max_seq_length -> max_length migration for backward compatibility
         if hasattr(self, "max_seq_length") and self.max_seq_length is not None:
diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
index 439b063b9..ba40b6768 100644
--- a/optimum/neuron/trainers/sft_trainer.py
+++ b/optimum/neuron/trainers/sft_trainer.py
@@ -21,7 +21,6 @@
 from torch.utils.data import Dataset, IterableDataset
 from transformers import (
     AutoModelForCausalLM,
-    AutoTokenizer,
     DataCollator,
     DataCollatorForLanguageModeling,
     PreTrainedModel,
@@ -107,7 +106,6 @@ def __init__(
             raise RuntimeError(f"Using NeuronSFTTrainer requires trl=={TRL_VERSION}.")

         from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-
         from trl.trainer.utils import peft_module_casting_to_bf16

         if is_peft_available():
             from peft import PeftConfig
@@ -161,7 +159,6 @@ def __init__(
                     "The `model_init_kwargs` will be ignored."
                )
-
         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
@@ -192,10 +189,11 @@ def make_inputs_require_grad(module, input, output):
         # Processing class (tokenizer) handling
         if processing_class is None:
             from transformers import AutoProcessor
+
             processing_class = AutoProcessor.from_pretrained(model_id)

         # Ensure we have a pad token
-        if hasattr(processing_class, 'pad_token') and getattr(processing_class, "pad_token", None) is None:
+        if hasattr(processing_class, "pad_token") and getattr(processing_class, "pad_token", None) is None:
             processing_class.pad_token = processing_class.eos_token

         # Handle max_length (renamed from max_seq_length)
@@ -211,6 +209,9 @@ def make_inputs_require_grad(module, input, output):

         self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")

+        # Vision Language Model (VLM) support - not yet supported in Neuron
+        self._is_vlm = False
+
         if args.dataset_kwargs is None:
             args.dataset_kwargs = {}

@@ -241,12 +242,7 @@ def make_inputs_require_grad(module, input, output):
         with NeuronPartialState().local_main_process_first():
             if train_dataset is not None:
                 train_dataset = self._prepare_dataset(
-                    train_dataset,
-                    processing_class,
-                    args,
-                    args.packing,
-                    formatting_func,
-                    "train"
+                    train_dataset, processing_class, args, args.packing, formatting_func, "train"
                 )
             if eval_dataset is not None:
                 _multiple = isinstance(eval_dataset, dict)
@@ -259,15 +255,19 @@ def make_inputs_require_grad(module, input, output):
                         args,
                         args.eval_packing if args.eval_packing is not None else args.packing,
                         formatting_func,
-                        _eval_dataset_name
+                        _eval_dataset_name,
                     )
                 if not _multiple:
                     eval_dataset = _eval_datasets["singleton"]

-        if hasattr(processing_class, "padding_side") and processing_class.padding_side is not None and processing_class.padding_side != "right":
+        if (
+            hasattr(processing_class, "padding_side")
+            and processing_class.padding_side is not None
+            and processing_class.padding_side != "right"
+        ):
             logger.warning(
                 "You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
                 'overflow issues when training a model in half-precision. You might consider adding `processing_class.padding_side = "right"` to your code.'
             )

         NeuronTrainer.__init__(
@@ -328,13 +328,15 @@ def _prepare_dataset(
             original_padding_free = getattr(args, "padding_free", False)
             args.padding_free = False
             try:
-                result = super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name)
+                result = super()._prepare_dataset(
+                    dataset, processing_class, args, packing, formatting_func, dataset_name
+                )
             finally:
                 args.padding_free = original_padding_free
             return result

         # For non-packed datasets, use our custom implementation with forced padding
-        from datasets import Dataset, IterableDataset
+        from datasets import Dataset

         # Apply formatting function if provided
         if formatting_func is not None:
@@ -342,7 +344,7 @@ def _prepare_dataset(
                 dataset = dataset.map(
                     lambda example: {"text": formatting_func(example)},
                     num_proc=args.dataset_num_proc,
-                    desc=f"Applying formatting function to {dataset_name} dataset"
+                    desc=f"Applying formatting function to {dataset_name} dataset",
                 )
             else:  # IterableDataset
                 dataset = dataset.map(lambda example: {"text": formatting_func(example)})
@@ -369,13 +371,15 @@ def tokenize(examples):
             return {
                 "input_ids": outputs["input_ids"],
                 "attention_mask": outputs["attention_mask"],
-                "labels": outputs["input_ids"].copy()  # For language modeling
+                "labels": outputs["input_ids"].copy(),  # For language modeling
             }

         # Build map kwargs
         map_kwargs = {
             "batched": True,
-            "remove_columns": dataset.column_names if hasattr(dataset, "column_names") and dataset.column_names else None,
+            "remove_columns": dataset.column_names
+            if hasattr(dataset, "column_names") and dataset.column_names
+            else None,
         }
         if isinstance(dataset, Dataset):
             map_kwargs["num_proc"] = args.dataset_num_proc

From 954cfdfd316ec6c96dbcf060451c5f56227f1bfc Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 16 Oct 2025 01:34:04 +0200
Subject: [PATCH 05/78] feat: sync with SFTTrainer

---
 optimum/neuron/trainers/sft_trainer.py  | 19 +++++++++++++
 optimum/neuron/trainers/transformers.py | 36 ++++++++++++++-----------
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
index ba40b6768..f31e59021 100644
--- a/optimum/neuron/trainers/sft_trainer.py
+++ b/optimum/neuron/trainers/sft_trainer.py
@@ -307,6 +307,25 @@ def train(
     ):
         return NeuronTrainer.train(self, resume_from_checkpoint=resume_from_checkpoint)

+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        """
+        Compute training loss.
+
+        This method overrides the TRL SFTTrainer's compute_loss to disable unsupported
+        metrics computation (entropy, token accuracy) for Neuron compatibility.
+        """
+        # Set use_cache to False to avoid warnings with gradient checkpointing
+        inputs["use_cache"] = False
+
+        # Call the parent NeuronTrainer's compute_loss method (not TRL's)
+        return NeuronTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
+
+    def training_step(
+        self, model: torch.nn.Module, inputs: dict[str, Any], num_items_in_batch: int | None = None
+    ) -> torch.Tensor:
+        # We do not use the SFTTrainer.training_step because it checks for an attribute the NeuronTrainer doesn't have.
+        return NeuronTrainer.training_step(self, model, inputs, num_items_in_batch=num_items_in_batch)
+
     def _prepare_dataset(
         self,
         dataset,
diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py
index 8e9ddcbf8..cfd009887 100644
--- a/optimum/neuron/trainers/transformers.py
+++ b/optimum/neuron/trainers/transformers.py
@@ -936,26 +936,29 @@ def get_batch_samples(

         return batch_samples, num_items_in_batch

-    def train_step(
-        self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch: int | torch.Tensor | None = None
-    ) -> torch.Tensor:
-        manager = self.autocast_smart_context_manager()
-
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        return_outputs: bool = False,
+        num_items_in_batch: torch.Tensor | None = None,
+    ):
         if isinstance(model, NxDPPModel):
-            with manager:
-                loss = model.run_train(**inputs)
+            loss = model.run_train(**inputs)

             # When using pipeline parallelism, the loss is only computed on the last stage.
             # So we set the loss to zero on other stages.
             if self.pp_rank != self.pp_size - 1:
                 dtype = torch.bfloat16 if self.args.bf16 else torch.float32
                 loss = torch.tensor(0, dtype=dtype).to(xm.xla_device())
+
+            # PP does not return any outputs except the loss
+            outputs = {"loss": loss}
         else:
             if num_items_in_batch is not None:
                 inputs = dict(**inputs, reduction="sum")
-            with manager:
-                outputs = model(**inputs)
+            outputs = model(**inputs)

             if isinstance(outputs, dict) and "loss" not in outputs:
                 raise ValueError(
@@ -970,8 +973,17 @@
         else:
             loss = loss / self.args.gradient_accumulation_steps

-        # Backward pass
-        self.accelerator.backward(loss)
+        return (loss, outputs) if return_outputs else loss
+
+    def training_step(
+        self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch: int | torch.Tensor | None = None
+    ) -> torch.Tensor:
+        manager = self.autocast_smart_context_manager()
+        with manager:
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+            # Backward pass
+            self.accelerator.backward(loss)

         return loss

@@ -1102,7 +1114,7 @@ def train(
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

-                loss_step = self.train_step(self.model, inputs, num_items_in_batch=num_items_in_batch)
+                loss_step = self.training_step(self.model, inputs, num_items_in_batch=num_items_in_batch)
                 self.running_loss += loss_step.detach()

                 if do_sync_step:

From 3c72216613d55d7022de015067ec42888ee6cddb Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 31 Oct 2025 14:44:43 +0100
Subject: [PATCH 06/78] fix: minor issues

---
 optimum/neuron/trainers/sft_config.py  | 3 +--
 optimum/neuron/trainers/sft_trainer.py | 6 +++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/optimum/neuron/trainers/sft_config.py b/optimum/neuron/trainers/sft_config.py
index b33763e72..ee31c067d 100644
--- a/optimum/neuron/trainers/sft_config.py
+++ b/optimum/neuron/trainers/sft_config.py
@@ -35,8 +35,7 @@ class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
     def __post_init__(self):
         # Handle max_seq_length -> max_length migration for backward compatibility
         if hasattr(self, "max_seq_length") and self.max_seq_length is not None:
-            if self.max_length == 1024:  # 1024 is the default
-                self.max_length = self.max_seq_length
+            self.max_length = self.max_seq_length

         # Force padding_free to False for Neuron - critical for avoiding recompilation
         self.padding_free = False
diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
index aa2ea18d5..cb2929f3f 100644
--- a/optimum/neuron/trainers/sft_trainer.py
+++ b/optimum/neuron/trainers/sft_trainer.py
@@ -212,9 +212,8 @@ def make_inputs_require_grad(module, input, output):
         if hasattr(processing_class, "pad_token") and getattr(processing_class, "pad_token", None) is None:
             processing_class.pad_token = processing_class.eos_token

-        # Handle max_length (renamed from max_seq_length)
         if args.max_length is None:
-            # to overcome some issues with broken tokenizers
+            # To overcome some issues with broken tokenizers
             args.max_length = min(processing_class.model_max_length, 1024)

             logger.warning(
@@ -223,7 +222,8 @@ def make_inputs_require_grad(module, input, output):

         self.dataset_num_proc = args.dataset_num_proc

-        self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
+        # We do not support NeFTune with NeuronSFTTrainer for now.
+        self._trainer_supports_neftune = False

         # Vision Language Model (VLM) support - not yet supported in Neuron
         self._is_vlm = False

From d286f5011219ac5da9c5973851e8683a2eca4db8 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 31 Oct 2025 15:07:34 +0100
Subject: [PATCH 07/78] chore: sync with trl==0.24.0

---
 optimum/neuron/trainers/sft_config.py  |  12 ++
 optimum/neuron/trainers/sft_trainer.py | 167 +++++++++++------------
 optimum/neuron/trainers/trl_utils.py   |   2 +-
 3 files changed, 102 insertions(+), 79 deletions(-)

diff --git a/optimum/neuron/trainers/sft_config.py b/optimum/neuron/trainers/sft_config.py
index ee31c067d..45cb07b3e 100644
--- a/optimum/neuron/trainers/sft_config.py
+++ b/optimum/neuron/trainers/sft_config.py
@@ -32,12 +32,24 @@ def __init__(self, *args, **kwargs):

 @dataclass
 class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig):
+    """
+    Configuration class for Neuron-optimized SFT training.
+
+    Inherits from both NeuronTrainingArguments (for Trainium-specific settings) and
+    trl's SFTConfig (for SFT-specific settings).
+
+    Key Neuron-specific behavior:
+    - padding_free is always set to False to avoid recompilation on Trainium devices
+    - All other SFT parameters from trl 0.24.0+ are supported
+    """
+
     def __post_init__(self):
         # Handle max_seq_length -> max_length migration for backward compatibility
         if hasattr(self, "max_seq_length") and self.max_seq_length is not None:
             self.max_length = self.max_seq_length

         # Force padding_free to False for Neuron - critical for avoiding recompilation
+        # Neuron devices require fixed input shapes; padding_free flattening breaks this requirement
         self.padding_free = False

         super().__post_init__()
diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
index cb2929f3f..29b9b6b2a 100644
--- a/optimum/neuron/trainers/sft_trainer.py
+++ b/optimum/neuron/trainers/sft_trainer.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from typing import Any, Callable

 import datasets
@@ -74,13 +75,18 @@ class PeftConfig:

 class NeuronSFTTrainer(_SFTTrainer):
     """
-    `SFTTrainer` adapted for Neuron.
-
-    It differs from the original `SFTTrainer` by:
-    - Using `_TrainerForNeuron.__init__()` instead of `Trainer.__init__()`
-    - Using the `_TrainerForNeuron.train()` instead of `Trainer.train()`
-    - Adapts the `_prepare_non_packed_dataloader` to pad to max length. In the original `SFTTrainer` examples are
-    not padded, which is an issue here because it triggers compilation every time.
+    `SFTTrainer` adapted for Neuron (Trainium) devices.
+
+    Overrides key methods for Neuron compatibility:
+    - Uses NeuronTrainer.__init__() instead of transformers.Trainer.__init__()
+    - Uses NeuronTrainer.train() for Neuron-optimized training
+    - Enforces padding_free=False for fixed input shapes (required for Trainium)
+    - Simplifies _prepare_dataset to delegate to parent with Neuron constraints
+
+    Neuron-specific constraints:
+    - padding_free is always False to avoid recompilation
+    - VLM training is not yet supported
+    - NeFTune training is not supported
     """

     def __init__(
@@ -175,6 +181,31 @@ def __init__(
                     "The `model_init_kwargs` will be ignored."
                 )

+        # Chat template handling (trl 0.24.0+)
+        # This allows users to provide a custom chat template via path or directory
+        if hasattr(args, 'chat_template_path') and args.chat_template_path is not None:
+            from trl.models import clone_chat_template
+
+            if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
+                # Load Jinja template directly
+                with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
+                    processing_class.chat_template = chat_template_file.read()
+                added_tokens = []
+            else:
+                # Clone template from another model
+                try:
+                    model, processing_class, added_tokens = clone_chat_template(
+                        model, processing_class, args.chat_template_path
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to clone chat template from {args.chat_template_path}: {e}. "
+                        "Continuing without custom chat template."
+                    )
+                    added_tokens = []
+        else:
+            added_tokens = []
+
         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
@@ -251,8 +282,30 @@ def make_inputs_require_grad(module, input, output):
                 "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
             )

+        # Data collator creation with Neuron-specific constraints
         if data_collator is None:
-            data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
+            # Determine if this is a VLM (vision language model)
+            is_vlm = isinstance(processing_class, ProcessorMixin) and hasattr(processing_class, 'image_processor')
+
+            if is_vlm:
+                # VLM support is not yet implemented in Neuron
+                logger.warning(
+                    "Vision Language Model (VLM) detected. VLM training is not yet fully supported in Neuron. "
+                    "Attempting to use standard language modeling collator."
+                )
+                # For now, use standard collator - user can override if needed
+                data_collator = DataCollatorForLanguageModeling(
+                    tokenizer=processing_class.tokenizer if hasattr(processing_class, 'tokenizer') else processing_class,
+                    mlm=False,
+                )
+            else:
+                # Standard language modeling collator
+                data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
+
+        # Ensure padding_free is False - critical Neuron requirement
+        # (this is already done in NeuronSFTConfig.__post_init__, but double-check)
+        if hasattr(data_collator, 'padding_free'):
+            data_collator.padding_free = False

         # Pre-process the datasets only once per node. The remaining processes will use the cache.
         with NeuronPartialState().local_main_process_first():
@@ -325,10 +378,10 @@ def train(

     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         """
-        Compute training loss.
+        Compute training loss for Neuron-optimized training.

-        This method overrides the TRL SFTTrainer's compute_loss to disable unsupported
-        metrics computation (entropy, token accuracy) for Neuron compatibility.
+        Overrides TRL SFTTrainer's compute_loss to set use_cache=False for gradient
+        checkpointing compatibility and delegate to NeuronTrainer's compute_loss.
         """
         # Set use_cache to False to avoid warnings with gradient checkpointing
         inputs["use_cache"] = False
@@ -339,7 +392,12 @@

     def training_step(
         self, model: torch.nn.Module, inputs: dict[str, Any], num_items_in_batch: int | None = None
     ) -> torch.Tensor:
-        # We do not use the SFTTrainer.training_step because it checks for an attribute the NeuronTrainer doesn't have.
+        """
+        Perform a training step for Neuron-optimized training.
+
+        Overrides SFTTrainer.training_step to delegate to NeuronTrainer's implementation,
+        which is compatible with Neuron's distributed training setup.
+        """
         return NeuronTrainer.training_step(self, model, inputs, num_items_in_batch=num_items_in_batch)

     def _prepare_dataset(
         self,
         dataset,
         processing_class,
         args,
         packing,
         formatting_func=None,
         dataset_name="train",
     ):
         """
-        Prepare dataset for Neuron training with proper padding.
+        Prepare dataset for Neuron training.
+
+        Delegates to parent SFTTrainer._prepare_dataset, which handles:
+        - Dataset type detection (language modeling, prompt-completion, conversational)
+        - Chat template application
+        - Tokenization
+        - Packing (if enabled)

-        This method overrides the base TRL implementation to ensure consistent padding
-        for Neuron devices, which require fixed input shapes to avoid recompilation.
+        Neuron-specific behavior:
+        - Ensures padding_free=False to avoid recompilation
+        - Enforces padding to max_length for fixed input shapes
         """
-        # For packing, delegate to parent implementation but ensure no padding_free
-        if packing:
-            # Temporarily disable padding_free for packing as well
-            original_padding_free = getattr(args, "padding_free", False)
-            args.padding_free = False
-            try:
-                result = super()._prepare_dataset(
-                    dataset, processing_class, args, packing, formatting_func, dataset_name
-                )
-            finally:
-                args.padding_free = original_padding_free
-            return result
-
-        # For non-packed datasets, use our custom implementation with forced padding
-        from datasets import Dataset
-
-        # Apply formatting function if provided
-        if formatting_func is not None:
-            if isinstance(dataset, Dataset):
-                dataset = dataset.map(
-                    lambda example: {"text": formatting_func(example)},
-                    num_proc=args.dataset_num_proc,
-                    desc=f"Applying formatting function to {dataset_name} dataset",
-                )
-            else:  # IterableDataset
-                dataset = dataset.map(lambda example: {"text": formatting_func(example)})
-
-        # Tokenization function with forced padding for Neuron
-        def tokenize(examples):
-            # Handle both single examples and batches
-            if isinstance(examples[args.dataset_text_field], list):
-                texts = examples[args.dataset_text_field]
-            else:
-                texts = [examples[args.dataset_text_field]]
-
-            outputs = processing_class(
-                texts,
-                add_special_tokens=True,
-                truncation=True,
-                # Critical for Neuron: always pad to max_length to avoid recompilation
-                padding="max_length",
-                max_length=args.max_length,
-                return_overflowing_tokens=False,
-                return_length=False,
-            )
-
-            return {
-                "input_ids": outputs["input_ids"],
-                "attention_mask": outputs["attention_mask"],
-                "labels": outputs["input_ids"].copy(),  # For language modeling
-            }
-
-        # Build map kwargs
-        map_kwargs = {
-            "batched": True,
-            "remove_columns": dataset.column_names
-            if hasattr(dataset, "column_names") and dataset.column_names
-            else None,
-        }
-        if isinstance(dataset, Dataset):
-            map_kwargs["num_proc"] = args.dataset_num_proc
-            map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
-
-        tokenized_dataset = dataset.map(tokenize, **map_kwargs)
-        return tokenized_dataset
+        # Ensure padding_free is disabled for Neuron - this is critical for Trainium devices
+        if args.padding_free:
+            raise ValueError(
+                "padding_free must be False for Neuron training. "
+                "Neuron devices require fixed input shapes to avoid recompilation."
+            )
+
+        # Call parent implementation from SFTTrainer
+        return super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name)
diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py
index 2a5d2af28..4046dd93b 100644
--- a/optimum/neuron/trainers/trl_utils.py
+++ b/optimum/neuron/trainers/trl_utils.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-TRL_VERSION = "0.23.1"
+TRL_VERSION = "0.24.0"

From 09927381cbd5bf79585f3e7798520c728864bbcb Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 31 Oct 2025 15:59:14 +0100
Subject: [PATCH 08/78] chore: sync sft_trainer

---
 optimum/neuron/trainers/sft_trainer.py | 101 ++++++++++++++++++-------
 1 file changed, 73 insertions(+), 28 deletions(-)

diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
index 29b9b6b2a..6edf01990 100644
--- a/optimum/neuron/trainers/sft_trainer.py
+++ b/optimum/neuron/trainers/sft_trainer.py
@@ -23,7 +23,6 @@
 from transformers import (
     AutoModelForCausalLM,
     DataCollator,
-    DataCollatorForLanguageModeling,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     ProcessorMixin,
@@ -43,6 +42,7 @@

 if is_trl_available():
     from trl import SFTConfig, SFTTrainer
+    from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling
 else:

     class SFTTrainer:
@@ -51,6 +51,12 @@
     class SFTConfig:
         pass

+    class DataCollatorForLanguageModeling:
+        pass
+
+    class DataCollatorForVisionLanguageModeling:
+        pass
+

 if is_peft_available():
     from peft import PeftConfig
@@ -99,7 +105,7 @@
         self,
         model: PreTrainedModel | torch.nn.Module | str,
         args: SFTConfig | None = None,
-        data_collator: DataCollator | None = None,  # type: ignore
+        data_collator: DataCollator | None = None,
         train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
         eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
         processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
@@ -111,8 +117,7 @@
         preprocess_logits_for_metrics: Callable | None = None,
         peft_config: PeftConfig | None = None,
         formatting_func: Callable | None = None,
-        # Deprecated parameters for backward compatibility
-        tokenizer: PreTrainedTokenizerBase | None = None,  # Use processing_class instead
+        tokenizer: PreTrainedTokenizerBase | None = None,
     ):
         if not is_trl_available(required_version=TRL_VERSION):
             raise RuntimeError(f"Using NeuronSFTTrainer requires trl=={TRL_VERSION}.")

         from trl.extras.dataset_formatting import get_formatting_func_from_dataset

         from trl.trainer.utils import peft_module_casting_to_bf16

         if is_peft_available():
             from peft import PeftConfig

-        # Handle backward compatibility for tokenizer parameter
         if tokenizer is not None and processing_class is None:
             processing_class = tokenizer

         args_is_none = args is None
         if args is None:
             model_name = model if isinstance(model, str) else model.config._name_or_path
             model_name = model_name.split("/")[-1]
             args = NeuronSFTConfig(f"{model_name}-SFT")
         elif args is not None and args.__class__.__name__ == "NeuronTrainingArguments":
             args_as_dict = args.to_dict()
-            # Manually copy token values as TrainingArguments.to_dict() redacts them
             args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")})
             args = NeuronSFTConfig(**args_as_dict)

-        # Set the correct log level depending on the node
         log_level = args.get_process_log_level()
         logging.set_verbosity(log_level)

-        # We wait for the verbosity of the logger to be set before logging the warning below.
         if args_is_none:
             logging.warning(f"No `SFTConfig` passed, using `output_dir={args.output_dir}`.")

-        # Model handling - use model_init_kwargs from args
         if args.model_init_kwargs is None:
             model_init_kwargs = {}
         elif not isinstance(model, str):
             raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
         else:
             model_init_kwargs = args.model_init_kwargs
             torch_dtype = model_init_kwargs.get("dtype")
             if torch_dtype is not None:
-                # Convert to `torch.dtype` if an str is passed
                 if isinstance(torch_dtype, str) and torch_dtype != "auto":
                     torch_dtype = getattr(torch, torch_dtype)
                 if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                     raise ValueError(
@@ -162,7 +162,7 @@
         if isinstance(model, str):
             model_id = model
             dtype = model_init_kwargs.get("dtype")
             if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
-                pass  # dtype is already a torch.dtype or "auto" or None
+                pass
             elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
                 dtype = getattr(torch, dtype)
                 model_init_kwargs["dtype"] = dtype
             else:
                 raise ValueError(
                     "Invalid `dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
                     f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
                 )
             model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
         else:
             model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
                     "The `model_init_kwargs` will be ignored."
                 )

-        # Chat template handling (trl 0.24.0+)
-        # This allows users to provide a custom chat template via path or directory
+        if processing_class is None:
+            from transformers import AutoProcessor
+            processing_class = AutoProcessor.from_pretrained(model_id)
+
+        if isinstance(processing_class, ProcessorMixin):
+            self._is_vlm = True
+            tokenizer = processing_class.tokenizer
+        elif isinstance(processing_class, PreTrainedTokenizerBase):
+            self._is_vlm = False
+            tokenizer = processing_class
+        else:
+            raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
+
+        if hasattr(args, "eos_token") and args.eos_token is not None:
+            eos_token = args.eos_token
+            eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
+            if eos_token_id is None:
+                raise ValueError(
+                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
+                    "in the vocabulary before using it as an EOS token."
+                )
+            tokenizer.eos_token_id = eos_token_id
+
         if hasattr(args, "chat_template_path") and args.chat_template_path is not None:
             from trl.models import clone_chat_template

             if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
-                # Load Jinja template directly
                 with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
                     processing_class.chat_template = chat_template_file.read()
                 added_tokens = []
             else:
-                # Clone template from another model
                 try:
                     model, processing_class, added_tokens = clone_chat_template(
                         model, processing_class, args.chat_template_path
                     )
                 except Exception as e:
                     logger.warning(
                         f"Failed to clone chat template from {args.chat_template_path}: {e}. "
                         "Continuing without custom chat template."
                     )
                     added_tokens = []
         else:
             added_tokens = []

+        if self._is_vlm and args.packing:
+            raise ValueError(
+                "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig."
+            )
+        if self._is_vlm and args.padding_free:
+            raise ValueError(
+                "Padding-free training is not yet supported for vision-language models. Please set "
+                "`padding_free=False` in the `SFTConfig`."
+            )
+
         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
                     f" and you passed a {type(peft_config)}."
                 )

-            # Handle added tokens from chat template
             if added_tokens:
-                # Ensure that the added tokens are trainable
                 if peft_config.trainable_token_indices is None:
                     peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
                 elif "embed_tokens" not in peft_config.trainable_token_indices:
                     peft_config.trainable_token_indices["embed_tokens"] = added_tokens
                 else:
                     peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)

-                # Ensure that the lm_head is trainable
                 if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
                     logger.warning(
                         "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
                         "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
                         "tokens, leading to degraded generation quality. To fix this, add "
                         "`modules_to_save=['lm_head']` to your PEFT configuration."
                     )
-
                     if peft_config.modules_to_save is None:
                         peft_config.modules_to_save = ["lm_head"]
                     else:
                         peft_config.modules_to_save.append("lm_head")

         if not isinstance(model, NeuronPeftModel):
             gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
             if getattr(args, "gradient_checkpointing", False) and (
                 "use_reentrant" not in gradient_checkpointing_kwargs
                 or gradient_checkpointing_kwargs["use_reentrant"]
             ):
-                # For backward compatibility with older versions of transformers
                 if hasattr(model, "enable_input_require_grads"):
                     model.enable_input_require_grads()
                 else:
-
                     def make_inputs_require_grad(module, input, output):
                         output.requires_grad_(True)
-
                     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

             model = get_peft_model(model, peft_config)
             if args is not None and args.bf16:
                 peft_module_casting_to_bf16(model)

         if hasattr(processing_class, "pad_token") and getattr(processing_class, "pad_token", None) is None:
             processing_class.pad_token = processing_class.eos_token

         if args.max_length is None:
             args.max_length = min(processing_class.model_max_length, 1024)
             logger.warning(
                 f"You didn't pass a `max_length` argument to the SFTTrainer, this will default to {args.max_length}"
             )

         self.dataset_num_proc = args.dataset_num_proc
-
-        # We do not support NeFTune with NeuronSFTTrainer for now.
         self._trainer_supports_neftune = False

+        self.padding_free = False
+        if args.padding_free:
+            logger.warning(
+                "padding_free=True is not supported for Neuron training. Neuron devices require fixed input shapes "
+                "to avoid recompilation. Setting padding_free=False."
+            )
+            args.padding_free = False
+
         if args.dataset_kwargs is None:
             args.dataset_kwargs = {}

+        if train_dataset is not None:
+            dataset_sample = next(iter(train_dataset))
+            if args.completion_only_loss is None:
+                self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample
+            else:
+                self.completion_only_loss = args.completion_only_loss
+            self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample

+            if self._is_vision_dataset and not self._is_vlm:
+                raise ValueError(
+                    "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                    "model does not seem to be a vision-language model. Please check your model and dataset."
+                )
+        else:
+            self.completion_only_loss = False
+            self._is_vision_dataset = False

+        if data_collator is not None and hasattr(data_collator, "padding_free"):
+            data_collator.padding_free = False

         if formatting_func is None and args.dataset_text_field is None:
-            # check if dataset has ChatML format or instruction format and is supported
-            # if not stays #None
             formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
-            # if a template is detected, we don't need to add special tokens again
             if formatting_func is not None:
                 args.dataset_kwargs["add_special_tokens"] = False

         if not args.packing:
-            # If we aren't skipping data preparation, then a dataset_text_field
-            # or formatting_func must be provided.
             if (
                 args.dataset_text_field is None
                 and formatting_func is None
             ):
                 raise ValueError(
                     "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
                 )

-        # Inspect dataset to determine dataset type and completion_only_loss
-        if train_dataset is not None:
-            dataset_sample = next(iter(train_dataset))
-            if args.completion_only_loss is None:
-                self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample
-            else:
-                self.completion_only_loss = args.completion_only_loss
-            self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
-        else:
-            self.completion_only_loss = False
-            self._is_vision_dataset = False
-
-        # Data collator creation with Neuron-specific constraints
-        # We delegate to parent SFTTrainer to create the proper data collator
-        # If user provides data_collator, ensure padding_free is False for Neuron
-        if data_collator is not None and hasattr(data_collator, "padding_free"):
-            data_collator.padding_free = False
-
-        # Pre-process the datasets only once per node. The remaining processes will use the cache.
         with NeuronPartialState().local_main_process_first():
             if train_dataset is not None:
                 train_dataset = self._prepare_dataset(
                     train_dataset, processing_class, args, args.packing, formatting_func, "train"
                 )
             if eval_dataset is not None:
                 _multiple = isinstance(eval_dataset, dict)
                 _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}

                 for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
                     _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
                         _eval_dataset,
                         processing_class,
                         args,
                         args.eval_packing if args.eval_packing is not None else args.packing,
                         formatting_func,
                         _eval_dataset_name,
                     )
                 if not _multiple:
                     eval_dataset = _eval_datasets["singleton"]

         if (
             hasattr(processing_class, "padding_side")
             and processing_class.padding_side is not None
             and processing_class.padding_side != "right"
         ):
             logger.warning(
                 "You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
                 'overflow issues when training a model in half-precision. You might consider adding `processing_class.padding_side = "right"` to your code.'
             )

-        # Detect if this is a vision dataset
-        if train_dataset is not None:
-            try:
-                dataset_sample = next(iter(train_dataset))
-                self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
-            except (StopIteration, KeyError):
-                # Empty dataset or no vision keys
-                self._is_vision_dataset = False
-
         NeuronTrainer.__init__(
             self,
             model,
diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py
index cfd009887..80ebb0094 100644
--- a/optimum/neuron/trainers/transformers.py
+++ b/optimum/neuron/trainers/transformers.py
@@ -958,6 +958,7 @@
             if num_items_in_batch is not None:
                 inputs = dict(**inputs, reduction="sum")
+            print("zaza", inputs)
             outputs = model(**inputs)

             if isinstance(outputs, dict) and "loss" not in outputs:

From cddbf5f38764c62a6006b183cc41fecfd6b24e99 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 3 Nov 2025 18:52:25 +0100
Subject: [PATCH 10/78] chore: sync sft_trainer

---
 examples/training/qwen3/finetune_qwen3.sh |   3 +
 optimum/neuron/trainers/sft_trainer.py    | 268 +++++++++-------------
 optimum/neuron/trainers/transformers.py   |   1 -
 3 files changed, 106 insertions(+), 166 deletions(-)

diff --git a/examples/training/qwen3/finetune_qwen3.sh b/examples/training/qwen3/finetune_qwen3.sh
index b2d7568e3..a97e37ad7 100755
--- a/examples/training/qwen3/finetune_qwen3.sh
+++ b/examples/training/qwen3/finetune_qwen3.sh
@@ -6,6 +6,9 @@ export NEURON_FUSE_SOFTMAX=1
 export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime
 export MALLOC_ARENA_MAX=64 # Host OOM mitigation

+# Enable Neuron logging
+export NEURON_RT_LOG_LEVEL=INFO
+
 # Variables for training
 PROCESSES_PER_NODE=32
 NUM_EPOCHS=3
diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py
@@ -29,13 +29,13 @@
 )
 from transformers.trainer_callback import TrainerCallback

-from ..accelerate import NeuronPartialState
 from ..peft import NeuronPeftModel, get_peft_model
 from ..utils import (
     is_trl_available,
 )
 from ..utils.import_utils import is_peft_available
 from .sft_config import NeuronSFTConfig
+from .training_args import NeuronTrainingArguments
 from
.transformers import NeuronTrainer from .trl_utils import TRL_VERSION @@ -111,89 +111,55 @@ def __init__( preprocess_logits_for_metrics: Callable | None = None, peft_config: PeftConfig | None = None, formatting_func: Callable | None = None, - tokenizer: PreTrainedTokenizerBase | None = None, ): if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError(f"Using NeuronSFTTrainer requires trl=={TRL_VERSION}.") - from trl.extras.dataset_formatting import get_formatting_func_from_dataset - from trl.trainer.callbacks import RichProgressCallback - from trl.trainer.utils import peft_module_casting_to_bf16 + from trl.models import clone_chat_template + from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling if is_peft_available(): - from peft import PeftConfig + pass - if tokenizer is not None and processing_class is None: - processing_class = tokenizer - - args_is_none = args is None + # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path model_name = model_name.split("/")[-1] args = NeuronSFTConfig(f"{model_name}-SFT") - elif args is not None and args.__class__.__name__ == "NeuronTrainingArguments": - args_as_dict = args.to_dict() - args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")}) - args = NeuronSFTConfig(**args_as_dict) - - log_level = args.get_process_log_level() - logging.set_verbosity(log_level) - - if args_is_none: - logging.warning(f"No `SFTConfig` passed, using `output_dir={args.output_dir}`.") - - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("dtype") - if torch_dtype is not None: - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." - ) - model_init_kwargs["dtype"] = torch_dtype + elif isinstance(args, NeuronTrainingArguments) and not isinstance(args, NeuronSFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token + dict_args.pop("push_to_hub_token", None) + args = NeuronSFTConfig(**dict_args) + # Model if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass - elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " - f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + model = AutoModelForCausalLM.from_pretrained(model, **args.model_init_kwargs or {}) else: - model_id = model.config._name_or_path if args.model_init_kwargs is not None: logger.warning( "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " "The `model_init_kwargs` will be ignored." 
) + model_id = model.config._name_or_path + # Processing class if processing_class is None: from transformers import AutoProcessor + processing_class = AutoProcessor.from_pretrained(model_id) + # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - self._is_vlm = True tokenizer = processing_class.tokenizer + self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - self._is_vlm = False tokenizer = processing_class + self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if hasattr(args, "eos_token") and args.eos_token is not None: + if args.eos_token is not None: eos_token = args.eos_token eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) if eos_token_id is None: @@ -204,45 +170,33 @@ def __init__( ) tokenizer.eos_token_id = eos_token_id - if hasattr(args, "chat_template_path") and args.chat_template_path is not None: - from trl.models import clone_chat_template - + if args.chat_template_path is not None: if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): with open(args.chat_template_path, encoding="utf-8") as chat_template_file: processing_class.chat_template = chat_template_file.read() added_tokens = [] else: - try: - model, processing_class, added_tokens = clone_chat_template( - model, processing_class, args.chat_template_path - ) - except Exception as e: - logger.warning( - f"Failed to clone chat template from {args.chat_template_path}: {e}. " - "Continuing without custom chat template." - ) - added_tokens = [] + model, processing_class, added_tokens = clone_chat_template( + model, processing_class, args.chat_template_path + ) else: added_tokens = [] + # Catch some wrong configurations related to VLMs if self._is_vlm and args.packing: raise ValueError( "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." ) if self._is_vlm and args.padding_free: raise ValueError( - "Padding-free training is not yet supported for vision-language models. Please set " + "Padding-free training is yet not supported for vision-language models. Please set " "`padding_free=False` in the `SFTConfig`." ) - if is_peft_available() and peft_config is not None: - if not isinstance(peft_config, PeftConfig): - raise ValueError( - "If you want to use the NeuronPeftModel, you need to pass a PeftConfig object to the NeuronSFTTrainer." - f" and you passed a {type(peft_config)}." - ) - + # PEFT configuration and model wrapping + if peft_config is not None: if added_tokens: + # Ensure that the added tokens are trainable if peft_config.trainable_token_indices is None: peft_config.trainable_token_indices = {"embed_tokens": added_tokens} elif "embed_tokens" not in peft_config.trainable_token_indices: @@ -250,6 +204,7 @@ def __init__( else: peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) + # Ensure that the lm_head is trainable if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: logger.warning( "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " @@ -257,40 +212,33 @@ def __init__( "tokens, leading to degraded generation quality. To fix this, add " "`modules_to_save=['lm_head']` to your PEFT configuration." 
) + if peft_config.modules_to_save is None: peft_config.modules_to_save = ["lm_head"] else: peft_config.modules_to_save.append("lm_head") if not isinstance(model, NeuronPeftModel): + # Enable gradient checkpointing if needed gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} - if getattr(args, "gradient_checkpointing", False) and ( + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + if args.gradient_checkpointing and ( "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] ): if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: + def make_inputs_require_grad(module, input, output): output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) model = get_peft_model(model, peft_config) - if args is not None and args.bf16: - peft_module_casting_to_bf16(model) - - if hasattr(processing_class, "pad_token") and getattr(processing_class, "pad_token", None) is None: - processing_class.pad_token = processing_class.eos_token - - if args.max_length is None: - args.max_length = min(processing_class.model_max_length, 1024) - logger.warning( - f"You didn't pass a `max_length` argument to the SFTTrainer, this will default to {args.max_length}" - ) - - self.dataset_num_proc = args.dataset_num_proc - self._trainer_supports_neftune = False + # Data collator + # Neuron-specific: padding_free must always be False for Neuron devices self.padding_free = False if args.padding_free: logger.warning( @@ -299,79 +247,83 @@ def make_inputs_require_grad(module, input, output): ) args.padding_free = False - if args.dataset_kwargs is None: - args.dataset_kwargs = {} + # Decide whether to use completion-only loss + dataset_sample = next(iter(train_dataset)) + if args.completion_only_loss is None: + self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample + else: + self.completion_only_loss = args.completion_only_loss - if train_dataset is not None: - dataset_sample = next(iter(train_dataset)) - if args.completion_only_loss is None: - self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample - else: - self.completion_only_loss = args.completion_only_loss - self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample + self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample + if self._is_vision_dataset and not self._is_vlm: + raise ValueError( + "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " + "model does not seem to be a vision-language model. Please check your model and dataset." + ) - if self._is_vision_dataset and not self._is_vlm: - raise ValueError( - "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " - "model does not seem to be a vision-language model. Please check your model and dataset." 
- ) - else: - self.completion_only_loss = False - self._is_vision_dataset = False - - if data_collator is not None and hasattr(data_collator, "padding_free"): - data_collator.padding_free = False - - if formatting_func is None and args.dataset_text_field is None: - formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer) - if formatting_func is not None: - args.dataset_kwargs["add_special_tokens"] = False - - if not args.packing: - if ( - args.dataset_text_field is None - and formatting_func is None - and not args.dataset_kwargs.get("skip_prepare_dataset", False) - ): + if data_collator is None and not self._is_vision_dataset: + # Get the pad token + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + if pad_token_id is None: raise ValueError( - "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." ) + # Neuron-specific: always pad to max_length for fixed input shapes + data_collator = DataCollatorForLanguageModeling( + pad_token_id=pad_token_id, + completion_only_loss=self.completion_only_loss, + padding_free=self.padding_free, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + elif data_collator is None and self._is_vision_dataset: + data_collator = DataCollatorForVisionLanguageModeling( + processor=processing_class, + max_length=args.max_length, + completion_only_loss=self.completion_only_loss, + pad_to_multiple_of=args.pad_to_multiple_of, + dataset_text_field=args.dataset_text_field, + ) - with NeuronPartialState().local_main_process_first(): - if train_dataset is not None: - train_dataset = self._prepare_dataset( - train_dataset, processing_class, args, args.packing, formatting_func, "train" + # Dataset + skip_prepare_dataset = ( + args.dataset_kwargs is not None + and args.dataset_kwargs.get("skip_prepare_dataset", False) + or self._is_vision_dataset + ) + if not skip_prepare_dataset: + if self.completion_only_loss and formatting_func: + raise ValueError( + "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " + "Using a formatter converts the dataset to a language modeling type, conflicting with " + "completion-only loss. To resolve this, apply your formatting function before passing the " + "dataset, or disable `completion_only_loss` in `SFTConfig`." 
) + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) if eval_dataset is not None: - _multiple = isinstance(eval_dataset, dict) - _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} - - for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): - _eval_datasets[_eval_dataset_name] = self._prepare_dataset( - _eval_dataset, - processing_class, - args, - args.eval_packing if args.eval_packing is not None else args.packing, - formatting_func, - _eval_dataset_name, + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" ) - if not _multiple: - eval_dataset = _eval_datasets["singleton"] - - if ( - hasattr(processing_class, "padding_side") - and processing_class.padding_side is not None - and processing_class.padding_side != "right" - ): - logger.warning( - "You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " - 'overflow issues when training a model in half-precision. You might consider adding `processing_class.padding_side = "right"` to your code.' - ) + # Neuron-specific: we don't support NeFTune + self._trainer_supports_neftune = False + + # Initialize NeuronTrainer NeuronTrainer.__init__( self, - model, - args, + model=model, + args=args, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, @@ -385,20 +337,6 @@ def make_inputs_require_grad(module, input, output): if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - if self.args.max_steps > 0 and args.packing: - logger.warning( - "You passed `packing=True` to the NeuronSFTTrainer/SFTConfig, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." 
- ) - self.train_dataset.infinite = True - elif self.args.max_steps == -1 and args.packing: - self.train_dataset.infinite = False - - if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks): - for callback in self.callback_handler.callbacks: - # Remove the PrinterCallback to avoid duplicated prints in case we passed a `RichProgressCallback` - if callback.__class__.__name__ == "PrinterCallback": - self.callback_handler.pop_callback(callback) - def train( self, resume_from_checkpoint: str | bool | None = None, diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index 80ebb0094..cfd009887 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -958,7 +958,6 @@ def compute_loss( if num_items_in_batch is not None: inputs = dict(**inputs, reduction="sum") - print("zaza", inputs) outputs = model(**inputs) if isinstance(outputs, dict) and "loss" not in outputs: From 0200820e80f584225571009c9359c42ef72fd16b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 10:28:58 +0100 Subject: [PATCH 11/78] fix: sft trainer --- examples/training/qwen3/finetune_qwen3.sh | 3 - optimum/neuron/trainers/sft_trainer.py | 148 ++++++++++++++-------- 2 files changed, 92 insertions(+), 59 deletions(-) diff --git a/examples/training/qwen3/finetune_qwen3.sh b/examples/training/qwen3/finetune_qwen3.sh index a97e37ad7..b2d7568e3 100755 --- a/examples/training/qwen3/finetune_qwen3.sh +++ b/examples/training/qwen3/finetune_qwen3.sh @@ -6,9 +6,6 @@ export NEURON_FUSE_SOFTMAX=1 export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime export MALLOC_ARENA_MAX=64 # Host OOM mitigation -# Enable Neuron logging -export NEURON_RT_LOG_LEVEL=INFO - # Variables for training PROCESSES_PER_NODE=32 NUM_EPOCHS=3 diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py index a85c586cf..b0a9ef149 100644 --- a/optimum/neuron/trainers/sft_trainer.py +++ b/optimum/neuron/trainers/sft_trainer.py @@ -21,7 +21,6 @@ from optimum.utils import logging from torch.utils.data import Dataset, IterableDataset from transformers import ( - AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, @@ -29,6 +28,7 @@ ) from transformers.trainer_callback import TrainerCallback +from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model from ..utils import ( is_trl_available, @@ -79,6 +79,52 @@ class PeftConfig: logger = logging.get_logger() +class NeuronDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): + """ + Data collator for Neuron that ensures all sequences are padded to exactly max_length. + + This is required for Neuron devices to maintain fixed input shapes and avoid recompilation. + Inherits from trl's DataCollatorForLanguageModeling but adds max_length enforcement. 
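+
+    A minimal usage sketch (the values below are illustrative, not taken from
+    the codebase): every field is padded or truncated to max_length before
+    delegating to the trl collator, so batch shapes stay constant across steps.
+
+        collator = NeuronDataCollatorForLanguageModeling(
+            max_length=8, pad_token_id=0, completion_only_loss=False
+        )
+        batch = collator([{"input_ids": [5, 6, 7]}])
+        # every sequence in the batch now has length 8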
+ """ + + def __init__(self, max_length: int, **kwargs): + super().__init__(**kwargs) + self.max_length = max_length + + def __call__(self, examples): + # Pad/truncate all sequences to max_length before calling parent + for example in examples: + if "input_ids" in example: + input_ids = example["input_ids"] + current_length = len(input_ids) + + if current_length > self.max_length: + # Truncate to max_length + example["input_ids"] = input_ids[: self.max_length] + elif current_length < self.max_length: + # Pad to max_length + example["input_ids"] = input_ids + [self.pad_token_id] * (self.max_length - current_length) + + # Handle other fields if present + for key in ["labels", "attention_mask", "completion_mask"]: + if key in example: + field = example[key] + field_length = len(field) + if field_length > self.max_length: + example[key] = field[: self.max_length] + elif field_length < self.max_length: + # Pad with appropriate value + if key == "labels": + pad_value = -100 + elif key == "attention_mask": + pad_value = 0 + elif key == "completion_mask": + pad_value = 0 + example[key] = field + [pad_value] * (self.max_length - field_length) + + return super().__call__(examples) + + class NeuronSFTTrainer(_SFTTrainer): """ `SFTTrainer` adapted for Neuron (Trainium) devices. @@ -87,7 +133,6 @@ class NeuronSFTTrainer(_SFTTrainer): - Uses NeuronTrainer.__init__() instead of transformers.Trainer.__init__() - Uses NeuronTrainer.train() for Neuron-optimized training - Enforces padding_free=False for fixed input shapes (required for Trainium) - - Simplifies _prepare_dataset to delegate to parent with Neuron constraints Neuron-specific constraints: - padding_free is always False to avoid recompilation @@ -116,10 +161,7 @@ def __init__( raise RuntimeError(f"Using NeuronSFTTrainer requires trl=={TRL_VERSION}.") from trl.models import clone_chat_template - from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling - - if is_peft_available(): - pass + from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling # Args if args is None: @@ -134,7 +176,7 @@ def __init__( # Model if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **args.model_init_kwargs or {}) + model = NeuronModelForCausalLM.from_pretrained(model, **args.model_init_kwargs or {}) else: if args.model_init_kwargs is not None: logger.warning( @@ -218,24 +260,31 @@ def __init__( else: peft_config.modules_to_save.append("lm_head") - if not isinstance(model, NeuronPeftModel): - # Enable gradient checkpointing if needed - gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} - gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs - if args.gradient_checkpointing and ( - "use_reentrant" not in gradient_checkpointing_kwargs - or gradient_checkpointing_kwargs["use_reentrant"] - ): - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. 
+ self.num_virtual_tokens = 0 + + if peft_config is not None and not isinstance(model, NeuronPeftModel): + # Enable gradient checkpointing if needed + gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + if args.gradient_checkpointing and ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + model = get_peft_model(model, peft_config) - model = get_peft_model(model, peft_config) + if model.active_adapter in model.peft_config: + peft_model_config = model.peft_config[model.active_adapter] + self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) # Data collator # Neuron-specific: padding_free must always be False for Neuron devices @@ -271,8 +320,9 @@ def make_inputs_require_grad(module, input, output): f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " "in the vocabulary before using it as a padding token." ) - # Neuron-specific: always pad to max_length for fixed input shapes - data_collator = DataCollatorForLanguageModeling( + # Neuron-specific: use NeuronDataCollatorForLanguageModeling to ensure fixed max_length padding + data_collator = NeuronDataCollatorForLanguageModeling( + max_length=args.max_length, pad_token_id=pad_token_id, completion_only_loss=self.completion_only_loss, padding_free=self.padding_free, @@ -343,6 +393,24 @@ def train( ): return NeuronTrainer.train(self, resume_from_checkpoint=resume_from_checkpoint) + def log(self, logs: dict[str, float]) -> None: + """ + Override SFTTrainer's log method to use NeuronTrainer's implementation. + + SFTTrainer has custom metrics tracking that we don't use for Neuron training. + """ + return NeuronTrainer.log(self, logs) + + def _save_checkpoint(self, model=None, trial=None, metrics=None): + """ + Override SFTTrainer's _save_checkpoint to use NeuronTrainer's implementation. + + SFTTrainer has a custom checkpoint saving method, but we use NeuronTrainer's + which is compatible with Neuron's distributed training and async saving. + NeuronTrainer._save_checkpoint only takes self, so we ignore the extra arguments. + """ + return NeuronTrainer._save_checkpoint(self) + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ Compute training loss for Neuron-optimized training. @@ -366,35 +434,3 @@ def training_step( which is compatible with Neuron's distributed training setup. """ return NeuronTrainer.training_step(self, model, inputs, num_items_in_batch=num_items_in_batch) - - def _prepare_dataset( - self, - dataset, - processing_class, - args, - packing, - formatting_func=None, - dataset_name="train", - ): - """ - Prepare dataset for Neuron training. 
- - Delegates to parent SFTTrainer._prepare_dataset, which handles: - - Dataset type detection (language modeling, prompt-completion, conversational) - - Chat template application - - Tokenization - - Packing (if enabled) - - Neuron-specific behavior: - - Ensures padding_free=False to avoid recompilation - - Enforces padding to max_length for fixed input shapes - """ - # Ensure padding_free is disabled for Neuron - this is critical for Trainium devices - if args.padding_free: - raise ValueError( - "padding_free must be False for Neuron training. " - "Neuron devices require fixed input shapes to avoid recompilation." - ) - - # Call parent implementation from SFTTrainer - return super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name) From 2c8c1d15de5721131d681a608a6b511a12478998 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 11:13:55 +0100 Subject: [PATCH 12/78] chore: update dependency version for trl --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index facb0eb40..db63a53a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ quality = [ "isort", ] training = [ - "trl == 0.23.1", + "trl == 0.24.0", "peft == 0.17.0", "evaluate == 0.4.3", ] From 7eda163445e878de8e42e741503f7ca62cdc149d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 11:38:37 +0100 Subject: [PATCH 13/78] chore: cleanup and fix no-packing test --- optimum/neuron/trainers/sft_trainer.py | 17 ++++------------- tests/training/test_neuron_sft_trainer.py | 16 +++++++++++++--- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py index b0a9ef149..3a51aeeea 100644 --- a/optimum/neuron/trainers/sft_trainer.py +++ b/optimum/neuron/trainers/sft_trainer.py @@ -42,6 +42,7 @@ if is_trl_available(): from trl import SFTConfig, SFTTrainer + from trl.models import clone_chat_template from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling else: @@ -57,6 +58,9 @@ class DataCollatorForLanguageModeling: class DataCollatorForVisionLanguageModeling: pass + def clone_chat_template(*args, **kwargs): + pass + if is_peft_available(): from peft import PeftConfig @@ -128,16 +132,6 @@ def __call__(self, examples): class NeuronSFTTrainer(_SFTTrainer): """ `SFTTrainer` adapted for Neuron (Trainium) devices. 
- - Overrides key methods for Neuron compatibility: - - Uses NeuronTrainer.__init__() instead of transformers.Trainer.__init__() - - Uses NeuronTrainer.train() for Neuron-optimized training - - Enforces padding_free=False for fixed input shapes (required for Trainium) - - Neuron-specific constraints: - - padding_free is always False to avoid recompilation - - VLM training is not yet supported - - NeFTune training is not supported """ def __init__( @@ -160,9 +154,6 @@ def __init__( if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError(f"Using NeuronSFTTrainer requires trl=={TRL_VERSION}.") - from trl.models import clone_chat_template - from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling - # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path diff --git a/tests/training/test_neuron_sft_trainer.py b/tests/training/test_neuron_sft_trainer.py index 4be2a1a80..3f2140537 100644 --- a/tests/training/test_neuron_sft_trainer.py +++ b/tests/training/test_neuron_sft_trainer.py @@ -53,9 +53,7 @@ def format_dolly(sample): context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None response = f"### Answer\n{sample['response']}" prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) - if packing: - return prompt - return [prompt] + return prompt args = NeuronTrainingArguments( output_dir=str(tmpdir), @@ -98,6 +96,18 @@ def format_dolly(sample): # Verify initial state assert trainer.state.global_step == 0, f"Expected initial global_step=0, got {trainer.state.global_step}" + # Verify that all inputs are padded to max_length + sample_batch = next(iter(trainer.get_train_dataloader())) + assert sample_batch["input_ids"].shape[1] == sft_config.max_length, ( + f"Expected input_ids to have length {sft_config.max_length}, got {sample_batch['input_ids'].shape[1]}" + ) + assert sample_batch["labels"].shape[1] == sft_config.max_length, ( + f"Expected labels to have length {sft_config.max_length}, got {sample_batch['labels'].shape[1]}" + ) + assert sample_batch["attention_mask"].shape[1] == sft_config.max_length, ( + f"Expected attention_mask to have length {sft_config.max_length}, got {sample_batch['attention_mask'].shape[1]}" + ) + # Run training trainer.train() From b6ee2a39237b00500d3f0d655b5c037fb0ba004a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 11:38:56 +0100 Subject: [PATCH 14/78] chore: restore finetune_qwen3.sh --- examples/training/qwen3/finetune_qwen3.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/training/qwen3/finetune_qwen3.sh b/examples/training/qwen3/finetune_qwen3.sh index b2d7568e3..d64a6572d 100755 --- a/examples/training/qwen3/finetune_qwen3.sh +++ b/examples/training/qwen3/finetune_qwen3.sh @@ -13,8 +13,7 @@ TP_DEGREE=8 BS=1 GRADIENT_ACCUMULATION_STEPS=8 LOGGING_STEPS=2 -# MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name -MODEL_NAME="Qwen/Qwen3-0.6B" # Change this to the desired model name +MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-finetuned" DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE" SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) From 72b338a4d2987057e786a7962e98c407fa5c1b11 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 11:43:43 +0100 Subject: [PATCH 15/78] feat: add model card creation when saving a checkpoint --- 
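
The checkpoint hook derives the model card name from `hub_model_id` when it is
set, falling back to the last component of `output_dir`. A minimal sketch of
the naming rule (the paths are made-up examples):

    from pathlib import Path

    def model_card_name(hub_model_id: str | None, output_dir: str) -> str:
        # Mirrors the logic added to _save_checkpoint below.
        if hub_model_id is None:
            return Path(output_dir).name      # "runs/qwen3-sft" -> "qwen3-sft"
        return hub_model_id.split("/")[-1]    # "org/qwen3-sft"  -> "qwen3-sft"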
optimum/neuron/trainers/sft_trainer.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py index 3a51aeeea..92a033190 100644 --- a/optimum/neuron/trainers/sft_trainer.py +++ b/optimum/neuron/trainers/sft_trainer.py @@ -14,6 +14,7 @@ # limitations under the License. import os +from pathlib import Path from typing import Any, Callable import datasets @@ -395,19 +396,17 @@ def log(self, logs: dict[str, float]) -> None: def _save_checkpoint(self, model=None, trial=None, metrics=None): """ Override SFTTrainer's _save_checkpoint to use NeuronTrainer's implementation. - - SFTTrainer has a custom checkpoint saving method, but we use NeuronTrainer's - which is compatible with Neuron's distributed training and async saving. - NeuronTrainer._save_checkpoint only takes self, so we ignore the extra arguments. """ + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) return NeuronTrainer._save_checkpoint(self) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ Compute training loss for Neuron-optimized training. - - Overrides TRL SFTTrainer's compute_loss to set use_cache=False for gradient - checkpointing compatibility and delegate to NeuronTrainer's compute_loss. """ # Set use_cache to False to avoid warnings with gradient checkpointing inputs["use_cache"] = False @@ -420,8 +419,5 @@ def training_step( ) -> torch.Tensor: """ Perform a training step for Neuron-optimized training. - - Overrides SFTTrainer.training_step to delegate to NeuronTrainer's implementation, - which is compatible with Neuron's distributed training setup. """ return NeuronTrainer.training_step(self, model, inputs, num_items_in_batch=num_items_in_batch) From 98a6210de9a90f941fde769809444eae6dc759be Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 11:49:48 +0100 Subject: [PATCH 16/78] chore: remove model card support --- optimum/neuron/trainers/sft_trainer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/optimum/neuron/trainers/sft_trainer.py b/optimum/neuron/trainers/sft_trainer.py index 92a033190..20651277b 100644 --- a/optimum/neuron/trainers/sft_trainer.py +++ b/optimum/neuron/trainers/sft_trainer.py @@ -14,7 +14,6 @@ # limitations under the License. import os -from pathlib import Path from typing import Any, Callable import datasets @@ -396,12 +395,9 @@ def log(self, logs: dict[str, float]) -> None: def _save_checkpoint(self, model=None, trial=None, metrics=None): """ Override SFTTrainer's _save_checkpoint to use NeuronTrainer's implementation. + The only difference is that this method does not create a model card after saving the checkpoint, it can be + added if needed. 
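+        To restore it, derive a name from `hub_model_id` (or the last component
+        of `output_dir`) and call `self.create_model_card(model_name=...)` before
+        delegating, as in the previous revision.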
""" - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) return NeuronTrainer._save_checkpoint(self) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): From ee6caeb4cf3d7784da759acb6224c061e9fff22d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 11:53:39 +0100 Subject: [PATCH 17/78] doc: align with trl==0.24.0 --- docs/source/training_tutorials/finetune_llama.mdx | 4 ++-- docs/source/training_tutorials/finetune_qwen3.mdx | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/training_tutorials/finetune_llama.mdx b/docs/source/training_tutorials/finetune_llama.mdx index 4e7ae6697..f6c4f6688 100644 --- a/docs/source/training_tutorials/finetune_llama.mdx +++ b/docs/source/training_tutorials/finetune_llama.mdx @@ -156,7 +156,7 @@ lora_config = LoraConfig( args = training_args.to_dict() sft_config = NeuronSFTConfig( - max_seq_length=2048, + max_length=2048, packing=True, **args, ) @@ -186,7 +186,7 @@ trainer = NeuronSFTTrainer( args=sft_config, model=model, peft_config=lora_config, - tokenizer=tokenizer, + processing_class=tokenizer, train_dataset=dataset, formatting_func=lambda example: format_dolly(example, tokenizer), ) diff --git a/docs/source/training_tutorials/finetune_qwen3.mdx b/docs/source/training_tutorials/finetune_qwen3.mdx index 1cf4a2d6e..a8f89e168 100644 --- a/docs/source/training_tutorials/finetune_qwen3.mdx +++ b/docs/source/training_tutorials/finetune_qwen3.mdx @@ -164,7 +164,7 @@ lora_config = LoraConfig( args = training_args.to_dict() sft_config = NeuronSFTConfig( - max_seq_length=4096, + max_length=4096, packing=True, **args, ) @@ -181,7 +181,7 @@ dataset = preprocess_dataset_with_eos(tokenizer.eos_token) args=sft_config, model=model, peft_config=lora_config, - tokenizer=tokenizer, + processing_class=tokenizer, train_dataset=dataset, formatting_func=formatting_function, ) From ac0c9f2124a480dc30de27a96e0ccafbad98d4d5 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 14:42:57 +0100 Subject: [PATCH 18/78] test: fix broken sft + peft test --- tests/training/test_neuron_sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/training/test_neuron_sft_trainer.py b/tests/training/test_neuron_sft_trainer.py index 3f2140537..dc9513530 100644 --- a/tests/training/test_neuron_sft_trainer.py +++ b/tests/training/test_neuron_sft_trainer.py @@ -151,7 +151,7 @@ def format_dolly(sample): context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None response = f"### Answer\n{sample['response']}" prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) - return [prompt] # No packing for simplicity + return prompt args = NeuronTrainingArguments( output_dir=str(tmpdir), From 8892d51e7d7f69c9304ac02f2ea61c06695e9446 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 14:50:39 +0100 Subject: [PATCH 19/78] chore: add GRPO imports in optimum.neuron --- optimum/neuron/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 260ad04b6..3cace0c0a 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -37,8 +37,10 @@ "trainers": [ "NeuronTrainer", "NeuronSFTTrainer", + "NeuronGRPOTrainer", "NeuronTrainingArguments", "NeuronSFTConfig", + "NeuronGRPOConfig", ], 
"modeling_traced": ["NeuronTracedModel"], "modeling": [ @@ -158,6 +160,8 @@ from .trainers import ( NeuronSFTConfig, NeuronSFTTrainer, + NeuronGRPOConfig, + NeuronGRPOTrainer, NeuronTrainer, NeuronTrainingArguments, ) From f26497d9c3ed0ac6b21e2545c6ded9e5eef24078 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 14:52:24 +0100 Subject: [PATCH 20/78] chore: add GRPO imports in optimum.neuron.trainers --- optimum/neuron/trainers/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/neuron/trainers/__init__.py b/optimum/neuron/trainers/__init__.py index df612b2dc..7081378b4 100644 --- a/optimum/neuron/trainers/__init__.py +++ b/optimum/neuron/trainers/__init__.py @@ -18,3 +18,5 @@ from .training_args import NeuronTrainingArguments from .transformers import NeuronTrainer from .trl_utils import TRL_VERSION +from .grpo_trainer import NeuronGRPOTrainer +from .grpo_config import NeuronGRPOConfig From f574d3ea7f0b0a6a76897236aec8dd072bfc2932 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 17:23:41 +0100 Subject: [PATCH 21/78] chore: add skeleton for GRPO trainer --- optimum/neuron/trainers/grpo_config.py | 35 ++ optimum/neuron/trainers/grpo_trainer.py | 576 ++++++++++++++++++++++++ 2 files changed, 611 insertions(+) create mode 100644 optimum/neuron/trainers/grpo_config.py create mode 100644 optimum/neuron/trainers/grpo_trainer.py diff --git a/optimum/neuron/trainers/grpo_config.py b/optimum/neuron/trainers/grpo_config.py new file mode 100644 index 000000000..32e6c4fbb --- /dev/null +++ b/optimum/neuron/trainers/grpo_config.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..utils.import_utils import is_trl_available +from .training_args import NeuronTrainingArguments +from .trl_utils import TRL_VERSION + + +if is_trl_available(): + from trl import GRPOConfig +else: + + @dataclass + class GRPOConfig: + def __init__(self, *args, **kwargs): + raise RuntimeError(f"You need to install the `trl=={TRL_VERSION}` library to use the `NeuronGRPOConfig`.") + + +@dataclass +class NeuronGRPOConfig(NeuronTrainingArguments, GRPOConfig): + pass diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py new file mode 100644 index 000000000..fd062661e --- /dev/null +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -0,0 +1,576 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from collections import defaultdict, deque +from typing import Any, Callable + +import datasets +import torch +from accelerate.utils import set_seed +from optimum.utils import logging +from torch.utils.data import Dataset, IterableDataset +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) + +from ..models.training import NeuronModelForCausalLM +from ..peft import NeuronPeftModel, get_peft_model +from ..utils import is_trl_available +from ..utils.import_utils import is_peft_available +from .grpo_config import NeuronGRPOConfig +from .training_args import NeuronTrainingArguments +from .transformers import NeuronTrainer +from .trl_utils import TRL_VERSION + + +if is_trl_available(): + from trl import GRPOConfig, GRPOTrainer +else: + + class GRPOTrainer: + pass + + class GRPOConfig: + pass + + +if is_peft_available(): + from peft import PeftConfig +else: + + class PeftConfig: + pass + + +# Create a new class that inherits from NeuronTrainer to use this class instead of the transformers Trainer, +# but has the same methods and attributes as GRPOTrainer. +# We can then inherit from this class to create our NeuronGRPOTrainer. +_GRPOTrainer = type( + "_GRPOTrainer", + (NeuronTrainer,), + GRPOTrainer.__dict__.copy(), +) + + +logger = logging.get_logger() + + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] + + +class NeuronGRPOTrainer(_GRPOTrainer): + """ + `GRPOTrainer` adapted for Neuron (Trainium) devices. + + This algorithm was initially proposed in the paper [DeepSeekMath: Pushing the Limits + of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). 
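+
+    For each prompt, a group of completions is sampled and scored by the reward
+    functions; each completion's advantage is then computed from its group's
+    statistics, roughly (a simplified sketch, see the paper for the exact
+    objective):
+
+        advantage_i = (reward_i - mean(group_rewards)) / std(group_rewards)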
+ """ + + def __init__( + self, + model: str | PreTrainedModel | torch.nn.Module, + reward_funcs: RewardFunc | list[RewardFunc], + args: GRPOConfig | None = None, + train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None, + eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, + peft_config: PeftConfig | None = None, + ): + if not is_trl_available(required_version=TRL_VERSION): + raise RuntimeError(f"Using NeuronGRPOTrainer requires trl=={TRL_VERSION}.") + + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = NeuronGRPOConfig(f"{model_name}-GRPO") + elif isinstance(args, NeuronTrainingArguments) and not isinstance(args, NeuronGRPOConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token + dict_args.pop("push_to_hub_token", None) + args = NeuronGRPOConfig(**dict_args) + + # Model + if isinstance(model, str): + model = NeuronModelForCausalLM.from_pretrained(model, **args.model_init_kwargs or {}) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + model_id = model.config._name_or_path + + # Processing class + if processing_class is None: + from transformers import AutoProcessor + + processing_class = AutoProcessor.from_pretrained(model_id, truncation_side="left") + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # PEFT configuration and model wrapping + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. 
+        self.num_virtual_tokens = 0
+
+        if peft_config is not None and not isinstance(model, NeuronPeftModel):
+            # Enable gradient checkpointing if needed
+            gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
+            if args.gradient_checkpointing and (
+                "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
+            ):
+                if hasattr(model, "enable_input_require_grads"):
+                    model.enable_input_require_grads()
+                else:
+
+                    def make_inputs_require_grad(module, input, output):
+                        output.requires_grad_(True)
+
+                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+            model = get_peft_model(model, peft_config)
+
+            if model.active_adapter in model.peft_config:
+                peft_model_config = model.peft_config[model.active_adapter]
+                self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)
+
+        # Reward functions - for now, only support callable reward functions
+        # TODO: Add support for reward models when they can be properly loaded on Neuron
+        if not isinstance(reward_funcs, list):
+            reward_funcs = [reward_funcs]
+
+        self.reward_func_names = []
+        for i, reward_func in enumerate(reward_funcs):
+            if isinstance(reward_func, str):
+                raise NotImplementedError(
+                    "Loading reward models from model IDs is not yet implemented for NeuronGRPOTrainer. "
+                    "Please provide either a PreTrainedModel or a custom callable reward function."
+                )
+            if isinstance(reward_func, PreTrainedModel):
+                raise NotImplementedError(
+                    "Using PreTrainedModel reward functions is not yet fully implemented for NeuronGRPOTrainer. "
+                    "Please use a custom callable reward function for now."
+                )
+            # Custom callable reward function
+            self.reward_func_names.append(reward_func.__name__)
+
+        self.reward_funcs = reward_funcs
+
+        # Reward weights
+        if args.reward_weights is not None:
+            if len(args.reward_weights) != len(reward_funcs):
+                raise ValueError(
+                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
+                    f"functions ({len(reward_funcs)})"
+                )
+            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
+        else:
+            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
+
+        # Reward processing class
+        if reward_processing_classes is None:
+            reward_processing_classes = [None] * len(reward_funcs)
+        elif not isinstance(reward_processing_classes, list):
+            reward_processing_classes = [reward_processing_classes]
+        if len(reward_processing_classes) != len(reward_funcs):
+            raise ValueError(
+                f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of "
+                f"reward functions ({len(reward_funcs)})."
+ ) + + # Note: We skip the loop that sets up tokenizers for PreTrainedModel reward functions + # since we currently raise errors for those anyway + self.reward_processing_classes = reward_processing_classes + + # Training arguments + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size + self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap + self.use_liger_loss = args.use_liger_loss + self.loss_type = args.loss_type + self.scale_rewards = args.scale_rewards + self.importance_sampling_level = args.importance_sampling_level + self.mask_truncated_completions = args.mask_truncated_completions + self.top_entropy_quantile = args.top_entropy_quantile + + # Validate liger kernel configuration + if self.use_liger_loss: + raise RuntimeError( + "Liger Kernel loss is not supported on Neuron devices. " + "Please set use_liger_loss=False in your GRPOConfig." + ) + + # Neuron GRPO only supports vLLM generation + if not self.use_vllm: + raise NotImplementedError( + "NeuronGRPOTrainer currently only supports vLLM generation. " + "Please set use_vllm=True in your GRPOConfig." + ) + + # Only server mode is supported for now + if self.vllm_mode != "server": + raise NotImplementedError( + "NeuronGRPOTrainer currently only supports vLLM server mode. " + "Please set vllm_mode='server' in your GRPOConfig." + ) + + if self._is_vlm: + raise NotImplementedError( + "Vision-language models are not yet supported in NeuronGRPOTrainer. " + "Please use text-only models for now." + ) + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or (isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())) + ): + raise NotImplementedError( + "Iterable datasets are not yet supported in NeuronGRPOTrainer. Please use a standard dataset instead." 
+ ) + + # Multi-step + self.num_iterations = args.num_iterations + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self._step = 0 + self._buffered_inputs = None + + # Suppress FLOP estimation warning + model.warnings_issued["estimate_tokens"] = True + + # Initialize NeuronTrainer + from trl.trainer.utils import identity + + NeuronTrainer.__init__( + self, + model=model, + args=args, + data_collator=identity, # No data collation is needed in GRPO + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + ) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + self.ref_model = None + elif isinstance(model, NeuronPeftModel): + self.ref_model = None + else: + # Create reference model using NeuronModelForCausalLM + self.ref_model = NeuronModelForCausalLM.from_pretrained(model_id, **args.model_init_kwargs or {}) + + # Disable dropout in the models + if args.disable_dropout: + from trl.trainer.utils import disable_dropout_in_model + + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + self._logs = { + "images": deque(maxlen=args.generation_batch_size), + "prompt": deque(maxlen=args.generation_batch_size), + "completion": deque(maxlen=args.generation_batch_size), + "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + "advantages": deque(maxlen=args.generation_batch_size), + } + + # Ensure each process receives a unique seed + set_seed(args.seed, device_specific=True) + + # vLLM setup - server mode only + from ..utils import is_vllm_available + + if not is_vllm_available(): + raise ImportError("vLLM is not available. 
Please install vLLM to use NeuronGRPOTrainer.") + + # Setup vLLM server client (only on main process) + if self.accelerator.is_main_process: + from trl.extras.vllm_client import VLLMClient + + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + + self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) + self.vllm_client.init_communicator(device=torch.cuda.current_device()) + + # vLLM specific sampling arguments + self.guided_decoding_regex = args.vllm_guided_decoding_regex + + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # Synchronize all processes after vLLM client setup + self.accelerator.wait_for_everyone() + + # Gradient accumulation requires scaled loss + self.model_accepts_loss_kwargs = False + + # Add tags for models + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + # Prepare reference model + if self.ref_model is not None: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + # Sync reference model callback + if args.sync_ref_model: + from trl.trainer.callbacks import SyncRefModelCallback + + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + # Prepare reward functions + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + def train( + self, + resume_from_checkpoint: str | bool | None = None, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint: Path to a checkpoint to resume from, or True to resume from the latest checkpoint. + """ + return NeuronTrainer.train(self, resume_from_checkpoint=resume_from_checkpoint) + + def log(self, logs: dict[str, float]) -> None: + """ + Override GRPOTrainer's log method to use NeuronTrainer's implementation. + + GRPOTrainer has custom metrics tracking that we don't use for Neuron training. + """ + return NeuronTrainer.log(self, logs) + + def _save_checkpoint(self, model=None, trial=None, metrics=None): + """ + Override GRPOTrainer's _save_checkpoint to use NeuronTrainer's implementation. + """ + return NeuronTrainer._save_checkpoint(self) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss for Neuron-optimized training. + + TODO: Implement GRPO-specific loss computation adapted for Neuron devices. + """ + raise NotImplementedError( + "compute_loss is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing GRPO-specific loss computation for Neuron devices." + ) + + def training_step( + self, model: torch.nn.Module, inputs: dict[str, Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + """ + Perform a training step for Neuron-optimized training. + + TODO: Implement GRPO-specific training step adapted for Neuron devices. + """ + raise NotImplementedError( + "training_step is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing GRPO-specific training logic for Neuron devices." + ) + + def _prepare_inputs(self, inputs): + """ + Prepare inputs for GRPO training, including generation and reward computation. + + TODO: Implement input preparation with Neuron-compatible generation and reward scoring. 
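+
+        A rough sketch of the intended flow, for orientation only (the exact calls are
+        assumptions, not a final API):
+
+            prompts = [x["prompt"] for x in inputs]      # raw prompts from the dataloader
+            completions = self._generate(prompts, None)  # Neuron-compatible generation
+            rewards = self._calculate_rewards(inputs, prompts, completions, ...)
+            # ...then tokenize prompt+completion pairs, compute group-normalized
+            # advantages, and return fixed-shape tensors for the model.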
+ """ + raise NotImplementedError( + "_prepare_inputs is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing prompt generation and reward computation for Neuron devices." + ) + + def _generate(self, prompts: list[str], images: list | None): + """ + Generate completions for the given prompts. + + TODO: Implement Neuron-compatible text generation. + """ + raise NotImplementedError( + "_generate is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing generation logic compatible with Neuron devices." + ) + + def _generate_single_turn(self, prompts: list[str], images: list | None): + """ + Generate a single turn of completions. + + TODO: Implement single-turn generation for Neuron devices. + """ + raise NotImplementedError( + "_generate_single_turn is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing single-turn generation for Neuron devices." + ) + + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + """ + Calculate rewards for the generated completions. + + TODO: Implement reward calculation compatible with Neuron devices. + """ + raise NotImplementedError( + "_calculate_rewards is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing reward computation for Neuron devices." + ) + + def _compute_loss(self, model, inputs): + """ + Internal loss computation for GRPO. + + TODO: Implement GRPO loss computation for Neuron devices. + """ + raise NotImplementedError( + "_compute_loss is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing the core GRPO loss computation for Neuron devices." + ) + + def get_train_dataloader(self): + """ + Get the training dataloader with GRPO-specific batching strategy. + + TODO: Implement GRPO-specific dataloader with proper batching for Neuron devices. + """ + raise NotImplementedError( + "get_train_dataloader is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing GRPO's custom batching strategy for Neuron devices." + ) + + def _get_train_sampler(self, dataset: Dataset | None = None): + """ + Get the training sampler with GRPO-specific sampling strategy. + + TODO: Implement RepeatSampler strategy for GRPO on Neuron devices. + """ + raise NotImplementedError( + "_get_train_sampler is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing GRPO's RepeatSampler strategy for Neuron devices." + ) + + def _get_eval_sampler(self, eval_dataset): + """ + Get the evaluation sampler. + + TODO: Implement evaluation sampler for GRPO on Neuron devices. + """ + raise NotImplementedError( + "_get_eval_sampler is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing the evaluation sampler for Neuron devices." + ) + + def _get_per_token_logps_and_entropies(self, *args, **kwargs): + """ + Compute per-token log probabilities and entropies. + + TODO: Implement log probability and entropy computation for Neuron devices. + """ + raise NotImplementedError( + "_get_per_token_logps_and_entropies is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing log probability computation for Neuron devices." + ) + + def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: + """ + Get mask for high-entropy tokens. + + TODO: Implement entropy-based masking for Neuron devices. + """ + raise NotImplementedError( + "get_high_entropy_mask is not yet implemented for NeuronGRPOTrainer. 
" + "This requires implementing entropy-based masking for Neuron devices." + ) + + def _set_signature_columns_if_needed(self): + """ + Set signature columns for GRPO-specific data preprocessing. + + TODO: Implement signature column handling for Neuron devices. + """ + raise NotImplementedError( + "_set_signature_columns_if_needed is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing signature column handling for GRPO on Neuron devices." + ) + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): + """ + Perform a prediction step during evaluation. + + TODO: Implement prediction step for GRPO evaluation on Neuron devices. + """ + raise NotImplementedError( + "prediction_step is not yet implemented for NeuronGRPOTrainer. " + "This requires implementing the prediction step for Neuron devices." + ) From b105f91f9545a5f455692bf67f4454530b639abe Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 4 Nov 2025 18:07:03 +0100 Subject: [PATCH 22/78] feat: add mock class for vLLM --- optimum/neuron/trainers/grpo_mocks.py | 170 ++++++++++++++++++++++++ optimum/neuron/trainers/grpo_trainer.py | 61 ++++++--- 2 files changed, 210 insertions(+), 21 deletions(-) create mode 100644 optimum/neuron/trainers/grpo_mocks.py diff --git a/optimum/neuron/trainers/grpo_mocks.py b/optimum/neuron/trainers/grpo_mocks.py new file mode 100644 index 000000000..621a360af --- /dev/null +++ b/optimum/neuron/trainers/grpo_mocks.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mock implementations for GRPO trainer testing and development. + +This module provides mock implementations of vLLM client and other components +to enable development and testing of NeuronGRPOTrainer without requiring a full +vLLM server setup. +""" + +from optimum.utils import logging + + +logger = logging.get_logger() + + +class MockVLLMClient: + """ + Mock vLLM client that generates dummy completions for testing. + + This mock client simulates the behavior of a real vLLM server by generating + placeholder completions. It's useful for: + - Development without vLLM server setup + - Testing trainer logic independently of generation quality + - Unit testing GRPO training loop + + Args: + tokenizer: Tokenizer to use for encoding/decoding + max_completion_length: Maximum length of generated completions + + Note: + This is a development tool and should not be used in production. + Generated completions are deterministic placeholders, not real language model outputs. + """ + + def __init__(self, tokenizer, max_completion_length=256): + self.tokenizer = tokenizer + self.max_completion_length = max_completion_length + logger.warning( + "Using MockVLLMClient for development. This generates placeholder completions " + "and should only be used for testing and development." 
+ ) + + def generate( + self, + prompts: list[str], + images=None, + n: int = 1, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 256, + repetition_penalty: float = 1.0, + truncate_prompt_tokens=None, + guided_decoding_regex=None, + generation_kwargs=None, + ): + """ + Generate mock completions for the given prompts. + + Args: + prompts: List of prompt strings + images: Optional list of images (not used in mock) + n: Number of completions to generate per prompt + temperature: Sampling temperature (not used in mock) + top_p: Nucleus sampling parameter (not used in mock) + top_k: Top-k sampling parameter (not used in mock) + min_p: Minimum probability threshold (not used in mock) + max_tokens: Maximum tokens to generate + repetition_penalty: Repetition penalty (not used in mock) + truncate_prompt_tokens: Maximum prompt length + guided_decoding_regex: Regex for guided decoding (not used in mock) + generation_kwargs: Additional generation arguments (not used in mock) + + Returns: + Dictionary with keys: + - prompt_ids: List of tokenized prompts (one per prompt) + - completion_ids: List of tokenized completions (n per prompt) + - logprobs: List of log probabilities (one list per completion) + """ + prompt_ids = [] + completion_ids = [] + logprobs = [] + + for prompt in prompts: + # Tokenize prompt + prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False) + + # Truncate if needed + if truncate_prompt_tokens is not None and len(prompt_tokens) > truncate_prompt_tokens: + prompt_tokens = prompt_tokens[-truncate_prompt_tokens:] + + # Store one copy of prompt tokens (will be repeated n times by caller) + prompt_ids.append(prompt_tokens) + + # Generate n completions per prompt + for i in range(n): + # Generate mock completion + # Use a simple pattern: repeat EOS token to create fixed-length completion + # In real scenario, this would be actual LLM generation + completion_length = min(max_tokens, self.max_completion_length) + + # Create a simple deterministic completion based on prompt and index + # This helps with debugging as completions will be consistent + if len(prompt_tokens) > 0: + # Use last prompt token as seed for variety + seed_token = prompt_tokens[-1] + else: + seed_token = self.tokenizer.eos_token_id + + # Generate completion: alternate between seed_token and eos_token + completion = [] + for j in range(completion_length): + if j % 2 == i % 2: # Use index for variation + completion.append(seed_token) + else: + completion.append(self.tokenizer.eos_token_id) + + completion_ids.append(completion) + + # Generate mock logprobs (uniform negative values) + # Real logprobs would come from the model's probability distribution + completion_logprobs = [-1.0] * completion_length + logprobs.append(completion_logprobs) + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + "logprobs": logprobs, + } + + def init_communicator(self, device): + """ + Mock initialization of communicator. + + Args: + device: Device to initialize on (not used in mock) + """ + pass + + +def create_mock_vllm_client(tokenizer, args): + """ + Factory function to create a mock vLLM client. 
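+
+    A usage sketch (illustrative only; assumes `args` carries `max_completion_length`):
+
+        client = create_mock_vllm_client(tokenizer, args)
+        out = client.generate(prompts=["Hello"], n=2, max_tokens=8)
+        # out["completion_ids"] holds two placeholder completions for the one prompt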
+ + Args: + tokenizer: Tokenizer to use for the mock client + args: Training arguments containing max_completion_length + + Returns: + MockVLLMClient instance + """ + return MockVLLMClient( + tokenizer=tokenizer, + max_completion_length=args.max_completion_length, + ) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index fd062661e..01ba593aa 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -15,13 +15,14 @@ import inspect from collections import defaultdict, deque +from functools import partial from typing import Any, Callable import datasets import torch from accelerate.utils import set_seed from optimum.utils import logging -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import DataLoader, Dataset, IterableDataset from transformers import ( AutoConfig, AutoModelForSequenceClassification, @@ -31,6 +32,8 @@ ProcessorMixin, TrainerCallback, ) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model @@ -346,20 +349,35 @@ def make_inputs_require_grad(module, input, output): # vLLM setup - server mode only from ..utils import is_vllm_available - if not is_vllm_available(): - raise ImportError("vLLM is not available. Please install vLLM to use NeuronGRPOTrainer.") + # For now, use mock vLLM client for development + # TODO: Set to False when real vLLM server is ready for Neuron + USE_MOCK_VLLM = True - # Setup vLLM server client (only on main process) - if self.accelerator.is_main_process: - from trl.extras.vllm_client import VLLMClient + if USE_MOCK_VLLM: + logger.warning( + "Using MOCK vLLM client for development. This generates placeholder completions " + "and should only be used for testing and development. Set USE_MOCK_VLLM=False in " + "grpo_trainer.py to use real vLLM server." + ) + from .grpo_mocks import create_mock_vllm_client + + if self.accelerator.is_main_process: + self.vllm_client = create_mock_vllm_client(tokenizer, args) + else: + if not is_vllm_available(): + raise ImportError("vLLM is not available. Please install vLLM to use NeuronGRPOTrainer.") - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + # Setup vLLM server client (only on main process) + if self.accelerator.is_main_process: + from trl.extras.vllm_client import VLLMClient - self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + + self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) + self.vllm_client.init_communicator(device=torch.cuda.current_device()) # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex @@ -557,20 +575,21 @@ def _set_signature_columns_if_needed(self): """ Set signature columns for GRPO-specific data preprocessing. - TODO: Implement signature column handling for Neuron devices. + In GRPOTrainer, we preprocess data differently than standard Trainer, + so we set the signature columns to those expected by the training_step method. 
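+
+        Concretely, only the raw prompt/image columns survive preprocessing here
+        (see the assignment below); everything else is produced during generation
+        and reward scoring.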
""" - raise NotImplementedError( - "_set_signature_columns_if_needed is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing signature column handling for GRPO on Neuron devices." - ) + if self._signature_columns is None: + self._signature_columns = ["prompt", "image", "images"] def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): """ - Perform a prediction step during evaluation. + Evaluation and prediction are not supported in NeuronGRPOTrainer. - TODO: Implement prediction step for GRPO evaluation on Neuron devices. + The trainer is designed for training only. NeuronTrainer does not provide + evaluation loop functionality, and GRPO-specific evaluation would require + significant additional implementation. """ raise NotImplementedError( - "prediction_step is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing the prediction step for Neuron devices." + "Evaluation and prediction are not supported in NeuronGRPOTrainer. " + "The trainer is designed for training only." ) From 781b27fcada60e845c67bca4821a3dae0a0a2777 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 5 Nov 2025 15:24:44 +0100 Subject: [PATCH 23/78] fix: add is_vllm_available imports --- optimum/neuron/utils/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 04e367429..a64262d9a 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -49,6 +49,7 @@ "is_neuronx_available", "is_torch_neuronx_available", "is_trl_available", + "is_vllm_available", ], "input_generators": [ "DTYPE_MAPPER", @@ -118,6 +119,7 @@ is_neuronx_available, is_torch_neuronx_available, is_trl_available, + is_vllm_available, ) from .input_generators import ( DTYPE_MAPPER, From e932e28c2fdabc476d0d75afcaa7946d0f514f9f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 5 Nov 2025 16:26:38 +0100 Subject: [PATCH 24/78] chore: add data loading --- optimum/neuron/__init__.py | 4 +- optimum/neuron/trainers/__init__.py | 4 +- optimum/neuron/trainers/grpo_config.py | 64 ++++++++++++++++++++++++- optimum/neuron/trainers/grpo_trainer.py | 49 ++++++------------- optimum/neuron/trainers/transformers.py | 7 +-- optimum/neuron/trainers/utils.py | 15 +++++- 6 files changed, 97 insertions(+), 46 deletions(-) diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 3cace0c0a..913bd32ec 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -158,10 +158,10 @@ from .models.inference.yolos import NeuronYolosForObjectDetection from .pipelines import pipeline from .trainers import ( - NeuronSFTConfig, - NeuronSFTTrainer, NeuronGRPOConfig, NeuronGRPOTrainer, + NeuronSFTConfig, + NeuronSFTTrainer, NeuronTrainer, NeuronTrainingArguments, ) diff --git a/optimum/neuron/trainers/__init__.py b/optimum/neuron/trainers/__init__.py index 7081378b4..1d87a937d 100644 --- a/optimum/neuron/trainers/__init__.py +++ b/optimum/neuron/trainers/__init__.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .grpo_config import NeuronGRPOConfig +from .grpo_trainer import NeuronGRPOTrainer from .sft_config import NeuronSFTConfig from .sft_trainer import NeuronSFTTrainer from .training_args import NeuronTrainingArguments from .transformers import NeuronTrainer from .trl_utils import TRL_VERSION -from .grpo_trainer import NeuronGRPOTrainer -from .grpo_config import NeuronGRPOConfig diff --git a/optimum/neuron/trainers/grpo_config.py b/optimum/neuron/trainers/grpo_config.py index 32e6c4fbb..c696ab2ba 100644 --- a/optimum/neuron/trainers/grpo_config.py +++ b/optimum/neuron/trainers/grpo_config.py @@ -32,4 +32,66 @@ def __init__(self, *args, **kwargs): @dataclass class NeuronGRPOConfig(NeuronTrainingArguments, GRPOConfig): - pass + """ + Configuration class for Neuron-optimized GRPO training. + + This class combines NeuronTrainingArguments for Trainium-specific settings + with GRPOConfig for GRPO algorithm parameters. + """ + + def __post_init__(self): + # Handle bf16 default (from GRPOConfig) + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + # Call NeuronTrainingArguments.__post_init__ to initialize Neuron-specific settings + NeuronTrainingArguments.__post_init__(self) + + # Convert scale_rewards boolean to string (from GRPOConfig) + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + + num_processes = self.world_size + # The current default effective batch size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + self.steps_per_generation = self.generation_batch_size // ( + self.per_device_train_batch_size * num_processes + ) + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.do_eval and self.eval_strategy != "no": + # Just ensure the value is divisible by the global batch size + if (self.per_device_eval_batch_size * num_processes) % self.num_generations != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by num_generations ({self.num_generations})." + ) + + # The generation batch must contain full prompt groups (no partials), so it must be divisible by + # num_generations. + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) + + if self.num_generations < 2: + raise ValueError( + "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." 
+ ) + + if self.delta is not None and self.use_liger_loss: + raise ValueError("Liger loss does not support two-sided GRPO loss yet.") diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 01ba593aa..f9ef6f658 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -13,27 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from collections import defaultdict, deque -from functools import partial from typing import Any, Callable import datasets import torch from accelerate.utils import set_seed from optimum.utils import logging -from torch.utils.data import DataLoader, Dataset, IterableDataset +from torch.utils.data import Dataset, IterableDataset, Sampler from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, ) -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model @@ -279,7 +272,9 @@ def make_inputs_require_grad(module, input, output): if ( isinstance(train_dataset, IterableDataset) or isinstance(eval_dataset, IterableDataset) - or (isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) ): raise NotImplementedError( "Iterable datasets are not yet supported in NeuronGRPOTrainer. Please use a standard dataset instead." @@ -311,6 +306,10 @@ def make_inputs_require_grad(module, input, output): optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, ) + # Set _train_batch_size for compatibility with GRPOTrainer's get_train_dataloader + # NeuronTrainer doesn't set this, but GRPOTrainer expects it + self._train_batch_size = args.train_batch_size + # Reference model self.beta = args.beta if self.beta == 0.0: @@ -516,37 +515,17 @@ def _compute_loss(self, model, inputs): "This requires implementing the core GRPO loss computation for Neuron devices." ) - def get_train_dataloader(self): - """ - Get the training dataloader with GRPO-specific batching strategy. - - TODO: Implement GRPO-specific dataloader with proper batching for Neuron devices. - """ - raise NotImplementedError( - "get_train_dataloader is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing GRPO's custom batching strategy for Neuron devices." - ) - - def _get_train_sampler(self, dataset: Dataset | None = None): - """ - Get the training sampler with GRPO-specific sampling strategy. - - TODO: Implement RepeatSampler strategy for GRPO on Neuron devices. - """ - raise NotImplementedError( - "_get_train_sampler is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing GRPO's RepeatSampler strategy for Neuron devices." - ) - - def _get_eval_sampler(self, eval_dataset): + def _get_eval_sampler(self, eval_dataset) -> Sampler: """ Get the evaluation sampler. - TODO: Implement evaluation sampler for GRPO on Neuron devices. + Note: Evaluation is not supported in NeuronGRPOTrainer as NeuronTrainer does not + provide evaluation loops. This method is kept for interface compatibility but will + raise NotImplementedError if called. 
""" raise NotImplementedError( - "_get_eval_sampler is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing the evaluation sampler for Neuron devices." + "Evaluation is not supported in NeuronGRPOTrainer. " + "NeuronTrainer does not provide evaluation loops for Trainium devices." ) def _get_per_token_logps_and_entropies(self, *args, **kwargs): diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index c77d4741e..9afca8925 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -101,7 +101,7 @@ from ..utils.misc import is_main_worker, is_precompilation from .metrics import TrainingMetricsCollector from .training_args import NeuronTrainingArguments -from .utils import XLAPrefetchIterator +from .utils import XLAPrefetchIterator, move_inputs_to_device logger = logging.get_logger() @@ -933,10 +933,7 @@ def get_batch_samples( if self.pp_size == 1 and device is not None and device.type == "xla": if prefetch_size is None: - for idx, batch in enumerate(batch_samples): - batch_samples[idx] = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } + batch_samples = move_inputs_to_device(batch_samples, device) else: batch_samples = XLAPrefetchIterator(batch_samples, prefetch_size) diff --git a/optimum/neuron/trainers/utils.py b/optimum/neuron/trainers/utils.py index 595f546cb..1cdaf14b2 100644 --- a/optimum/neuron/trainers/utils.py +++ b/optimum/neuron/trainers/utils.py @@ -17,6 +17,19 @@ import torch_xla.core.xla_model as xm +def move_inputs_to_device(inputs, device: torch.device): + if isinstance(inputs, torch.Tensor): + return inputs.to(device) + elif isinstance(inputs, dict): + return {k: move_inputs_to_device(v, device) for k, v in inputs.items()} + elif isinstance(inputs, list): + return [move_inputs_to_device(v, device) for v in inputs] + elif isinstance(inputs, tuple): + return tuple(move_inputs_to_device(v, device) for v in inputs) + else: + return inputs + + class XLAPrefetchIterator: def __init__(self, examples: list[dict[str, torch.Tensor]], prefetch_size: int = 1): self.examples = examples @@ -28,7 +41,7 @@ def __init__(self, examples: list[dict[str, torch.Tensor]], prefetch_size: int = def _prefetch(self): while len(self.buffer) < self.prefetch_size and self.current_index < len(self.examples): example = self.examples[self.current_index] - example_on_xla = {k: v.to(xm.xla_device()) for k, v in example.items()} + example_on_xla = move_inputs_to_device(example, xm.xla_device()) self.buffer.append(example_on_xla) self.current_index += 1 From cef6d30fd06f43a0f34a7348e31634bb9c802612 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 5 Nov 2025 16:38:42 +0100 Subject: [PATCH 25/78] chore: add _prepare_inputs --- optimum/neuron/trainers/grpo_trainer.py | 32 +++++++++++-------------- optimum/neuron/trainers/transformers.py | 14 +++++++++++ 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index f9ef6f658..199f5c62c 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -447,29 +447,25 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N "This requires implementing GRPO-specific loss computation for Neuron devices." 
) - def training_step( - self, model: torch.nn.Module, inputs: dict[str, Any], num_items_in_batch: int | None = None - ) -> torch.Tensor: + def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: """ - Perform a training step for Neuron-optimized training. + Prepare inputs for GRPO training. - TODO: Implement GRPO-specific training step adapted for Neuron devices. - """ - raise NotImplementedError( - "training_step is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing GRPO-specific training logic for Neuron devices." - ) + This method overrides NeuronTrainer._prepare_inputs to use GRPOTrainer's + implementation, which handles: + 1. Generation of completions using vLLM + 2. Scoring completions using reward functions + 3. Buffering completions for reuse across multiple gradient steps + 4. Tokenization and conversion to model inputs - def _prepare_inputs(self, inputs): - """ - Prepare inputs for GRPO training, including generation and reward computation. + Args: + inputs: Raw batch from dataloader (list of prompt dicts for GRPO) - TODO: Implement input preparation with Neuron-compatible generation and reward scoring. + Returns: + Dictionary of tokenized tensors ready for the model """ - raise NotImplementedError( - "_prepare_inputs is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing prompt generation and reward computation for Neuron devices." - ) + # Explicitly call GRPOTrainer's _prepare_inputs + return GRPOTrainer._prepare_inputs(self, inputs) def _generate(self, prompts: list[str], images: list | None): """ diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index 9afca8925..97220ffca 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -986,9 +986,23 @@ def compute_loss( return (loss, outputs) if return_outputs else loss + def _prepare_inputs(self, inputs: Any) -> Any: + """ + Prepare inputs before feeding them to the model. + + This is a no-op for standard NeuronTrainer as inputs are already moved to device in get_batch_samples(). + Subclasses can override this method for custom preprocessing (e.g., GRPOTrainer uses + this for generation, scoring, and tokenization). + + """ + return inputs + def training_step( self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch: int | torch.Tensor | None = None ) -> torch.Tensor: + # Prepare inputs (no-op for base trainer, overridden by subclasses for custom preprocessing) + inputs = self._prepare_inputs(inputs) + manager = self.autocast_smart_context_manager() with manager: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) From a567289783eff8c48067b062a2cd0b6ac7c32414 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 6 Nov 2025 15:57:08 +0100 Subject: [PATCH 26/78] chore: keep replacing stub methods --- optimum/neuron/trainers/grpo_mocks.py | 22 +++++ optimum/neuron/trainers/grpo_trainer.py | 111 ++++++++++++++++-------- 2 files changed, 95 insertions(+), 38 deletions(-) diff --git a/optimum/neuron/trainers/grpo_mocks.py b/optimum/neuron/trainers/grpo_mocks.py index 621a360af..d3e149b41 100644 --- a/optimum/neuron/trainers/grpo_mocks.py +++ b/optimum/neuron/trainers/grpo_mocks.py @@ -152,6 +152,28 @@ def init_communicator(self, device): """ pass + def update_named_param(self, name, data): + """ + Mock update of named parameter. + + In a real vLLM setup, this would sync model weights to the vLLM server. 
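+        (In TRL's real VLLMClient, this is what `_move_model_to_vllm` calls for each
+        updated weight before a new generation round.)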
+ For mock mode, this is a no-op since we're not using a real server. + + Args: + name: Parameter name + data: Parameter data tensor (not used in mock) + """ + pass + + def reset_prefix_cache(self): + """ + Mock reset of prefix cache. + + In a real vLLM setup, this would clear the KV cache for prefix caching. + For mock mode, this is a no-op since we're not using a real server. + """ + pass + def create_mock_vllm_client(tokenizer, args): """ diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 199f5c62c..46d5f99c2 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -14,6 +14,7 @@ # limitations under the License. from collections import defaultdict, deque +import inspect from typing import Any, Callable import datasets @@ -113,7 +114,7 @@ def __init__( # Model if isinstance(model, str): - model = NeuronModelForCausalLM.from_pretrained(model, **args.model_init_kwargs or {}) + model = NeuronModelForCausalLM.from_pretrained(model, args.trn_config, **args.model_init_kwargs or {}) else: if args.model_init_kwargs is not None: logger.warning( @@ -141,6 +142,18 @@ def __init__( if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + # Store tokens and token IDs for generation and reward computation + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + # Store model forward signature keys for checking supported kwargs + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + # PEFT configuration and model wrapping # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. @@ -310,6 +323,10 @@ def make_inputs_require_grad(module, input, output): # NeuronTrainer doesn't set this, but GRPOTrainer expects it self._train_batch_size = args.train_batch_size + # Set FSDP flag to False (NeuronTrainer doesn't support FSDP) + # GRPOTrainer's methods check this attribute + self.is_fsdp_enabled = False + # Reference model self.beta = args.beta if self.beta == 0.0: @@ -318,7 +335,7 @@ def make_inputs_require_grad(module, input, output): self.ref_model = None else: # Create reference model using NeuronModelForCausalLM - self.ref_model = NeuronModelForCausalLM.from_pretrained(model_id, **args.model_init_kwargs or {}) + self.ref_model = NeuronModelForCausalLM.from_pretrained(model_id, args.trn_config, **args.model_init_kwargs or {}) # Disable dropout in the models if args.disable_dropout: @@ -467,28 +484,65 @@ def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: # Explicitly call GRPOTrainer's _prepare_inputs return GRPOTrainer._prepare_inputs(self, inputs) - def _generate(self, prompts: list[str], images: list | None): - """ - Generate completions for the given prompts. - - TODO: Implement Neuron-compatible text generation. - """ - raise NotImplementedError( - "_generate is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing generation logic compatible with Neuron devices." - ) + # _generate is inherited from GRPOTrainer via the type() trick def _generate_single_turn(self, prompts: list[str], images: list | None): """ - Generate a single turn of completions. 
+ Generate a single turn of completions using mock vLLM. + + This overrides GRPOTrainer's implementation to work with Neuron/XLA devices. + The main difference is avoiding gather_object which doesn't work on XLA. + Since we're using mock vLLM, we generate locally on each process without + gathering/broadcasting. + + Args: + prompts: List of prompt strings + images: Optional list of images - TODO: Implement single-turn generation for Neuron devices. + Returns: + Tuple of (prompt_ids, completion_ids, logprobs, forward_kwargs) """ - raise NotImplementedError( - "_generate_single_turn is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing single-turn generation for Neuron devices." + # Move model weights to vLLM if needed (no-op for mock) + if self.state.global_step != getattr(self, "_last_loaded_step", -1): + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # For mock vLLM, generate locally on each process (no gather/broadcast needed) + # Take unique prompts since we have num_generations duplicates + prompts_text = [prompt if isinstance(prompt, str) else prompt["content"] for prompt in prompts] + ordered_set_of_prompts = prompts_text[:: self.num_generations] + + if images is not None: + ordered_set_of_images = images[:: self.num_generations] + else: + ordered_set_of_images = None + + # Generate using mock vLLM client + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, ) + # Repeat prompt_ids num_generations times to match completion_ids + prompt_ids = [ids for ids in output["prompt_ids"] for _ in range(self.num_generations)] + completion_ids = output["completion_ids"] + logprobs = output["logprobs"] + + # No forward_kwargs for mock vLLM + forward_kwargs = {} + + return prompt_ids, completion_ids, logprobs, forward_kwargs + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): """ Calculate rewards for the generated completions. @@ -524,27 +578,8 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler: "NeuronTrainer does not provide evaluation loops for Trainium devices." ) - def _get_per_token_logps_and_entropies(self, *args, **kwargs): - """ - Compute per-token log probabilities and entropies. - - TODO: Implement log probability and entropy computation for Neuron devices. - """ - raise NotImplementedError( - "_get_per_token_logps_and_entropies is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing log probability computation for Neuron devices." - ) - - def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: - """ - Get mask for high-entropy tokens. - - TODO: Implement entropy-based masking for Neuron devices. - """ - raise NotImplementedError( - "get_high_entropy_mask is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing entropy-based masking for Neuron devices." 
- ) + # _get_per_token_logps_and_entropies and get_high_entropy_mask are inherited from GRPOTrainer + # They work with standard PyTorch operations and don't need Neuron-specific implementations def _set_signature_columns_if_needed(self): """ From c8a7ed8359f61b598ba2d0c4276c6d5ed6c1f9ab Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 6 Nov 2025 16:16:16 +0100 Subject: [PATCH 27/78] chore: add mock specific comment --- optimum/neuron/trainers/grpo_trainer.py | 28 ++++++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 46d5f99c2..8d2c7bbea 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict, deque import inspect +from collections import defaultdict, deque from typing import Any, Callable import datasets @@ -362,14 +362,17 @@ def make_inputs_require_grad(module, input, output): # Ensure each process receives a unique seed set_seed(args.seed, device_specific=True) + # =================================================================================== + # MOCK CONTROL: Set USE_MOCK_VLLM to False when using real vLLM server + # =================================================================================== # vLLM setup - server mode only from ..utils import is_vllm_available - # For now, use mock vLLM client for development - # TODO: Set to False when real vLLM server is ready for Neuron + # MOCK FLAG: Change this to False when real vLLM server is ready USE_MOCK_VLLM = True if USE_MOCK_VLLM: + # ============= MOCK-SPECIFIC BRANCH ============= logger.warning( "Using MOCK vLLM client for development. This generates placeholder completions " "and should only be used for testing and development. Set USE_MOCK_VLLM=False in " @@ -377,9 +380,11 @@ def make_inputs_require_grad(module, input, output): ) from .grpo_mocks import create_mock_vllm_client - if self.accelerator.is_main_process: - self.vllm_client = create_mock_vllm_client(tokenizer, args) + # MOCK: Each process needs its own client (generates locally, no server) + self.vllm_client = create_mock_vllm_client(tokenizer, args) + # ============= END MOCK-SPECIFIC BRANCH ============= else: + # ============= REAL vLLM SERVER BRANCH ============= if not is_vllm_available(): raise ImportError("vLLM is not available. 
Please install vLLM to use NeuronGRPOTrainer.")
+
+            # Setup vLLM server client (only on main process)
+            if self.accelerator.is_main_process:
+                from trl.extras.vllm_client import VLLMClient
+
+                if args.vllm_server_base_url is not None:
+                    base_url = args.vllm_server_base_url
+                else:
+                    base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
+
+                self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+                self.vllm_client.init_communicator(device=torch.cuda.current_device())
+            # ============= END REAL vLLM SERVER BRANCH =============

         # vLLM specific sampling arguments
         self.guided_decoding_regex = args.vllm_guided_decoding_regex

From 0ddc40f2cefc1144a6d01ff949ce86df103629f7 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 13 Nov 2025 17:21:42 +0100
Subject: [PATCH 28/78] chore: wip, full training cycle with mocks

---
 optimum/neuron/trainers/grpo_mocks.py   |  19 +-
 optimum/neuron/trainers/grpo_trainer.py | 571 +++++++++++++++++++++---
 2 files changed, 508 insertions(+), 82 deletions(-)

diff --git a/optimum/neuron/trainers/grpo_mocks.py b/optimum/neuron/trainers/grpo_mocks.py
index d3e149b41..39e819106 100644
--- a/optimum/neuron/trainers/grpo_mocks.py
+++ b/optimum/neuron/trainers/grpo_mocks.py
@@ -104,7 +104,6 @@ def generate(
             if truncate_prompt_tokens is not None and len(prompt_tokens) > truncate_prompt_tokens:
                 prompt_tokens = prompt_tokens[-truncate_prompt_tokens:]
 
-            # Store one copy of prompt tokens (will be repeated n times by caller)
             prompt_ids.append(prompt_tokens)
 
             # Generate n completions per prompt
             for i in range(n):
                 # Generate mock completion
                 # Use a simple pattern: repeat EOS token to create fixed-length completion
                 # In real scenario, this would be actual LLM generation
                 completion_length = min(max_tokens, self.max_completion_length)
 
-                # Create a simple deterministic completion based on prompt and index
-                # This helps with debugging as completions will be consistent
-                if len(prompt_tokens) > 0:
-                    # Use last prompt token as seed for variety
-                    seed_token = prompt_tokens[-1]
-                else:
-                    seed_token = self.tokenizer.eos_token_id
-
-                # Generate completion: alternate between seed_token and eos_token
-                completion = []
-                for j in range(completion_length):
-                    if j % 2 == i % 2:  # Use index for variation
-                        completion.append(seed_token)
-                    else:
-                        completion.append(self.tokenizer.eos_token_id)
-
+                # Generate completion: fill with the EOS token, a safe in-vocabulary ID
+                completion = [self.tokenizer.eos_token_id] * completion_length
                 completion_ids.append(completion)
 
             # Generate mock logprobs (uniform negative values)
diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py
index 8d2c7bbea..8caed58bb 100644
--- a/optimum/neuron/trainers/grpo_trainer.py
+++ b/optimum/neuron/trainers/grpo_trainer.py
@@ -18,10 +18,11 @@
 from typing import Any, Callable
 
 import datasets
+import numpy as np
 import torch
 from accelerate.utils import set_seed
 from optimum.utils import logging
-from torch.utils.data import Dataset, IterableDataset, Sampler
+from torch.utils.data import Dataset, IterableDataset
 from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase,
     ProcessorMixin,
     TrainerCallback,
 )
@@ -31,7 +32,7 @@
 
 from ..models.training import NeuronModelForCausalLM
 from ..peft import NeuronPeftModel, get_peft_model
-from ..utils import is_trl_available
+from ..utils import is_precompilation, is_trl_available
 from ..utils.import_utils import is_peft_available
 from .grpo_config import NeuronGRPOConfig
 from .training_args import NeuronTrainingArguments
@@ -41,6 +42,8 @@
 
 if is_trl_available():
     from trl import GRPOConfig, GRPOTrainer
+    from trl.data_utils import is_conversational
+    from trl.trainer.utils import disable_dropout_in_model, identity, nanmax, nanmin, nanstd
 else:
 
     class GRPOTrainer:
@@ -76,6 +79,65 @@ class PeftConfig:
 
 RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]
 
 
+def pad(
+    tensors: list[torch.Tensor],
+    padding_value: int = 0,
+    padding_side: str = "right",
+    max_length: int | None = None,
+) -> torch.Tensor:
+    """
+    Pads a list of tensors to the same shape along the first dimension.
+    It differs from `trl` by enforcing the same sequence length for all tensors, which is required to avoid
+    recompilation.
+    """
+    batch_size = len(tensors)
+    if max_length is None:
+        max_length = np.max([t.shape[0] for t in tensors]).tolist()
+
+    output_shape = (max_length,) + tensors[0].shape[1:]
+
+    # Create an output tensor filled with the padding value
+    output = torch.full((batch_size, *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
+
+    for i, t in enumerate(tensors):
+        if padding_side == "left":
+            seq_start = output_shape[0] - t.shape[0]
+        elif padding_side == "right":
+            seq_start = 0
+        else:
+            raise ValueError("padding_side must be 'left' or 'right'")
+
+        # Define the slices
+        seq_slice = slice(seq_start, seq_start + t.shape[0])
+        slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
+        output[i][slices] = t
+
+    return output
+
+
+def neuron_parallel_compile_tokenizer_decoder_method(
+    self,
+    token_ids: int | list[int],
+    skip_special_tokens: bool = False,
+    clean_up_tokenization_spaces: bool | None = None,
+    **kwargs,
+) -> str:
+    """
+    Patched `tokenizer._decode` method for Neuron parallel compilation.
+    This is needed because any tensor operation during `neuron_parallel_compile` produces rubbish results, which is
+    not an issue in general, but causes failures when the token IDs end up being out of range for the tokenizer
+    vocabulary.
+    """
+    if not is_precompilation():
+        raise RuntimeError("This patch method should only be used with `neuron_parallel_compile`.")
+
+    # We log the token IDs to force the data movement to CPU, which would happen during actual decoding.
+    logger.debug(f"Using patched tokenizer.decode method for Neuron parallel compilation, token_ids = {token_ids}")
+
+    # Returns a dummy string, we do not care about the value in this context.
+    return "dummy"
+
+
 class NeuronGRPOTrainer(_GRPOTrainer):
     """
     `GRPOTrainer` adapted for Neuron (Trainium) devices.
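A quick illustration of the `pad` helper above (values are made up): pinning `max_length` keeps the padded shape constant across batches, which is what avoids XLA recompilation.

    a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
    pad([a, b], padding_value=0, padding_side="right", max_length=5)
    # -> tensor([[1, 2, 3, 0, 0],
    #            [4, 5, 0, 0, 0]])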
@@ -101,6 +163,12 @@ def __init__( if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError(f"Using NeuronGRPOTrainer requires trl=={TRL_VERSION}.") + # Patch tokenizer decode method for Neuron parallel compilation to avoid failures. + if is_precompilation() and hasattr(processing_class, "_decode"): + processing_class._decode = neuron_parallel_compile_tokenizer_decoder_method.__get__( + processing_class, processing_class.__class__ + ) + # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path @@ -304,8 +372,6 @@ def make_inputs_require_grad(module, input, output): model.warnings_issued["estimate_tokens"] = True # Initialize NeuronTrainer - from trl.trainer.utils import identity - NeuronTrainer.__init__( self, model=model, @@ -339,8 +405,6 @@ def make_inputs_require_grad(module, input, output): # Disable dropout in the models if args.disable_dropout: - from trl.trainer.utils import disable_dropout_in_model - disable_dropout_in_model(model) if self.ref_model is not None: disable_dropout_in_model(self.ref_model) @@ -362,9 +426,6 @@ def make_inputs_require_grad(module, input, output): # Ensure each process receives a unique seed set_seed(args.seed, device_specific=True) - # =================================================================================== - # MOCK CONTROL: Set USE_MOCK_VLLM to False when using real vLLM server - # =================================================================================== # vLLM setup - server mode only from ..utils import is_vllm_available @@ -372,7 +433,6 @@ def make_inputs_require_grad(module, input, output): USE_MOCK_VLLM = True if USE_MOCK_VLLM: - # ============= MOCK-SPECIFIC BRANCH ============= logger.warning( "Using MOCK vLLM client for development. This generates placeholder completions " "and should only be used for testing and development. Set USE_MOCK_VLLM=False in " @@ -382,9 +442,7 @@ def make_inputs_require_grad(module, input, output): # MOCK: Each process needs its own client (generates locally, no server) self.vllm_client = create_mock_vllm_client(tokenizer, args) - # ============= END MOCK-SPECIFIC BRANCH ============= else: - # ============= REAL vLLM SERVER BRANCH ============= if not is_vllm_available(): raise ImportError("vLLM is not available. Please install vLLM to use NeuronGRPOTrainer.") @@ -399,7 +457,6 @@ def make_inputs_require_grad(module, input, output): self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) self.vllm_client.init_communicator(device=torch.cuda.current_device()) - # ============= END REAL vLLM SERVER BRANCH ============= # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex @@ -459,17 +516,6 @@ def _save_checkpoint(self, model=None, trial=None, metrics=None): """ return NeuronTrainer._save_checkpoint(self) - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - """ - Compute training loss for Neuron-optimized training. - - TODO: Implement GRPO-specific loss computation adapted for Neuron devices. - """ - raise NotImplementedError( - "compute_loss is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing GRPO-specific loss computation for Neuron devices." - ) - def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: """ Prepare inputs for GRPO training. 
@@ -490,13 +536,6 @@ def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: # Explicitly call GRPOTrainer's _prepare_inputs return GRPOTrainer._prepare_inputs(self, inputs) - # _generate is inherited from GRPOTrainer via the type() trick - - # =================================================================================== - # MOCK-SPECIFIC OVERRIDE: This method is needed for mock vLLM mode - # When using real vLLM server, test if TRL's implementation works or if this - # override is still needed to avoid gather_object on XLA - # =================================================================================== def _generate_single_turn(self, prompts: list[str], images: list | None): """ Generate a single turn of completions using vLLM (mock or real server). @@ -555,53 +594,455 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): return prompt_ids, completion_ids, logprobs, forward_kwargs - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + # We patch the pad function to make it compatible with `neuron_parallel_compile`. + # patcher = Patcher([("trl.trainer.grpo_trainer.pad", pad)]) + # with patcher: + return GRPOTrainer._generate_and_score_completions(self, inputs) + + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + ): """ - Calculate rewards for the generated completions. + Override to pad sequences to max_length for XLA compilation. - TODO: Implement reward calculation compatible with Neuron devices. + GRPO generates variable-length prompts + completions. XLA compilation requires + fixed shapes, so we pad all sequences to max_length (max_prompt_length + max_completion_length). """ - raise NotImplementedError( - "_calculate_rewards is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing reward computation for Neuron devices." + # Calculate max_length from GRPO config + max_length = self.max_prompt_length + self.max_completion_length + seq_len = input_ids.shape[1] + + if seq_len < max_length: + pad_amount = max_length - seq_len + + # Pad input_ids with pad_token_id + input_ids = torch.nn.functional.pad( + input_ids, + (0, pad_amount), + value=self.pad_token_id + ) + + # Pad attention_mask + if attention_mask is not None: + attention_mask = torch.nn.functional.pad( + attention_mask, + (0, pad_amount), + value=0 # Padded positions should be masked out + ) + + # Call parent implementation with padded tensors + return GRPOTrainer._get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + compute_entropy=compute_entropy, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + num_images=num_images, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + token_type_ids=token_type_ids, ) - def _compute_loss(self, model, inputs): - """ - Internal loss computation for GRPO. + def _generate_and_score_completions( + self, inputs: list[dict[str,torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" - TODO: Implement GRPO loss computation for Neuron devices. 
- """ - raise NotImplementedError( - "_compute_loss is not yet implemented for NeuronGRPOTrainer. " - "This requires implementing the core GRPO loss computation for Neuron devices." + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + ( + prompt_ids_list, + completion_ids_list, + num_items_in_batch, + sampling_per_token_logps_list, + forward_kwargs, + ) = self._generate(prompts, images) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + with torch.no_grad(): + # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. 
+ # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. + generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) + importance_sampling_ratio = torch.clamp( + importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. 
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = rewards - mean_grouped_rewards + + if self.scale_rewards in ["group", "none"]: + # If self.scale_rewards = "none", we'll still log group level std + std_rewards = rewards.view(-1, self.num_generations).std(dim=1) + std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0) + elif self.scale_rewards == "batch": + # Compute global std + std_rewards = rewards.std().expand_as(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." + ) + + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_rewards.mean().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + # TODO: handle this later. 
+ # self._logs["prompt"].extend(gather_object(prompts_text)) + # self._logs["completion"].extend(gather_object(completions_text)) + # for i, name in enumerate(self.reward_func_names): + # self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + # self._logs["advantages"].extend(all_process_advantages.tolist()) + + # if images is not None: + # self._logs["images"].extend(gather_object(images)) + + # if self.use_vllm and self.vllm_importance_sampling_correction: + # delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + # delta = delta[completion_mask.bool()] + # mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + # self.accelerator.gather(mean_delta).mean().item() + # ) + # self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + # self.accelerator.gather(max_delta).max().item() + # ) + + # flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + # min_importance_sampling_ratio = ( + # torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # mean_importance_sampling_ratio = ( + # torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # max_importance_sampling_ratio = ( + # torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + # nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + # ) + # self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + # self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + # ) + # self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + # nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + # ) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if self.use_vllm and self.vllm_importance_sampling_correction: + output["importance_sampling_ratio"] = importance_sampling_ratio + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + return output - def _get_eval_sampler(self, eval_dataset) -> Sampler: - """ - Get the evaluation sampler. 
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + else: + return self._compute_loss(model, inputs) - Note: Evaluation is not supported in NeuronGRPOTrainer as NeuronTrainer does not - provide evaluation loops. This method is kept for interface compatibility but will - raise NotImplementedError if called. - """ - raise NotImplementedError( - "Evaluation is not supported in NeuronGRPOTrainer. " - "NeuronTrainer does not provide evaluation loops for Trainium devices." + def _compute_loss(self, model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), ) - # _get_per_token_logps_and_entropies and get_high_entropy_mask are inherited from GRPOTrainer - # They work with standard PyTorch operations and don't need Neuron-specific implementations + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None - def _set_signature_columns_if_needed(self): - """ - Set signature columns for GRPO-specific data preprocessing. + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) - In GRPOTrainer, we preprocess data differently than standard Trainer, - so we set the signature columns to those expected by the training_step method. - """ - if self._signature_columns is None: - self._signature_columns = ["prompt", "image", "images"] + # Compute the loss + advantages = inputs["advantages"] + # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, + # old_per_token_logps == per_token_logps. In this case we can skip its computation + # (see _generate_and_score_completions) and instead use per_token_logps.detach(). 
+ # The exception is when using vLLM, where we always compute old_per_token_logps + # for importance sampling + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "token": + log_importance_weights = log_ratio + elif self.importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + + if self.use_vllm and self.vllm_importance_sampling_correction: + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == "grpo": + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "bnpo": + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + loss = loss / self.current_gradient_accumulation_steps + normalizer = inputs["num_items_in_batch"] / self.accelerator.num_pro + loss = (per_token_loss * completion_mask).sum() / normalizer + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # Log the metrics + mode = "train" if self.model.training else "eval" + + completion_token_count = completion_mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + mean_entropy = masked_batch_mean(entropies) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + 
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + return loss def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): """ From 8cf284249e3eaea06723c3e0a78c6fae49587691 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 14 Nov 2025 17:35:42 +0100 Subject: [PATCH 29/78] wip: grpo trainer almost working with mocks (recompilation issues) --- optimum/neuron/trainers/grpo_trainer.py | 116 +++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 8caed58bb..95d160b86 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -20,6 +20,7 @@ import datasets import numpy as np import torch +import torch_xla from accelerate.utils import set_seed from optimum.utils import logging from torch.utils.data import Dataset, IterableDataset @@ -43,7 +44,7 @@ if is_trl_available(): from trl import GRPOConfig, GRPOTrainer from trl.data_utils import is_conversational - from trl.trainer.utils import disable_dropout_in_model, identity, nanmax, nanmin, nanstd + from trl.trainer.utils import disable_dropout_in_model, identity else: class GRPOTrainer: @@ -137,6 +138,69 @@ def neuron_parallel_compile_tokenizer_decoder_method( return "dummy" +def nanmin(tensor: torch.Tensor) -> torch.Tensor: + """ + XLA-compatible version of nanmin that doesn't use dynamic indexing. + + Compute the minimum value of a tensor, ignoring NaNs. + + Args: + tensor: Input tensor of shape `(N,)`. + + Returns: + Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. + """ + # Replace NaN with a very large value before computing min + # This avoids dynamic indexing which XLA can't handle + mask = torch.isnan(tensor) + if mask.all(): + return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) + + # Replace NaNs with max float value so they don't affect min + filled = torch.where(mask, torch.tensor(float("inf"), dtype=tensor.dtype, device=tensor.device), tensor) + return torch.min(filled) + + +def nanmax(tensor: torch.Tensor) -> torch.Tensor: + """ + XLA-compatible version of nanmax that doesn't use dynamic indexing. + + Compute the maximum value of a tensor, ignoring NaNs. + + Args: + tensor: Input tensor of shape `(N,)`. + + Returns: + Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. + """ + # Replace NaN with a very small value before computing max + mask = torch.isnan(tensor) + if mask.all(): + return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) + + # Replace NaNs with min float value so they don't affect max + filled = torch.where(mask, torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device), tensor) + return torch.max(filled) + + +def nanstd(tensor: torch.Tensor) -> torch.Tensor: + """ + XLA-compatible version of nanstd. + + Compute the standard deviation of a tensor, ignoring NaNs. + + Args: + tensor: Input tensor of shape `(N,)`. + + Returns: + Standard deviation of the tensor, ignoring NaNs. 
+ """ + # Use torch's built-in nanmean and compute variance with Bessel's correction + variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) + count = torch.sum(~torch.isnan(tensor)) + variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero + return torch.sqrt(variance) + class NeuronGRPOTrainer(_GRPOTrainer): """ @@ -536,6 +600,46 @@ def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: # Explicitly call GRPOTrainer's _prepare_inputs return GRPOTrainer._prepare_inputs(self, inputs) + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + """ + Override to use NeuronAccelerator.gather instead of standalone gather function. + + The standalone gather from accelerate.utils may not be XLA-compatible, + so we use self.accelerator.gather which uses _xla_gather internally. + """ + device = self.accelerator.device + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) + ): + if isinstance(reward_func, torch.nn.Module): + if is_conversational(inputs[0]): + from trl.data_utils import apply_chat_template + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = NeuronTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + else: + output_reward_func = reward_func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + rewards_per_func = self.accelerator.gather(rewards_per_func) + return rewards_per_func + def _generate_single_turn(self, prompts: list[str], images: list | None): """ Generate a single turn of completions using vLLM (mock or real server). 
@@ -745,6 +849,7 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) + torch_xla.sync() else: old_per_token_logps = None @@ -767,6 +872,7 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) + torch_xla.sync() else: with self.accelerator.unwrap_model(self.model).disable_adapter(): ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( @@ -778,6 +884,7 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) + torch_xla.sync() else: ref_per_token_logps = None @@ -940,6 +1047,7 @@ def _compute_loss(self, model, inputs): image_sizes=inputs.get("image_sizes"), token_type_ids=inputs.get("token_type_ids"), ) + torch_xla.sync() if self.top_entropy_quantile < 1.0: entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) @@ -1002,7 +1110,11 @@ def _compute_loss(self, model, inputs): elif self.loss_type == "bnpo": loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) loss = loss / self.current_gradient_accumulation_steps - normalizer = inputs["num_items_in_batch"] / self.accelerator.num_pro + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "dapo": + normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes loss = (per_token_loss * completion_mask).sum() / normalizer else: raise ValueError(f"Unknown loss type: {self.loss_type}") From 26bdd70105de7989faffb4eb492e90fbf72ab04c Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 17 Nov 2025 18:47:59 +0100 Subject: [PATCH 30/78] fix: gradient checkpointing --- .../neuron/models/training/granite/modeling_granite.py | 6 +++--- optimum/neuron/models/training/llama/modeling_llama.py | 5 +++-- optimum/neuron/models/training/training_utils.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/optimum/neuron/models/training/granite/modeling_granite.py b/optimum/neuron/models/training/granite/modeling_granite.py index 157b31924..f0f5782bb 100644 --- a/optimum/neuron/models/training/granite/modeling_granite.py +++ b/optimum/neuron/models/training/granite/modeling_granite.py @@ -21,7 +21,6 @@ scatter_to_sequence_parallel_region, ) from torch import nn -from torch_xla.utils.checkpoint import checkpoint from transformers.loss.loss_utils import ForCausalLMLoss from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -40,6 +39,7 @@ LlamaRotaryEmbedding, ) from ..masking_utils import create_causal_mask +from ..training_utils import checkpoint_with_kwargs # Wrap the gather and scatter functions to ensure they are properly traced by `torch.fx`. 
@@ -170,8 +170,8 @@ def forward( # Decoder layers for decoder_layer in self.layers[: self.config.num_hidden_layers]: if self.gradient_checkpointing and self.training: - hidden_states = checkpoint( - decoder_layer.__call__, + hidden_states = checkpoint_with_kwargs( + decoder_layer, hidden_states, causal_mask, position_ids, diff --git a/optimum/neuron/models/training/llama/modeling_llama.py b/optimum/neuron/models/training/llama/modeling_llama.py index 45c913626..fdc7ef878 100644 --- a/optimum/neuron/models/training/llama/modeling_llama.py +++ b/optimum/neuron/models/training/llama/modeling_llama.py @@ -45,6 +45,7 @@ from ..loss_utils import ForCausalLMLoss from ..masking_utils import create_causal_mask from ..modeling_utils import NeuronModelMixin +from ..training_utils import checkpoint_with_kwargs from ..transformations_utils import ( CustomModule, FusedLinearsSpec, @@ -667,8 +668,8 @@ def forward( # Decoder layers for decoder_layer in self.layers[: self.config.num_hidden_layers]: if self.gradient_checkpointing and self.training: - hidden_states = checkpoint( - decoder_layer.__call__, + hidden_states = checkpoint_with_kwargs( + decoder_layer, hidden_states, causal_mask, position_ids, diff --git a/optimum/neuron/models/training/training_utils.py b/optimum/neuron/models/training/training_utils.py index 43fda903c..f4fd7b561 100644 --- a/optimum/neuron/models/training/training_utils.py +++ b/optimum/neuron/models/training/training_utils.py @@ -199,3 +199,13 @@ def is_custom_modeling_model(model) -> bool: if isinstance(model, PeftModel): model_to_consider = model.get_base_model() return inspect.getmodule(model_to_consider.__class__).__name__.startswith("optimum.neuron.models.training") + + +def checkpoint_with_kwargs(fn, *args, **kwargs): + """XLA-compatible gradient checkpointing that accepts keyword arguments via functools.partial.""" + from functools import partial + + from torch_xla.utils.checkpoint import checkpoint + + fn_with_kwargs = partial(fn, **kwargs) + return checkpoint(fn_with_kwargs, *args) From e9ca8816cbe01bc5fe185eb4e6e821cce0611ac4 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 17 Nov 2025 18:48:24 +0100 Subject: [PATCH 31/78] temp: added the example script, temporary --- .../grpo_qwen3/finetune_grpo_qwen3.py | 234 ++++++++++++++++++ .../grpo_qwen3/finetune_grpo_qwen3.sh | 78 ++++++ 2 files changed, 312 insertions(+) create mode 100755 examples/training/grpo_qwen3/finetune_grpo_qwen3.py create mode 100755 examples/training/grpo_qwen3/finetune_grpo_qwen3.sh diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py new file mode 100755 index 000000000..d146c9908 --- /dev/null +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example script for fine-tuning a Qwen3 model using GRPO (Group Relative Policy Optimization) on Neuron devices. 
+
+This script demonstrates how to use NeuronGRPOTrainer to train a model with reinforcement learning
+using reward functions. GRPO is particularly effective for reasoning tasks and instruction following.
+
+For more information about GRPO, see: https://huggingface.co/papers/2402.03300
+"""
+
+from dataclasses import dataclass, field
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import AutoTokenizer, HfArgumentParser
+
+from optimum.neuron import NeuronGRPOConfig, NeuronGRPOTrainer, NeuronTrainingArguments
+from optimum.neuron.models.training import NeuronModelForCausalLM
+
+
+# =============================================================================
+# Reward Functions
+# =============================================================================
+# GRPO requires reward functions to score the generated completions.
+# These can be:
+# 1. Model-based: Use a reward model to score completions
+# 2. Rule-based: Custom Python functions that compute rewards
+#
+# For this example, we use simple rule-based rewards for demonstration.


+def length_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
+    """
+    Simple reward function that rewards longer responses (up to a point).
+
+    This is a toy example. In practice, you'd want more sophisticated rewards
+    based on task-specific criteria (e.g., accuracy, coherence, helpfulness).
+
+    Args:
+        prompts: List of input prompts
+        completions: List of generated completions
+        **kwargs: Additional arguments (e.g., trainer_state)
+
+    Returns:
+        List of reward scores (one per completion)
+    """
+    rewards = []
+    for completion in completions:
+        # Reward grows with completion length (in words); it is capped so overly long responses gain nothing extra
+        length = len(completion.split())
+        reward = min(length / 50.0, 2.0)  # Scale: 0-2, saturating at ~100 words
+        rewards.append(reward)
+    return rewards
+
+
+def unique_words_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
+    """
+    Reward function that encourages diversity by rewarding unique words.
+
+    Args:
+        prompts: List of input prompts
+        completions: List of generated completions
+        **kwargs: Additional arguments
+
+    Returns:
+        List of reward scores (one per completion)
+    """
+    rewards = []
+    for completion in completions:
+        words = completion.lower().split()
+        unique_words = len(set(words))
+        total_words = len(words)
+        # Reward diversity: ratio of unique words
+        reward = unique_words / max(total_words, 1)
+        rewards.append(reward)
+    return rewards
+
+
+# =============================================================================
+# Data Loading and Preprocessing Function
+# =============================================================================
+# GRPO requires datasets with a "prompt" column. The trainer will generate
+# multiple completions for each prompt and score them using reward functions.
+
+
+def load_grpo_dataset():
+    """
+    Load and prepare a dataset for GRPO training.
+
+    For this example, we use the "trl-internal-testing/zen" dataset, which is
+    a simple test dataset. In practice, you'd use a dataset appropriate for
+    your task (e.g., math problems, coding tasks, instruction following).
+
+    Returns:
+        Dataset with "prompt" column
+    """
+    # Load a simple test dataset
+    # This dataset has prompts in the "prompt" column
+    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+    # Repeat the first example many times: this demo trains on a single fixed prompt
+    dataset = dataset.select([0] * 100000)
+
+    return dataset
+
+
+# =============================================================================
+# Model Loading and Training Loop Function
+# =============================================================================
+def train(model_id, tokenizer, dataset, training_args):
+    """
+    Train the model using GRPO.
+
+    Args:
+        model_id: HuggingFace model ID or path
+        tokenizer: Tokenizer for the model
+        dataset: Training dataset with "prompt" column
+        training_args: NeuronTrainingArguments
+    """
+    # NOTE: Models with a custom modeling implementation need a TrainingNeuronConfig
+    # This is automatically created when using NeuronTrainingArguments
+    trn_config = training_args.trn_config
+    dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+    model = NeuronModelForCausalLM.from_pretrained(
+        model_id,
+        trn_config,
+        torch_dtype=dtype,
+        # Use FlashAttention2 for better performance
+        attn_implementation="flash_attention_2",
+    )
+
+    # LoRA configuration for efficient fine-tuning
+    lora_config = LoraConfig(
+        r=64,
+        lora_alpha=128,
+        lora_dropout=0.05,
+        target_modules=["embed_tokens", "q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"],
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+
+    # Convert NeuronTrainingArguments to dict for NeuronGRPOConfig
+    args = training_args.to_dict()
+
+    # GRPO-specific configuration
+    grpo_config = NeuronGRPOConfig(
+        # Generation parameters
+        max_prompt_length=1024,  # Maximum prompt length
+        max_completion_length=1024,  # Maximum completion length
+        num_generations=4,  # Number of completions to generate per prompt (G in paper)
+        temperature=0.8,  # Sampling temperature
+        # GRPO algorithm parameters
+        num_iterations=1,  # Number of iterations per batch (μ in paper)
+        epsilon=0.2,  # Clipping parameter
+        beta=0.01,  # KL divergence coefficient
+        scale_rewards="group",  # Reward scaling strategy
+        # vLLM parameters
+        use_vllm=True,  # Use vLLM for generation (required for Neuron)
+        vllm_mode="server",  # Use vLLM server mode
+        vllm_server_host="localhost",
+        vllm_server_port=8000,
+        # Standard training arguments from NeuronTrainingArguments
+        **args,
+    )
+
+    # Define reward functions
+    # You can use multiple reward functions - they will be summed
+    reward_funcs = [
+        length_reward,
+        unique_words_reward,
+    ]
+
+    # Create the GRPO trainer
+    trainer = NeuronGRPOTrainer(
+        model=model,
+        reward_funcs=reward_funcs,
+        args=grpo_config,
+        train_dataset=dataset,
+        processing_class=tokenizer,
+        # peft_config=lora_config,
+    )
+
+    # Train the model
+    trainer.train()
+
+
+# =============================================================================
+# Defining the script-specific arguments
+# =============================================================================
+@dataclass
+class ScriptArguments:
+    model_id: str = field(
+        metadata={"help": "The model that you want to train from the Hugging Face hub."},
+    )
+
+
+# =============================================================================
+# Main Function
+# =============================================================================
+if __name__ == "__main__":
+    parser = HfArgumentParser((ScriptArguments, NeuronTrainingArguments))
+    script_args, training_args =
parser.parse_args_into_dataclasses()
+
+    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
+
+    # Ensure tokenizer has pad token
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load dataset
+    dataset = load_grpo_dataset()
+
+    # Start training
+    train(
+        model_id=script_args.model_id,
+        tokenizer=tokenizer,
+        dataset=dataset,
+        training_args=training_args,
+    )
diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh
new file mode 100755
index 000000000..5458ae86b
--- /dev/null
+++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+# ============================================================================
+# GRPO Fine-tuning Script for Qwen3 on AWS Trainium
+# ============================================================================
+# This script demonstrates how to fine-tune a Qwen3 model using GRPO
+# (Group Relative Policy Optimization) on AWS Trainium devices.
+#
+# Prerequisites:
+# 1. vLLM server running (or use mock vLLM for development)
+# 2. Neuron SDK installed
+# 3. A Trainium instance with multiple NeuronCores (e.g., trn1.32xlarge)
+#
+# For mock vLLM (development/testing):
+# - Set USE_MOCK_VLLM=True in grpo_trainer.py
+#
+# For real vLLM:
+# - Start vLLM server with: trl vllm-serve --model MODEL_NAME
+# ============================================================================
+
+# Flags for Neuron compilation
+export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation"
+export NEURON_FUSE_SOFTMAX=1
+export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime
+export MALLOC_ARENA_MAX=64 # Host OOM mitigation
+
+# Variables for training
+PROCESSES_PER_NODE=2
+NUM_EPOCHS=1  # GRPO typically needs fewer epochs than SFT
+TP_DEGREE=1
+BS=1
+GRADIENT_ACCUMULATION_STEPS=4  # Smaller for GRPO due to generation overhead
+LOGGING_STEPS=1
+MODEL_NAME="Qwen/Qwen3-0.6B"  # Use smaller model for testing
+OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned"
+DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE"
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+# GRPO-specific variables
+NUM_GENERATIONS=4  # Number of completions per prompt (G in paper)
+MAX_PROMPT_LENGTH=512
+MAX_COMPLETION_LENGTH=256
+TEMPERATURE=0.8
+STEPS_PER_GENERATION=4  # Generate every N steps to amortize generation cost
+
+if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
+    MAX_STEPS=5
+else
+    MAX_STEPS=100  # Limit steps for testing
+fi
+
+# Note: Adjust these parameters based on your hardware and task
+# - Increase num_generations for better exploration (but slower training)
+# - Adjust temperature for sampling diversity
+# - Tune epsilon and beta for GRPO algorithm sensitivity
+
+torchrun $DISTRIBUTED_ARGS finetune_grpo_qwen3.py \
+    --model_id $MODEL_NAME \
+    --num_train_epochs $NUM_EPOCHS \
+    --do_train \
+    --max_steps $MAX_STEPS \
+    --per_device_train_batch_size $BS \
+    --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
+    --gradient_checkpointing \
+    --learning_rate 5e-5 \
+    --bf16 \
+    --tensor_parallel_size $TP_DEGREE \
+    --zero_1 \
+    --async_save \
+    --logging_steps $LOGGING_STEPS \
+    --output_dir $OUTPUT_DIR \
+    --lr_scheduler_type "cosine" \
+    --overwrite_output_dir
+
+echo "================================"
+echo "Training completed!"
+echo "Model saved to: $OUTPUT_DIR" +echo "================================" From 8eed696769994ea6040cf5d9f17243bcbcff32b6 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 18 Nov 2025 10:12:15 +0100 Subject: [PATCH 32/78] chore: wip, added torch.sync() --- optimum/neuron/trainers/grpo_trainer.py | 42 +++++++------------------ 1 file changed, 11 insertions(+), 31 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 95d160b86..f70bf49c7 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -141,22 +141,11 @@ def neuron_parallel_compile_tokenizer_decoder_method( def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ XLA-compatible version of nanmin that doesn't use dynamic indexing. - Compute the minimum value of a tensor, ignoring NaNs. - - Args: - tensor: Input tensor of shape `(N,)`. - - Returns: - Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. """ - # Replace NaN with a very large value before computing min - # This avoids dynamic indexing which XLA can't handle mask = torch.isnan(tensor) if mask.all(): return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) - - # Replace NaNs with max float value so they don't affect min filled = torch.where(mask, torch.tensor(float("inf"), dtype=tensor.dtype, device=tensor.device), tensor) return torch.min(filled) @@ -164,21 +153,11 @@ def nanmin(tensor: torch.Tensor) -> torch.Tensor: def nanmax(tensor: torch.Tensor) -> torch.Tensor: """ XLA-compatible version of nanmax that doesn't use dynamic indexing. - Compute the maximum value of a tensor, ignoring NaNs. - - Args: - tensor: Input tensor of shape `(N,)`. - - Returns: - Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. """ - # Replace NaN with a very small value before computing max mask = torch.isnan(tensor) if mask.all(): return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) - - # Replace NaNs with min float value so they don't affect max filled = torch.where(mask, torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device), tensor) return torch.max(filled) @@ -186,19 +165,13 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: def nanstd(tensor: torch.Tensor) -> torch.Tensor: """ XLA-compatible version of nanstd. - Compute the standard deviation of a tensor, ignoring NaNs. - - Args: - tensor: Input tensor of shape `(N,)`. - - Returns: - Standard deviation of the tensor, ignoring NaNs. 
""" # Use torch's built-in nanmean and compute variance with Bessel's correction - variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) - count = torch.sum(~torch.isnan(tensor)) - variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero + # variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) + # count = torch.sum(~torch.isnan(tensor)) + # variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero + return torch.tensor(1) return torch.sqrt(variance) @@ -791,6 +764,7 @@ def _generate_and_score_completions( sampling_per_token_logps_list, forward_kwargs, ) = self._generate(prompts, images) + torch_xla.sync() # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -806,6 +780,7 @@ def _generate_and_score_completions( sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") else: sampling_per_token_logps = None + torch_xla.sync() # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -889,6 +864,7 @@ def _generate_and_score_completions( ref_per_token_logps = None # Decode + torch_xla.sync() prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): @@ -903,6 +879,7 @@ def _generate_and_score_completions( # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + torch_xla.sync() # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -937,6 +914,7 @@ def _generate_and_score_completions( ) all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] + torch_xla.sync() # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): @@ -1017,6 +995,8 @@ def _generate_and_score_completions( output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: output["num_images"] = num_images + + torch_xla.sync() return output def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): From 02b625331bb66cc4e740f136e3a3cd80cf738004 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 18 Nov 2025 17:53:07 +0100 Subject: [PATCH 33/78] wip: fix computation device in --- optimum/neuron/trainers/grpo_trainer.py | 132 ++++++++++++++---------- 1 file changed, 77 insertions(+), 55 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index f70bf49c7..ec5a0e5e7 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -20,7 +20,10 @@ import datasets import numpy as np import torch +import torch.utils._pytree as pytree import torch_xla +import torch_xla.core.xla_model as xm +from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from accelerate.utils import set_seed from optimum.utils import logging from torch.utils.data import 
Dataset, IterableDataset @@ -168,10 +171,9 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor: Compute the standard deviation of a tensor, ignoring NaNs. """ # Use torch's built-in nanmean and compute variance with Bessel's correction - # variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) - # count = torch.sum(~torch.isnan(tensor)) - # variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero - return torch.tensor(1) + variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) + count = torch.sum(~torch.isnan(tensor)) + variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero return torch.sqrt(variance) @@ -318,6 +320,7 @@ def make_inputs_require_grad(module, input, output): self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) else: self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + self.reward_weights = self.reward_weights.to(xm.xla_device()) # Reward processing class if reward_processing_classes is None: @@ -764,7 +767,6 @@ def _generate_and_score_completions( sampling_per_token_logps_list, forward_kwargs, ) = self._generate(prompts, images) - torch_xla.sync() # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -780,7 +782,6 @@ def _generate_and_score_completions( sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") else: sampling_per_token_logps = None - torch_xla.sync() # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -803,6 +804,9 @@ def _generate_and_score_completions( num_images = [len(img_list) for img_list in images] if images is not None else None + # Graph break before computing the log probabilities. 
+ torch_xla.sync() + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -824,10 +828,11 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) - torch_xla.sync() else: old_per_token_logps = None + torch_xla.sync() + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) @@ -835,6 +840,8 @@ def _generate_and_score_completions( importance_sampling_ratio, max=self.vllm_importance_sampling_cap ) + torch_xla.sync() + # Compute the per-token log probabilities for the reference model if self.beta != 0.0: if self.ref_model is not None: @@ -847,7 +854,6 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) - torch_xla.sync() else: with self.accelerator.unwrap_model(self.model).disable_adapter(): ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( @@ -859,12 +865,11 @@ def _generate_and_score_completions( num_images=num_images, **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) - torch_xla.sync() else: ref_per_token_logps = None - # Decode torch_xla.sync() + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): @@ -879,10 +884,9 @@ def _generate_and_score_completions( # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. 
rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - torch_xla.sync() # Apply weights to each reward function's output and sum - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + rewards = (rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) # Compute grouped-wise rewards mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) @@ -914,60 +918,78 @@ def _generate_and_score_completions( ) all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - torch_xla.sync() + + metrics = defaultdict(list) + logs = {} # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): - mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) - std_func_rewards = nanstd(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) - self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) - self._metrics[mode]["reward_std"].append(std_rewards.mean().item()) - self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + mean_rewards = torch.nanmean(rewards_per_func[:, i]) + metrics[f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]) + metrics[f"rewards/{reward_func_name}/std"].append(std_func_rewards) + + metrics["reward"].append(mean_grouped_rewards.mean()) + metrics["reward_std"].append(std_rewards.mean()) + metrics["frac_reward_zero_std"].append(is_std_zero.float().mean()) # Log prompt and completion texts - # TODO: handle this later. 
# self._logs["prompt"].extend(gather_object(prompts_text)) # self._logs["completion"].extend(gather_object(completions_text)) - # for i, name in enumerate(self.reward_func_names): - # self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) - # self._logs["advantages"].extend(all_process_advantages.tolist()) + logs["rewards"] = {} + logs["advantages"] = [] + for i, name in enumerate(self.reward_func_names): + logs["rewards"][name] = rewards_per_func[:, i] + logs["advantages"] = all_process_advantages # if images is not None: # self._logs["images"].extend(gather_object(images)) - # if self.use_vllm and self.vllm_importance_sampling_correction: - # delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - # delta = delta[completion_mask.bool()] - # mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - # max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - # self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( - # self.accelerator.gather(mean_delta).mean().item() - # ) - # self._metrics[mode]["sampling/sampling_logp_difference/max"].append( - # self.accelerator.gather(max_delta).max().item() - # ) - - # flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] - # min_importance_sampling_ratio = ( - # torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - # ) - # mean_importance_sampling_ratio = ( - # torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - # ) - # max_importance_sampling_ratio = ( - # torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - # ) - # self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( - # nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() - # ) - # self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( - # self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() - # ) - # self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( - # nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() - # ) + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + delta = delta[completion_mask.bool()] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + metrics["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max() + ) + + flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + metrics["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)) + ) + metrics["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean() + ) + 
metrics["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)) + ) + + # Graph break after metrics and logs computation. + torch_xla.sync() + + # Move metrics and logs to CPU. + metrics = pytree.tree_map(move_all_tensor_to_cpu, metrics) + logs = pytree.tree_map(move_all_tensor_to_cpu, logs) + + # Update the actual metrics and logs. + self._metrics[mode].update(metrics) + for name in self.reward_func_names: + self._logs["rewards"][name].extend(logs["rewards"][name].tolist()) + self._logs["advantages"].extend(logs["advantages"].tolist()) + output = { "prompt_ids": prompt_ids, From e844af95af80526072b230abf0ae7fa8d13e1a50 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 18 Nov 2025 18:22:13 +0100 Subject: [PATCH 34/78] wip: fix computation device in --- optimum/neuron/accelerate/accelerator.py | 70 +++++++++++++++++-- optimum/neuron/accelerate/utils/operations.py | 56 --------------- optimum/neuron/trainers/grpo_trainer.py | 19 +++-- 3 files changed, 72 insertions(+), 73 deletions(-) delete mode 100644 optimum/neuron/accelerate/utils/operations.py diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index bcb1be71b..9b6d54ff7 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -15,6 +15,7 @@ import contextlib import os +import pickle import re import shutil import sys @@ -23,6 +24,7 @@ from typing import Any, Callable import torch +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from accelerate import Accelerator @@ -33,6 +35,7 @@ from neuronx_distributed.optimizer import NeuronZero1Optimizer from neuronx_distributed.parallel_layers.parallel_state import ( get_context_model_parallel_size, + get_data_parallel_group, get_data_parallel_replica_groups, get_data_parallel_size, get_tensor_model_parallel_replica_groups, @@ -66,7 +69,6 @@ apply_activation_checkpointing, create_patched_save_pretrained, ) -from .utils.operations import _xla_gather # Setup logging so that the main process logs at the INFO level and the others are silent. @@ -547,10 +549,65 @@ def save_state( output_dir=output_dir, safe_serialization=safe_serialization, **save_model_func_kwargs ) - def gather(self, tensor, out_of_graph: bool = False): - return _xla_gather(tensor, out_of_graph=out_of_graph) - - def gather_for_metrics(self, input_data, use_gather_object: bool = False): + def gather(self, tensor, sync: bool = False): + groups = get_data_parallel_group(as_list=True) + gathered = xm.all_gather(tensor, groups=groups, pin_layout=False) + if sync: + torch_xla.sync() + return gathered + + def gather_object(self, obj: Any) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed processes. + Returns list of objects from all ranks on all ranks. + + Note: Requires two all-gather operations (lengths then data). + For small objects, this overhead may be significant. 
+ """ + world_size = get_data_parallel_size() + + # Early exit for single process + if world_size == 1: + return [obj] + + groups = get_data_parallel_group(as_list=True) + + # Step 1: Serialize to bytes + serialized = pickle.dumps(obj) + byte_len = len(serialized) + + # Step 2: Convert to tensor on XLA device + byte_tensor = torch.frombuffer(serialized, dtype=torch.uint8).clone() + byte_tensor = byte_tensor.to(xm.xla_device()) + + # Step 3: Gather lengths + len_tensor = torch.tensor([byte_len], dtype=torch.int64, device=byte_tensor.device) + len_list = [torch.zeros_like(len_tensor) for _ in range(world_size)] + xm.all_gather(len_list, len_tensor, groups=groups, pin_layout=False) + + # Step 4: Pad to max length (fixed shape for XLA) + max_len = max(int(l.item()) for l in len_list) + padded = torch.zeros(max_len, dtype=torch.uint8, device=byte_tensor.device) + padded[:byte_len] = byte_tensor + + # Step 5: Gather padded data + gathered_tensors = [torch.zeros_like(padded) for _ in range(world_size)] + xm.all_gather(gathered_tensors, padded, groups=groups, pin_layout=False) + + # Step 6: Sync once, then transfer to CPU + torch_xla.sync() + cpu_tensors = [t.cpu() for t in gathered_tensors] + + # Step 7: Deserialize + results = [] + for tensor, length in zip(cpu_tensors, len_list): + actual_len = int(length.item()) + valid_bytes = tensor[:actual_len].numpy().tobytes() + results.append(pickle.loads(valid_bytes)) + + return results + + def gather_for_metrics(self, input_data, use_gather_object: bool = False, sync: bool = False): try: recursively_apply(lambda x: x, input_data, error_on_other_type=True) all_tensors = True @@ -562,8 +619,7 @@ def gather_for_metrics(self, input_data, use_gather_object: bool = False): if use_gather_object: data = gather_object(input_data) else: - # It is needed to perform out-of-graph gather otherwise re-compilation happens at every evaluation step. - data = self.gather(input_data, out_of_graph=True) + data = self.gather(input_data, sync=sync) try: if self.gradient_state.end_of_dataloader: diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py deleted file mode 100644 index 683086684..000000000 --- a/optimum/neuron/accelerate/utils/operations.py +++ /dev/null @@ -1,56 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -import torch_xla.core.xla_model as xm -from accelerate.utils.operations import recursively_apply -from neuronx_distributed.parallel_layers.parallel_state import ( - get_data_parallel_group, - model_parallel_is_initialized, -) - - -def _xla_gather(tensor, out_of_graph: bool = False): - groups = None - if model_parallel_is_initialized(): - groups = get_data_parallel_group(as_list=True) - - def _xla_gather_one(tensor): - if tensor.ndim == 0: - tensor = tensor.clone()[None] - # Can only gather contiguous tensors - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - - if out_of_graph: - gathered_tensors = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x) - if groups is not None: - new_gathered_tensors = [] - # Since groups is containing list of group of replicas, we consider that visiting the first group of - # replicas is enough since the value should be the same across other axes. - replicas_to_consider = set(groups[0]) - for idx, tensor in enumerate(gathered_tensors): - if idx not in replicas_to_consider: - continue - new_gathered_tensors.append(tensor) - gathered_tensors = new_gathered_tensors - gathered = torch.cat(gathered_tensors) - else: - gathered = xm.all_gather(tensor, groups=groups, pin_layout=False) - return gathered - - res = recursively_apply(_xla_gather_one, tensor, error_on_other_type=True) - xm.mark_step() - return res diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index ec5a0e5e7..3cfc6cf18 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -23,8 +23,8 @@ import torch.utils._pytree as pytree import torch_xla import torch_xla.core.xla_model as xm -from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from accelerate.utils import set_seed +from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from optimum.utils import logging from torch.utils.data import Dataset, IterableDataset from transformers import ( @@ -91,7 +91,7 @@ def pad( ) -> torch.Tensor: """ Pads a list of tensors to the same shape along the first dimension. - It differs from `trl` by enfoncing the same sequence length for all tensors, which is required to avoid + It differs from `trl` by enfoncing the same sequence length for all tensors, which is required to avoid recompilation. """ batch_size = len(tensors) @@ -934,16 +934,16 @@ def _generate_and_score_completions( metrics["frac_reward_zero_std"].append(is_std_zero.float().mean()) # Log prompt and completion texts - # self._logs["prompt"].extend(gather_object(prompts_text)) - # self._logs["completion"].extend(gather_object(completions_text)) + self._logs["prompt"].extend(self.accelerator.gather_object(prompts_text)) + self._logs["completion"].extend(self.accelerator.gather_object(completions_text)) logs["rewards"] = {} logs["advantages"] = [] for i, name in enumerate(self.reward_func_names): logs["rewards"][name] = rewards_per_func[:, i] logs["advantages"] = all_process_advantages - # if images is not None: - # self._logs["images"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(self.accelerator.gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -977,13 +977,13 @@ def _generate_and_score_completions( nanmax(self.accelerator.gather(max_importance_sampling_ratio)) ) - # Graph break after metrics and logs computation. 
- torch_xla.sync() - # Move metrics and logs to CPU. metrics = pytree.tree_map(move_all_tensor_to_cpu, metrics) logs = pytree.tree_map(move_all_tensor_to_cpu, logs) + # Graph break after metrics and logs computation. + torch_xla.sync() + # Update the actual metrics and logs. self._metrics[mode].update(metrics) for name in self.reward_func_names: @@ -1018,7 +1018,6 @@ def _generate_and_score_completions( if images is not None: output["num_images"] = num_images - torch_xla.sync() return output def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): From 745674dd92a94cca742bdf20a230df4fbe3400da Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 18 Nov 2025 19:03:27 +0100 Subject: [PATCH 35/78] precompilation --- optimum/neuron/accelerate/accelerator.py | 36 ++-- optimum/neuron/trainers/grpo_trainer.py | 200 +++++++++++++++-------- 2 files changed, 154 insertions(+), 82 deletions(-) diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 9b6d54ff7..7e9e8288a 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -551,7 +551,11 @@ def save_state( def gather(self, tensor, sync: bool = False): groups = get_data_parallel_group(as_list=True) - gathered = xm.all_gather(tensor, groups=groups, pin_layout=False) + + # Ensure tensor is at least 1D for all_gather (scalars need to be unsqueezed) + input_tensor = tensor.unsqueeze(0) if tensor.ndim == 0 else tensor + gathered = xm.all_gather(input_tensor, dim=0, groups=groups, pin_layout=False) + if sync: torch_xla.sync() return gathered @@ -572,38 +576,36 @@ def gather_object(self, obj: Any) -> list[Any]: groups = get_data_parallel_group(as_list=True) - # Step 1: Serialize to bytes serialized = pickle.dumps(obj) byte_len = len(serialized) - # Step 2: Convert to tensor on XLA device byte_tensor = torch.frombuffer(serialized, dtype=torch.uint8).clone() byte_tensor = byte_tensor.to(xm.xla_device()) - # Step 3: Gather lengths len_tensor = torch.tensor([byte_len], dtype=torch.int64, device=byte_tensor.device) - len_list = [torch.zeros_like(len_tensor) for _ in range(world_size)] - xm.all_gather(len_list, len_tensor, groups=groups, pin_layout=False) + # all_gather concatenates along dim=0, so [1] -> [world_size] + gathered_lengths = xm.all_gather(len_tensor, dim=0, groups=groups, pin_layout=False) + + torch_xla.sync() + max_len = int(gathered_lengths.max().item()) - # Step 4: Pad to max length (fixed shape for XLA) - max_len = max(int(l.item()) for l in len_list) padded = torch.zeros(max_len, dtype=torch.uint8, device=byte_tensor.device) padded[:byte_len] = byte_tensor - # Step 5: Gather padded data - gathered_tensors = [torch.zeros_like(padded) for _ in range(world_size)] - xm.all_gather(gathered_tensors, padded, groups=groups, pin_layout=False) + # all_gather concatenates, so [max_len] -> [world_size * max_len] + gathered_data = xm.all_gather(padded, dim=0, groups=groups, pin_layout=False) - # Step 6: Sync once, then transfer to CPU torch_xla.sync() - cpu_tensors = [t.cpu() for t in gathered_tensors] + gathered_data_cpu = gathered_data.cpu() + gathered_lengths_cpu = gathered_lengths.cpu() - # Step 7: Deserialize results = [] - for tensor, length in zip(cpu_tensors, len_list): - actual_len = int(length.item()) - valid_bytes = tensor[:actual_len].numpy().tobytes() + offset = 0 + for i in range(world_size): + actual_len = int(gathered_lengths_cpu[i].item()) + valid_bytes = gathered_data_cpu[offset:offset + 
max_len][:actual_len].numpy().tobytes() results.append(pickle.loads(valid_bytes)) + offset += max_len return results diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 3cfc6cf18..46d2e71a4 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -47,7 +47,12 @@ if is_trl_available(): from trl import GRPOConfig, GRPOTrainer from trl.data_utils import is_conversational - from trl.trainer.utils import disable_dropout_in_model, identity + from trl.trainer.utils import ( + disable_dropout_in_model, + entropy_from_logits, + identity, + selective_log_softmax, + ) else: class GRPOTrainer: @@ -682,13 +687,51 @@ def _generate_and_score_completions( # with patcher: return GRPOTrainer._generate_and_score_completions(self, inputs) + def _to_fixed_length( + self, + tensor: torch.Tensor, + padding_value: int = 0, + padding_side: str = "right" + ) -> torch.Tensor: + """ + Pads or truncates tensor to fixed length for XLA compilation. + + XLA requires static shapes at graph construction time. This method ensures + all tensors (input_ids, attention_mask) have the same fixed length to enable + graph reuse across training steps. + + Args: + tensor: Input tensor to pad/truncate (2D: batch × seq_len) + padding_value: Value to use for padding (default: 0) + padding_side: "left" or "right" padding (default: "right") + + Returns: + Tensor with fixed length = max_prompt_length + max_completion_length + """ + fixed_length = self.max_prompt_length + self.max_completion_length + seq_len = tensor.shape[1] + + if seq_len == fixed_length: + return tensor + elif seq_len < fixed_length: + # Pad to fixed length + pad_amount = fixed_length - seq_len + pad_config = (pad_amount, 0) if padding_side == "left" else (0, pad_amount) + return torch.nn.functional.pad(tensor, pad_config, value=padding_value) + else: + # Truncate to fixed length + if padding_side == "left": + return tensor[:, -fixed_length:] + else: + return tensor[:, :fixed_length] + def _get_per_token_logps_and_entropies( self, model, input_ids, attention_mask, logits_to_keep, - batch_size=None, + batch_size, # Compared to the original `trl` implementation, `batch_size` must be specified. compute_entropy=False, pixel_values=None, image_grid_thw=None, @@ -696,51 +739,80 @@ def _get_per_token_logps_and_entropies( pixel_attention_mask=None, image_sizes=None, token_type_ids=None, - ): - """ - Override to pad sequences to max_length for XLA compilation. + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Make sure the inputs have a fixed shape. + input_ids = self._to_fixed_length( + input_ids, padding_value=self.pad_token_id, padding_side="left" + ) + attention_mask = self._to_fixed_length( + attention_mask, padding_value=0, padding_side="left" + ) - GRPO generates variable-length prompts + completions. XLA compilation requires - fixed shapes, so we pad all sequences to max_length (max_prompt_length + max_completion_length). - """ - # Calculate max_length from GRPO config - max_length = self.max_prompt_length + self.max_completion_length - seq_len = input_ids.shape[1] - - if seq_len < max_length: - pad_amount = max_length - seq_len - - # Pad input_ids with pad_token_id - input_ids = torch.nn.functional.pad( - input_ids, - (0, pad_amount), - value=self.pad_token_id - ) + # Force synchronization before starting computation to re-use the same graph. 
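+        # Without this break, lazily-traced ops from the caller would be fused into this function's graph and defeat re-use.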
+ torch_xla.sync() - # Pad attention_mask - if attention_mask is not None: - attention_mask = torch.nn.functional.pad( - attention_mask, - (0, pad_amount), - value=0 # Padded positions should be masked out - ) + batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + all_logps = [] + all_entropies = [] + # TODO: check if it's ok with TORCH XLA + for start in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[start : start + batch_size] + attention_mask_batch = attention_mask[start : start + batch_size] + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + # TODO: not support with torch XLA, fix it later. + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + logits = model(**model_inputs).logits + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + + completion_ids = input_ids_batch[:, -logits_to_keep:] + logps = selective_log_softmax(logits, completion_ids) # compute logprobs + all_logps.append(logps) + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + all_entropies.append(entropies) + + logps = torch.cat(all_logps, dim=0) + entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + + # Force synchronization after computation to ensure graph is re-used. 
+ torch_xla.sync() - # Call parent implementation with padded tensors - return GRPOTrainer._get_per_token_logps_and_entropies( - self, - model, - input_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - compute_entropy=compute_entropy, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - num_images=num_images, - pixel_attention_mask=pixel_attention_mask, - image_sizes=image_sizes, - token_type_ids=token_type_ids, - ) + return logps, entropies def _generate_and_score_completions( self, inputs: list[dict[str,torch.Tensor | Any]] @@ -804,9 +876,6 @@ def _generate_and_score_completions( num_images = [len(img_list) for img_list in images] if images is not None else None - # Graph break before computing the log probabilities. - torch_xla.sync() - with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -831,8 +900,6 @@ def _generate_and_score_completions( else: old_per_token_logps = None - torch_xla.sync() - # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) @@ -840,8 +907,6 @@ def _generate_and_score_completions( importance_sampling_ratio, max=self.vllm_importance_sampling_cap ) - torch_xla.sync() - # Compute the per-token log probabilities for the reference model if self.beta != 0.0: if self.ref_model is not None: @@ -868,8 +933,6 @@ def _generate_and_score_completions( else: ref_per_token_logps = None - torch_xla.sync() - prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): @@ -934,16 +997,16 @@ def _generate_and_score_completions( metrics["frac_reward_zero_std"].append(is_std_zero.float().mean()) # Log prompt and completion texts - self._logs["prompt"].extend(self.accelerator.gather_object(prompts_text)) - self._logs["completion"].extend(self.accelerator.gather_object(completions_text)) + # self._logs["prompt"].extend(self.accelerator.gather_object(prompts_text)) + # self._logs["completion"].extend(self.accelerator.gather_object(completions_text)) logs["rewards"] = {} logs["advantages"] = [] for i, name in enumerate(self.reward_func_names): logs["rewards"][name] = rewards_per_func[:, i] logs["advantages"] = all_process_advantages - if images is not None: - self._logs["images"].extend(self.accelerator.gather_object(images)) + # if images is not None: + # self._logs["images"].extend(self.accelerator.gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -1040,6 +1103,7 @@ def _compute_loss(self, model, inputs): input_ids, attention_mask, logits_to_keep, + batch_size=self.args.per_device_train_batch_size, compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), @@ -1048,7 +1112,6 @@ def _compute_loss(self, model, inputs): image_sizes=inputs.get("image_sizes"), token_type_ids=inputs.get("token_type_ids"), ) - torch_xla.sync() if self.top_entropy_quantile < 1.0: entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) @@ -1131,12 +1194,14 @@ 
def masked_batch_mean(x): else: return (x * completion_mask).sum() / completion_token_count + metrics = defaultdict(list) + if self.beta != 0.0: mean_kl = masked_batch_mean(per_token_kl) - self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + metrics["kl"].append(self.accelerator.gather(mean_kl).nanmean()) mean_entropy = masked_batch_mean(entropies) - self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + metrics["entropy"].append(self.accelerator.gather(mean_entropy).nanmean()) # Compute the clipped probability ratios is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) @@ -1148,13 +1213,18 @@ def masked_batch_mean(x): clip_ratio = masked_batch_mean(is_region_clipped.float()) gathered_low_clip = self.accelerator.gather(low_clip) - self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + metrics["clip_ratio/low_mean"].append(gathered_low_clip.nanmean()) + metrics["clip_ratio/low_min"].append(nanmin(gathered_low_clip)) gathered_high_clip = self.accelerator.gather(high_clip) - self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + metrics["clip_ratio/high_mean"].append(gathered_high_clip.nanmean()) + metrics["clip_ratio/high_max"].append(nanmax(gathered_high_clip)) gathered_clip_ratio = self.accelerator.gather(clip_ratio) - self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) + + torch_xla.sync() # Graph break before moving metrics to CPU. 
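+        # The sync above has already executed the pending graph, so the moves below are plain device-to-host copies.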
+ metrics = pytree.tree_map(move_all_tensor_to_cpu, metrics) + self._metrics[mode].update(metrics) + return loss def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): From b82025d763e9087b504eefb879626c077ee2bef8 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 19 Nov 2025 16:19:17 +0100 Subject: [PATCH 36/78] make ops XLA friendly --- optimum/neuron/trainers/grpo_trainer.py | 334 ++++++++++++------------ optimum/neuron/trainers/trl_utils.py | 103 ++++++++ 2 files changed, 264 insertions(+), 173 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 46d2e71a4..fda5880cc 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -18,9 +18,7 @@ from typing import Any, Callable import datasets -import numpy as np import torch -import torch.utils._pytree as pytree import torch_xla import torch_xla.core.xla_model as xm from accelerate.utils import set_seed @@ -32,7 +30,9 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_wandb_available, ) +from transformers.utils import is_rich_available from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model @@ -41,9 +41,12 @@ from .grpo_config import NeuronGRPOConfig from .training_args import NeuronTrainingArguments from .transformers import NeuronTrainer -from .trl_utils import TRL_VERSION +from .trl_utils import TRL_VERSION, nanmax, nanmin, nanstd, neuron_parallel_compile_tokenizer_decoder_method, pad +if is_wandb_available(): + import wandb + if is_trl_available(): from trl import GRPOConfig, GRPOTrainer from trl.data_utils import is_conversational @@ -51,6 +54,7 @@ disable_dropout_in_model, entropy_from_logits, identity, + print_prompt_completions_sample, selective_log_softmax, ) else: @@ -88,100 +92,6 @@ class PeftConfig: RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] -def pad( - tensors: list[torch.Tensor], - padding_value: int = 0, - padding_side: str = "right", - max_length: int | None = None, - ) -> torch.Tensor: - """ - Pads a list of tensors to the same shape along the first dimension. - It differs from `trl` by enfoncing the same sequence length for all tensors, which is required to avoid - recompilation. - """ - batch_size = len(tensors) - if max_length is None: - max_length = np.max([t.shape[0] for t in tensors]).tolist() - - output_shape = (max_length,) + tensors[0].shape[1:] - - # Create an output tensor filled with the padding value - output = torch.full((batch_size, *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) - - for i, t in enumerate(tensors): - if padding_side == "left": - seq_start = output_shape[0] - t.shape[0] - elif padding_side == "right": - seq_start = 0 - else: - raise ValueError("padding_side must be 'left' or 'right'") - - # Define the slices - seq_slice = slice(seq_start, seq_start + t.shape[0]) - slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) - output[i][slices] = t - - return output - - -def neuron_parallel_compile_tokenizer_decoder_method( - self, - token_ids: int | list[int], - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool | None = None, - **kwargs, -) -> str: - """ - Patched `tokenizer._decode` method for Neuron parallel compilation. 
- This is needed because any tensor operation during `neuron_parallel_compile` produces rubbish results, which is not - an issue in general, but causes failure when the token IDS end up being out of range for the tokenizer vocabulary. - """ - if not is_precompilation(): - raise RuntimeError("This patch method should only be used with `neuron_parallel_compile`.") - - # We log the token IDs to force the data mouvement to CPU, which would happen during actual decoding. - logger.debug("Using patched tokenizer.decode method for Neuron parallel compilation, token_ids = ", token_ids) - - # Returns a dummy string, we do not care about the value in this context. - return "dummy" - - -def nanmin(tensor: torch.Tensor) -> torch.Tensor: - """ - XLA-compatible version of nanmin that doesn't use dynamic indexing. - Compute the minimum value of a tensor, ignoring NaNs. - """ - mask = torch.isnan(tensor) - if mask.all(): - return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) - filled = torch.where(mask, torch.tensor(float("inf"), dtype=tensor.dtype, device=tensor.device), tensor) - return torch.min(filled) - - -def nanmax(tensor: torch.Tensor) -> torch.Tensor: - """ - XLA-compatible version of nanmax that doesn't use dynamic indexing. - Compute the maximum value of a tensor, ignoring NaNs. - """ - mask = torch.isnan(tensor) - if mask.all(): - return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) - filled = torch.where(mask, torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device), tensor) - return torch.max(filled) - - -def nanstd(tensor: torch.Tensor) -> torch.Tensor: - """ - XLA-compatible version of nanstd. - Compute the standard deviation of a tensor, ignoring NaNs. - """ - # Use torch's built-in nanmean and compute variance with Bessel's correction - variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) - count = torch.sum(~torch.isnan(tensor)) - variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero - return torch.sqrt(variance) - - class NeuronGRPOTrainer(_GRPOTrainer): """ `GRPOTrainer` adapted for Neuron (Trainium) devices. @@ -539,55 +449,65 @@ def train( self, resume_from_checkpoint: str | bool | None = None, ): - """ - Main training entry point. - - Args: - resume_from_checkpoint: Path to a checkpoint to resume from, or True to resume from the latest checkpoint. - """ return NeuronTrainer.train(self, resume_from_checkpoint=resume_from_checkpoint) - def log(self, logs: dict[str, float]) -> None: - """ - Override GRPOTrainer's log method to use NeuronTrainer's implementation. + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics - GRPOTrainer has custom metrics tracking that we don't use for Neuron training. - """ - return NeuronTrainer.log(self, logs) + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} - def _save_checkpoint(self, model=None, trial=None, metrics=None): - """ - Override GRPOTrainer's _save_checkpoint to use NeuronTrainer's implementation. 
- """ - return NeuronTrainer._save_checkpoint(self) + logs = {**logs, **metrics} - def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: - """ - Prepare inputs for GRPO training. + # Using the NeuronTrainer log method instead of super().log. + NeuronTrainer.log(self, logs) - This method overrides NeuronTrainer._prepare_inputs to use GRPOTrainer's - implementation, which handles: - 1. Generation of completions using vLLM - 2. Scoring completions using reward functions - 3. Buffering completions for reuse across multiple gradient steps - 4. Tokenization and conversion to model inputs + self._metrics[mode].clear() - Args: - inputs: Raw batch from dataloader (list of prompt dicts for GRPO) + if self.accelerator.is_main_process and self.log_completions: + if is_rich_available(): + print_prompt_completions_sample( + self._logs["prompt"], + self._logs["completion"], + self._logs["rewards"], + self._logs["advantages"], + self.state.global_step, + self.num_completions_to_print, + ) - Returns: - Dictionary of tokenized tensors ready for the model - """ + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + import pandas as pd + + table = { + "step": [str(self.state.global_step)] * len(self._logs["prompt"]), + "prompt": self._logs["prompt"], + "completion": self._logs["completion"], + **self._logs["rewards"], + "advantage": self._logs["advantages"], + } + + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) + + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + wandb.log({"completions": wandb.Table(dataframe=df)}) + + def _save_checkpoint(self, model=None, trial=None, metrics=None): + return NeuronTrainer._save_checkpoint(self) + + def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: # Explicitly call GRPOTrainer's _prepare_inputs return GRPOTrainer._prepare_inputs(self, inputs) def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - """ - Override to use NeuronAccelerator.gather instead of standalone gather function. - - The standalone gather from accelerate.utils may not be XLA-compatible, - so we use self.accelerator.gather which uses _xla_gather internally. 
- """ device = self.accelerator.device rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) @@ -611,6 +531,7 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): reward_inputs = NeuronTrainer._prepare_inputs(self, reward_inputs) with torch.inference_mode(): rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] + torch_xla.sync() else: output_reward_func = reward_func( prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs @@ -619,6 +540,7 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) rewards_per_func = self.accelerator.gather(rewards_per_func) + torch_xla.sync() return rewards_per_func def _generate_single_turn(self, prompts: list[str], images: list | None): @@ -679,14 +601,6 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): return prompt_ids, completion_ids, logprobs, forward_kwargs - def _generate_and_score_completions( - self, inputs: list[dict[str, torch.Tensor | Any]] - ) -> dict[str, torch.Tensor | Any]: - # We patch the pad function to make it compatible with `neuron_parallel_compile`. - # patcher = Patcher([("trl.trainer.grpo_trainer.pad", pad)]) - # with patcher: - return GRPOTrainer._generate_and_score_completions(self, inputs) - def _to_fixed_length( self, tensor: torch.Tensor, @@ -694,19 +608,7 @@ def _to_fixed_length( padding_side: str = "right" ) -> torch.Tensor: """ - Pads or truncates tensor to fixed length for XLA compilation. - - XLA requires static shapes at graph construction time. This method ensures - all tensors (input_ids, attention_mask) have the same fixed length to enable - graph reuse across training steps. - - Args: - tensor: Input tensor to pad/truncate (2D: batch × seq_len) - padding_value: Value to use for padding (default: 0) - padding_side: "left" or "right" padding (default: "right") - - Returns: - Tensor with fixed length = max_prompt_length + max_completion_length + Pads or truncates tensor to fixed length = max_prompt_length + max_completion_length. """ fixed_length = self.max_prompt_length + self.max_completion_length seq_len = tensor.shape[1] @@ -1010,26 +912,49 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - delta = delta[completion_mask.bool()] - mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # Original code was: + # delta = delta[completion_mask.bool()] + # mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # But it is not XLA friendly because it involves dynamic indexing before reduction, so we rewrite it as: + completion_mask_count = completion_mask.sum() + delta_masked = delta * completion_mask + sum_delta = delta_masked.sum() + mean_delta = sum_delta / (completion_mask_count + 1e-10) + # We can simply take the max of the masked delta because values in delta are >= 0 (torch.abs). 
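+            # If every position is masked, delta_masked is all zeros and max_delta is 0, matching the original 0.0 fallback.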
+ max_delta = delta_masked.max() + metrics["sampling/sampling_logp_difference/mean"].append( self.accelerator.gather(mean_delta).mean() ) - self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + metrics["sampling/sampling_logp_difference/max"].append( self.accelerator.gather(max_delta).max() ) - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] - min_importance_sampling_ratio = ( - torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - mean_importance_sampling_ratio = ( - torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - max_importance_sampling_ratio = ( - torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # Original code was: + # flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + # min_importance_sampling_ratio = ( + # torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # mean_importance_sampling_ratio = ( + # torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # max_importance_sampling_ratio = ( + # torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # But it is not XLA friendly because it involves dynamic indexing before reduction, so we rewrite it as: + masked_is_ratio_for_min = torch.where( + completion_mask.bool(), + importance_sampling_ratio, + torch.tensor(float('inf'), device=device, dtype=importance_sampling_ratio.dtype) ) + min_importance_sampling_ratio = masked_is_ratio_for_min.min() + # importance_sampling_ratio values are >= 0 (torch.exp) so we can use the same computation as for delta. + flat_is_ratio_masked = importance_sampling_ratio * completion_mask + sum_flat_is_ratio = flat_is_ratio_masked.sum() + mean_importance_sampling_ratio = sum_flat_is_ratio / (completion_mask_count + 1e-10) + max_importance_sampling_ratio = flat_is_ratio_masked.max() + metrics["sampling/importance_sampling_ratio/min"].append( nanmin(self.accelerator.gather(min_importance_sampling_ratio)) ) @@ -1040,13 +965,14 @@ def _generate_and_score_completions( nanmax(self.accelerator.gather(max_importance_sampling_ratio)) ) - # Move metrics and logs to CPU. - metrics = pytree.tree_map(move_all_tensor_to_cpu, metrics) - logs = pytree.tree_map(move_all_tensor_to_cpu, logs) - # Graph break after metrics and logs computation. torch_xla.sync() + # Move metrics and logs to CPU. + metrics = move_all_tensor_to_cpu(metrics) + metrics = {key: [val.item() for val in value] for key, value in metrics.items()} + logs = move_all_tensor_to_cpu(logs) + # Update the actual metrics and logs. self._metrics[mode].update(metrics) for name in self.reward_func_names: @@ -1083,6 +1009,66 @@ def _generate_and_score_completions( return output + def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: + # Original code does the following: + # local = entropies[mask.bool()].float() + # # Use a negative pad_value as a sentinel because entropy values are always >= 0. + # # This guarantees that the sentinel cannot collide with any real entropy value. 
+ # pad_value = -1e9 + + # # Pad across processes so that every rank has the same tensor length + # padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) + # gathered = self.accelerator.gather(padded) + + # # Drop sentinel values (safe because no entropy can be negative) + # gathered = gathered[gathered != pad_value] + + # if gathered.numel() == 0: + # return torch.zeros_like(entropies, dtype=torch.bool) + + # entropy_threshold = torch.quantile(gathered, threshold) + # masked_entropies = entropies * mask.float() + # entropy_mask = masked_entropies >= entropy_threshold + # return entropy_mask & mask.bool() # ensure padding tokens are always masked out + + pad_value = -1e9 + device = entropies.device + + masked_entropies = torch.where( + mask.bool(), + entropies, + torch.tensor(pad_value, device=device, dtype=entropies.dtype), + ) + + local_flat = masked_entropies.view(-1) + gathered = self.accelerator.gather(local_flat) + + # Sort gathered values, so that pad_value sentinels are at the beginning + sorted_values, _ = torch.sort(gathered) + + # Compute the number of valid (non-sentinel) values + num_valid = (sorted_values != pad_value).sum() + num_sentinels = (sorted_values == pad_value).sum() + valid_start_idx = num_sentinels + num_valid_values = gathered.numel() - num_sentinels + + # Get the quantile index and the corresponding entropy threshold value + quantile_idx = valid_start_idx + (threshold * num_valid_values).long() + quantile_idx = quantile_idx.clamp(max=gathered.numel() - 1) + entropy_threshold = sorted_values[quantile_idx] + + # Handle empty case, if everything is sentinel, set threshold to +inf so no token is selected + has_valid = num_valid > 0 + entropy_threshold = torch.where( + has_valid, + entropy_threshold, + torch.tensor(float('inf'), device=device, dtype=entropies.dtype) + ) + + entropy_mask = (entropies > entropy_threshold) & mask.bool() + return entropy_mask + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") @@ -1222,7 +1208,9 @@ def masked_batch_mean(x): metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) torch_xla.sync() # Graph break before moving metrics to CPU. - metrics = pytree.tree_map(move_all_tensor_to_cpu, metrics) + metrics = move_all_tensor_to_cpu(metrics) + metrics = {key: [val.item() for val in value] for key, value in metrics.items()} + self._metrics[mode].update(metrics) return loss diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index 4046dd93b..fa8c96aca 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -13,4 +13,107 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np +import torch +from optimum.utils import logging + +from ..utils import is_precompilation + + +logger = logging.get_logger() + TRL_VERSION = "0.24.0" + +def pad( + tensors: list[torch.Tensor], + padding_value: int = 0, + padding_side: str = "right", + max_length: int | None = None, + ) -> torch.Tensor: + """ + Pads a list of tensors to the same shape along the first dimension. + It differs from `trl` by enfoncing the same sequence length for all tensors, which is required to avoid + recompilation. 
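+
+    Example (illustrative; behavior inferred from the implementation below):
+
+        >>> pad([torch.tensor([1, 2]), torch.tensor([3])], padding_value=0, max_length=4)
+        tensor([[1, 2, 0, 0],
+                [3, 0, 0, 0]])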
+    """
+    batch_size = len(tensors)
+    if max_length is None:
+        max_length = np.max([t.shape[0] for t in tensors]).tolist()
+
+    output_shape = (max_length,) + tensors[0].shape[1:]
+
+    # Create an output tensor filled with the padding value
+    output = torch.full((batch_size, *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
+
+    for i, t in enumerate(tensors):
+        if padding_side == "left":
+            seq_start = output_shape[0] - t.shape[0]
+        elif padding_side == "right":
+            seq_start = 0
+        else:
+            raise ValueError("padding_side must be 'left' or 'right'")
+
+        # Define the slices
+        seq_slice = slice(seq_start, seq_start + t.shape[0])
+        slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
+        output[i][slices] = t
+
+    return output
+
+
+def neuron_parallel_compile_tokenizer_decoder_method(
+    self,
+    token_ids: int | list[int],
+    skip_special_tokens: bool = False,
+    clean_up_tokenization_spaces: bool | None = None,
+    **kwargs,
+) -> str:
+    """
+    Patched `tokenizer._decode` method for `neuron_parallel_compile`.
+    This is needed because any tensor operation on the XLA device during `neuron_parallel_compile` produces rubbish
+    results, which is not an issue in general, but causes failure when the token IDs end up being out of range for the
+    tokenizer vocabulary.
+    """
+    if not is_precompilation():
+        raise RuntimeError("This patch method should only be used with `neuron_parallel_compile`.")
+
+    # We format the token IDs into the log message to force the data movement to CPU, which would happen during
+    # actual decoding.
+    logger.debug(f"Using patched tokenizer.decode method for Neuron parallel compilation, token_ids = {token_ids}")
+
+    # Returns a dummy string, we do not care about the value in this context.
+    return "dummy"
+
+
+def nanmin(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    XLA-compatible version of nanmin that doesn't use dynamic indexing.
+    Compute the minimum value of a tensor, ignoring NaNs.
+    """
+    mask = torch.isnan(tensor)
+    if mask.all():
+        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
+    filled = torch.where(mask, torch.tensor(float("inf"), dtype=tensor.dtype, device=tensor.device), tensor)
+    return torch.min(filled)
+
+
+def nanmax(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    XLA-compatible version of nanmax that doesn't use dynamic indexing.
+    Compute the maximum value of a tensor, ignoring NaNs.
+    """
+    mask = torch.isnan(tensor)
+    if mask.all():
+        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
+    filled = torch.where(mask, torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device), tensor)
+    return torch.max(filled)
+
+
+def nanstd(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    XLA-compatible version of nanstd.
+    Compute the standard deviation of a tensor, ignoring NaNs.
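+    Applies Bessel's correction, so on a NaN-free tensor it matches torch.std(tensor).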
+ """ + # Use torch's built-in nanmean and compute variance with Bessel's correction + variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) + count = torch.sum(~torch.isnan(tensor)) + variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero + return torch.sqrt(variance) From a224bed7af9b8ce81e98b46e234e204c73c26833 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 19 Nov 2025 21:00:31 +0100 Subject: [PATCH 37/78] add torch_xla.sync() to break the graphs in the for loops --- optimum/neuron/trainers/grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index fda5880cc..239066825 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -708,6 +708,8 @@ def _get_per_token_logps_and_entropies( entropies = entropy_from_logits(logits) all_entropies.append(entropies) + torch_xla.sync() + logps = torch.cat(all_logps, dim=0) entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None From fc1836607893fb664b068809174e165156e3b321 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Nov 2025 14:20:55 +0100 Subject: [PATCH 38/78] add DistributedRepeatSampler --- .../grpo_qwen3/finetune_grpo_qwen3.py | 2 +- .../grpo_qwen3/finetune_grpo_qwen3.sh | 5 +- optimum/neuron/trainers/grpo_trainer.py | 81 +++++++++++- optimum/neuron/trainers/trl_utils.py | 125 ++++++++++++++++++ 4 files changed, 209 insertions(+), 4 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index d146c9908..47df95de8 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -192,7 +192,7 @@ def train(model_id, tokenizer, dataset, training_args): args=grpo_config, train_dataset=dataset, processing_class=tokenizer, - # peft_config=lora_config, + peft_config=lora_config, ) # Train the model diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index 5458ae86b..e2b9afcb4 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -32,6 +32,7 @@ BS=1 GRADIENT_ACCUMULATION_STEPS=4 # Smaller for GRPO due to generation overhead LOGGING_STEPS=1 MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing +# MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE" SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) @@ -65,7 +66,9 @@ torchrun $DISTRIBUTED_ARGS finetune_grpo_qwen3.py \ --learning_rate 5e-5 \ --bf16 \ --tensor_parallel_size $TP_DEGREE \ - --zero_1 \ + --zero_1 false \ + --optimizer_use_master_weights false \ + --optimizer_use_fp32_grad_acc false \ --async_save \ --logging_steps $LOGGING_STEPS \ --output_dir $OUTPUT_DIR \ diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 239066825..4847503a1 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -22,9 +22,10 @@ import torch_xla import torch_xla.core.xla_model as xm from accelerate.utils import set_seed +from neuronx_distributed import parallel_layers from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from 
optimum.utils import logging -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import Dataset, IterableDataset, Sampler from transformers import ( PreTrainedModel, PreTrainedTokenizerBase, @@ -41,7 +42,15 @@ from .grpo_config import NeuronGRPOConfig from .training_args import NeuronTrainingArguments from .transformers import NeuronTrainer -from .trl_utils import TRL_VERSION, nanmax, nanmin, nanstd, neuron_parallel_compile_tokenizer_decoder_method, pad +from .trl_utils import ( + TRL_VERSION, + DistributedRepeatSampler, + nanmax, + nanmin, + nanstd, + neuron_parallel_compile_tokenizer_decoder_method, + pad, +) if is_wandb_available(): @@ -51,6 +60,7 @@ from trl import GRPOConfig, GRPOTrainer from trl.data_utils import is_conversational from trl.trainer.utils import ( + RepeatSampler, disable_dropout_in_model, entropy_from_logits, identity, @@ -444,6 +454,34 @@ def make_inputs_require_grad(module, input, output): self.reward_funcs[i] = self.accelerator.prepare_model( reward_func, evaluation_mode=True, device_placement=True ) + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: + if dataset is None: + dataset = self.train_dataset + if self.accelerator.num_processes == 1: + sampler = RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + else: + trn_config = self.accelerator.state.trn_config + num_replicas = trn_config.data_parallel_size + rank = parallel_layers.parallel_state.get_data_parallel_rank() + sampler = DistributedRepeatSampler( + dataset=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + num_replicas=num_replicas, + rank=rank, + ) + return sampler + def train( self, @@ -500,6 +538,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: df = df.drop_duplicates(subset=["prompt"]) wandb.log({"completions": wandb.Table(dataframe=df)}) + def _save_checkpoint(self, model=None, trial=None, metrics=None): return NeuronTrainer._save_checkpoint(self) @@ -543,6 +582,44 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): torch_xla.sync() return rewards_per_func + def _move_model_to_vllm(self): + if isinstance(self.model, NeuronPeftModel): + self.model.merge_adapter() + + # DeepSpeed ZeRO-3 with PEFT + for name, param in self.model.named_parameters(): + # When using PEFT, we need to recover the original parameter name and discard some parameters + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + if self.model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + self.model.unmerge_adapter() + # Parameters will automatically 
be repartitioned when exiting the context + else: + for name, param in self.model.named_parameters(): + name = self._fix_param_name_to_vllm(name) + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == "colocate": + self.llm.reset_prefix_cache() + def _generate_single_turn(self, prompts: list[str], images: list | None): """ Generate a single turn of completions using vLLM (mock or real server). diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index fa8c96aca..4da05da22 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -13,9 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + import numpy as np import torch +import torch.distributed as dist from optimum.utils import logging +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler from ..utils import is_precompilation @@ -117,3 +122,123 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor: count = torch.sum(~torch.isnan(tensor)) variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero return torch.sqrt(variance) + + +class DistributedRepeatSampler(DistributedSampler): + """ + Sampler that repeats the indices of a dataset in a structured manner. + Same as `trl.trainer.utils.RepeatSampler` but adapted to work with distributed training. + + To implement it, we simply combine the logic from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py + with the logic from https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1692. + + First, we distribute the dataset indices across the different ranks, then we repeat the indices on each rank. + + We inherit from `torch.utils.data.DistributedSampler` even though we override all of its methods to pass the checks + "isinstance(sampler, DistributedSampler)" done in `torch.utils.data.DataLoader` when using distributed training. + """ + + def __init__( + self, + dataset: Dataset, + mini_repeat_count: int, + batch_size: int = 1, + repeat_count: int = 1, + num_replicas: int | None = None, + rank: int | None = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. 
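+        # (The length/padding computation below mirrors torch.utils.data.DistributedSampler.__init__.)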
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + self.mini_repeat_count = mini_repeat_count + self.batch_size = batch_size + self.repeat_count = repeat_count + self.shuffle = shuffle + + if shuffle: + self.generator = torch.Generator() # Create a local random generator + self.generator.manual_seed(seed) + + def __iter__(self): + # First, we produce indices for each rank. + # That is the distributed part of the sampler. + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + # Second, we repeat the indices on each rank. + # This is the non-distributed part of the sampler. 
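+        # As in trl's RepeatSampler: chunk into generation batches, drop the ragged tail, then repeat each chunk.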
+ # [2, 4, 3, 1, 0, 6, 5] + # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) + indices = [indices[i : i + self.batch_size] for i in range(0, len(indices), self.batch_size)] + + # [[2, 4, 3], [1, 0, 6], [5]] + # -> [[2, 4, 3], [1, 0, 6]] + indices = [chunk for chunk in indices if len(chunk) == self.batch_size] + + for chunk in indices: + for _ in range(self.repeat_count): + for index in chunk: + for _ in range(self.mini_repeat_count): + yield index + + def __len__(self) -> int: + return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + From caf238a31104a61e4454fbae65fe6313368396ac Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Nov 2025 14:48:44 +0100 Subject: [PATCH 39/78] merge for lora.ParallelLinear --- optimum/neuron/peft/tuners/lora/layer.py | 132 ++++++++++++++++++++++- optimum/neuron/trainers/grpo_trainer.py | 1 - 2 files changed, 129 insertions(+), 4 deletions(-) diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index aa4f4ed1b..2b37d756b 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -35,6 +35,7 @@ from peft.tuners.lora import Linear as LoraLinear from peft.tuners.lora import LoraLayer from peft.tuners.lora.variants import LoraVariant + from peft.tuners.tuners_utils import check_adapters_to_merge from peft.utils.integrations import gather_params_ctx else: @@ -53,6 +54,9 @@ class LoraVariant: def gather_params_ctx(param): pass + def check_adapters_to_merge(layer, adapter_names): + return [] + def use_peft_instead_of_optimum_neuron(neuron_lora_layer_method): """ @@ -262,9 +266,131 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> LoraVariant | Non return DoraLinearVariant() - merge = use_peft_instead_of_optimum_neuron(LoraLinear.merge) - unmerge = use_peft_instead_of_optimum_neuron(LoraLinear.unmerge) - get_delta_weight = use_peft_instead_of_optimum_neuron(LoraLinear.get_delta_weight) + def get_delta_weight(self, adapter: str) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + For parallel linear layers, this handles both RowParallelLinear (lora_A) and + ColumnParallelLinear (lora_B) cases. The delta is computed in the sharded form. + + Args: + adapter: The name of the adapter for which the delta weight should be computed. + + Returns: + The delta weight tensor (sharded if the base layer is sharded). 
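+
+        Note (illustrative shapes for a column-parallel base layer): with
+        lora_A.weight of shape (r, in_features) and lora_B.weight of shape
+        (local_out_features, r), lora_B @ lora_A has the same local shape as the
+        sharded base weight, so the delta can be applied shard-by-shard without
+        any collective communication.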
+ """ + lora_A = self.lora_A[adapter] + lora_B = self.lora_B[adapter] + + device = lora_B.weight.device + dtype = lora_B.weight.dtype + + # Cast to fp32 on CPU for better performance with bf16 + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = lora_A.weight + weight_B = lora_B.weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # Compute delta: B @ A * scaling + # The result is sharded the same way as the base layer: + # - If lora_A is RowParallelLinear: delta is sharded along input dimension + # - If lora_B is ColumnParallelLinear: delta is sharded along output dimension + output_tensor = (weight_B @ weight_A) * self.scaling[adapter] + + if self.fan_in_fan_out: + output_tensor = output_tensor.transpose(0, 1) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + # Cast weights back + lora_A.weight.data = weight_A.to(dtype) + lora_B.weight.data = weight_B.to(dtype) + + return output_tensor + + def merge(self, safe_merge: bool = False, adapter_names: list[str] | None = None) -> None: + """ + Merge the active adapter weights into the base weights. + + This works with distributed parallel linear layers (RowParallelLinear, ColumnParallelLinear). + The merge happens on the sharded weights - each rank merges its own shard. + + Args: + safe_merge: If True, perform merge in a copy and check for NaNs before merging. + adapter_names: List of adapter names to merge. If None, all active adapters will be merged. + """ + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + + if self.use_dora[active_adapter]: + raise NotImplementedError("DoRA is not yet supported for merge with Neuron parallel layers") + + if safe_merge: + orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + orig_weights += delta_weight + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weights + + if self.lora_bias[active_adapter]: + lora_B = self.lora_B[active_adapter] + if hasattr(base_layer, "bias") and base_layer.bias is not None: + new_bias = base_layer.bias + lora_B.bias + if not torch.isfinite(new_bias).all(): + raise ValueError( + f"NaNs detected in the merged bias. The adapter {active_adapter} seems to be broken" + ) + base_layer.bias.data = new_bias + else: + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data += delta_weight + + if self.lora_bias[active_adapter]: + lora_B = self.lora_B[active_adapter] + if hasattr(base_layer, "bias") and base_layer.bias is not None: + base_layer.bias.data += lora_B.bias + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + Unmerge all merged adapter layers from the base weights. + + This works with distributed parallel linear layers (RowParallelLinear, ColumnParallelLinear). + The unmerge happens on the sharded weights - each rank unmerges its own shard. 
+ """ + if not self.merged: + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + + if self.use_dora[active_adapter]: + raise NotImplementedError("DoRA is not yet supported for unmerge with Neuron parallel layers") + + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data -= delta_weight + + if self.lora_bias[active_adapter]: + lora_B = self.lora_B[active_adapter] + if hasattr(base_layer, "bias") and base_layer.bias is not None: + base_layer.bias.data -= lora_B.bias def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: self._check_forward_args(x, *args, **kwargs) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 4847503a1..d30dcbc24 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -586,7 +586,6 @@ def _move_model_to_vllm(self): if isinstance(self.model, NeuronPeftModel): self.model.merge_adapter() - # DeepSpeed ZeRO-3 with PEFT for name, param in self.model.named_parameters(): # When using PEFT, we need to recover the original parameter name and discard some parameters name = name.removeprefix("base_model.model.").replace(".base_layer", "") From 026a237a4b213dbfc10d98ec5949658958e47c44 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Nov 2025 14:52:47 +0100 Subject: [PATCH 40/78] merge for ParallelEmbedding --- optimum/neuron/peft/tuners/lora/layer.py | 248 ++++++++++++++++++++++- 1 file changed, 245 insertions(+), 3 deletions(-) diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index 2b37d756b..232cd12cd 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -583,6 +583,154 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights): nn.init.zeros_(self.lora_B[adapter_name].bias_k) nn.init.zeros_(self.lora_B[adapter_name].bias_v) + def get_delta_weight(self, adapter: str) -> dict[str, torch.Tensor]: + """ + Compute the delta weights for Q, K, V for the given adapter. + + Returns a dict with keys 'q', 'k', 'v' (or 'qkv' if fused) containing the delta tensors. + + Args: + adapter: The name of the adapter for which the delta weight should be computed. + + Returns: + Dict mapping 'q'/'k'/'v' (or 'qkv') to their delta weight tensors (sharded). 
+ """ + lora_A = self.lora_A[adapter] + lora_B = self.lora_B[adapter] + + device = lora_A.weight.device + dtype = lora_A.weight.dtype + + # Cast to fp32 on CPU for better performance with fp16/bf16 + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = lora_A.weight + if cast_to_fp32: + weight_A = weight_A.float() + + base_layer = self.get_base_layer() + delta_weights = {} + + # Compute delta for each Q, K, V + if base_layer.fuse_qkv: + weight_B_qkv = lora_B.weight_qkv + if cast_to_fp32: + weight_B_qkv = weight_B_qkv.float() + delta_weights["qkv"] = (weight_B_qkv @ weight_A) * self.scaling[adapter] + if cast_to_fp32: + delta_weights["qkv"] = delta_weights["qkv"].to(dtype=dtype) + else: + for key, weight_attr in [("q", "weight_q"), ("k", "weight_k"), ("v", "weight_v")]: + weight_B = getattr(lora_B, weight_attr) + if cast_to_fp32: + weight_B = weight_B.float() + delta_weights[key] = (weight_B @ weight_A) * self.scaling[adapter] + if cast_to_fp32: + delta_weights[key] = delta_weights[key].to(dtype=dtype) + + return delta_weights + + def merge(self, safe_merge: bool = False, adapter_names: list[str] | None = None) -> None: + """ + Merge the active adapter weights into the base Q, K, V weights. + + This works with GQAQKVColumnParallelLinear layers. + The merge happens on the sharded weights - each rank merges its own shard. + + Args: + safe_merge: If True, perform merge in a copy and check for NaNs before merging. + adapter_names: List of adapter names to merge. If None, all active adapters will be merged. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + + if self.use_dora[active_adapter]: + raise NotImplementedError("DoRA is not yet supported for merge with GQA QKV layers") + + delta_weights = self.get_delta_weight(active_adapter) + + if safe_merge: + if base_layer.fuse_qkv: + orig_weight_qkv = base_layer.weight_qkv.data.clone() + orig_weight_qkv += delta_weights["qkv"] + if not torch.isfinite(orig_weight_qkv).all(): + raise ValueError( + f"NaNs detected in merged QKV weights. Adapter {active_adapter} seems broken" + ) + base_layer.weight_qkv.data = orig_weight_qkv + else: + for key, weight_attr in [("q", "weight_q"), ("k", "weight_k"), ("v", "weight_v")]: + orig_weight = getattr(base_layer, weight_attr).data.clone() + orig_weight += delta_weights[key] + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in merged {key.upper()} weights. 
Adapter {active_adapter} seems broken" + ) + getattr(base_layer, weight_attr).data = orig_weight + else: + if base_layer.fuse_qkv: + base_layer.weight_qkv.data += delta_weights["qkv"] + else: + for key, weight_attr in [("q", "weight_q"), ("k", "weight_k"), ("v", "weight_v")]: + getattr(base_layer, weight_attr).data += delta_weights[key] + + # Handle bias if present + if self.lora_bias[active_adapter]: + lora_B = self.lora_B[active_adapter] + if base_layer.fuse_qkv and hasattr(lora_B, "bias_qkv") and lora_B.bias_qkv is not None: + if hasattr(base_layer, "bias_qkv") and base_layer.bias_qkv is not None: + base_layer.bias_qkv.data += lora_B.bias_qkv + elif not base_layer.fuse_qkv: + for key, bias_attr in [("q", "bias_q"), ("k", "bias_k"), ("v", "bias_v")]: + if hasattr(lora_B, bias_attr) and getattr(lora_B, bias_attr) is not None: + if hasattr(base_layer, bias_attr) and getattr(base_layer, bias_attr) is not None: + getattr(base_layer, bias_attr).data += getattr(lora_B, bias_attr) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + Unmerge all merged adapter layers from the base Q, K, V weights. + + This works with GQAQKVColumnParallelLinear layers. + The unmerge happens on the sharded weights - each rank unmerges its own shard. + """ + if not self.merged: + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + + if self.use_dora[active_adapter]: + raise NotImplementedError("DoRA is not yet supported for unmerge with GQA QKV layers") + + delta_weights = self.get_delta_weight(active_adapter) + + if base_layer.fuse_qkv: + base_layer.weight_qkv.data -= delta_weights["qkv"] + else: + for key, weight_attr in [("q", "weight_q"), ("k", "weight_k"), ("v", "weight_v")]: + getattr(base_layer, weight_attr).data -= delta_weights[key] + + # Handle bias if present + if self.lora_bias[active_adapter]: + lora_B = self.lora_B[active_adapter] + if base_layer.fuse_qkv and hasattr(lora_B, "bias_qkv") and lora_B.bias_qkv is not None: + if hasattr(base_layer, "bias_qkv") and base_layer.bias_qkv is not None: + base_layer.bias_qkv.data -= lora_B.bias_qkv + elif not base_layer.fuse_qkv: + for key, bias_attr in [("q", "bias_q"), ("k", "bias_k"), ("v", "bias_v")]: + if hasattr(lora_B, bias_attr) and getattr(lora_B, bias_attr) is not None: + if hasattr(base_layer, bias_attr) and getattr(base_layer, bias_attr) is not None: + getattr(base_layer, bias_attr).data -= getattr(lora_B, bias_attr) + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: previous_dtype = x.dtype output_q, output_k, output_v = self.base_layer(x, *args, **kwargs) @@ -655,12 +803,106 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> LoraVariant | Non return DoraEmbeddingVariant() update_layer = LoraEmbedding.update_layer - merge = use_peft_instead_of_optimum_neuron(LoraEmbedding.merge) - unmerge = use_peft_instead_of_optimum_neuron(LoraEmbedding.unmerge) - get_delta_weight = use_peft_instead_of_optimum_neuron(LoraEmbedding.get_delta_weight) _mixed_batch_forward = LoraEmbedding._mixed_batch_forward _embed = LoraEmbedding._embed + def get_delta_weight(self, adapter: str) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + For parallel embedding layers, the delta is computed in the sharded form. + + Args: + adapter: The name of the adapter for which the delta weight should be computed. 
+ + Returns: + The delta weight tensor (sharded if the base layer is sharded). + """ + device = self.lora_embedding_B[adapter].device + dtype = self.lora_embedding_A[adapter].dtype + + # Cast to fp32 on CPU for better performance with fp16/bf16 + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.lora_embedding_A[adapter] + weight_B = self.lora_embedding_B[adapter] + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # Compute delta: B @ A (transposed if fan_in_fan_out) + output_tensor = (weight_B @ weight_A) * self.scaling[adapter] + if self.fan_in_fan_out: + output_tensor = output_tensor.T + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + # Cast weights back + self.lora_embedding_A[adapter] = weight_A.to(dtype) + self.lora_embedding_B[adapter] = weight_B.to(dtype) + + return output_tensor + + def merge(self, safe_merge: bool = False, adapter_names: list[str] | None = None) -> None: + """ + Merge the active adapter weights into the base embedding weights. + + This works with ParallelEmbedding layers. + The merge happens on the sharded weights - each rank merges its own shard. + + Args: + safe_merge: If True, perform merge in a copy and check for NaNs before merging. + adapter_names: List of adapter names to merge. If None, all active adapters will be merged. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_embedding_A.keys(): + base_layer = self.get_base_layer() + + if self.use_dora[active_adapter]: + raise NotImplementedError("DoRA is not yet supported for merge with Neuron parallel embeddings") + + if safe_merge: + orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + orig_weights += delta_weight + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data += delta_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + Unmerge all merged adapter layers from the base embedding weights. + + This works with ParallelEmbedding layers. + The unmerge happens on the sharded weights - each rank unmerges its own shard. + """ + if not self.merged: + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_embedding_A.keys(): + base_layer = self.get_base_layer() + + if self.use_dora[active_adapter]: + raise NotImplementedError("DoRA is not yet supported for unmerge with Neuron parallel embeddings") + + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data -= delta_weight + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: # TODO: no dtype conversion here, unlike in Linear, is that correct? 
self._check_forward_args(x, *args, **kwargs) From a15843a56bf95cd48071217b3030dc406b6c427a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Nov 2025 18:13:01 +0100 Subject: [PATCH 41/78] merge for peft models --- optimum/neuron/peft/tuners/lora/layer.py | 28 +------ tests/training/test_custom_modeling.py | 100 +++++++++++++++++++++++ 2 files changed, 103 insertions(+), 25 deletions(-) diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index 232cd12cd..f4d019cec 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -267,18 +267,6 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> LoraVariant | Non return DoraLinearVariant() def get_delta_weight(self, adapter: str) -> torch.Tensor: - """ - Compute the delta weight for the given adapter. - - For parallel linear layers, this handles both RowParallelLinear (lora_A) and - ColumnParallelLinear (lora_B) cases. The delta is computed in the sharded form. - - Args: - adapter: The name of the adapter for which the delta weight should be computed. - - Returns: - The delta weight tensor (sharded if the base layer is sharded). - """ lora_A = self.lora_A[adapter] lora_B = self.lora_B[adapter] @@ -295,6 +283,7 @@ def get_delta_weight(self, adapter: str) -> torch.Tensor: weight_A = weight_A.float() weight_B = weight_B.float() + base_layer = self.get_base_layer() # Compute delta: B @ A * scaling # The result is sharded the same way as the base layer: # - If lora_A is RowParallelLinear: delta is sharded along input dimension @@ -587,13 +576,13 @@ def get_delta_weight(self, adapter: str) -> dict[str, torch.Tensor]: """ Compute the delta weights for Q, K, V for the given adapter. - Returns a dict with keys 'q', 'k', 'v' (or 'qkv' if fused) containing the delta tensors. + Returns a dict with keys "q", "k", "v" (or "qkv" if fused) containing the delta tensors. Args: adapter: The name of the adapter for which the delta weight should be computed. Returns: - Dict mapping 'q'/'k'/'v' (or 'qkv') to their delta weight tensors (sharded). + Dict mapping "q"/"k"/"v" (or "qkv") to their delta weight tensors (sharded). """ lora_A = self.lora_A[adapter] lora_B = self.lora_B[adapter] @@ -807,17 +796,6 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> LoraVariant | Non _embed = LoraEmbedding._embed def get_delta_weight(self, adapter: str) -> torch.Tensor: - """ - Compute the delta weight for the given adapter. - - For parallel embedding layers, the delta is computed in the sharded form. - - Args: - adapter: The name of the adapter for which the delta weight should be computed. - - Returns: - The delta weight tensor (sharded if the base layer is sharded). 
- """ device = self.lora_embedding_B[adapter].device dtype = self.lora_embedding_A[adapter].dtype diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index bda6d3b89..d2abd70dc 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -732,3 +732,103 @@ def test_peft_adapters_with_pp(set_cache_for_ci): if param.requires_grad: # Skip parameters that might be trainable (like embeddings) continue assert param.grad is None, f"Base parameter {name} should not have gradients" + + +@distributed_test(world_size=2, tp_size=2, pp_size=1) +def test_peft_merge_unmerge(set_cache_for_ci): + tp_size = get_tensor_model_parallel_size() + pp_size = get_pipeline_model_parallel_size() + + trn_config = TrainingNeuronConfig( + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + ) + accelerator = NeuronAccelerator(trn_config=trn_config) + + tok = AutoTokenizer.from_pretrained(LLAMA_V2_MODEL_NAME) + inputs = tok("Hello, my dog is cute", return_tensors="pt") + inputs = {k: v.to("xla") for k, v in inputs.items()} + xm.mark_step() + + model = NeuronModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, trn_config, torch_dtype=torch.float32) + + peft_config = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.0, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, peft_config) + + # Store original base weights for verification + original_weights = {} + for name, param in model.named_parameters(): + if "lora" not in name.lower() and "weight" in name: + original_weights[name] = param.data.clone() + elif "lora_B" in name: + # LoRA B weights should be initialized to zero, we change that for the test otherwise the delta is zero, + # which prevents meaningful checks. + assert torch.all(param.data == 0), f"LoRA B weight {name} should be initialized to zero" + param.data += 0.1 + + model = accelerator.prepare_model(model) + model.eval() + + # Get output with LoRA (unmerged) + with torch.no_grad(): + output_unmerged = model(**inputs) + logits_unmerged = output_unmerged.logits.clone() + xm.mark_step() + + # Merge LoRA adapters + model.merge_adapter() + xm.mark_step() + + # Verify weights changed after merge (at least one weight should change) + current_weights = move_all_tensor_to_cpu(dict(model.named_parameters())) + xm.mark_step() + weights_changed = False + for name, original_weight in original_weights.items(): + current_weight = current_weights[name].data + if not torch.allclose(original_weight, current_weight, rtol=1e-4): + weights_changed = True + break + + assert weights_changed, "At least one base weight should change after merge" + + # Get output with merged weights + with torch.no_grad(): + output_merged = model(**inputs) + logits_merged = output_merged.logits.clone() + xm.mark_step() + + print(output_merged) + print(output_unmerged) + + # Outputs should match + assert torch.allclose(logits_unmerged, logits_merged, rtol=1e-3, atol=1e-3), \ + f"Merged and unmerged outputs should match. 
Max diff: {(logits_unmerged - logits_merged).abs().max().item()}" + + # Unmerge LoRA adapters + model.unmerge_adapter() + xm.mark_step() + + # Verify weights restored after unmerge + current_weights = move_all_tensor_to_cpu(dict(model.named_parameters())) + xm.mark_step() + for name, original_weight in original_weights.items(): + current_weight = current_weights[name].data + assert torch.allclose(original_weight, current_weight, rtol=1e-5, atol=1e-6), \ + f"Weight {name} should be restored after unmerge. Max diff: {(original_weight - current_weight).abs().max().item()}" + + # Final output check + with torch.no_grad(): + output_final = model(**inputs) + logits_final = output_final.logits.clone() + xm.mark_step() + + assert torch.allclose(logits_unmerged, logits_final, rtol=1e-5, atol=1e-5), \ + f"Final output should match original unmerged output. Max diff: {(logits_unmerged - logits_final).abs().max().item()}" From b052efc79920294b59dd9b5975d5c018f274aa33 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 21 Nov 2025 16:37:06 +0100 Subject: [PATCH 42/78] merge for peft models --- .../models/training/transformations_utils.py | 4 +- optimum/neuron/peft/tuners/lora/layer.py | 6 +- optimum/neuron/peft/utils/__init__.py | 1 + optimum/neuron/trainers/grpo_trainer.py | 27 +++--- tests/training/test_custom_modeling.py | 95 ++++++++++++++++++- 5 files changed, 110 insertions(+), 23 deletions(-) diff --git a/optimum/neuron/models/training/transformations_utils.py b/optimum/neuron/models/training/transformations_utils.py index 7c8ec9281..0e5e87c3b 100644 --- a/optimum/neuron/models/training/transformations_utils.py +++ b/optimum/neuron/models/training/transformations_utils.py @@ -101,7 +101,7 @@ def peft_type(self) -> str | None: return self._peft_type @peft_type.setter - def peft_type(self, value: str): + def peft_type(self, value: str | None): self._peft_type = value @abstractmethod @@ -1662,7 +1662,9 @@ def create_parameter_metadata(model) -> dict[str, dict[str, Any]]: consolidating the sharded state dicts. 
""" metadata = {"parameters": {}, "model_weight_transformation_specs": []} + print("micka", model.parameters_for_current_stage) for name, param in model.named_parameters(): + print(name) if name not in model.parameters_for_current_stage: continue tensor_model_parallel = getattr(param, "tensor_model_parallel", False) diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index f4d019cec..b2c23a388 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -809,10 +809,8 @@ def get_delta_weight(self, adapter: str) -> torch.Tensor: weight_A = weight_A.float() weight_B = weight_B.float() - # Compute delta: B @ A (transposed if fan_in_fan_out) - output_tensor = (weight_B @ weight_A) * self.scaling[adapter] - if self.fan_in_fan_out: - output_tensor = output_tensor.T + # Compute delta: (B @ A).T * scaling + output_tensor = (weight_B @ weight_A).T * self.scaling[adapter] if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) diff --git a/optimum/neuron/peft/utils/__init__.py b/optimum/neuron/peft/utils/__init__.py index e69de29bb..3d03b0b25 100644 --- a/optimum/neuron/peft/utils/__init__.py +++ b/optimum/neuron/peft/utils/__init__.py @@ -0,0 +1 @@ +from .vllm import get_original_merged_weights_for_vllm diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index d30dcbc24..4ca612b25 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -37,6 +37,7 @@ from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model +from ..peft.utils.vllm import get_original_merged_weights_for_vllm from ..utils import is_precompilation, is_trl_available from ..utils.import_utils import is_peft_available from .grpo_config import NeuronGRPOConfig @@ -193,8 +194,9 @@ def __init__( if peft_config is not None and not isinstance(model, NeuronPeftModel): # Enable gradient checkpointing if needed - gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} if args.gradient_checkpointing and ( "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] ): @@ -584,26 +586,19 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): def _move_model_to_vllm(self): if isinstance(self.model, NeuronPeftModel): - self.model.merge_adapter() + # Get original (unsharded, untransformed) merged weights for vLLM + original_weights = get_original_merged_weights_for_vllm(self.model) - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + # Send weights to vLLM server (only main process for server mode) + for name, weight in original_weights.items(): + # Clean up parameter name for vLLM + name = self._fix_param_name_to_vllm(name) if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) + 
self.vllm_client.update_named_param(name, weight) elif self.vllm_mode == "colocate": llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context + llm_model.load_weights([(name, weight)]) else: for name, param in self.model.named_parameters(): name = self._fix_param_name_to_vllm(name) diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index d2abd70dc..c775051b9 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -769,11 +769,11 @@ def test_peft_merge_unmerge(set_cache_for_ci): if "lora" not in name.lower() and "weight" in name: original_weights[name] = param.data.clone() elif "lora_B" in name: - # LoRA B weights should be initialized to zero, we change that for the test otherwise the delta is zero, + # LoRA B weights should be initialized to zero, we change that for the test otherwise the delta is zero, # which prevents meaningful checks. assert torch.all(param.data == 0), f"LoRA B weight {name} should be initialized to zero" param.data += 0.1 - + model = accelerator.prepare_model(model) model.eval() @@ -832,3 +832,94 @@ def test_peft_merge_unmerge(set_cache_for_ci): assert torch.allclose(logits_unmerged, logits_final, rtol=1e-5, atol=1e-5), \ f"Final output should match original unmerged output. Max diff: {(logits_unmerged - logits_final).abs().max().item()}" + + +@distributed_test(world_size=8, tp_size=2, pp_size=1) +def test_get_original_merged_weights_for_vllm(set_cache_for_ci): + """Test that get_original_merged_weights_for_vllm produces correct unsharded, original-format weights.""" + tp_size = get_tensor_model_parallel_size() + tp_rank = get_tensor_model_parallel_rank() + + trn_config = TrainingNeuronConfig( + tensor_parallel_size=tp_size, + ) + mixed_precision = MixedPrecisionConfig(mode="FULL_BF16") + accelerator = NeuronAccelerator(trn_config=trn_config, mixed_precision_config=mixed_precision) + + model = NeuronModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, trn_config, torch_dtype=torch.bfloat16) + + # Store original base weights before PEFT for comparison + original_base_weights = {} + for name, param in model.named_parameters(): + original_base_weights[name] = param.data.clone() + + peft_config = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.0, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, peft_config) + + # Set lora_B weights to non-zero values + for name, param in model.named_parameters(): + if "lora_B" in name: + assert torch.all(param.data == 0), f"LoRA B weight {name} should be initialized to zero" + param.data += 0.1 + + model = accelerator.prepare_model(model) + model.eval() + + # Get original merged weights for vLLM + from optimum.neuron.peft.utils.vllm import get_original_merged_weights_for_vllm + original_weights = get_original_merged_weights_for_vllm(model) + xm.mark_step() + + # Only check on main process since get_original_merged_weights_for_vllm returns same weights on all ranks + if tp_rank == 0: + # Test 1: Check weights are unsharded (full size) + # For Llama-2-7b with TP=2, hidden_size=4096 + hidden_size = model.config.hidden_size + intermediate_size = model.config.intermediate_size + + # Check attention 
projection sizes (should be full, not sharded) + assert "model.layers.0.self_attn.q_proj.weight" in original_weights + q_proj_weight = original_weights["model.layers.0.self_attn.q_proj.weight"] + assert q_proj_weight.shape == (hidden_size, hidden_size), \ + f"q_proj should be unsharded {hidden_size}x{hidden_size}, got {q_proj_weight.shape}" + + # Test 2: Check weights are in original format (separate gate/up, not fused) + # The custom model uses fused gate_up_proj, but original format should have separate projections + assert "model.layers.0.mlp.gate_proj.weight" in original_weights + assert "model.layers.0.mlp.up_proj.weight" in original_weights + assert "model.layers.0.mlp.gate_up_proj.weight" not in original_weights, \ + "Should use original format (separate gate/up), not custom format (fused gate_up)" + + gate_proj_weight = original_weights["model.layers.0.mlp.gate_proj.weight"] + up_proj_weight = original_weights["model.layers.0.mlp.up_proj.weight"] + assert gate_proj_weight.shape == (intermediate_size, hidden_size) + assert up_proj_weight.shape == (intermediate_size, hidden_size) + + # Test 3: Verify LoRA delta is merged + # Since we set lora_B += 0.1, the merged weights should differ from original base weights + # Get the corresponding original base weight (need to map from PEFT name to base name) + base_q_proj_name = "model.layers.0.self_attn.q_proj.weight" + if base_q_proj_name in original_base_weights: + original_q_proj = original_base_weights[base_q_proj_name] + merged_q_proj = original_weights["model.layers.0.self_attn.q_proj.weight"] + + # Weights should be different (LoRA delta was merged) + assert not torch.allclose(original_q_proj, merged_q_proj, rtol=1e-4), \ + "Merged weight should differ from original base weight (LoRA delta should be added)" + + # Test 4: Verify model state is restored (adapters are unmerged) + # After calling get_original_merged_weights_for_vllm, the model should be back to unmerged state + for module in model.modules(): + if hasattr(module, "merged"): + assert not module.merged, \ + f"Module {module.__class__.__name__} should be unmerged after get_original_merged_weights_for_vllm" + + print("✓ All get_original_merged_weights_for_vllm tests passed") From b2ff3106d55218f61bc513f239ae3eb4433c983d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 21 Nov 2025 16:48:59 +0100 Subject: [PATCH 43/78] merge for peft models --- optimum/neuron/peft/tuners/lora/layer.py | 1 - tests/training/test_custom_modeling.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index b2c23a388..79c1f2e07 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -283,7 +283,6 @@ def get_delta_weight(self, adapter: str) -> torch.Tensor: weight_A = weight_A.float() weight_B = weight_B.float() - base_layer = self.get_base_layer() # Compute delta: B @ A * scaling # The result is sharded the same way as the base layer: # - If lora_A is RowParallelLinear: delta is sharded along input dimension diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index c775051b9..2ba044cb8 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -28,6 +28,7 @@ get_pipeline_model_parallel_rank, get_pipeline_model_parallel_size, get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_size, ) from 
neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu @@ -45,6 +46,7 @@ from optimum.neuron.models.training.modeling_auto import NeuronModelForCausalLM from optimum.neuron.models.training.transformations_utils import GQAQKVColumnParallelLinearSpec from optimum.neuron.peft import get_peft_model +from optimum.neuron.peft.utils.vllm import get_original_merged_weights_for_vllm from optimum.neuron.utils.import_utils import ( is_neuronx_available, ) @@ -836,12 +838,13 @@ def test_peft_merge_unmerge(set_cache_for_ci): @distributed_test(world_size=8, tp_size=2, pp_size=1) def test_get_original_merged_weights_for_vllm(set_cache_for_ci): - """Test that get_original_merged_weights_for_vllm produces correct unsharded, original-format weights.""" tp_size = get_tensor_model_parallel_size() tp_rank = get_tensor_model_parallel_rank() + pp_size = get_pipeline_model_parallel_size() trn_config = TrainingNeuronConfig( tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, ) mixed_precision = MixedPrecisionConfig(mode="FULL_BF16") accelerator = NeuronAccelerator(trn_config=trn_config, mixed_precision_config=mixed_precision) @@ -874,24 +877,20 @@ def test_get_original_merged_weights_for_vllm(set_cache_for_ci): model.eval() # Get original merged weights for vLLM - from optimum.neuron.peft.utils.vllm import get_original_merged_weights_for_vllm original_weights = get_original_merged_weights_for_vllm(model) xm.mark_step() # Only check on main process since get_original_merged_weights_for_vllm returns same weights on all ranks if tp_rank == 0: - # Test 1: Check weights are unsharded (full size) - # For Llama-2-7b with TP=2, hidden_size=4096 + # Test 1: Check that we have unsharded weights (full size) and with the original naming / fusing / unfusing. 
hidden_size = model.config.hidden_size intermediate_size = model.config.intermediate_size - # Check attention projection sizes (should be full, not sharded) assert "model.layers.0.self_attn.q_proj.weight" in original_weights q_proj_weight = original_weights["model.layers.0.self_attn.q_proj.weight"] assert q_proj_weight.shape == (hidden_size, hidden_size), \ f"q_proj should be unsharded {hidden_size}x{hidden_size}, got {q_proj_weight.shape}" - # Test 2: Check weights are in original format (separate gate/up, not fused) # The custom model uses fused gate_up_proj, but original format should have separate projections assert "model.layers.0.mlp.gate_proj.weight" in original_weights assert "model.layers.0.mlp.up_proj.weight" in original_weights @@ -903,19 +902,19 @@ def test_get_original_merged_weights_for_vllm(set_cache_for_ci): assert gate_proj_weight.shape == (intermediate_size, hidden_size) assert up_proj_weight.shape == (intermediate_size, hidden_size) - # Test 3: Verify LoRA delta is merged + # Test 2: Verify LoRA delta is merged # Since we set lora_B += 0.1, the merged weights should differ from original base weights # Get the corresponding original base weight (need to map from PEFT name to base name) base_q_proj_name = "model.layers.0.self_attn.q_proj.weight" if base_q_proj_name in original_base_weights: original_q_proj = original_base_weights[base_q_proj_name] - merged_q_proj = original_weights["model.layers.0.self_attn.q_proj.weight"] + merged_q_proj = original_weights[base_q_proj_name] # Weights should be different (LoRA delta was merged) assert not torch.allclose(original_q_proj, merged_q_proj, rtol=1e-4), \ "Merged weight should differ from original base weight (LoRA delta should be added)" - # Test 4: Verify model state is restored (adapters are unmerged) + # Test 3: Verify model state is restored (adapters are unmerged) # After calling get_original_merged_weights_for_vllm, the model should be back to unmerged state for module in model.modules(): if hasattr(module, "merged"): From 21dc065b577161b3490fe6c66c895fc3428111e3 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 21 Nov 2025 17:54:44 +0100 Subject: [PATCH 44/78] fix test --- .../models/training/transformations_utils.py | 14 ++++++----- optimum/neuron/peft/tuners/lora/layer.py | 23 ++++++++++++++++++- tests/training/test_custom_modeling.py | 22 ++++++++++-------- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/optimum/neuron/models/training/transformations_utils.py b/optimum/neuron/models/training/transformations_utils.py index 0e5e87c3b..2304b3e4d 100644 --- a/optimum/neuron/models/training/transformations_utils.py +++ b/optimum/neuron/models/training/transformations_utils.py @@ -533,6 +533,9 @@ def _lora_adapt_state_dict( f"{module_fully_qualified_name}.{name}.lora_A.{param_name}" for name in self.linear_names ] + if not all(name in state_dict for name in lora_A_weight_names): + continue + logger.warning("Taking the mean of the LoRA A weights since there is only one LoRA A weight after fusing.") lora_A_weight = torch.mean( torch.stack([state_dict.pop(name) for name in lora_A_weight_names], dim=0), @@ -650,9 +653,7 @@ def _lora_to_original_weights( break if weight_name is None or to_duplicate_name is None: - raise ValueError( - f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}." - ) + continue # When saved, the name of the adapter is removed in the weight qualified name since weights for each # adapter are saved separately. 
@@ -700,9 +701,7 @@ def _lora_to_original_weights( if to_concat_and_duplicate_name is not None and to_unfuse_name is not None: break if to_concat_and_duplicate_name is None or to_unfuse_name is None: - raise ValueError( - f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}." - ) + continue weight_name_without_adapter_name = remove_adapter_name(to_concat_and_duplicate_name) linear_sharded_weights = sharded_state_dicts[weight_name_without_adapter_name] @@ -1100,6 +1099,9 @@ def _lora_adapt_state_dict( lora_A_weight_names = [lora_A_q_name, lora_A_k_name, lora_A_v_name] + if not all(name in state_dict for name in lora_A_weight_names): + continue + logger.warning("Taking the mean of the LoRA A weights since there is only one LoRA A weight after fusing.") lora_A_weight = torch.mean( torch.stack([state_dict.pop(name) for name in lora_A_weight_names], dim=0), diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index 79c1f2e07..07f0f768d 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -25,6 +25,10 @@ ) from neuronx_distributed.parallel_layers.layers import ParallelEmbedding as NxDParallelEmbedding from neuronx_distributed.parallel_layers.mappings import scatter_to_sequence_parallel_region +from neuronx_distributed.parallel_layers.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_size, +) from torch import nn from ....utils.import_utils import is_peft_available @@ -500,7 +504,7 @@ def update_layer( self.lora_B[adapter_name] = NxDGQAQKVColumnParallelLinear( input_size=r, output_sizes=self.out_features, - bias=False, + bias=lora_bias, gather_output=self.base_layer.gather_output, dtype=self.base_layer.dtype, init_method=self.base_layer.arg_init_method, @@ -817,6 +821,23 @@ def get_delta_weight(self, adapter: str) -> torch.Tensor: self.lora_embedding_A[adapter] = weight_A.to(dtype) self.lora_embedding_B[adapter] = weight_B.to(dtype) + tp_size = get_tensor_model_parallel_size() + if tp_size > 1: + tp_rank = get_tensor_model_parallel_rank() + base_layer = self.get_base_layer() + # We need to slice the delta weight to match the local shard + # The ParallelEmbedding layer pads the weight so we need to handle that + vocab_size_per_partition = base_layer.weight.shape[0] + start_idx = tp_rank * vocab_size_per_partition + end_idx = start_idx + vocab_size_per_partition + + # Pad output_tensor if needed (last rank might need padding) + if end_idx > output_tensor.shape[0]: + pad_len = end_idx - output_tensor.shape[0] + output_tensor = torch.nn.functional.pad(output_tensor, (0, 0, 0, pad_len)) + + output_tensor = output_tensor[start_idx:end_idx, :] + return output_tensor def merge(self, safe_merge: bool = False, adapter_names: list[str] | None = None) -> None: diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index 2ba044cb8..f143e5abb 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -840,6 +840,7 @@ def test_peft_merge_unmerge(set_cache_for_ci): def test_get_original_merged_weights_for_vllm(set_cache_for_ci): tp_size = get_tensor_model_parallel_size() tp_rank = get_tensor_model_parallel_rank() + tp_group = get_tensor_model_parallel_group(as_list=True) pp_size = get_pipeline_model_parallel_size() trn_config = TrainingNeuronConfig( @@ -876,8 +877,14 @@ def test_get_original_merged_weights_for_vllm(set_cache_for_ci): model = 
accelerator.prepare_model(model) model.eval() + base_q_proj_name = "model.layers.0.self_attn.q_proj.weight" + original_q_proj = original_base_weights[base_q_proj_name].to("xla") + original_q_proj = xm.all_gather(original_q_proj, dim=0, groups=tp_group) + original_q_proj = original_q_proj.cpu() + # Get original merged weights for vLLM original_weights = get_original_merged_weights_for_vllm(model) + original_weights = {k: v.cpu() for k, v in original_weights.items()} xm.mark_step() # Only check on main process since get_original_merged_weights_for_vllm returns same weights on all ranks @@ -903,16 +910,11 @@ def test_get_original_merged_weights_for_vllm(set_cache_for_ci): assert up_proj_weight.shape == (intermediate_size, hidden_size) # Test 2: Verify LoRA delta is merged - # Since we set lora_B += 0.1, the merged weights should differ from original base weights - # Get the corresponding original base weight (need to map from PEFT name to base name) - base_q_proj_name = "model.layers.0.self_attn.q_proj.weight" - if base_q_proj_name in original_base_weights: - original_q_proj = original_base_weights[base_q_proj_name] - merged_q_proj = original_weights[base_q_proj_name] - - # Weights should be different (LoRA delta was merged) - assert not torch.allclose(original_q_proj, merged_q_proj, rtol=1e-4), \ - "Merged weight should differ from original base weight (LoRA delta should be added)" + merged_q_proj = original_weights[base_q_proj_name] + + # Weights should be different (LoRA delta was merged) + assert not torch.allclose(original_q_proj, merged_q_proj, rtol=1e-4), \ + "Merged weight should differ from original base weight (LoRA delta should be added)" # Test 3: Verify model state is restored (adapters are unmerged) # After calling get_original_merged_weights_for_vllm, the model should be back to unmerged state From 58c438cfc5065be7d687177af8ad126ccff71cb3 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 24 Nov 2025 17:08:33 +0100 Subject: [PATCH 45/78] trainer runs with mock but produces NaNs --- examples/training/grpo_qwen3/finetune_grpo_qwen3.py | 7 ++++--- optimum/neuron/models/training/transformations_utils.py | 2 -- tests/training/test_custom_modeling.py | 3 --- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index 47df95de8..97be45f02 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -141,7 +141,8 @@ def train(model_id, tokenizer, dataset, training_args): trn_config, torch_dtype=dtype, # Use FlashAttention2 for better performance - attn_implementation="flash_attention_2", + # attn_implementation="flash_attention_2", + attn_implementation="eager", ) # LoRA configuration for efficient fine-tuning @@ -160,8 +161,8 @@ def train(model_id, tokenizer, dataset, training_args): # GRPO-specific configuration grpo_config = NeuronGRPOConfig( # Generation parameters - max_prompt_length=1024, # Maximum prompt length - max_completion_length=1024, # Maximum completion length + max_prompt_length=512, # Maximum prompt length + max_completion_length=268, # Maximum completion length num_generations=4, # Number of completions to generate per prompt (G in paper) temperature=0.8, # Sampling temperature # GRPO algorithm parameters diff --git a/optimum/neuron/models/training/transformations_utils.py b/optimum/neuron/models/training/transformations_utils.py index 2304b3e4d..a53e57841 100644 --- 
a/optimum/neuron/models/training/transformations_utils.py +++ b/optimum/neuron/models/training/transformations_utils.py @@ -1664,9 +1664,7 @@ def create_parameter_metadata(model) -> dict[str, dict[str, Any]]: consolidating the sharded state dicts. """ metadata = {"parameters": {}, "model_weight_transformation_specs": []} - print("micka", model.parameters_for_current_stage) for name, param in model.named_parameters(): - print(name) if name not in model.parameters_for_current_stage: continue tensor_model_parallel = getattr(param, "tensor_model_parallel", False) diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index f143e5abb..08099f9b5 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -807,9 +807,6 @@ def test_peft_merge_unmerge(set_cache_for_ci): logits_merged = output_merged.logits.clone() xm.mark_step() - print(output_merged) - print(output_unmerged) - # Outputs should match assert torch.allclose(logits_unmerged, logits_merged, rtol=1e-3, atol=1e-3), \ f"Merged and unmerged outputs should match. Max diff: {(logits_unmerged - logits_merged).abs().max().item()}" From b0d3056aa8486f0fa781d6e9c8d1cce1914da806 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 25 Nov 2025 17:01:12 +0100 Subject: [PATCH 46/78] add vllm file --- optimum/neuron/peft/utils/vllm.py | 126 ++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 optimum/neuron/peft/utils/vllm.py diff --git a/optimum/neuron/peft/utils/vllm.py b/optimum/neuron/peft/utils/vllm.py new file mode 100644 index 000000000..923f14aec --- /dev/null +++ b/optimum/neuron/peft/utils/vllm.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +from neuronx_distributed.parallel_layers.parallel_state import ( + get_tensor_model_parallel_group, +) + +from ...models.training import create_parameter_metadata +from ...models.training.transformations_utils import to_original_weights + + +def get_original_merged_weights_for_vllm(model) -> dict[str, torch.Tensor]: + """ + Gets original (unsharded, untransformed) weights from a NeuronPeftModel for vLLM. + + Steps: + 1. Merge LoRA adapters in-place on each TP shard. This way we go from "base + LoRA" to "merged weights" on each + shard. + 2. Gather the sharded state dicts from all TP ranks. + 3. Get transformation specs and metadata required to revert to original weights. + 4. Use `to_original_weights` to revert the gathered sharded weights to original weights. + 5. Unmerge LoRA adapters to restore the model state. 
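+
+    The returned dict maps original checkpoint parameter names (e.g. separate
+    gate_proj/up_proj instead of the fused gate_up_proj) to full, unsharded tensors.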
+ """ + from ...peft import NeuronPeftModel + + if not isinstance(model, NeuronPeftModel): + raise TypeError(f"Expected NeuronPeftModel, got {type(model).__name__}") + + tp_group = get_tensor_model_parallel_group(as_list=True) + + # Step 1: Merge LoRA adapters (modifies weights in-place on each shard) + model.merge_adapter() + torch_xla.sync() + + # Step 2: Gather state dict across TP ranks + # Get local state dict (sharded weights from this rank) + # Strip PEFT prefixes to make it look like a regular (non-PEFT) model + local_state_dict = {} + for name, param in model.named_parameters(): + # Skip LoRA adapter parameters (lora_A, lora_B, lora_embedding_A, lora_embedding_B, etc.) + if "lora" in name.lower(): + continue + + # Strip PEFT prefixes: "base_model.model." and ".base_layer" + # This makes merged PEFT weights look like regular model weights + clean_name = name.removeprefix("base_model.model.").replace(".base_layer", "") + + # Skip modules_to_save and original_module (PEFT-specific) + if "modules_to_save" in clean_name or "original_module" in clean_name: + continue + + local_state_dict[clean_name] = param.data + + # Gather all TP rank state dicts into format: {name: [tensor_rank0, tensor_rank1, ...]} + sharded_state_dicts = {} + for name, local_tensor in local_state_dict.items(): + gathered = xm.all_gather(local_tensor, dim=0, groups=tp_group) + gathered_tensors = list(torch.split(gathered, local_tensor.size(0), dim=0)) + sharded_state_dicts[name] = gathered_tensors + torch_xla.sync() + + # Step 3: Get transformation specs and metadata + # For NeuronPeftModel: model.base_model.model is the actual NeuronModelForCausalLM + base_model = model.base_model.model + + # Get transformation specs from base model + transformation_specs = [] + for module in base_model.modules(): + if hasattr(module, "specs"): + transformation_specs.append(copy.deepcopy(module.specs)) + + for specs in transformation_specs: + specs.module_fully_qualified_name = specs.module_fully_qualified_name.removeprefix("base_model.model.").replace(".base_layer", "") + for spec in specs: + spec.peft_type = None + + # Create parameter metadata from peft model + metadata = create_parameter_metadata(model) + parameters_metadata = metadata["parameters"] + + # Clean parameter names in metadata to match the cleaned state dict + cleaned_parameters_metadata = {} + for name, param_metadata in parameters_metadata.items(): + # Skip LoRA parameters + if "lora" in name.lower(): + continue + + # Strip PEFT prefixes + clean_name = name.removeprefix("base_model.model.").replace(".base_layer", "") + + # Skip modules_to_save and original_module + if "modules_to_save" in clean_name or "original_module" in clean_name: + continue + + cleaned_parameters_metadata[clean_name] = param_metadata + parameters_metadata = cleaned_parameters_metadata + + # Step 5: Transform to original weights + original_state_dict = to_original_weights( + transformation_specs, + sharded_state_dicts, + parameters_metadata, + ) + + # Step 6: Unmerge LoRA adapters to restore model state + model.unmerge_adapter() + torch_xla.sync() + + return original_state_dict From 4f42f21e5c97d60a97377ac7fc8b4ce5f20c5b39 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 26 Nov 2025 12:43:43 +0100 Subject: [PATCH 47/78] add collectives for python objects --- .../grpo_qwen3/finetune_grpo_qwen3.py | 2 +- optimum/neuron/accelerate/accelerator.py | 56 +------- optimum/neuron/accelerate/utils/__init__.py | 7 + optimum/neuron/trainers/grpo_mocks.py | 21 ++- 
optimum/neuron/trainers/grpo_trainer.py | 123 +++++++++++++----- pyproject.toml | 3 + 6 files changed, 119 insertions(+), 93 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index 97be45f02..6a93f9572 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -173,7 +173,7 @@ def train(model_id, tokenizer, dataset, training_args): # vLLM parameters use_vllm=True, # Use vLLM for generation (required for Neuron) vllm_mode="server", # Use vLLM server mode - vllm_server_host="localhost", + vllm_server_host="0.0.0.0", vllm_server_port=8000, # Standard training arguments from NeuronTrainingArguments **args, diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 7e9e8288a..ca1ec1f70 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -15,7 +15,6 @@ import contextlib import os -import pickle import re import shutil import sys @@ -30,7 +29,7 @@ from accelerate import Accelerator from accelerate.checkpointing import save_accelerator_state, save_custom_state from accelerate.utils import AutocastKwargs, DistributedType -from accelerate.utils.operations import gather_object, recursively_apply +from accelerate.utils.operations import recursively_apply from neuronx_distributed import parallel_layers from neuronx_distributed.optimizer import NeuronZero1Optimizer from neuronx_distributed.parallel_layers.parallel_state import ( @@ -61,14 +60,13 @@ from .optimizer import NeuronAcceleratedOptimizer from .scheduler import NeuronAcceleratedScheduler from .state import NeuronAcceleratorState -from .utils import ( - patch_accelerate_is_torch_xla_available, -) from .utils.dataclasses import MixedPrecisionConfig, MixedPrecisionMode from .utils.misc import ( apply_activation_checkpointing, create_patched_save_pretrained, + patch_accelerate_is_torch_xla_available, ) +from .utils.operations import gather_object # Setup logging so that the main process logs at the INFO level and the others are silent. @@ -560,54 +558,6 @@ def gather(self, tensor, sync: bool = False): torch_xla.sync() return gathered - def gather_object(self, obj: Any) -> list[Any]: - """ - Gathers arbitrary objects across XLA-distributed processes. - Returns list of objects from all ranks on all ranks. - - Note: Requires two all-gather operations (lengths then data). - For small objects, this overhead may be significant. 
- """ - world_size = get_data_parallel_size() - - # Early exit for single process - if world_size == 1: - return [obj] - - groups = get_data_parallel_group(as_list=True) - - serialized = pickle.dumps(obj) - byte_len = len(serialized) - - byte_tensor = torch.frombuffer(serialized, dtype=torch.uint8).clone() - byte_tensor = byte_tensor.to(xm.xla_device()) - - len_tensor = torch.tensor([byte_len], dtype=torch.int64, device=byte_tensor.device) - # all_gather concatenates along dim=0, so [1] -> [world_size] - gathered_lengths = xm.all_gather(len_tensor, dim=0, groups=groups, pin_layout=False) - - torch_xla.sync() - max_len = int(gathered_lengths.max().item()) - - padded = torch.zeros(max_len, dtype=torch.uint8, device=byte_tensor.device) - padded[:byte_len] = byte_tensor - - # all_gather concatenates, so [max_len] -> [world_size * max_len] - gathered_data = xm.all_gather(padded, dim=0, groups=groups, pin_layout=False) - - torch_xla.sync() - gathered_data_cpu = gathered_data.cpu() - gathered_lengths_cpu = gathered_lengths.cpu() - - results = [] - offset = 0 - for i in range(world_size): - actual_len = int(gathered_lengths_cpu[i].item()) - valid_bytes = gathered_data_cpu[offset:offset + max_len][:actual_len].numpy().tobytes() - results.append(pickle.loads(valid_bytes)) - offset += max_len - - return results def gather_for_metrics(self, input_data, use_gather_object: bool = False, sync: bool = False): try: diff --git a/optimum/neuron/accelerate/utils/__init__.py b/optimum/neuron/accelerate/utils/__init__.py index 6218b6f31..d46cba470 100644 --- a/optimum/neuron/accelerate/utils/__init__.py +++ b/optimum/neuron/accelerate/utils/__init__.py @@ -15,3 +15,10 @@ from .dataclasses import MixedPrecisionConfig, MixedPrecisionMode from .misc import patch_accelerate_is_torch_xla_available +from .operations import ( + broadcast_object, + broadcast_object_to_data_parallel_group, + broadcast_object_to_pipeline_model_parallel_group, + broadcast_object_to_tensor_model_parallel_group, + gather_object, +) diff --git a/optimum/neuron/trainers/grpo_mocks.py b/optimum/neuron/trainers/grpo_mocks.py index 39e819106..a900bf0ed 100644 --- a/optimum/neuron/trainers/grpo_mocks.py +++ b/optimum/neuron/trainers/grpo_mocks.py @@ -21,6 +21,8 @@ vLLM server setup. 
""" +import random + from optimum.utils import logging @@ -108,18 +110,23 @@ def generate( # Generate n completions per prompt for i in range(n): - # Generate mock completion - # Use a simple pattern: repeat EOS token to create fixed-length completion + # Generate mock completion with realistic varied tokens # In real scenario, this would be actual LLM generation completion_length = min(max_tokens, self.max_completion_length) - # Generate completion: cycle through safe token IDs - completion = [self.tokenizer.eos_token_id] * completion_length + # Generate completion with varied tokens from the vocabulary + # Avoid special tokens by using a safe range of token IDs + vocab_size = self.tokenizer.vocab_size + # Use token IDs in a safe range (skip first 100 tokens which often include special tokens) + min_token_id = min(100, vocab_size - 1) + max_token_id = vocab_size - 1 + completion = [random.randint(min_token_id, max_token_id) for _ in range(completion_length)] completion_ids.append(completion) - # Generate mock logprobs (uniform negative values) - # Real logprobs would come from the model's probability distribution - completion_logprobs = [-1.0] * completion_length + # Generate realistic mock logprobs + # Real language model logprobs typically range from -2 to -10 + # with occasional values outside this range + completion_logprobs = [-random.uniform(2.0, 8.0) for _ in range(completion_length)] logprobs.append(completion_logprobs) return { diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 4ca612b25..add90eb79 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -23,6 +23,11 @@ import torch_xla.core.xla_model as xm from accelerate.utils import set_seed from neuronx_distributed import parallel_layers +from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_rank, + get_pipeline_model_parallel_rank, + get_tensor_model_parallel_rank, +) from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from optimum.utils import logging from torch.utils.data import Dataset, IterableDataset, Sampler @@ -35,6 +40,10 @@ ) from transformers.utils import is_rich_available +from ..accelerate.utils import ( + broadcast_object_to_pipeline_model_parallel_group, + broadcast_object_to_tensor_model_parallel_group, +) from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model from ..peft.utils.vllm import get_original_merged_weights_for_vllm @@ -397,7 +406,10 @@ def make_inputs_require_grad(module, input, output): from ..utils import is_vllm_available # MOCK FLAG: Change this to False when real vLLM server is ready - USE_MOCK_VLLM = True + USE_MOCK_VLLM = False + # TODO: Remove this when mock is no longer used - tracks mock usage to disable + # importance sampling correction which doesn't work with mock's random tokens + self._using_mock_vllm = USE_MOCK_VLLM if USE_MOCK_VLLM: logger.warning( @@ -413,17 +425,18 @@ def make_inputs_require_grad(module, input, output): if not is_vllm_available(): raise ImportError("vLLM is not available. 
Please install vLLM to use NeuronGRPOTrainer.") - # Setup vLLM server client (only on main process) - if self.accelerator.is_main_process: - from trl.extras.vllm_client import VLLMClient + # Setup vLLM server client (all processes need it to call the server) + from trl.extras.vllm_client import VLLMClient - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) + self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) + # Only main process initializes the communicator for weight updates + if self.accelerator.is_main_process: + self.vllm_client.init_communicator(device="cpu") # vLLM specific sampling arguments self.guided_decoding_regex = args.vllm_guided_decoding_regex @@ -456,6 +469,12 @@ def make_inputs_require_grad(module, input, output): self.reward_funcs[i] = self.accelerator.prepare_model( reward_func, evaluation_mode=True, device_placement=True ) + + + self.dp_rank = get_data_parallel_rank() + self.tp_rank = get_tensor_model_parallel_rank() + self.pp_rank = get_pipeline_model_parallel_rank() + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: if dataset is None: dataset = self.train_dataset @@ -595,7 +614,8 @@ def _move_model_to_vllm(self): name = self._fix_param_name_to_vllm(name) if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, weight) + pass + # self.vllm_client.update_named_param(name, weight) elif self.vllm_mode == "colocate": llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model llm_model.load_weights([(name, weight)]) @@ -622,7 +642,7 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): The main difference is avoiding gather_object which doesn't work on XLA. MOCK MODE: Each process generates locally without gathering/broadcasting. - REAL SERVER MODE: May need gather_object workaround - test when implementing! + REAL SERVER MODE: Only main process generates, results are broadcast to all processes. 
Args: prompts: List of prompt strings @@ -636,7 +656,6 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step - # For mock vLLM, generate locally on each process (no gather/broadcast needed) # Take unique prompts since we have num_generations duplicates prompts_text = [prompt if isinstance(prompt, str) else prompt["content"] for prompt in prompts] ordered_set_of_prompts = prompts_text[:: self.num_generations] @@ -646,21 +665,56 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): else: ordered_set_of_images = None - # Generate using mock vLLM client - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) + # For mock vLLM: each process generates independently + # For real vLLM server: only main process generates, then broadcast + if self._using_mock_vllm: + # Each process generates locally + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) + else: + # Real vLLM server mode: only main process generates + if self.tp_rank == self.pp_rank == 0: + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) + else: + output = { + "prompt_ids": [[]], + "completion_ids": [[]], + "logprobs": [[]], + } + output = None + + trn_config = self.accelerator.state.trn_config + # TODO: change that to a better default. + fixed_size = int(2e6) # 2MB fixed size for buffer, should be enough. 
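+        # NOTE: a fixed-size buffer keeps the broadcast tensor shape identical across steps, so the
+        # XLA compiler can reuse a single graph instead of recompiling whenever the serialized output
+        # length changes. `broadcast_object` raises a ValueError if a payload exceeds this size.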
+ if trn_config.tensor_parallel_size > 1: + output = broadcast_object_to_tensor_model_parallel_group(output, fixed_size=fixed_size) + if trn_config.pipeline_parallel_size > 1: + output = broadcast_object_to_pipeline_model_parallel_group(output, fixed_size=fixed_size) # Repeat prompt_ids num_generations times to match completion_ids prompt_ids = [ids for ids in output["prompt_ids"] for _ in range(self.num_generations)] @@ -876,7 +930,9 @@ def _generate_and_score_completions( old_per_token_logps = None # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch - if self.use_vllm and self.vllm_importance_sampling_correction: + # TODO: Remove _using_mock_vllm check when mock is no longer used - mock's random tokens + # cause -inf logprobs from the model, leading to NaN in importance sampling computation + if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) importance_sampling_ratio = torch.clamp( importance_sampling_ratio, max=self.vllm_importance_sampling_cap @@ -983,7 +1039,8 @@ def _generate_and_score_completions( # if images is not None: # self._logs["images"].extend(self.accelerator.gather_object(images)) - if self.use_vllm and self.vllm_importance_sampling_correction: + # TODO: Remove _using_mock_vllm check when mock is no longer used + if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) # Original code was: # delta = delta[completion_mask.bool()] @@ -1063,7 +1120,8 @@ def _generate_and_score_completions( } if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps - if self.use_vllm and self.vllm_importance_sampling_correction: + # TODO: Remove _using_mock_vllm check when mock is no longer used + if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: output["importance_sampling_ratio"] = importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps @@ -1221,7 +1279,8 @@ def _compute_loss(self, model, inputs): if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask - if self.use_vllm and self.vllm_importance_sampling_correction: + # TODO: Remove _using_mock_vllm check when mock is no longer used + if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] if self.beta != 0.0: diff --git a/pyproject.toml b/pyproject.toml index 7c90a7eca..1a49c1711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,9 @@ include = ["optimum*"] [tool.setuptools.package-data] "*" = ["*"] +[tool.uv] +extra-index-url = ["https://pip.repos.neuron.amazonaws.com"] + [tool.ruff] line-length = 119 From 23b00902f88a995e8988cac17d80f90be40d577a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 26 Nov 2025 18:27:12 +0100 Subject: [PATCH 48/78] add vllm_client for CPU --- optimum/neuron/trainers/extras/vllm_client.py | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 optimum/neuron/trainers/extras/vllm_client.py diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py new file mode 100644 index 000000000..1d5a2799e --- /dev/null +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -0,0 +1,145 @@ +# 
coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import time +from typing import Union + +import torch +from trl.extras.vllm_client import VLLMClient as TRLVLLMClient + +from trl.import_utils import is_vllm_available + + +if is_vllm_available(): + from vllm.distributed.utils import StatelessProcessGroup +else: + class StatelessProcessGroup: + pass + + + + +class VLLMClient(TRLVLLMClient): + """ + Extension of TRL's VLLMClient that adds CPU support for development and testing. + + This class inherits all functionality from trl.extras.vllm_client.VLLMClient and only + overrides the init_communicator method to add CPU fallback support when neither CUDA + nor XPU devices are available. + """ + + def init_communicator(self, device: Union[torch.device, str, int] = 0): + """ + Initializes the weight update group in a distributed setup for model synchronization. + + This method extends the parent implementation to support CPU-only environments by adding + a CPU fallback communicator when neither XPU nor CUDA devices are available. + + Args: + device (`torch.device`, `str`, or `int`, *optional*, defaults to `0`): + Device of trainer main process. It's the device that will be used for the weights synchronization. Can + be a `torch.device` object, a string like `'cuda:0'`, or an integer device index. + """ + # Get the world size from the server + import requests + url = f"{self.base_url}/get_world_size/" + response = requests.get(url) + if response.status_code == 200: + vllm_world_size = response.json()["world_size"] + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + world_size = vllm_world_size + 1 # add the client to the world + self.rank = vllm_world_size # the client's rank is the last process + + # Initialize weight update group + url = f"{self.base_url}/init_communicator/" + # Will simplify it after torch xpu 2.9 support get uuid. + if is_torch_xpu_available(): + if hasattr(torch.xpu.get_device_properties(device), "uuid"): + client_device_uuid = str(torch.xpu.get_device_properties(device).uuid) + else: + client_device_uuid = "42" + elif torch.cuda.is_available(): + client_device_uuid = str(torch.cuda.get_device_properties(device).uuid) + else: + # CPU fallback - use dummy UUID + client_device_uuid = "42" + + # In the server side, the host is set to 0.0.0.0 + response = self.session.post( + url, + json={ + "host": "0.0.0.0", + "port": self.group_port, + "world_size": world_size, + "client_device_uuid": client_device_uuid, + }, + ) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + # Brief delay to allow server initialization. While not strictly required (client socket will retry on + # connection failure), this prevents log warnings like: + # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. 
err=-3 + time.sleep(0.1) + + # Set up the communication group for weight broadcasting + if is_torch_xpu_available(): + store = torch.distributed.TCPStore( + host_name=self.host, port=self.group_port, world_size=world_size, is_master=(self.rank == 0) + ) + prefixed_store = c10d.PrefixStore("client2server", store) + pg = c10d.ProcessGroupXCCL( + store=prefixed_store, + rank=self.rank, + size=world_size, + ) + self.communicator = pg + elif torch.cuda.is_available(): + pg = StatelessProcessGroup.create( + host=self.host, port=self.group_port, rank=self.rank, world_size=world_size + ) + self.communicator = PyNcclCommunicator(pg, device=device) + else: + # CPU fallback - create a custom communicator that uses object broadcasting + from collections import namedtuple + Group = namedtuple('Group', 'barrier') + + class CPUCommunicator: + def __init__(self, store, rank): + self.rank = rank + self.store = store + self.group = Group(barrier=self.barrier) + + def broadcast(self, tensor, src): + # Move tensor to CPU to avoid issues on the server side when running vLLM+CPU + tensor = tensor.cpu() + self.store.broadcast_obj(tensor, src=self.rank) + + def barrier(self): + self.store.barrier() + + def __del__(self): + del self.store + + pg = StatelessProcessGroup.create( + host=self.host, port=self.group_port, rank=self.rank, world_size=world_size + ) + self.communicator = CPUCommunicator(pg, self.rank) + + # When the client object is deleted, close the weight update group + atexit.register(self.close_communicator) From 1e15a27fbaa48c26e26fe882ecf95474eb2c5a04 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 26 Nov 2025 18:35:13 +0100 Subject: [PATCH 49/78] add VLLMClient and collectives on python objects --- optimum/neuron/accelerate/utils/operations.py | 192 ++++++++++++++++++ optimum/neuron/trainers/extras/__init__.py | 19 ++ optimum/neuron/trainers/extras/vllm_client.py | 98 ++++----- 3 files changed, 246 insertions(+), 63 deletions(-) create mode 100644 optimum/neuron/accelerate/utils/operations.py create mode 100644 optimum/neuron/trainers/extras/__init__.py diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py new file mode 100644 index 000000000..5db12c6e2 --- /dev/null +++ b/optimum/neuron/accelerate/utils/operations.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
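+
+"""Collectives over arbitrary picklable Python objects for XLA/Neuron-distributed processes."""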
+
+import pickle
+from typing import Any, Callable
+
+import torch
+import numpy as np
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
+from neuronx_distributed.parallel_layers.parallel_state import (
+    get_data_parallel_group,
+    get_data_parallel_rank,
+    get_data_parallel_replica_groups,
+    get_data_parallel_size,
+    get_pipeline_model_parallel_rank,
+    get_pipeline_model_parallel_replica_groups,
+    get_pipeline_model_parallel_size,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_replica_groups,
+    get_tensor_model_parallel_size,
+)
+
+
+def broadcast_object(
+    obj: Any,
+    src: int = 0,
+    groups: list[list[int]] | None = None,
+    world_size_function: Callable[[], int] = xr.world_size,
+    get_rank_function: Callable[[], int] = xr.global_ordinal,
+    fixed_size: int | None = None,
+) -> Any:
+    """
+    Broadcasts arbitrary objects across XLA-distributed processes.
+    Returns the object from the source rank on all ranks.
+    If `groups` is specified, broadcast is done separately in each group, and the `src` rank is relative to each
+    group; `get_rank_function` must then return the rank within the group.
+    """
+    world_size = world_size_function()
+    if world_size == 1:
+        return obj
+
+    rank = get_rank_function()
+
+    if rank == src:
+        bytes_ = pickle.dumps(obj)
+        length = len(bytes_)
+        # Ensure the serialized object fits in the fixed size if specified.
+        # Otherwise we would corrupt the transferred data.
+        if fixed_size is not None and length > fixed_size:
+            raise ValueError(f"Serialized object size {length} exceeds the specified fixed_size {fixed_size}")
+    else:
+        bytes_ = b""
+        length = 0
+
+    # First, broadcast the length of the serialized object.
+    max_length = xm.all_reduce("max", torch.tensor(length, dtype=torch.int64).to(xm.xla_device()))
+    max_length = max_length.cpu()
+
+    # Ensure all ranks agree on the max length.
+    torch_xla.sync()
+
+    max_length = int(max_length.item())
+
+    if fixed_size is not None:
+        target_length = fixed_size
+    else:
+        target_length = max_length
+
+    if rank == src:
+        np_buffer = np.frombuffer(bytes_, dtype=np.uint8)
+        data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device())
+        padding_length = target_length - length
+        if padding_length > 0:
+            padding_tensor = torch.zeros(padding_length, dtype=torch.uint8, device=xm.xla_device())
+            data_tensor = torch.cat([data_tensor, padding_tensor], dim=0)
+    else:
+        data_tensor = torch.zeros(target_length, dtype=torch.uint8, device=xm.xla_device())
+
+    data_tensor = xm.all_reduce("sum", data_tensor, groups=groups)
+    torch_xla.sync()
+
+    # If no fixed size is specified, truncate the tensor to the original max length on device to minimize the data
+    # transfer from device to host.
+    if fixed_size is None:
+        data_tensor = data_tensor[:max_length]
+
+    data_tensor_cpu = data_tensor.cpu()
+    reduced_bytes = data_tensor_cpu.numpy().tobytes()
+
+    # Truncate to the original max length on host if fixed_size is specified to avoid changing shapes on device.
+    if fixed_size is not None:
+        reduced_bytes = reduced_bytes[:max_length]
+
+    return pickle.loads(reduced_bytes)
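+
+
+# A minimal usage sketch (illustration only): every rank in the group calls `broadcast_object`
+# with the same arguments, and each rank returns the source rank's object. Passing `fixed_size`
+# keeps the on-device buffer shape constant across calls, which avoids XLA recompilation when
+# payload sizes vary, at the cost of always transferring the full buffer:
+#
+#     payload = {"step": 10, "prompts": ["a", "b"]} if xr.global_ordinal() == 0 else None
+#     payload = broadcast_object(payload, src=0, fixed_size=int(2e6))
+#     assert payload["step"] == 10  # holds on every rank after the broadcast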
+
+
+def broadcast_object_to_data_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any:
+    """
+    Broadcasts arbitrary objects across the XLA-distributed data parallel group.
+    Returns the object from the source rank on all ranks in the data parallel group.
+    """
+    groups = get_data_parallel_replica_groups()
+    return broadcast_object(
+        obj,
+        src=src,
+        groups=groups,
+        world_size_function=get_data_parallel_size,
+        get_rank_function=get_data_parallel_rank,
+        fixed_size=fixed_size,
+    )
+
+
+def broadcast_object_to_tensor_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any:
+    """
+    Broadcasts arbitrary objects across the XLA-distributed tensor model parallel group.
+    Returns the object from the source rank on all ranks in the tensor model parallel group.
+    """
+    groups = get_tensor_model_parallel_replica_groups()
+    return broadcast_object(
+        obj,
+        src=src,
+        groups=groups,
+        world_size_function=get_tensor_model_parallel_size,
+        get_rank_function=get_tensor_model_parallel_rank,
+        fixed_size=fixed_size,
+    )
+
+
+def broadcast_object_to_pipeline_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any:
+    """
+    Broadcasts arbitrary objects across the XLA-distributed pipeline model parallel group.
+    Returns the object from the source rank on all ranks in the pipeline model parallel group.
+    """
+    groups = get_pipeline_model_parallel_replica_groups()
+    return broadcast_object(
+        obj,
+        src=src,
+        groups=groups,
+        world_size_function=get_pipeline_model_parallel_size,
+        get_rank_function=get_pipeline_model_parallel_rank,
+        fixed_size=fixed_size,
+    )
+
+
+def gather_object(obj: Any) -> list[Any]:
+    """
+    Gathers arbitrary objects across XLA-distributed processes.
+    Returns list of objects from all ranks on all ranks.
+
+    Note: Requires two all-gather operations (lengths then data).
+    For small objects, this overhead may be significant.
+    """
+    world_size = get_data_parallel_size()
+
+    # Early exit for single process
+    if world_size == 1:
+        return [obj]
+
+    groups = get_data_parallel_group(as_list=True)
+
+    serialized = pickle.dumps(obj)
+    byte_len = len(serialized)
+
+    # `frombuffer` returns a read-only view over `serialized`, so clone before moving to device.
+    byte_tensor = torch.frombuffer(serialized, dtype=torch.uint8).clone().to(xm.xla_device())
+
+    len_tensor = torch.tensor([byte_len], dtype=torch.int64, device=byte_tensor.device)
+    # all_gather concatenates along dim=0, so [1] -> [world_size]
+    gathered_lengths = xm.all_gather(len_tensor, dim=0, groups=groups, pin_layout=False)
+    torch_xla.sync()
+
+    max_len = torch.max(gathered_lengths)
+    max_len = int(max_len.item())
+    padded = torch.zeros(max_len, dtype=torch.uint8, device=byte_tensor.device)
+    padded[:byte_len] = byte_tensor
+
+    # all_gather concatenates, so [max_len] -> [world_size * max_len]
+    gathered_data = xm.all_gather(padded, dim=0, groups=groups, pin_layout=False)
+    torch_xla.sync()
+
+    gathered_data_cpu = gathered_data.cpu()
+    gathered_lengths_cpu = gathered_lengths.cpu()
+
+    results = []
+    offset = 0
+    for i in range(world_size):
+        actual_len = int(gathered_lengths_cpu[i].item())
+        valid_bytes = gathered_data_cpu[offset:offset + max_len][:actual_len].numpy().tobytes()
+        results.append(pickle.loads(valid_bytes))
+        offset += max_len
+
+    return results
diff --git a/optimum/neuron/trainers/extras/__init__.py b/optimum/neuron/trainers/extras/__init__.py
new file mode 100644
index 000000000..da5732120
--- /dev/null
+++ b/optimum/neuron/trainers/extras/__init__.py
@@ -0,0 +1,19 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vllm_client import VLLMClient + + +__all__ = ["VLLMClient"] diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index 1d5a2799e..e0c41ac4c 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -14,10 +14,13 @@ # limitations under the License. import atexit +import requests import time from typing import Union +from collections import namedtuple import torch +import torch_xla from trl.extras.vllm_client import VLLMClient as TRLVLLMClient from trl.import_utils import is_vllm_available @@ -29,32 +32,48 @@ class StatelessProcessGroup: pass +# Set up the communication group for weight broadcasting using CPU communicator +Group = namedtuple('Group', 'barrier') +class CPUCommunicator: + def __init__(self, store, rank): + self.rank = rank + self.store = store + self.group = Group(barrier=self.barrier) + def broadcast(self, tensor, src): + # Move tensor to CPU to ensure compatibility with vLLM server + if tensor.device.type == "xla": + tensor = tensor.cpu() + torch_xla.sync() + self.store.broadcast_obj(tensor, src=self.rank) + + def barrier(self): + self.store.barrier() + + def __del__(self): + del self.store class VLLMClient(TRLVLLMClient): """ - Extension of TRL's VLLMClient that adds CPU support for development and testing. + VLLMClient for Neuron environments. - This class inherits all functionality from trl.extras.vllm_client.VLLMClient and only - overrides the init_communicator method to add CPU fallback support when neither CUDA - nor XPU devices are available. + This class inherits all functionality from trl.extras.vllm_client.VLLMClient and only overrides methods + to enable CPU-based communication suitable for Neuron setups and development/testing scenarios. """ def init_communicator(self, device: Union[torch.device, str, int] = 0): """ Initializes the weight update group in a distributed setup for model synchronization. - This method extends the parent implementation to support CPU-only environments by adding - a CPU fallback communicator when neither XPU nor CUDA devices are available. + This method uses CPU-based communication via object broadcasting, suitable for Neuron + environments and development/testing scenarios. Args: device (`torch.device`, `str`, or `int`, *optional*, defaults to `0`): - Device of trainer main process. It's the device that will be used for the weights synchronization. Can - be a `torch.device` object, a string like `'cuda:0'`, or an integer device index. + Device parameter for compatibility. Communication is handled via CPU. """ # Get the world size from the server - import requests url = f"{self.base_url}/get_world_size/" response = requests.get(url) if response.status_code == 200: @@ -67,17 +86,9 @@ def init_communicator(self, device: Union[torch.device, str, int] = 0): # Initialize weight update group url = f"{self.base_url}/init_communicator/" - # Will simplify it after torch xpu 2.9 support get uuid. 
- if is_torch_xpu_available(): - if hasattr(torch.xpu.get_device_properties(device), "uuid"): - client_device_uuid = str(torch.xpu.get_device_properties(device).uuid) - else: - client_device_uuid = "42" - elif torch.cuda.is_available(): - client_device_uuid = str(torch.cuda.get_device_properties(device).uuid) - else: - # CPU fallback - use dummy UUID - client_device_uuid = "42" + + # Use dummy UUID for CPU/Neuron environments + client_device_uuid = "42" # In the server side, the host is set to 0.0.0.0 response = self.session.post( @@ -97,49 +108,10 @@ def init_communicator(self, device: Union[torch.device, str, int] = 0): # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 time.sleep(0.1) - # Set up the communication group for weight broadcasting - if is_torch_xpu_available(): - store = torch.distributed.TCPStore( - host_name=self.host, port=self.group_port, world_size=world_size, is_master=(self.rank == 0) - ) - prefixed_store = c10d.PrefixStore("client2server", store) - pg = c10d.ProcessGroupXCCL( - store=prefixed_store, - rank=self.rank, - size=world_size, - ) - self.communicator = pg - elif torch.cuda.is_available(): - pg = StatelessProcessGroup.create( - host=self.host, port=self.group_port, rank=self.rank, world_size=world_size - ) - self.communicator = PyNcclCommunicator(pg, device=device) - else: - # CPU fallback - create a custom communicator that uses object broadcasting - from collections import namedtuple - Group = namedtuple('Group', 'barrier') - - class CPUCommunicator: - def __init__(self, store, rank): - self.rank = rank - self.store = store - self.group = Group(barrier=self.barrier) - - def broadcast(self, tensor, src): - # Move tensor to CPU to avoid issues on the server side when running vLLM+CPU - tensor = tensor.cpu() - self.store.broadcast_obj(tensor, src=self.rank) - - def barrier(self): - self.store.barrier() - - def __del__(self): - del self.store - - pg = StatelessProcessGroup.create( - host=self.host, port=self.group_port, rank=self.rank, world_size=world_size - ) - self.communicator = CPUCommunicator(pg, self.rank) + pg = StatelessProcessGroup.create( + host=self.host, port=self.group_port, rank=self.rank, world_size=world_size + ) + self.communicator = CPUCommunicator(pg, self.rank) # When the client object is deleted, close the weight update group atexit.register(self.close_communicator) From fd1239bc8b0f73d98e367ed48a87bcd7e3a9a939 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 28 Nov 2025 15:57:57 +0100 Subject: [PATCH 50/78] add MockVLLMClient --- .../grpo_qwen3/finetune_grpo_qwen3.py | 8 +- optimum/neuron/accelerate/utils/operations.py | 2 +- optimum/neuron/trainers/extras/__init__.py | 4 +- optimum/neuron/trainers/extras/vllm_client.py | 133 ++++++++++-- optimum/neuron/trainers/grpo_mocks.py | 184 ----------------- optimum/neuron/trainers/grpo_trainer.py | 193 ++++++++++-------- optimum/neuron/trainers/trl_utils.py | 85 ++++++++ 7 files changed, 319 insertions(+), 290 deletions(-) delete mode 100644 optimum/neuron/trainers/grpo_mocks.py diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index 6a93f9572..a671deb72 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -31,6 +31,7 @@ from optimum.neuron import NeuronGRPOConfig, NeuronGRPOTrainer, NeuronTrainingArguments from optimum.neuron.models.training import NeuronModelForCausalLM +from 
optimum.neuron.trainers.extras import MockVLLMClient # ============================================================================= @@ -113,9 +114,6 @@ def load_grpo_dataset(): # This dataset has prompts in the "prompt" column dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") - # Take a small subset for this example - dataset = dataset.select([0] * 100000) - return dataset @@ -150,7 +148,7 @@ def train(model_id, tokenizer, dataset, training_args): r=64, lora_alpha=128, lora_dropout=0.05, - target_modules=["embed_tokens", "q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"], + target_modules=["q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"], bias="none", task_type="CAUSAL_LM", ) @@ -194,6 +192,8 @@ def train(model_id, tokenizer, dataset, training_args): train_dataset=dataset, processing_class=tokenizer, peft_config=lora_config, + # To do: disable this fake client, only for development without vLLM server. + vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), ) # Train the model diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py index 5db12c6e2..e7cf13a51 100644 --- a/optimum/neuron/accelerate/utils/operations.py +++ b/optimum/neuron/accelerate/utils/operations.py @@ -16,8 +16,8 @@ import pickle from typing import Any, Callable -import torch import numpy as np +import torch import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr diff --git a/optimum/neuron/trainers/extras/__init__.py b/optimum/neuron/trainers/extras/__init__.py index da5732120..5ce9115ed 100644 --- a/optimum/neuron/trainers/extras/__init__.py +++ b/optimum/neuron/trainers/extras/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .vllm_client import VLLMClient +from .vllm_client import MockVLLMClient, VLLMClient -__all__ = ["VLLMClient"] +__all__ = ["VLLMClient", "MockVLLMClient"] diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index e0c41ac4c..4bd089438 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -14,15 +14,16 @@ # limitations under the License. import atexit -import requests +import random import time -from typing import Union from collections import namedtuple +from typing import Union +import requests import torch import torch_xla +from optimum.utils import logging from trl.extras.vllm_client import VLLMClient as TRLVLLMClient - from trl.import_utils import is_vllm_available @@ -32,6 +33,8 @@ class StatelessProcessGroup: pass +logger = logging.get_logger() + # Set up the communication group for weight broadcasting using CPU communicator Group = namedtuple('Group', 'barrier') @@ -55,24 +58,9 @@ def __del__(self): del self.store class VLLMClient(TRLVLLMClient): - """ - VLLMClient for Neuron environments. - - This class inherits all functionality from trl.extras.vllm_client.VLLMClient and only overrides methods - to enable CPU-based communication suitable for Neuron setups and development/testing scenarios. - """ + """VLLMClient with CPU-based communication for Neuron environments.""" def init_communicator(self, device: Union[torch.device, str, int] = 0): - """ - Initializes the weight update group in a distributed setup for model synchronization. 
- - This method uses CPU-based communication via object broadcasting, suitable for Neuron - environments and development/testing scenarios. - - Args: - device (`torch.device`, `str`, or `int`, *optional*, defaults to `0`): - Device parameter for compatibility. Communication is handled via CPU. - """ # Get the world size from the server url = f"{self.base_url}/get_world_size/" response = requests.get(url) @@ -115,3 +103,110 @@ def init_communicator(self, device: Union[torch.device, str, int] = 0): # When the client object is deleted, close the weight update group atexit.register(self.close_communicator) + +class MockVLLMClient(VLLMClient): + """ + Mock VLLMClient that generates random completions and triggers XLA compilation without vLLM server. + + Used for neuron_parallel_compile and testing. Generates random tokens, not real LLM outputs. + + Args: + tokenizer: Tokenizer for encoding/decoding + max_completion_length: Maximum completion length + min_completion_length: Minimum completion length (default: 10) + seed: Random seed for reproducibility + """ + + def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None): + # Don't call super().__init__() - we don't need server connection + self.tokenizer = tokenizer + self.max_completion_length = max_completion_length + self.min_completion_length = min(min_completion_length, max_completion_length) + self.random = random.Random(seed) + + logger.warning( + "Using MockVLLMClient for neuron_parallel_compile or testing. " + "This generates random dummy completions and should only be used for compilation/testing." + ) + + def generate( + self, + prompts: list[str], + images=None, + n: int = 1, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 256, + repetition_penalty: float = 1.0, + truncate_prompt_tokens=None, + guided_decoding_regex=None, + generation_kwargs=None, + ): + """ + Generate random completions with random lengths. + + Returns dict with prompt_ids, completion_ids, and logprobs. 
+ """ + prompt_ids = [] + completion_ids = [] + logprobs = [] + + # Determine vocab range (avoid special tokens) + vocab_size = self.tokenizer.vocab_size + min_token_id = min(100, vocab_size - 1) + max_token_id = vocab_size - 1 + + for prompt in prompts: + # Tokenize prompt + prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False) + + # Truncate if needed + if truncate_prompt_tokens is not None and len(prompt_tokens) > truncate_prompt_tokens: + prompt_tokens = prompt_tokens[-truncate_prompt_tokens:] + + prompt_ids.append(prompt_tokens) + + # Generate n completions per prompt + for _ in range(n): + # Random completion length within bounds + max_len = min(max_tokens, self.max_completion_length) + completion_length = self.random.randint(self.min_completion_length, max_len) + + # Generate random tokens from safe vocab range + completion = [ + self.random.randint(min_token_id, max_token_id) + for _ in range(completion_length) + ] + completion_ids.append(completion) + + # Generate realistic random logprobs (typical range: -2 to -10) + completion_logprobs = [ + -self.random.uniform(2.0, 8.0) + for _ in range(completion_length) + ] + logprobs.append(completion_logprobs) + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + "logprobs": logprobs, + } + + def init_communicator(self, device): + """No-op: mock has no communicator.""" + pass + + def update_named_param(self, name, weights): + """No-op: mock has no model to update.""" + pass + + def reset_prefix_cache(self): + """No-op: mock has no cache.""" + pass + + def close_communicator(self): + """No-op: mock has no communicator.""" + pass + diff --git a/optimum/neuron/trainers/grpo_mocks.py b/optimum/neuron/trainers/grpo_mocks.py deleted file mode 100644 index a900bf0ed..000000000 --- a/optimum/neuron/trainers/grpo_mocks.py +++ /dev/null @@ -1,184 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Mock implementations for GRPO trainer testing and development. - -This module provides mock implementations of vLLM client and other components -to enable development and testing of NeuronGRPOTrainer without requiring a full -vLLM server setup. -""" - -import random - -from optimum.utils import logging - - -logger = logging.get_logger() - - -class MockVLLMClient: - """ - Mock vLLM client that generates dummy completions for testing. - - This mock client simulates the behavior of a real vLLM server by generating - placeholder completions. It's useful for: - - Development without vLLM server setup - - Testing trainer logic independently of generation quality - - Unit testing GRPO training loop - - Args: - tokenizer: Tokenizer to use for encoding/decoding - max_completion_length: Maximum length of generated completions - - Note: - This is a development tool and should not be used in production. - Generated completions are deterministic placeholders, not real language model outputs. 
- """ - - def __init__(self, tokenizer, max_completion_length=256): - self.tokenizer = tokenizer - self.max_completion_length = max_completion_length - logger.warning( - "Using MockVLLMClient for development. This generates placeholder completions " - "and should only be used for testing and development." - ) - - def generate( - self, - prompts: list[str], - images=None, - n: int = 1, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - max_tokens: int = 256, - repetition_penalty: float = 1.0, - truncate_prompt_tokens=None, - guided_decoding_regex=None, - generation_kwargs=None, - ): - """ - Generate mock completions for the given prompts. - - Args: - prompts: List of prompt strings - images: Optional list of images (not used in mock) - n: Number of completions to generate per prompt - temperature: Sampling temperature (not used in mock) - top_p: Nucleus sampling parameter (not used in mock) - top_k: Top-k sampling parameter (not used in mock) - min_p: Minimum probability threshold (not used in mock) - max_tokens: Maximum tokens to generate - repetition_penalty: Repetition penalty (not used in mock) - truncate_prompt_tokens: Maximum prompt length - guided_decoding_regex: Regex for guided decoding (not used in mock) - generation_kwargs: Additional generation arguments (not used in mock) - - Returns: - Dictionary with keys: - - prompt_ids: List of tokenized prompts (one per prompt) - - completion_ids: List of tokenized completions (n per prompt) - - logprobs: List of log probabilities (one list per completion) - """ - prompt_ids = [] - completion_ids = [] - logprobs = [] - - for prompt in prompts: - # Tokenize prompt - prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False) - - # Truncate if needed - if truncate_prompt_tokens is not None and len(prompt_tokens) > truncate_prompt_tokens: - prompt_tokens = prompt_tokens[-truncate_prompt_tokens:] - - prompt_ids.append(prompt_tokens) - - # Generate n completions per prompt - for i in range(n): - # Generate mock completion with realistic varied tokens - # In real scenario, this would be actual LLM generation - completion_length = min(max_tokens, self.max_completion_length) - - # Generate completion with varied tokens from the vocabulary - # Avoid special tokens by using a safe range of token IDs - vocab_size = self.tokenizer.vocab_size - # Use token IDs in a safe range (skip first 100 tokens which often include special tokens) - min_token_id = min(100, vocab_size - 1) - max_token_id = vocab_size - 1 - completion = [random.randint(min_token_id, max_token_id) for _ in range(completion_length)] - completion_ids.append(completion) - - # Generate realistic mock logprobs - # Real language model logprobs typically range from -2 to -10 - # with occasional values outside this range - completion_logprobs = [-random.uniform(2.0, 8.0) for _ in range(completion_length)] - logprobs.append(completion_logprobs) - - return { - "prompt_ids": prompt_ids, - "completion_ids": completion_ids, - "logprobs": logprobs, - } - - def init_communicator(self, device): - """ - Mock initialization of communicator. - - Args: - device: Device to initialize on (not used in mock) - """ - pass - - def update_named_param(self, name, data): - """ - Mock update of named parameter. - - In a real vLLM setup, this would sync model weights to the vLLM server. - For mock mode, this is a no-op since we're not using a real server. 
- - Args: - name: Parameter name - data: Parameter data tensor (not used in mock) - """ - pass - - def reset_prefix_cache(self): - """ - Mock reset of prefix cache. - - In a real vLLM setup, this would clear the KV cache for prefix caching. - For mock mode, this is a no-op since we're not using a real server. - """ - pass - - -def create_mock_vllm_client(tokenizer, args): - """ - Factory function to create a mock vLLM client. - - Args: - tokenizer: Tokenizer to use for the mock client - args: Training arguments containing max_completion_length - - Returns: - MockVLLMClient instance - """ - return MockVLLMClient( - tokenizer=tokenizer, - max_completion_length=args.max_completion_length, - ) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index add90eb79..858bc69a2 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -49,6 +49,7 @@ from ..peft.utils.vllm import get_original_merged_weights_for_vllm from ..utils import is_precompilation, is_trl_available from ..utils.import_utils import is_peft_available +from .extras import MockVLLMClient, VLLMClient from .grpo_config import NeuronGRPOConfig from .training_args import NeuronTrainingArguments from .transformers import NeuronTrainer @@ -59,7 +60,7 @@ nanmin, nanstd, neuron_parallel_compile_tokenizer_decoder_method, - pad, + pad_or_truncate_to_length, ) @@ -69,6 +70,7 @@ if is_trl_available(): from trl import GRPOConfig, GRPOTrainer from trl.data_utils import is_conversational + from trl.extras.vllm_client import VLLMClient as TRLVLLMClient from trl.trainer.utils import ( RepeatSampler, disable_dropout_in_model, @@ -85,6 +87,9 @@ class GRPOTrainer: class GRPOConfig: pass + class TRLVLLMClient: + pass + if is_peft_available(): from peft import PeftConfig @@ -133,6 +138,7 @@ def __init__( optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, peft_config: PeftConfig | None = None, + vllm_client: TRLVLLMClient | None = None, ): if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError(f"Using NeuronGRPOTrainer requires trl=={TRL_VERSION}.") @@ -403,37 +409,26 @@ def make_inputs_require_grad(module, input, output): set_seed(args.seed, device_specific=True) # vLLM setup - server mode only - from ..utils import is_vllm_available - - # MOCK FLAG: Change this to False when real vLLM server is ready - USE_MOCK_VLLM = False - # TODO: Remove this when mock is no longer used - tracks mock usage to disable - # importance sampling correction which doesn't work with mock's random tokens - self._using_mock_vllm = USE_MOCK_VLLM - - if USE_MOCK_VLLM: - logger.warning( - "Using MOCK vLLM client for development. This generates placeholder completions " - "and should only be used for testing and development. Set USE_MOCK_VLLM=False in " - "grpo_trainer.py to use real vLLM server." - ) - from .grpo_mocks import create_mock_vllm_client - - # MOCK: Each process needs its own client (generates locally, no server) - self.vllm_client = create_mock_vllm_client(tokenizer, args) + if vllm_client is not None: + # Use injected client (for testing, mocking, or custom implementations) + self.vllm_client = vllm_client else: + # Default: Create VLLMClient from args + from ..utils import is_vllm_available + if not is_vllm_available(): raise ImportError("vLLM is not available. 
Please install vLLM to use NeuronGRPOTrainer.")

             if args.vllm_server_base_url is not None:
                 base_url = args.vllm_server_base_url
             else:
                 base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"

-            self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+            # For `neuron_parallel_compile`, use a mock VLLM client that doesn't make actual server requests.
+            if is_precompilation():
+                self.vllm_client = MockVLLMClient(tokenizer, max_completion_length=self.max_completion_length)
+            else:
+                self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
         # Only main process initializes the communicator for weight updates
         if self.accelerator.is_main_process:
             self.vllm_client.init_communicator(device="cpu")
@@ -607,6 +603,11 @@ def _move_model_to_vllm(self):
         if isinstance(self.model, NeuronPeftModel):
             # Get original (unsharded, untransformed) merged weights for vLLM
             original_weights = get_original_merged_weights_for_vllm(self.model)
+            # For now, we only support the CPU communicator in Neuron environments.
+            # The CPU communicator moves weights to CPU before broadcasting, but to avoid a lot of device -> host moves,
+ original_weights = move_all_tensor_to_cpu(original_weights) + torch_xla.sync() # Send weights to vLLM server (only main process for server mode) for name, weight in original_weights.items(): @@ -614,8 +615,7 @@ def _move_model_to_vllm(self): name = self._fix_param_name_to_vllm(name) if self.vllm_mode == "server" and self.accelerator.is_main_process: - pass - # self.vllm_client.update_named_param(name, weight) + self.vllm_client.update_named_param(name, weight) elif self.vllm_mode == "colocate": llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model llm_model.load_weights([(name, weight)]) @@ -665,10 +665,8 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): else: ordered_set_of_images = None - # For mock vLLM: each process generates independently - # For real vLLM server: only main process generates, then broadcast - if self._using_mock_vllm: - # Each process generates locally + # Generate on main process only, then broadcast to all ranks + if self.tp_rank == self.pp_rank == 0: output = self.vllm_client.generate( prompts=ordered_set_of_prompts, images=ordered_set_of_images, @@ -684,37 +682,16 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): generation_kwargs=self.args.generation_kwargs, ) else: - # Real vLLM server mode: only main process generates - if self.tp_rank == self.pp_rank == 0: - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) - else: - output = { - "prompt_ids": [[]], - "completion_ids": [[]], - "logprobs": [[]], - } - output = None + output = None - trn_config = self.accelerator.state.trn_config - # TODO: change that to a better default. - fixed_size = int(2e6) # 2MB fixed size for buffer, should be enough. - if trn_config.tensor_parallel_size > 1: - output = broadcast_object_to_tensor_model_parallel_group(output, fixed_size=fixed_size) - if trn_config.pipeline_parallel_size > 1: - output = broadcast_object_to_pipeline_model_parallel_group(output, fixed_size=fixed_size) + # Broadcast output to all ranks + trn_config = self.accelerator.state.trn_config + # TODO: change that to a better default. + fixed_size = int(2e6) # 2MB fixed size for buffer, should be enough. + if trn_config.tensor_parallel_size > 1: + output = broadcast_object_to_tensor_model_parallel_group(output, fixed_size=fixed_size) + if trn_config.pipeline_parallel_size > 1: + output = broadcast_object_to_pipeline_model_parallel_group(output, fixed_size=fixed_size) # Repeat prompt_ids num_generations times to match completion_ids prompt_ids = [ids for ids in output["prompt_ids"] for _ in range(self.num_generations)] @@ -816,6 +793,10 @@ def _get_per_token_logps_and_entropies( model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings logits = model(**model_inputs).logits + + # Synchronize after model forward to avoid recompiling multiple model graphs. + torch_xla.sync() + # Exclude the last value: it corresponds to the next token pred logits = logits[:, :-1, :] # (B, L-1, H) # Only keep the last logits_to_keep. 
For model that support logits_to_keep, this is a no-op.
@@ -870,17 +851,58 @@ def _generate_and_score_completions(
         ) = self._generate(prompts, images)

         # Convert lists of token IDs to padded tensors
-        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
-        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
-        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
-        completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
-        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
-        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
+        prompt_ids = [
+            pad_or_truncate_to_length(
+                torch.tensor(ids),
+                self.max_prompt_length,
+                padding_value=self.pad_token_id,
+                padding_or_truncate_side="left",
+            )
+            for ids in prompt_ids_list
+        ]
+        prompt_mask = [
+            pad_or_truncate_to_length(
+                torch.ones(len(ids), dtype=torch.long),
+                self.max_prompt_length,
+                padding_value=0,
+                padding_or_truncate_side="left",
+            )
+            for ids in prompt_ids_list
+        ]
+        prompt_ids = torch.stack(prompt_ids, dim=0).to(device)
+        prompt_mask = torch.stack(prompt_mask, dim=0).to(device)
+        completion_ids = [
+            pad_or_truncate_to_length(
+                torch.tensor(ids),
+                self.max_completion_length,
+                padding_value=self.pad_token_id,
+                padding_or_truncate_side="right",
+            )
+            for ids in completion_ids_list
+        ]
+        completion_mask = [
+            pad_or_truncate_to_length(
+                torch.ones(len(ids), dtype=torch.long),
+                self.max_completion_length,
+                padding_value=0,
+                padding_or_truncate_side="right",
+            )
+            for ids in completion_ids_list
+        ]
+        completion_ids = torch.stack(completion_ids, dim=0).to(device)
+        completion_mask = torch.stack(completion_mask, dim=0).to(device)
+
         if sampling_per_token_logps_list is not None:
-            sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
-            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
+            sampling_per_token_logps = [
+                pad_or_truncate_to_length(
+                    torch.tensor(logps),
+                    self.max_completion_length,
+                    padding_value=0.0,
+                    padding_or_truncate_side="right",
+                )
+                for logps in sampling_per_token_logps_list
+            ]
+            sampling_per_token_logps = torch.stack(sampling_per_token_logps, dim=0).to(device)
         else:
             sampling_per_token_logps = None

@@ -930,9 +962,7 @@
         old_per_token_logps = None

         # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
when using vLLM, to correct for potential distribution mismatch - # TODO: Remove _using_mock_vllm check when mock is no longer used - mock's random tokens - # cause -inf logprobs from the model, leading to NaN in importance sampling computation - if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: + if self.use_vllm and self.vllm_importance_sampling_correction: importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) importance_sampling_ratio = torch.clamp( importance_sampling_ratio, max=self.vllm_importance_sampling_cap @@ -1039,8 +1069,7 @@ def _generate_and_score_completions( # if images is not None: # self._logs["images"].extend(self.accelerator.gather_object(images)) - # TODO: Remove _using_mock_vllm check when mock is no longer used - if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: + if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) # Original code was: # delta = delta[completion_mask.bool()] @@ -1100,8 +1129,10 @@ def _generate_and_score_completions( # Move metrics and logs to CPU. metrics = move_all_tensor_to_cpu(metrics) - metrics = {key: [val.item() for val in value] for key, value in metrics.items()} logs = move_all_tensor_to_cpu(logs) + torch_xla.sync() + + metrics = {key: [val.item() for val in value] for key, value in metrics.items()} # Update the actual metrics and logs. self._metrics[mode].update(metrics) @@ -1120,8 +1151,7 @@ def _generate_and_score_completions( } if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps - # TODO: Remove _using_mock_vllm check when mock is no longer used - if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: + if self.use_vllm and self.vllm_importance_sampling_correction: output["importance_sampling_ratio"] = importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps @@ -1230,6 +1260,9 @@ def _compute_loss(self, model, inputs): token_type_ids=inputs.get("token_type_ids"), ) + print("Per-token log probabilities:", per_token_logps) + print("Entropies:", entropies) + if self.top_entropy_quantile < 1.0: entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) else: @@ -1279,8 +1312,7 @@ def _compute_loss(self, model, inputs): if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask - # TODO: Remove _using_mock_vllm check when mock is no longer used - if self.use_vllm and self.vllm_importance_sampling_correction and not self._using_mock_vllm: + if self.use_vllm and self.vllm_importance_sampling_correction: per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] if self.beta != 0.0: @@ -1339,8 +1371,9 @@ def masked_batch_mean(x): gathered_clip_ratio = self.accelerator.gather(clip_ratio) metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) - torch_xla.sync() # Graph break before moving metrics to CPU. + # torch_xla.sync() # Graph break before moving metrics to CPU. 
metrics = move_all_tensor_to_cpu(metrics) + torch_xla.sync() metrics = {key: [val.item() for val in value] for key, value in metrics.items()} self._metrics[mode].update(metrics) diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index 4da05da22..54c73c186 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -14,10 +14,12 @@ # limitations under the License. import math +from typing import Literal import numpy as np import torch import torch.distributed as dist +import torch.nn.functional as F from optimum.utils import logging from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler @@ -65,6 +67,89 @@ def pad( return output +def pad_or_truncate_to_length( + tensor: torch.Tensor, + length: int, + dim: int = 0, + padding_value: int = 0, + padding_or_truncate_side: Literal["left", "right"] = "right", +) -> torch.Tensor: + """ + Pads or truncates a tensor to a given length along the provided dimension. + + Args: + tensor: Input tensor to pad or truncate + length: Target length + dim: Dimension along which to pad/truncate + padding_value: Value to use for padding + padding_or_truncate_side: Side for both padding and truncation + - "left": Pads on left, truncates from left (keeps last tokens) + - "right": Pads on right, truncates from right (keeps first tokens) + """ + current_length = tensor.shape[dim] + if current_length == length: + return tensor + elif current_length > length: + # Truncate + slice_ = [slice(None)] * tensor.dim() + if padding_or_truncate_side == "left": + # Keep last tokens (truncate from left) + slice_[dim] = slice(current_length - length, current_length) + elif padding_or_truncate_side == "right": + # Keep first tokens (truncate from right) + slice_[dim] = slice(0, length) + else: + raise ValueError("padding_or_truncate_side must be 'left' or 'right'") + return tensor[slice_] + else: + # Pad + padding_shape = list(tensor.shape) + padding_shape[dim] = length - current_length + padding = torch.full(padding_shape, padding_value, dtype=tensor.dtype, device=tensor.device) + if padding_or_truncate_side == "left": + return torch.cat([padding, tensor], dim=dim) + elif padding_or_truncate_side == "right": + return torch.cat([tensor, padding], dim=dim) + else: + raise ValueError("padding_or_truncate_side must be 'left' or 'right'") + + +def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor: + """ + Compute the Shannon entropy (in nats) for each row of *logits* in a memory-efficient way. + + Instead of materializing the full softmax for all rows at once, the logits are flattened to shape (N, num_classes), + where N is the product of all leading dimensions. Computation is then performed in chunks of size `chunk_size` + along this flattened dimension, reducing peak memory usage. The result is reshaped back to match the input's + leading dimensions. + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading dimensions + are preserved in the output. + chunk_size (`int`, *optional*, defaults to `128`): + Number of rows from the flattened logits to process per iteration. Smaller values reduce memory usage at + the cost of more iterations. + + Returns: + `torch.Tensor`: + Entropy values with shape `logits.shape[:-1]`. 
+ """ + original_shape = logits.shape[:-1] # all dims except num_classes + num_classes = logits.shape[-1] + + # Flatten all leading dimensions into one + flat_logits = logits.reshape(-1, num_classes) + + entropies = [] + for chunk in flat_logits.split(chunk_size, dim=0): + logps = F.log_softmax(chunk, dim=-1) + chunk_entropy = -(torch.exp(logps) * logps).sum(-1) + entropies.append(chunk_entropy) + + entropies = torch.cat(entropies, dim=0) + return entropies.reshape(original_shape) + def neuron_parallel_compile_tokenizer_decoder_method( self, token_ids: int | list[int], From 5e8af493a9dc889c8aaaf954ccb25bb0c31ff234 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 2 Dec 2025 09:39:58 +0100 Subject: [PATCH 51/78] wip, recompilations --- .../grpo_qwen3/finetune_grpo_qwen3.py | 2 + optimum/neuron/accelerate/accelerator.py | 1 - optimum/neuron/accelerate/utils/__init__.py | 3 + optimum/neuron/accelerate/utils/operations.py | 142 +++++++++++++----- optimum/neuron/peft/utils/vllm.py | 4 +- optimum/neuron/trainers/extras/vllm_client.py | 22 ++- optimum/neuron/trainers/grpo_trainer.py | 94 +++++------- optimum/neuron/trainers/trl_utils.py | 16 +- pyproject.toml | 29 ++-- tests/training/test_custom_modeling.py | 21 ++- 10 files changed, 197 insertions(+), 137 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index a671deb72..6bb478665 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -34,6 +34,8 @@ from optimum.neuron.trainers.extras import MockVLLMClient +x = MockVLLMClient # To avoid linter warning about unused import + # ============================================================================= # Reward Functions # ============================================================================= diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index ca1ec1f70..4b9180481 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -558,7 +558,6 @@ def gather(self, tensor, sync: bool = False): torch_xla.sync() return gathered - def gather_for_metrics(self, input_data, use_gather_object: bool = False, sync: bool = False): try: recursively_apply(lambda x: x, input_data, error_on_other_type=True) diff --git a/optimum/neuron/accelerate/utils/__init__.py b/optimum/neuron/accelerate/utils/__init__.py index d46cba470..294e53bb3 100644 --- a/optimum/neuron/accelerate/utils/__init__.py +++ b/optimum/neuron/accelerate/utils/__init__.py @@ -21,4 +21,7 @@ broadcast_object_to_pipeline_model_parallel_group, broadcast_object_to_tensor_model_parallel_group, gather_object, + gather_object_from_data_parallel_group, + gather_object_from_pipeline_model_parallel_group, + gather_object_from_tensor_model_parallel_group, ) diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py index e7cf13a51..8570ae73a 100644 --- a/optimum/neuron/accelerate/utils/operations.py +++ b/optimum/neuron/accelerate/utils/operations.py @@ -23,7 +23,6 @@ import torch_xla.runtime as xr from neuronx_distributed.parallel_layers.parallel_state import ( get_context_model_parallel_size, - get_data_parallel_group, get_data_parallel_replica_groups, get_data_parallel_size, get_pipeline_model_parallel_replica_groups, @@ -31,7 +30,14 @@ ) -def broadcast_object(obj: Any, src: int = 0, groups: list[list[int]] | None = None, world_size_function: 
Callable[[], int] = xr.world_size, get_rank_function: Callable[[], int] = xr.global_ordinal, fixed_size: int | None = None) -> Any: +def broadcast_object( + obj: Any, + src: int = 0, + groups: list[list[int]] | None = None, + world_size_function: Callable[[], int] = xr.world_size, + get_rank_function: Callable[[], int] = xr.global_ordinal, + fixed_size: int | None = None, +) -> Any: """ Broadcasts arbitrary objects across XLA-distributed processes. Returns the object from the source rank on all ranks. @@ -55,7 +61,7 @@ def broadcast_object(obj: Any, src: int = 0, groups: list[list[int]] | None = No length = 0 # First, broadcast the length of the serialized object. - max_length = xm.all_reduce("max", torch.tensor(length, dtype=torch.int64).to(xm.xla_device())) + max_length = xm.all_reduce("max", torch.tensor([length], dtype=torch.int64).to(xm.xla_device())) max_length = max_length.cpu() # Ensure all ranks agree on the max length. @@ -70,28 +76,21 @@ def broadcast_object(obj: Any, src: int = 0, groups: list[list[int]] | None = No if rank == src: np_buffer = np.frombuffer(bytes_, dtype=np.uint8) - data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device()) padding_length = target_length - length if padding_length > 0: - padding_tensor = torch.zeros(padding_length, dtype=torch.uint8, device=xm.xla_device()) - data_tensor = torch.cat([data_tensor, padding_tensor], dim=0) + padding = np.zeros(padding_length, dtype=np.uint8) + np_buffer = np.concatenate([np_buffer, padding], axis=0) + data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device()) else: data_tensor = torch.zeros(target_length, dtype=torch.uint8, device=xm.xla_device()) data_tensor = xm.all_reduce("sum", data_tensor, groups=groups) torch_xla.sync() - # In this case we truncate the tensor to the original max length on device to minimize the data transfer from device - # to host. - if fixed_size is None: - data_tensor = data_tensor[:max_length] - data_tensor_cpu = data_tensor.cpu() reduced_bytes = data_tensor_cpu.numpy().tobytes() - # Truncate to the original max length on host if fixed_size is specified to avoid changing shapes on device. - if fixed_size is not None: - reduced_bytes = reduced_bytes[:max_length] + reduced_bytes = reduced_bytes[:max_length] return pickle.loads(reduced_bytes) @@ -111,6 +110,7 @@ def broadcast_object_to_data_parallel_group(obj: Any, src: int = 0, fixed_size: fixed_size=fixed_size, ) + def broadcast_object_to_tensor_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: """ Broadcasts arbitrary objects across XLA-distributed tensor model parallel group. @@ -126,6 +126,7 @@ def broadcast_object_to_tensor_model_parallel_group(obj: Any, src: int = 0, fixe fixed_size=fixed_size, ) + def broadcast_object_to_pipeline_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: """ Broadcasts arbitrary objects across XLA-distributed pipeline model parallel group. @@ -142,51 +143,110 @@ def broadcast_object_to_pipeline_model_parallel_group(obj: Any, src: int = 0, fi ) -def gather_object(obj: Any) -> list[Any]: +def gather_object( + obj: Any, + groups: list[list[int]] | None = None, + world_size_function: Callable[[], int] = xr.world_size, + fixed_size: int | None = None, +) -> list[Any]: """ Gathers arbitrary objects across XLA-distributed processes. Returns list of objects from all ranks on all ranks. - - Note: Requires two all-gather operations (lengths then data). - For small objects, this overhead may be significant. 
+ If `groups` is specified, gather is done separately in each group. """ - world_size = get_data_parallel_size() + world_size = world_size_function() # Early exit for single process if world_size == 1: return [obj] - groups = get_data_parallel_group(as_list=True) - serialized = pickle.dumps(obj) - byte_len = len(serialized) + length = len(serialized) - byte_tensor = torch.frombuffer([serialized], dtype=torch.uint8).to(xm.xla_device()) + if fixed_size is not None and length > fixed_size: + raise ValueError(f"Serialized object size {length} exceeds the specified fixed_size {fixed_size}") - len_tensor = torch.tensor([byte_len], dtype=torch.int64, device=byte_tensor.device) - # all_gather concatenates along dim=0, so [1] -> [world_size] - gathered_lengths = xm.all_gather(len_tensor, dim=0, groups=groups, pin_layout=False) + lengths = xm.all_gather( + torch.tensor([length], dtype=torch.int64, device=xm.xla_device()), + dim=0, + groups=groups, + pin_layout=False, + ) + max_length = torch.max(lengths) torch_xla.sync() - max_len = torch.max(gathered_lengths) - max_len = int(max_len.item()) - padded = torch.zeros(max_len, dtype=torch.uint8, device=byte_tensor.device) - padded[:byte_len] = byte_tensor + max_len = int(max_length.item()) - # all_gather concatenates, so [max_len] -> [world_size * max_len] - gathered_data = xm.all_gather(padded, dim=0, groups=groups, pin_layout=False) + if fixed_size is not None: + target_length = fixed_size + else: + target_length = max_len + + np_buffer = np.frombuffer(serialized, dtype=np.uint8) + padding_length = target_length - length + if padding_length > 0: + padding = np.zeros(padding_length, dtype=np.uint8) + np_buffer = np.concatenate([np_buffer, padding], axis=0) + data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device()) + + data_tensors = xm.all_gather( + data_tensor, + dim=0, + groups=groups, + pin_layout=False, + ) torch_xla.sync() - - gathered_data_cpu = gathered_data.cpu() - gathered_lengths_cpu = gathered_lengths.cpu() + lengths_cpu = lengths.cpu() + data_tensors_cpu = [t.cpu() for t in data_tensors] + data_bytes = [t.numpy().tobytes() for t in data_tensors_cpu] results = [] - offset = 0 for i in range(world_size): - actual_len = int(gathered_lengths_cpu[i].item()) - valid_bytes = gathered_data_cpu[offset:offset + max_len][:actual_len].numpy().tobytes() - valid_bytes = valid_bytes[0] - results.append(pickle.loads(valid_bytes)) - offset += max_len + length_i = lengths_cpu[i].item() + bytes_i = data_bytes[i][:length_i] + obj_i = pickle.loads(bytes_i) + results.append(obj_i) return results + + +def gather_object_from_data_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed data parallel group. + Returns list of objects from all ranks in the data parallel group on all ranks. + """ + groups = get_data_parallel_replica_groups() + return gather_object( + obj, + groups=groups, + world_size_function=get_data_parallel_size, + fixed_size=fixed_size, + ) + + +def gather_object_from_tensor_model_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed tensor model parallel group. + Returns list of objects from all ranks in the tensor model parallel group on all ranks. 
+ """ + groups = get_tensor_model_parallel_replica_groups() + return gather_object( + obj, + groups=groups, + world_size_function=get_context_model_parallel_size, + fixed_size=fixed_size, + ) + + +def gather_object_from_pipeline_model_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed pipeline model parallel group. + Returns list of objects from all ranks in the pipeline model parallel group on all ranks. + """ + groups = get_pipeline_model_parallel_replica_groups() + return gather_object( + obj, + groups=groups, + world_size_function=get_context_model_parallel_size, + fixed_size=fixed_size, + ) diff --git a/optimum/neuron/peft/utils/vllm.py b/optimum/neuron/peft/utils/vllm.py index 923f14aec..c77b3d0ef 100644 --- a/optimum/neuron/peft/utils/vllm.py +++ b/optimum/neuron/peft/utils/vllm.py @@ -87,7 +87,9 @@ def get_original_merged_weights_for_vllm(model) -> dict[str, torch.Tensor]: transformation_specs.append(copy.deepcopy(module.specs)) for specs in transformation_specs: - specs.module_fully_qualified_name = specs.module_fully_qualified_name.removeprefix("base_model.model.").replace(".base_layer", "") + specs.module_fully_qualified_name = specs.module_fully_qualified_name.removeprefix( + "base_model.model." + ).replace(".base_layer", "") for spec in specs: spec.peft_type = None diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index 4bd089438..b8c08bc12 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -30,13 +30,16 @@ if is_vllm_available(): from vllm.distributed.utils import StatelessProcessGroup else: + class StatelessProcessGroup: pass + logger = logging.get_logger() # Set up the communication group for weight broadcasting using CPU communicator -Group = namedtuple('Group', 'barrier') +Group = namedtuple("Group", "barrier") + class CPUCommunicator: def __init__(self, store, rank): @@ -57,6 +60,7 @@ def barrier(self): def __del__(self): del self.store + class VLLMClient(TRLVLLMClient): """VLLMClient with CPU-based communication for Neuron environments.""" @@ -96,14 +100,13 @@ def init_communicator(self, device: Union[torch.device, str, int] = 0): # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 time.sleep(0.1) - pg = StatelessProcessGroup.create( - host=self.host, port=self.group_port, rank=self.rank, world_size=world_size - ) + pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) self.communicator = CPUCommunicator(pg, self.rank) # When the client object is deleted, close the weight update group atexit.register(self.close_communicator) + class MockVLLMClient(VLLMClient): """ Mock VLLMClient that generates random completions and triggers XLA compilation without vLLM server. 
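The store-backed barrier that CPUCommunicator provides is easier to see outside the diff. A minimal sketch of the pattern, assuming a torch.distributed.TCPStore-like store with add/get; the key naming and polling interval are illustrative, not code from this patch:

import time

def store_barrier(store, key: str, world_size: int) -> None:
    # Each rank bumps a shared counter, then spins until every rank has arrived.
    # One-shot only: the counter is never reset, so use a fresh key per barrier.
    store.add(key, 1)
    while int(store.get(key)) < world_size:
        time.sleep(0.01)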
@@ -175,17 +178,11 @@ def generate( completion_length = self.random.randint(self.min_completion_length, max_len) # Generate random tokens from safe vocab range - completion = [ - self.random.randint(min_token_id, max_token_id) - for _ in range(completion_length) - ] + completion = [self.random.randint(min_token_id, max_token_id) for _ in range(completion_length)] completion_ids.append(completion) # Generate realistic random logprobs (typical range: -2 to -10) - completion_logprobs = [ - -self.random.uniform(2.0, 8.0) - for _ in range(completion_length) - ] + completion_logprobs = [-self.random.uniform(2.0, 8.0) for _ in range(completion_length)] logprobs.append(completion_logprobs) return { @@ -209,4 +206,3 @@ def reset_prefix_cache(self): def close_communicator(self): """No-op: mock has no communicator.""" pass - diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 858bc69a2..fadd65a9a 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -43,6 +43,7 @@ from ..accelerate.utils import ( broadcast_object_to_pipeline_model_parallel_group, broadcast_object_to_tensor_model_parallel_group, + gather_object_from_data_parallel_group, ) from ..models.training import NeuronModelForCausalLM from ..peft import NeuronPeftModel, get_peft_model @@ -139,6 +140,7 @@ def __init__( optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, peft_config: PeftConfig | None = None, vllm_client: TRLVLLMClient | None = None, + fixed_size_for_obj_collectives: int | None = 10 * 1024 * 1024, # 10 MB ): if not is_trl_available(required_version=TRL_VERSION): raise RuntimeError(f"Using NeuronGRPOTrainer requires trl=={TRL_VERSION}.") @@ -383,7 +385,9 @@ def make_inputs_require_grad(module, input, output): self.ref_model = None else: # Create reference model using NeuronModelForCausalLM - self.ref_model = NeuronModelForCausalLM.from_pretrained(model_id, args.trn_config, **args.model_init_kwargs or {}) + self.ref_model = NeuronModelForCausalLM.from_pretrained( + model_id, args.trn_config, **args.model_init_kwargs or {} + ) # Disable dropout in the models if args.disable_dropout: @@ -465,11 +469,12 @@ def make_inputs_require_grad(module, input, output): reward_func, evaluation_mode=True, device_placement=True ) - self.dp_rank = get_data_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank() self.pp_rank = get_pipeline_model_parallel_rank() + self.fixed_size_obj_collectives = fixed_size_for_obj_collectives + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: if dataset is None: dataset = self.train_dataset @@ -498,7 +503,6 @@ def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: ) return sampler - def train( self, resume_from_checkpoint: str | bool | None = None, @@ -554,7 +558,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: df = df.drop_duplicates(subset=["prompt"]) wandb.log({"completions": wandb.Table(dataframe=df)}) - def _save_checkpoint(self, model=None, trial=None, metrics=None): return NeuronTrainer._save_checkpoint(self) @@ -576,6 +579,7 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): if isinstance(reward_func, torch.nn.Module): if is_conversational(inputs[0]): from trl.data_utils import apply_chat_template + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: 
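The fixed_size_for_obj_collectives knob added above exists because XLA recompiles whenever a collective's input shape changes; serializing into a constant-length byte buffer keeps every broadcast and gather shape identical across steps. A standalone sketch of the trick (to_fixed_buffer is a hypothetical helper, not part of the patch):

import pickle

import numpy as np

def to_fixed_buffer(obj, fixed_size: int) -> np.ndarray:
    # Serialize, then right-pad with zeros so the buffer shape never varies.
    data = np.frombuffer(pickle.dumps(obj), dtype=np.uint8)
    if data.size > fixed_size:
        raise ValueError(f"object needs {data.size} bytes, buffer holds {fixed_size}")
    return np.concatenate([data, np.zeros(fixed_size - data.size, dtype=np.uint8)])

buf = to_fixed_buffer({"completion_ids": [[1, 2, 3]]}, fixed_size=1024)
assert buf.shape == (1024,)  # same shape every step, so the graph compiles once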
@@ -591,7 +595,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): output_reward_func = reward_func( prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs ) - print("Reward function output:", reward_func_name, output_reward_func) output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) @@ -686,12 +689,14 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): # Broadcast output to all ranks trn_config = self.accelerator.state.trn_config - # TODO: change that to a better default. - fixed_size = int(2e6) # 2MB fixed size for buffer, should be enough. if trn_config.tensor_parallel_size > 1: - output = broadcast_object_to_tensor_model_parallel_group(output, fixed_size=fixed_size) + output = broadcast_object_to_tensor_model_parallel_group( + output, fixed_size=self.fixed_size_obj_collectives + ) if trn_config.pipeline_parallel_size > 1: - output = broadcast_object_to_pipeline_model_parallel_group(output, fixed_size=fixed_size) + output = broadcast_object_to_pipeline_model_parallel_group( + output, fixed_size=self.fixed_size_obj_collectives + ) # Repeat prompt_ids num_generations times to match completion_ids prompt_ids = [ids for ids in output["prompt_ids"] for _ in range(self.num_generations)] @@ -704,10 +709,7 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): return prompt_ids, completion_ids, logprobs, forward_kwargs def _to_fixed_length( - self, - tensor: torch.Tensor, - padding_value: int = 0, - padding_side: str = "right" + self, tensor: torch.Tensor, padding_value: int = 0, padding_side: str = "right" ) -> torch.Tensor: """ Pads or truncates tensor to fixed length = max_prompt_length + max_completion_length. @@ -735,7 +737,7 @@ def _get_per_token_logps_and_entropies( input_ids, attention_mask, logits_to_keep, - batch_size, # Compared to the original `trl` implementation, `batch_size` must be specified. + batch_size: int | None = None, compute_entropy=False, pixel_values=None, image_grid_thw=None, @@ -744,21 +746,18 @@ def _get_per_token_logps_and_entropies( image_sizes=None, token_type_ids=None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - # Make sure the inputs have a fixed shape. - input_ids = self._to_fixed_length( - input_ids, padding_value=self.pad_token_id, padding_side="left" - ) - attention_mask = self._to_fixed_length( - attention_mask, padding_value=0, padding_side="left" - ) - - # Force synchronization before starting computation to re-use the same graph. torch_xla.sync() - batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + + # Ensure input batch size is divisible by `batch_size` to avoid issues with XLA graph compilation. + if input_ids.size(0) % batch_size != 0: + raise ValueError( + f"The input_ids batch size must be divisible by `batch_size`, but got {input_ids.shape[0]} and " + f"{batch_size}." + ) + all_logps = [] all_entropies = [] - # TODO: check if it's ok with TORCH XLA for start in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[start : start + batch_size] attention_mask_batch = attention_mask[start : start + batch_size] @@ -794,9 +793,6 @@ def _get_per_token_logps_and_entropies( logits = model(**model_inputs).logits - # Synchronize after model forward to avoid recompiling multiple model graphs. 
-        torch_xla.sync()
-
         # Exclude the last value: it corresponds to the next token pred
         logits = logits[:, :-1, :]  # (B, L-1, H)
         # Only keep the last logits_to_keep. For models that support logits_to_keep, this is a no-op.
@@ -825,7 +821,7 @@ def _get_per_token_logps_and_entropies(
         return logps, entropies

     def _generate_and_score_completions(
-        self, inputs: list[dict[str,torch.Tensor | Any]]
+        self, inputs: list[dict[str, torch.Tensor | Any]]
     ) -> dict[str, torch.Tensor | Any]:
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
@@ -856,7 +852,7 @@ def _generate_and_score_completions(
                 torch.tensor(ids),
                 self.max_prompt_length,
                 padding_value=self.pad_token_id,
-                padding_or_truncate_side="left"
+                padding_or_truncate_side="left",
             )
             for ids in prompt_ids_list
         ]
@@ -884,7 +880,7 @@ def _generate_and_score_completions(
                 torch.tensor(ids),
                 self.max_completion_length,
                 padding_value=self.pad_token_id,
-                padding_or_truncate_side="right"
+                padding_or_truncate_side="right",
             )
             for ids in completion_ids_list
         ]
@@ -1058,16 +1054,22 @@ def _generate_and_score_completions(
             metrics["frac_reward_zero_std"].append(is_std_zero.float().mean())

         # Log prompt and completion texts
-        # self._logs["prompt"].extend(self.accelerator.gather_object(prompts_text))
-        # self._logs["completion"].extend(self.accelerator.gather_object(completions_text))
+        # self._logs["prompt"].extend(
+        #     gather_object_from_data_parallel_group(prompts_text, fixed_size=self.fixed_size_obj_collectives)
+        # )
+        # self._logs["completion"].extend(
+        #     gather_object_from_data_parallel_group(completions_text, fixed_size=self.fixed_size_obj_collectives)
+        # )
         logs["rewards"] = {}
         logs["advantages"] = []
         for i, name in enumerate(self.reward_func_names):
             logs["rewards"][name] = rewards_per_func[:, i]
         logs["advantages"] = all_process_advantages

-        # if images is not None:
-        #     self._logs["images"].extend(self.accelerator.gather_object(images))
+        if images is not None:
+            self._logs["images"].extend(
+                gather_object_from_data_parallel_group(images, fixed_size=self.fixed_size_obj_collectives)
+            )

         if self.use_vllm and self.vllm_importance_sampling_correction:
             delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
@@ -1083,12 +1085,8 @@ def _generate_and_score_completions(
             # We can simply take the max of the masked delta because values in delta are >= 0 (torch.abs).
             max_delta = delta_masked.max()

-            metrics["sampling/sampling_logp_difference/mean"].append(
-                self.accelerator.gather(mean_delta).mean()
-            )
-            metrics["sampling/sampling_logp_difference/max"].append(
-                self.accelerator.gather(max_delta).max()
-            )
+            metrics["sampling/sampling_logp_difference/mean"].append(self.accelerator.gather(mean_delta).mean())
+            metrics["sampling/sampling_logp_difference/max"].append(self.accelerator.gather(max_delta).max())

             # Original code was:
             # flat_is_ratio = importance_sampling_ratio[completion_mask.bool()]
@@ -1105,7 +1103,7 @@ def _generate_and_score_completions(
             masked_is_ratio_for_min = torch.where(
                 completion_mask.bool(),
                 importance_sampling_ratio,
-                torch.tensor(float('inf'), device=device, dtype=importance_sampling_ratio.dtype)
+                torch.tensor(float("inf"), device=device, dtype=importance_sampling_ratio.dtype),
             )
             min_importance_sampling_ratio = masked_is_ratio_for_min.min()
             # importance_sampling_ratio values are >= 0 (torch.exp) so we can use the same computation as for delta.
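The torch.where(..., inf) and (x * mask) constructions above replace boolean indexing, since t[mask] produces a data-dependent shape that forces an XLA recompile while these keep shapes static. The pattern in isolation, a sketch rather than code from the patch:

import torch

def masked_min(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Masked-out entries become +inf, so they can never win the min.
    return torch.where(mask.bool(), t, torch.full_like(t, float("inf"))).min()

def masked_max_nonnegative(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Valid only for t >= 0: zeroed entries can never win the max.
    return (t * mask).max()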
@@ -1124,9 +1122,6 @@ def _generate_and_score_completions(
                 nanmax(self.accelerator.gather(max_importance_sampling_ratio))
             )

-        # Graph break after metrics and logs computation.
-        torch_xla.sync()
-
         # Move metrics and logs to CPU.
         metrics = move_all_tensor_to_cpu(metrics)
         logs = move_all_tensor_to_cpu(logs)
@@ -1140,7 +1135,6 @@ def _generate_and_score_completions(
                 self._logs["rewards"][name].extend(logs["rewards"][name].tolist())
         self._logs["advantages"].extend(logs["advantages"].tolist())

-
         output = {
             "prompt_ids": prompt_ids,
             "prompt_mask": prompt_mask,
@@ -1221,15 +1215,12 @@ def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, thr
         # Handle empty case, if everything is sentinel, set threshold to +inf so no token is selected
         has_valid = num_valid > 0
         entropy_threshold = torch.where(
-            has_valid,
-            entropy_threshold,
-            torch.tensor(float('inf'), device=device, dtype=entropies.dtype)
+            has_valid, entropy_threshold, torch.tensor(float("inf"), device=device, dtype=entropies.dtype)
         )

         entropy_mask = (entropies > entropy_threshold) & mask.bool()
         return entropy_mask

-
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
@@ -1260,9 +1251,6 @@ def _compute_loss(self, model, inputs):
             token_type_ids=inputs.get("token_type_ids"),
         )

-        print("Per-token log probabilities:", per_token_logps)
-        print("Entropies:", entropies)
-
         if self.top_entropy_quantile < 1.0:
             entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
         else:
diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py
index 54c73c186..9b44c540e 100644
--- a/optimum/neuron/trainers/trl_utils.py
+++ b/optimum/neuron/trainers/trl_utils.py
@@ -31,12 +31,13 @@

 TRL_VERSION = "0.24.0"

+
 def pad(
     tensors: list[torch.Tensor],
     padding_value: int = 0,
     padding_side: str = "right",
     max_length: int | None = None,
-    ) -> torch.Tensor:
+) -> torch.Tensor:
     """
     Pads a list of tensors to the same shape along the first dimension.
     It differs from `trl` by enforcing the same sequence length for all tensors, which is required to avoid
@@ -150,6 +151,7 @@ def entropy_from_logits(logits: torch.Te
     entropies = torch.cat(entropies, dim=0)
     return entropies.reshape(original_shape)

+
 def neuron_parallel_compile_tokenizer_decoder_method(
     self,
     token_ids: int | list[int],
@@ -202,10 +204,7 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor:
     XLA-compatible version of nanstd.
     Compute the standard deviation of a tensor, ignoring NaNs.
""" - # Use torch's built-in nanmean and compute variance with Bessel's correction variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) - count = torch.sum(~torch.isnan(tensor)) - variance *= count / (count - 1).clamp(min=1.0) # Bessel's correction, avoid division by zero return torch.sqrt(variance) @@ -244,9 +243,7 @@ def __init__( raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() if rank >= num_replicas or rank < 0: - raise ValueError( - f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" - ) + raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -293,9 +290,7 @@ def __iter__(self): if padding_size <= len(indices): indices += indices[:padding_size] else: - indices += (indices * math.ceil(padding_size / len(indices)))[ - :padding_size - ] + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. indices = indices[: self.total_size] @@ -326,4 +321,3 @@ def __len__(self) -> int: def set_epoch(self, epoch: int) -> None: self.epoch = epoch - diff --git a/pyproject.toml b/pyproject.toml index 1a49c1711..4b8bac540 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,15 +79,15 @@ training = [ "peft == 0.17.0", "evaluate == 0.4.3", ] -neuron = [ - "wheel", - "torch-neuron==1.13.1.2.9.74.0", - "torch==1.13.1.*", - "neuron-cc[tensorflow]==1.22.0.0", - "protobuf", - "torchvision", - "numpy==1.22.3", -] +# neuron = [ +# "wheel", +# "torch-neuron==1.13.1.2.9.74.0", +# "torch==1.13.1.*", +# "neuron-cc[tensorflow]==1.22.0.0", +# "protobuf", +# "torchvision", +# "numpy==1.22.3", +# ] neuronx = [ "wheel", "neuronx-cc==2.21.18209.0", @@ -133,7 +133,16 @@ include = ["optimum*"] "*" = ["*"] [tool.uv] -extra-index-url = ["https://pip.repos.neuron.amazonaws.com"] +index-strategy = "unsafe-best-match" + +[[tool.uv.index]] +name = "neuron" +url = "https://pip.repos.neuron.amazonaws.com/" +default = true + +[[tool.uv.index]] +name = "pypi" +url = "https://pypi.org/simple" [tool.ruff] line-length = 119 diff --git a/tests/training/test_custom_modeling.py b/tests/training/test_custom_modeling.py index 08099f9b5..51cada1ea 100644 --- a/tests/training/test_custom_modeling.py +++ b/tests/training/test_custom_modeling.py @@ -808,8 +808,9 @@ def test_peft_merge_unmerge(set_cache_for_ci): xm.mark_step() # Outputs should match - assert torch.allclose(logits_unmerged, logits_merged, rtol=1e-3, atol=1e-3), \ + assert torch.allclose(logits_unmerged, logits_merged, rtol=1e-3, atol=1e-3), ( f"Merged and unmerged outputs should match. Max diff: {(logits_unmerged - logits_merged).abs().max().item()}" + ) # Unmerge LoRA adapters model.unmerge_adapter() @@ -820,8 +821,9 @@ def test_peft_merge_unmerge(set_cache_for_ci): xm.mark_step() for name, original_weight in original_weights.items(): current_weight = current_weights[name].data - assert torch.allclose(original_weight, current_weight, rtol=1e-5, atol=1e-6), \ + assert torch.allclose(original_weight, current_weight, rtol=1e-5, atol=1e-6), ( f"Weight {name} should be restored after unmerge. 
Max diff: {(original_weight - current_weight).abs().max().item()}" + ) # Final output check with torch.no_grad(): @@ -829,8 +831,9 @@ def test_peft_merge_unmerge(set_cache_for_ci): logits_final = output_final.logits.clone() xm.mark_step() - assert torch.allclose(logits_unmerged, logits_final, rtol=1e-5, atol=1e-5), \ + assert torch.allclose(logits_unmerged, logits_final, rtol=1e-5, atol=1e-5), ( f"Final output should match original unmerged output. Max diff: {(logits_unmerged - logits_final).abs().max().item()}" + ) @distributed_test(world_size=8, tp_size=2, pp_size=1) @@ -892,14 +895,16 @@ def test_get_original_merged_weights_for_vllm(set_cache_for_ci): assert "model.layers.0.self_attn.q_proj.weight" in original_weights q_proj_weight = original_weights["model.layers.0.self_attn.q_proj.weight"] - assert q_proj_weight.shape == (hidden_size, hidden_size), \ + assert q_proj_weight.shape == (hidden_size, hidden_size), ( f"q_proj should be unsharded {hidden_size}x{hidden_size}, got {q_proj_weight.shape}" + ) # The custom model uses fused gate_up_proj, but original format should have separate projections assert "model.layers.0.mlp.gate_proj.weight" in original_weights assert "model.layers.0.mlp.up_proj.weight" in original_weights - assert "model.layers.0.mlp.gate_up_proj.weight" not in original_weights, \ + assert "model.layers.0.mlp.gate_up_proj.weight" not in original_weights, ( "Should use original format (separate gate/up), not custom format (fused gate_up)" + ) gate_proj_weight = original_weights["model.layers.0.mlp.gate_proj.weight"] up_proj_weight = original_weights["model.layers.0.mlp.up_proj.weight"] @@ -910,14 +915,16 @@ def test_get_original_merged_weights_for_vllm(set_cache_for_ci): merged_q_proj = original_weights[base_q_proj_name] # Weights should be different (LoRA delta was merged) - assert not torch.allclose(original_q_proj, merged_q_proj, rtol=1e-4), \ + assert not torch.allclose(original_q_proj, merged_q_proj, rtol=1e-4), ( "Merged weight should differ from original base weight (LoRA delta should be added)" + ) # Test 3: Verify model state is restored (adapters are unmerged) # After calling get_original_merged_weights_for_vllm, the model should be back to unmerged state for module in model.modules(): if hasattr(module, "merged"): - assert not module.merged, \ + assert not module.merged, ( f"Module {module.__class__.__name__} should be unmerged after get_original_merged_weights_for_vllm" + ) print("✓ All get_original_merged_weights_for_vllm tests passed") From dd4041dbbdb1d4ecffec72c35d74acc3e3ac4702 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 2 Dec 2025 17:49:01 +0100 Subject: [PATCH 52/78] collectives work --- .../grpo_qwen3/finetune_grpo_qwen3.py | 33 +- optimum/neuron/accelerate/accelerator.py | 2 +- optimum/neuron/accelerate/utils/operations.py | 28 +- optimum/neuron/trainers/grpo_trainer.py | 292 ++++++++---------- optimum/neuron/trainers/metrics/collector.py | 6 +- optimum/neuron/trainers/transformers.py | 15 +- 6 files changed, 162 insertions(+), 214 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index 6bb478665..0c6daab64 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -47,25 +47,16 @@ # For this example, we use simple rule-based rewards for demonstration. 
-def length_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]: +def length_reward( + prompts: list[str], completions: list[str], completion_ids: list[list[int]], **kwargs +) -> list[float]: """ Simple reward function that rewards longer responses (up to a point). - - This is a toy example. In practice, you'd want more sophisticated rewards - based on task-specific criteria (e.g., accuracy, coherence, helpfulness). - - Args: - prompts: List of input prompts - completions: List of generated completions - **kwargs: Additional arguments (e.g., trainer_state) - - Returns: - List of reward scores (one per completion) """ rewards = [] - for completion in completions: + for completion in completion_ids: # Reward based on length, but cap at 100 tokens to avoid overly long responses - length = len(completion.split()) + length = len(completion) reward = min(length / 50.0, 2.0) # Scale: 0-2 rewards.append(reward) return rewards @@ -74,17 +65,11 @@ def length_reward(prompts: list[str], completions: list[str], **kwargs) -> list[ def unique_words_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]: """ Reward function that encourages diversity by rewarding unique words. - - Args: - prompts: List of input prompts - completions: List of generated completions - **kwargs: Additional arguments - - Returns: - List of reward scores (one per completion) """ rewards = [] for completion in completions: + if isinstance(completion, list): + completion = completion[0]["content"] words = completion.lower().split() unique_words = len(set(words)) total_words = len(words) @@ -114,7 +99,7 @@ def load_grpo_dataset(): """ # Load a simple test dataset # This dataset has prompts in the "prompt" column - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-lib/DeepMath-103K", split="train") return dataset @@ -162,7 +147,7 @@ def train(model_id, tokenizer, dataset, training_args): grpo_config = NeuronGRPOConfig( # Generation parameters max_prompt_length=512, # Maximum prompt length - max_completion_length=268, # Maximum completion length + max_completion_length=300, # Maximum completion length num_generations=4, # Number of completions to generate per prompt (G in paper) temperature=0.8, # Sampling temperature # GRPO algorithm parameters diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 4b9180481..8141a8383 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -547,7 +547,7 @@ def save_state( output_dir=output_dir, safe_serialization=safe_serialization, **save_model_func_kwargs ) - def gather(self, tensor, sync: bool = False): + def gather(self, tensor: torch.Tensor, sync: bool = False) -> torch.Tensor: groups = get_data_parallel_group(as_list=True) # Ensure tensor is at least 1D for all_gather (scalars need to be unsqueezed) diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py index 8570ae73a..ab6ffa407 100644 --- a/optimum/neuron/accelerate/utils/operations.py +++ b/optimum/neuron/accelerate/utils/operations.py @@ -29,6 +29,8 @@ get_tensor_model_parallel_replica_groups, ) +from ...utils.misc import is_precompilation + def broadcast_object( obj: Any, @@ -78,7 +80,7 @@ def broadcast_object( np_buffer = np.frombuffer(bytes_, dtype=np.uint8) padding_length = target_length - length if padding_length > 0: - padding = np.zeros(padding_length, dtype=np.uint8) + 
padding = np.zeros([padding_length], dtype=np.uint8)
             np_buffer = np.concatenate([np_buffer, padding], axis=0)
         data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device())
     else:
@@ -167,39 +169,45 @@ gather_object(
         raise ValueError(f"Serialized object size {length} exceeds the specified fixed_size {fixed_size}")

     lengths = xm.all_gather(
-        torch.tensor([length], dtype=torch.int64, device=xm.xla_device()),
+        torch.tensor([length], dtype=torch.int64).to(device=xm.xla_device()),
         dim=0,
         groups=groups,
         pin_layout=False,
     )
-    max_length = torch.max(lengths)
     torch_xla.sync()
-
-    max_len = int(max_length.item())
+    lengths_cpu = lengths.cpu()
+    max_length = lengths_cpu.max()
+    max_length = int(max_length.item())

     if fixed_size is not None:
         target_length = fixed_size
     else:
-        target_length = max_len
+        target_length = max_length

     np_buffer = np.frombuffer(serialized, dtype=np.uint8)
     padding_length = target_length - length
     if padding_length > 0:
-        padding = np.zeros(padding_length, dtype=np.uint8)
+        padding = np.zeros([padding_length], dtype=np.uint8)
         np_buffer = np.concatenate([np_buffer, padding], axis=0)
     data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device())

-    data_tensors = xm.all_gather(
+    data_tensor = xm.all_gather(
         data_tensor,
         dim=0,
         groups=groups,
         pin_layout=False,
     )
     torch_xla.sync()
-    lengths_cpu = lengths.cpu()
-    data_tensors_cpu = [t.cpu() for t in data_tensors]
+
+    data_tensors_cpu = data_tensor.cpu().split(target_length)
     data_bytes = [t.numpy().tobytes() for t in data_tensors_cpu]

+    # During precompilation, all_gather returns tensors with uninitialized data or zeros,
+    # breaking the pickle.loads step below, so we return a list of the original object instead;
+    # this should not break anything since precompilation does not rely on the gathered objects.
+    if is_precompilation():
+        return [obj for _ in range(world_size)]
+
     results = []
     for i in range(world_size):
         length_i = lengths_cpu[i].item()
diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py
index fadd65a9a..da14b2f86 100644
--- a/optimum/neuron/trainers/grpo_trainer.py
+++ b/optimum/neuron/trainers/grpo_trainer.py
@@ -70,7 +70,7 @@

 if is_trl_available():
     from trl import GRPOConfig, GRPOTrainer
-    from trl.data_utils import is_conversational
+    from trl.data_utils import is_conversational, maybe_apply_chat_template
     from trl.extras.vllm_client import VLLMClient as TRLVLLMClient
     from trl.trainer.utils import (
         RepeatSampler,
@@ -638,29 +638,15 @@ def _generate_single_turn(self, prompts: list[str], images: list | None):
-        """
-        Generate a single turn of completions using vLLM (mock or real server).
-
-        This overrides GRPOTrainer's implementation to work with Neuron/XLA devices.
-        The main difference is avoiding gather_object which doesn't work on XLA.
-
-        MOCK MODE: Each process generates locally without gathering/broadcasting.
-        REAL SERVER MODE: Only main process generates, results are broadcast to all processes.
- - Args: - prompts: List of prompt strings - images: Optional list of images - - Returns: - Tuple of (prompt_ids, completion_ids, logprobs, forward_kwargs) - """ - # Move model weights to vLLM if needed (no-op for mock) if self.state.global_step != getattr(self, "_last_loaded_step", -1): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step # Take unique prompts since we have num_generations duplicates - prompts_text = [prompt if isinstance(prompt, str) else prompt["content"] for prompt in prompts] + # Use maybe_apply_chat_template to handle both conversational and simple formats + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] ordered_set_of_prompts = prompts_text[:: self.num_generations] if images is not None: @@ -708,29 +694,6 @@ def _generate_single_turn(self, prompts: list[str], images: list | None): return prompt_ids, completion_ids, logprobs, forward_kwargs - def _to_fixed_length( - self, tensor: torch.Tensor, padding_value: int = 0, padding_side: str = "right" - ) -> torch.Tensor: - """ - Pads or truncates tensor to fixed length = max_prompt_length + max_completion_length. - """ - fixed_length = self.max_prompt_length + self.max_completion_length - seq_len = tensor.shape[1] - - if seq_len == fixed_length: - return tensor - elif seq_len < fixed_length: - # Pad to fixed length - pad_amount = fixed_length - seq_len - pad_config = (pad_amount, 0) if padding_side == "left" else (0, pad_amount) - return torch.nn.functional.pad(tensor, pad_config, value=padding_value) - else: - # Truncate to fixed length - if padding_side == "left": - return tensor[:, -fixed_length:] - else: - return tensor[:, :fixed_length] - def _get_per_token_logps_and_entropies( self, model, @@ -1054,12 +1017,12 @@ def _generate_and_score_completions( metrics["frac_reward_zero_std"].append(is_std_zero.float().mean()) # Log prompt and completion texts - # self._logs["prompt"].extend( - # gather_object_from_data_parallel_group(prompts_text, fixed_size=self.fixed_size_obj_collectives) - # ) - # self._logs["completion"].extend( - # gather_object_from_data_parallel_group(completions_text, fixed_size=self.fixed_size_obj_collectives) - # ) + self._logs["prompt"].extend( + gather_object_from_data_parallel_group(prompts_text, fixed_size=self.fixed_size_obj_collectives) + ) + self._logs["completion"].extend( + gather_object_from_data_parallel_group(completions_text, fixed_size=self.fixed_size_obj_collectives) + ) logs["rewards"] = {} logs["advantages"] = [] for i, name in enumerate(self.reward_func_names): @@ -1123,9 +1086,9 @@ def _generate_and_score_completions( ) # Move metrics and logs to CPU. 
+ torch_xla.sync() metrics = move_all_tensor_to_cpu(metrics) logs = move_all_tensor_to_cpu(logs) - torch_xla.sync() metrics = {key: [val.item() for val in value] for key, value in metrics.items()} @@ -1235,136 +1198,139 @@ def _compute_loss(self, model, inputs): attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - # Compute the per_token_logps and the entropy at each position in the completion - per_token_logps, entropies = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - batch_size=self.args.per_device_train_batch_size, - compute_entropy=True, - pixel_values=inputs.get("pixel_values"), - image_grid_thw=inputs.get("image_grid_thw"), - num_images=inputs.get("num_images"), - pixel_attention_mask=inputs.get("pixel_attention_mask"), - image_sizes=inputs.get("image_sizes"), - token_type_ids=inputs.get("token_type_ids"), - ) - - if self.top_entropy_quantile < 1.0: - entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) - else: - entropy_mask = None - - # Compute the KL divergence between the model and the reference model - if self.beta != 0.0: - ref_per_token_logps = inputs["ref_per_token_logps"] - per_token_kl = ( - torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + with self.metrics_collector.time_metric("forward_pass", inputs=inputs): + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=self.args.per_device_train_batch_size, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), ) - # Compute the loss - advantages = inputs["advantages"] - # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, - # old_per_token_logps == per_token_logps. In this case we can skip its computation - # (see _generate_and_score_completions) and instead use per_token_logps.detach(). - # The exception is when using vLLM, where we always compute old_per_token_logps - # for importance sampling - old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps - - log_ratio = per_token_logps - old_per_token_logps - if self.importance_sampling_level == "token": - log_importance_weights = log_ratio - elif self.importance_sampling_level == "sequence": - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - log_importance_weights = log_importance_weights.unsqueeze(-1) - else: - raise ValueError( - f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " - "and 'sequence'." - ) - # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) 
shape depends on - # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None - coef_1 = torch.exp(log_importance_weights) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) - # Two-sided clipping - if self.args.delta is not None: - coef_1 = torch.clamp(coef_1, max=self.args.delta) + # Compute the loss + advantages = inputs["advantages"] + # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, + # old_per_token_logps == per_token_logps. In this case we can skip its computation + # (see _generate_and_score_completions) and instead use per_token_logps.detach(). + # The exception is when using vLLM, where we always compute old_per_token_logps + # for importance sampling + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "token": + log_importance_weights = log_ratio + elif self.importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) 
shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) - per_token_loss = -torch.min(per_token_loss1, per_token_loss2) - if entropy_mask is not None: - per_token_loss = per_token_loss * entropy_mask + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - if self.use_vllm and self.vllm_importance_sampling_correction: - per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] - - if self.beta != 0.0: - per_token_loss = per_token_loss + self.beta * per_token_kl - - if self.loss_type == "grpo": - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() - loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) - loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "dr_grpo": - loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "dapo": - normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes - loss = (per_token_loss * completion_mask).sum() / normalizer - else: - raise ValueError(f"Unknown loss type: {self.loss_type}") + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) - # Log the metrics - mode = "train" if self.model.training else "eval" + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask - completion_token_count = completion_mask.sum().clamp(min=1.0) + if self.use_vllm and self.vllm_importance_sampling_correction: + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] - def masked_batch_mean(x): - if x.shape[1] == 1: # when importance_sampling_level == "sequence" - return x.mean() + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == "grpo": + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "bnpo": + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "dapo": + normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes + loss = (per_token_loss * completion_mask).sum() / normalizer else: - return (x * completion_mask).sum() / completion_token_count + raise ValueError(f"Unknown loss type: {self.loss_type}") - metrics = defaultdict(list) - - if self.beta != 0.0: - mean_kl = masked_batch_mean(per_token_kl) - metrics["kl"].append(self.accelerator.gather(mean_kl).nanmean()) - - mean_entropy = masked_batch_mean(entropies) - metrics["entropy"].append(self.accelerator.gather(mean_entropy).nanmean()) + # Log the metrics + mode = "train" if self.model.training 
else "eval" - # Compute the clipped probability ratios - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) - is_region_clipped = is_low_clipped | is_high_clipped + completion_token_count = completion_mask.sum().clamp( + min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) + ) - low_clip = masked_batch_mean(is_low_clipped.float()) - high_clip = masked_batch_mean(is_high_clipped.float()) - clip_ratio = masked_batch_mean(is_region_clipped.float()) + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count - gathered_low_clip = self.accelerator.gather(low_clip) - metrics["clip_ratio/low_mean"].append(gathered_low_clip.nanmean()) - metrics["clip_ratio/low_min"].append(nanmin(gathered_low_clip)) - gathered_high_clip = self.accelerator.gather(high_clip) - metrics["clip_ratio/high_mean"].append(gathered_high_clip.nanmean()) - metrics["clip_ratio/high_max"].append(nanmax(gathered_high_clip)) - gathered_clip_ratio = self.accelerator.gather(clip_ratio) - metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) + metrics = defaultdict(list) - # torch_xla.sync() # Graph break before moving metrics to CPU. - metrics = move_all_tensor_to_cpu(metrics) - torch_xla.sync() - metrics = {key: [val.item() for val in value] for key, value in metrics.items()} + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + metrics["kl"].append(self.accelerator.gather(mean_kl).nanmean()) + + mean_entropy = masked_batch_mean(entropies) + metrics["entropy"].append(self.accelerator.gather(mean_entropy).nanmean()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + metrics["clip_ratio/low_mean"].append(gathered_low_clip.nanmean()) + metrics["clip_ratio/low_min"].append(nanmin(gathered_low_clip)) + gathered_high_clip = self.accelerator.gather(high_clip) + metrics["clip_ratio/high_mean"].append(gathered_high_clip.nanmean()) + metrics["clip_ratio/high_max"].append(nanmax(gathered_high_clip)) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) + + # torch_xla.sync() # Graph break before moving metrics to CPU. 
+ metrics = move_all_tensor_to_cpu(metrics) + torch_xla.sync() + metrics = {key: [val.item() for val in value] for key, value in metrics.items()} - self._metrics[mode].update(metrics) + self._metrics[mode].update(metrics) return loss diff --git a/optimum/neuron/trainers/metrics/collector.py b/optimum/neuron/trainers/metrics/collector.py index a6d9db3e8..59da8571a 100644 --- a/optimum/neuron/trainers/metrics/collector.py +++ b/optimum/neuron/trainers/metrics/collector.py @@ -206,9 +206,9 @@ def get_metric_window_stats(self, metric_name: str) -> dict: def get_metric_unit(self, metric_name: str) -> str: """Get the unit for a specific metric.""" for plugin in self.active_plugins: - if plugin.handles_metric(metric_name): - units = plugin.get_metric_units() - return units.get(metric_name, "") + metric_units = plugin.get_metric_units() + if metric_name in metric_units: + return metric_units[metric_name] return "" def get_all_metric_units(self) -> dict[str, str]: diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index 0b7b5532d..e8ad272da 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -1371,19 +1371,8 @@ def report_and_save_summary_metrics(self): # Group and format metrics for better readability for metric_name, value in summary_metrics.items(): if isinstance(value, float): - if "time" in metric_name: - logger.info(f"{metric_name}: {value:.4f}s") - elif "per_sec" in metric_name: - logger.info(f"{metric_name}: {value:.2f}") - elif ( - "mfu" in metric_name - or "efficiency" in metric_name - or "consistency" in metric_name - or "percent" in metric_name - ): - logger.info(f"{metric_name}: {value:.2f}%") - else: - logger.info(f"{metric_name}: {value:.2f}") + unit = self.metrics_collector.get_metric_unit(metric_name) + logger.info(f"{metric_name}: {value:.2f}{unit}") else: logger.info(f"{metric_name}: {value}") logger.info("=" * 80) From 56011c4365d7a73fd7ad9930f270a24d02e45b7f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 2 Dec 2025 18:33:17 +0100 Subject: [PATCH 53/78] fix clamping bug --- .../training/grpo_qwen3/finetune_grpo_qwen3.py | 3 ++- .../training/grpo_qwen3/finetune_grpo_qwen3.sh | 2 +- optimum/neuron/trainers/grpo_trainer.py | 17 +++++++++++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index 0c6daab64..67d40f0b7 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -99,7 +99,8 @@ def load_grpo_dataset(): """ # Load a simple test dataset # This dataset has prompts in the "prompt" column - dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + # dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") return dataset diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index e2b9afcb4..c387b851e 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -30,7 +30,7 @@ NUM_EPOCHS=1 # GRPO typically needs fewer epochs than SFT TP_DEGREE=1 BS=1 GRADIENT_ACCUMULATION_STEPS=4 # Smaller for GRPO due to generation overhead -LOGGING_STEPS=1 +LOGGING_STEPS=10 MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing # 
MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index da14b2f86..496b17fda 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -1241,7 +1241,9 @@ def _compute_loss(self, model, inputs): if self.importance_sampling_level == "token": log_importance_weights = log_ratio elif self.importance_sampling_level == "sequence": - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp( + min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) + ) log_importance_weights = log_importance_weights.unsqueeze(-1) else: raise ValueError( @@ -1271,16 +1273,23 @@ def _compute_loss(self, model, inputs): per_token_loss = per_token_loss + self.beta * per_token_kl if self.loss_type == "grpo": - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + loss = ( + (per_token_loss * completion_mask).sum(-1) + / completion_mask.sum(-1).clamp( + min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) + ) + ).mean() loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp( + min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) + ) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "dr_grpo": loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "dapo": - normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes + normalizer = inputs["num_items_in_batch"] loss = (per_token_loss * completion_mask).sum() / normalizer else: raise ValueError(f"Unknown loss type: {self.loss_type}") From a7f58a326683101b33ee7b0f64ce735e03367f89 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 3 Dec 2025 11:30:47 +0100 Subject: [PATCH 54/78] make use_vllm the default --- .../grpo_qwen3/finetune_grpo_qwen3.py | 28 ++----------------- .../grpo_qwen3/finetune_grpo_qwen3.sh | 11 ++++++-- optimum/neuron/trainers/grpo_config.py | 14 +++++++++- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index 67d40f0b7..d0d76665a 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -29,7 +29,7 @@ from peft import LoraConfig from transformers import AutoTokenizer, HfArgumentParser -from optimum.neuron import NeuronGRPOConfig, NeuronGRPOTrainer, NeuronTrainingArguments +from optimum.neuron import NeuronGRPOConfig, NeuronGRPOTrainer from optimum.neuron.models.training import NeuronModelForCausalLM from optimum.neuron.trainers.extras import MockVLLMClient @@ -141,29 +141,7 @@ def train(model_id, tokenizer, dataset, training_args): task_type="CAUSAL_LM", ) - # Convert NeuronTrainingArguments to dict for NeuronGRPOConfig - args = training_args.to_dict() - - # GRPO-specific configuration - grpo_config = 
NeuronGRPOConfig( - # Generation parameters - max_prompt_length=512, # Maximum prompt length - max_completion_length=300, # Maximum completion length - num_generations=4, # Number of completions to generate per prompt (G in paper) - temperature=0.8, # Sampling temperature - # GRPO algorithm parameters - num_iterations=1, # Number of iterations per batch (μ in paper) - epsilon=0.2, # Clipping parameter - beta=0.01, # KL divergence coefficient - scale_rewards="group", # Reward scaling strategy - # vLLM parameters - use_vllm=True, # Use vLLM for generation (required for Neuron) - vllm_mode="server", # Use vLLM server mode - vllm_server_host="0.0.0.0", - vllm_server_port=8000, - # Standard training arguments from NeuronTrainingArguments - **args, - ) + grpo_config = training_args # Define reward functions # You can use multiple reward functions - they will be summed @@ -202,7 +180,7 @@ class ScriptArguments: # Main Function # ============================================================================= if __name__ == "__main__": - parser = HfArgumentParser((ScriptArguments, NeuronTrainingArguments)) + parser = HfArgumentParser((ScriptArguments, NeuronGRPOConfig)) script_args, training_args = parser.parse_args_into_dataclasses() tokenizer = AutoTokenizer.from_pretrained(script_args.model_id) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index c387b851e..8ab369f82 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -29,8 +29,8 @@ PROCESSES_PER_NODE=2 NUM_EPOCHS=1 # GRPO typically needs fewer epochs than SFT TP_DEGREE=1 BS=1 -GRADIENT_ACCUMULATION_STEPS=4 # Smaller for GRPO due to generation overhead -LOGGING_STEPS=10 +GRADIENT_ACCUMULATION_STEPS=1 # Smaller for GRPO due to generation overhead +LOGGING_STEPS=2 MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing # MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" @@ -73,7 +73,12 @@ torchrun $DISTRIBUTED_ARGS finetune_grpo_qwen3.py \ --logging_steps $LOGGING_STEPS \ --output_dir $OUTPUT_DIR \ --lr_scheduler_type "cosine" \ - --overwrite_output_dir + --overwrite_output_dir \ + --num_generations $NUM_GENERATIONS \ + --max_prompt_length $MAX_PROMPT_LENGTH \ + --max_completion_length $MAX_COMPLETION_LENGTH \ + --temperature $TEMPERATURE \ + --steps_per_generation $STEPS_PER_GENERATION echo "================================" echo "Training completed!" diff --git a/optimum/neuron/trainers/grpo_config.py b/optimum/neuron/trainers/grpo_config.py index c696ab2ba..89af7c99d 100644 --- a/optimum/neuron/trainers/grpo_config.py +++ b/optimum/neuron/trainers/grpo_config.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, field from ..utils.import_utils import is_trl_available from .training_args import NeuronTrainingArguments @@ -39,7 +39,19 @@ class NeuronGRPOConfig(NeuronTrainingArguments, GRPOConfig): with GRPOConfig for GRPO algorithm parameters. """ + use_vllm: bool = field( + default=True, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for " + "generation instead of the default model.generate(). Requires `vllm` to be installed. Required for NeuronGRPOTrainer." 
+ }, + ) + def __post_init__(self): + # For now, NeuronGRPOTrainer requires vLLM for generation, no other way is supported. + if not self.use_vllm: + raise ValueError("NeuronGRPOTrainer requires `use_vllm` to be set to `True`.") + # Handle bf16 default (from GRPOConfig) self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 From 544cadfb3adba10bad242450719a53cf0c22e3a3 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 4 Dec 2025 13:56:36 +0100 Subject: [PATCH 55/78] update with torch_xla.sync and peft --- optimum/neuron/accelerate/accelerator.py | 4 +-- optimum/neuron/accelerate/utils/misc.py | 3 ++- .../neuron/models/training/training_utils.py | 5 ++-- optimum/neuron/peft/mapping_func.py | 26 ++++++++++++------- optimum/neuron/peft/tuners/lora/layer.py | 6 ++++- optimum/neuron/trainers/transformers.py | 19 +++++++------- optimum/neuron/trainers/utils.py | 3 ++- 7 files changed, 41 insertions(+), 25 deletions(-) diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index 8141a8383..7d0d0d325 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -390,7 +390,7 @@ def prepare_model( move_model_to_device(model, xm.xla_device()) model.tie_weights() - xm.mark_step() + torch_xla.sync() # Adding the model to the list of prepared models. self._models.append(model) @@ -474,7 +474,7 @@ def _inner(folder): logger.info(f"Saving current state to {output_dir}") # Finish running the previous step before checkpointing - xm.mark_step() + torch_xla.sync() # Save the models if save_model_func is not None: diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py index 6cd32e8bc..9aa60b13e 100644 --- a/optimum/neuron/accelerate/utils/misc.py +++ b/optimum/neuron/accelerate/utils/misc.py @@ -23,6 +23,7 @@ import accelerate import torch +import torch_xla import torch_xla.core.xla_model as xm from neuronx_distributed.parallel_layers.parallel_state import ( get_data_parallel_rank, @@ -119,7 +120,7 @@ def wrapper(*args, **kwargs): with patcher: output = orig_func(*args, **kwargs) self.load_state_dict(orig_state_dict, assign=True) - xm.mark_step() + torch_xla.sync() del cpu_state_dict gc.collect() return output diff --git a/optimum/neuron/models/training/training_utils.py b/optimum/neuron/models/training/training_utils.py index f4fd7b561..f389f2578 100644 --- a/optimum/neuron/models/training/training_utils.py +++ b/optimum/neuron/models/training/training_utils.py @@ -89,6 +89,7 @@ def skip_first_batches(dataloader, num_batches=0): def _get_model_param_count(model: "torch.nn.Module | NxDPPModel"): """Counts the number of parameters of the model.""" + import torch_xla import torch_xla.core.xla_model as xm from neuronx_distributed.parallel_layers.parallel_state import ( get_pipeline_model_parallel_group, @@ -143,7 +144,7 @@ def numel(parameter_name, parameter) -> int: def reduce_param_count_over_pp_ranks(param_count: int): param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device()) param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True)) - xm.mark_step() + torch_xla.sync() param_count = int(param_count.detach().cpu().item()) return param_count @@ -153,7 +154,7 @@ def reduce_param_count_over_pp_ranks(param_count: int): all_param_count = reduce_param_count_over_pp_ranks(all_param_count) trainable_param_count = reduce_param_count_over_pp_ranks(trainable_param_count) - xm.mark_step() + torch_xla.sync() 
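This commit applies the same substitution throughout the tree: `xm.mark_step()` is replaced by `torch_xla.sync()`, the newer top-level API for cutting the lazy trace and dispatching the pending graph. A minimal sketch of the pattern, assuming a torch_xla 2.x environment where `torch_xla.sync()` is available:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.ones(4, 4, device=device)
    y = (x @ x).sum()

    torch_xla.sync()  # previously: xm.mark_step(); executes the traced graph
    print(y.item())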
return trainable_param_count, all_param_count diff --git a/optimum/neuron/peft/mapping_func.py b/optimum/neuron/peft/mapping_func.py index 7e25c4e90..37deda15c 100644 --- a/optimum/neuron/peft/mapping_func.py +++ b/optimum/neuron/peft/mapping_func.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib + from transformers import PreTrainedModel from ..utils.import_utils import is_peft_available @@ -49,16 +51,22 @@ def get_peft_model( revision: str | None = None, low_cpu_mem_usage: bool = False, ) -> PeftModel | PeftMixedModel: - if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING: - raise ValueError( - "PEFT type {peft_config.peft_type} not supported in Optimum Neuron. Supported types are: " - f"{list(PEFT_TYPE_TO_TUNER_MAPPING.keys())}" + from ..models.training import NeuronModelMixin + + if isinstance(model, NeuronModelMixin): + if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING: + raise ValueError( + "PEFT type {peft_config.peft_type} not supported in Optimum Neuron. Supported types are: " + f"{list(PEFT_TYPE_TO_TUNER_MAPPING.keys())}" + ) + patcher = Patcher( + [ + ("peft.mapping_func.MODEL_TYPE_TO_PEFT_MODEL_MAPPING", MODEL_TYPE_TO_PEFT_MODEL_MAPPING), + ], ) - patcher = Patcher( - [ - ("peft.mapping_func.MODEL_TYPE_TO_PEFT_MODEL_MAPPING", MODEL_TYPE_TO_PEFT_MODEL_MAPPING), - ], - ) + else: + # No patching needed since model parallelism is not enabled, we can use PEFT as is. + patcher = contextlib.nullcontext() with patcher: peft_model = orig_get_peft_model( model, diff --git a/optimum/neuron/peft/tuners/lora/layer.py b/optimum/neuron/peft/tuners/lora/layer.py index 07f0f768d..544fae500 100644 --- a/optimum/neuron/peft/tuners/lora/layer.py +++ b/optimum/neuron/peft/tuners/lora/layer.py @@ -726,7 +726,11 @@ def unmerge(self) -> None: def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: previous_dtype = x.dtype output_q, output_k, output_v = self.base_layer(x, *args, **kwargs) - if not self.merged: + + if self.disable_adapters: + if self.merged: + self.unmerge() + elif not self.merged: for active_adapter in self.active_adapters: if active_adapter not in self.lora_A.keys(): continue diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index e8ad272da..748af77ef 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn +import torch_xla import torch_xla.core.xla_model as xm from accelerate.utils import AutocastKwargs, DataLoaderConfiguration from neuronx_distributed.parallel_layers.parallel_state import ( @@ -907,7 +908,7 @@ def setup_training( self.running_loss = torch.zeros(1, dtype=torch.double, device=xm.xla_device()) self.grad_norm = None - xm.mark_step() + torch_xla.sync() def get_batch_samples( self, @@ -1042,7 +1043,7 @@ def maybe_log_train_step_metrics(self): return if self.control.should_log: - xm.mark_step() + torch_xla.sync() running_loss_div = self.running_loss / self.dp_size reduced_loss = xm.all_reduce(xm.REDUCE_SUM, running_loss_div, groups=get_data_parallel_replica_groups()) reduced_loss = reduced_loss.detach() @@ -1091,7 +1092,7 @@ def log_closure(): def maybe_save_checkpoint(self): """Save checkpoint if saving is due.""" if self.control.should_save: - xm.mark_step() + torch_xla.sync() def save_closure(self): self._save_checkpoint() @@ -1168,7 +1169,7 @@ def _train( 
self.metrics_collector.start_metric("total_step") for inputs in batch_samples: - xm.mark_step() + torch_xla.sync() step += 1 do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch @@ -1184,7 +1185,7 @@ def _train( if do_sync_step: self.accelerator.gradient_state.sync_gradients = True - xm.mark_step() + torch_xla.sync() # Gradient clipping self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) @@ -1205,7 +1206,7 @@ def _train( self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) - xm.mark_step() + torch_xla.sync() self.metrics_collector.stop_metric("throughput") self.metrics_collector.stop_metric("total_step") self.metrics_collector.end_gradient_accumulation_cycle(step_number=self.state.global_step) @@ -1220,12 +1221,12 @@ def _train( # PyTorch/XLA relies on the data loader to insert the mark_step for # each step. Since we are breaking the loop early, we need to manually # insert the mark_step here. - xm.mark_step() + torch_xla.sync() break # We also need to break out of the nested loop if self.control.should_epoch_stop or self.control.should_training_stop: - xm.mark_step() + torch_xla.sync() break if step < 0: @@ -1237,7 +1238,7 @@ def _train( self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) - xm.mark_step() + torch_xla.sync() if self.control.should_training_stop: break diff --git a/optimum/neuron/trainers/utils.py b/optimum/neuron/trainers/utils.py index 1cdaf14b2..065be0b60 100644 --- a/optimum/neuron/trainers/utils.py +++ b/optimum/neuron/trainers/utils.py @@ -14,6 +14,7 @@ # limitations under the License. import torch +import torch_xla import torch_xla.core.xla_model as xm @@ -51,7 +52,7 @@ def __iter__(self): def __next__(self): if not self.buffer: raise StopIteration - xm.mark_step() + torch_xla.sync() next_example = self.buffer.pop(0) self._prefetch() return next_example From b0aa60c576a24869ef99e68c52c3ccc5edc48c23 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 8 Dec 2025 16:16:32 +0100 Subject: [PATCH 56/78] wip training --- .../grpo_qwen3/finetune_grpo_qwen3.py | 5 +- .../grpo_qwen3/finetune_grpo_qwen3.sh | 49 +++----- optimum/neuron/trainers/extras/vllm_client.py | 40 ++++++ optimum/neuron/trainers/grpo_trainer.py | 12 +- optimum/neuron/trainers/trl_utils.py | 115 +++--------------- 5 files changed, 87 insertions(+), 134 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index d0d76665a..c0425d511 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -100,7 +100,8 @@ def load_grpo_dataset(): # Load a simple test dataset # This dataset has prompts in the "prompt" column # dataset = load_dataset("trl-lib/DeepMath-103K", split="train") - dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + # dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + dataset = load_dataset("trl-lib/tldr", split="train") return dataset @@ -159,7 +160,7 @@ def train(model_id, tokenizer, dataset, training_args): processing_class=tokenizer, peft_config=lora_config, # To do: disable this fake client, only for development without vLLM server. 
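The mock client referenced in the comment above exists so the GRPO loop can run without a live vLLM server; its implementation is not shown in this patch. A hypothetical minimal stand-in (not the actual MockVLLMClient) only has to cover the entry points the trainer touches, namely `generate`, `init_communicator`, and `update_named_param`:

    import random

    class TinyMockVLLMClient:
        # Hypothetical sketch: returns random token ids so the training loop
        # can be exercised offline, with no weight sync or server setup.
        def __init__(self, tokenizer, max_completion_length: int):
            self.tokenizer = tokenizer
            self.max_completion_length = max_completion_length

        def generate(self, prompts, n: int = 1, **sampling_kwargs):
            vocab_size = len(self.tokenizer)
            completions = []
            for _ in prompts:
                for _ in range(n):
                    length = random.randint(1, self.max_completion_length)
                    completions.append([random.randrange(vocab_size) for _ in range(length)])
            return completions

        def init_communicator(self, *args, **kwargs):
            pass  # no server, nothing to set up

        def update_named_param(self, name, weight):
            pass  # weight sync is a no-op for the mock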
- vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), + # vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), ) # Train the model diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index 8ab369f82..7a7201b8a 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -1,28 +1,15 @@ #!/bin/bash - -# ============================================================================ -# GRPO Fine-tuning Script for Qwen3 on AWS Trainium -# ============================================================================ -# This script demonstrates how to fine-tune a Qwen3 model using GRPO -# (Group Relative Policy Optimization) on AWS Trainium devices. -# -# Prerequisites: -# 1. vLLM server running (or use mock vLLM for development) -# 2. Neuron SDK installed -# 3. Multi-node Trainium instance (e.g., trn1.32xlarge) -# -# For mock vLLM (development/testing): -# - Set USE_MOCK_VLLM=True in grpo_trainer.py -# -# For real vLLM: -# - Start vLLM server with: trl vllm-serve --model MODEL_NAME -# ============================================================================ - # Flags for Neuron compilation -export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation" +export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation --cache_dir=$HOME/cache_dir_neuron/" export NEURON_FUSE_SOFTMAX=1 -export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime -export MALLOC_ARENA_MAX=64 # Host OOM mitigation +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=7 # Async Runtime +export MALLOC_ARENA_MAX=128 # Host OOM mitigation +# Force NCCL to ignore the AWS OFI plugin +export NCCL_NET_GDR_LEVEL=0 +export FI_PROVIDER=sockets +# export FI_EFA_USE_DEVICE_RDMA=1 +# export FI_PROVIDER=efa +# export FI_EFA_FORK_SAFE=1 # Variables for training PROCESSES_PER_NODE=2 @@ -30,9 +17,11 @@ NUM_EPOCHS=1 # GRPO typically needs fewer epochs than SFT TP_DEGREE=1 BS=1 GRADIENT_ACCUMULATION_STEPS=1 # Smaller for GRPO due to generation overhead -LOGGING_STEPS=2 +LOGGING_STEPS=1 MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing +# MODEL_NAME="yujiepan/qwen3-tiny-random" # Use smaller model for testing # MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" +# MODEL_NAME="HuggingFaceTB/SmolLM2-135M" OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE" SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) @@ -40,9 +29,9 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) # GRPO-specific variables NUM_GENERATIONS=4 # Number of completions per prompt (G in paper) MAX_PROMPT_LENGTH=512 -MAX_COMPLETION_LENGTH=256 +MAX_COMPLETION_LENGTH=512 TEMPERATURE=0.8 -STEPS_PER_GENERATION=4 # Generate every N steps to amortize generation cost +STEPS_PER_GENERATION=8 # Generate every N steps to amortize generation cost if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then MAX_STEPS=5 @@ -63,22 +52,24 @@ torchrun $DISTRIBUTED_ARGS finetune_grpo_qwen3.py \ --per_device_train_batch_size $BS \ --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ --gradient_checkpointing \ - --learning_rate 5e-5 \ + --learning_rate 5e-8 \ --bf16 \ --tensor_parallel_size $TP_DEGREE \ - --zero_1 false \ + --zero_1 \ --optimizer_use_master_weights false \ --optimizer_use_fp32_grad_acc 
false \ --async_save \ --logging_steps $LOGGING_STEPS \ --output_dir $OUTPUT_DIR \ - --lr_scheduler_type "cosine" \ + --lr_scheduler_type "constant" \ --overwrite_output_dir \ --num_generations $NUM_GENERATIONS \ --max_prompt_length $MAX_PROMPT_LENGTH \ --max_completion_length $MAX_COMPLETION_LENGTH \ --temperature $TEMPERATURE \ - --steps_per_generation $STEPS_PER_GENERATION + --steps_per_generation $STEPS_PER_GENERATION \ + --epsilon 0.1 \ + --beta 0.01 echo "================================" echo "Training completed!" diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index b8c08bc12..7fce0ef2a 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -15,8 +15,10 @@ import atexit import random +import socket import time from collections import namedtuple +from contextlib import closing from typing import Union import requests @@ -41,6 +43,22 @@ class StatelessProcessGroup: Group = namedtuple("Group", "barrier") +def find_closest_port(host: str, start_port: int, max_attempts: int = 100) -> int: + for port in range(start_port, start_port + max_attempts): + if is_port_available(host, port): + return port + raise RuntimeError(f"No available port found in range {start_port} to {start_port + max_attempts - 1}") + + +def is_port_available(host: str, port: int) -> bool: + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + try: + sock.bind((host, port)) + return True + except OSError: + return False + + class CPUCommunicator: def __init__(self, store, rank): self.rank = rank @@ -64,6 +82,28 @@ def __del__(self): class VLLMClient(TRLVLLMClient): """VLLMClient with CPU-based communication for Neuron environments.""" + def __init__( + self, + base_url: str | None = None, + host: str = "0.0.0.0", + server_port: int = 8000, + group_port: int = 51216, + connection_timeout: float = 0.0, + ): + # free_group_port = find_closest_port(host, group_port) + # if free_group_port != group_port: + # logger.warning( + # f"Requested group_port {group_port} is not available. Using closest available port {free_group_port} instead." 
+ # ) + free_group_port = group_port + super().__init__( + base_url=base_url, + host=host, + server_port=server_port, + group_port=free_group_port, + connection_timeout=connection_timeout, + ) + def init_communicator(self, device: Union[torch.device, str, int] = 0): # Get the world size from the server url = f"{self.base_url}/get_world_size/" diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 496b17fda..c87481a3e 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -617,11 +617,11 @@ def _move_model_to_vllm(self): # Clean up parameter name for vLLM name = self._fix_param_name_to_vllm(name) - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, weight) - elif self.vllm_mode == "colocate": - llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, weight)]) + # if self.vllm_mode == "server" and self.accelerator.is_main_process: + # self.vllm_client.update_named_param(name, weight) + # elif self.vllm_mode == "colocate": + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + # llm_model.load_weights([(name, weight)]) else: for name, param in self.model.named_parameters(): name = self._fix_param_name_to_vllm(name) @@ -940,7 +940,7 @@ def _generate_and_score_completions( **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: - with self.accelerator.unwrap_model(self.model).disable_adapter(): + with self.model.disable_adapter(): ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( self.model, prompt_completion_ids, diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index 9b44c540e..080f16173 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Literal import numpy as np @@ -23,6 +22,7 @@ from optimum.utils import logging from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler +from trl.trainer.utils import RepeatSampler from ..utils import is_precompilation @@ -208,20 +208,7 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor: return torch.sqrt(variance) -class DistributedRepeatSampler(DistributedSampler): - """ - Sampler that repeats the indices of a dataset in a structured manner. - Same as `trl.trainer.utils.RepeatSampler` but adapted to work with distributed training. - - To implement it, we simply combine the logic from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py - with the logic from https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1692. - - First, we distribute the dataset indices across the different ranks, then we repeat the indices on each rank. - - We inherit from `torch.utils.data.DistributedSampler` even though we override all of its methods to pass the checks - "isinstance(sampler, DistributedSampler)" done in `torch.utils.data.DataLoader` when using distributed training. 
- """ - +class DistributedRepeatSampler(RepeatSampler, DistributedSampler): def __init__( self, dataset: Dataset, @@ -234,90 +221,24 @@ def __init__( seed: int = 0, drop_last: bool = False, ): + # Initialize RepeatSampler with the actual sampling logic + RepeatSampler.__init__( + self, + data_source=dataset, + mini_repeat_count=mini_repeat_count, + batch_size=batch_size, + repeat_count=repeat_count, + shuffle=shuffle, + seed=seed, + ) + + # Store DistributedSampler attributes for interface compatibility + # (but we don't use them for actual distribution) if num_replicas is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - num_replicas = dist.get_world_size() + num_replicas = dist.get_world_size() if dist.is_available() else 1 if rank is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - rank = dist.get_rank() - if rank >= num_replicas or rank < 0: - raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") - self.dataset = dataset + rank = dist.get_rank() if dist.is_available() else 0 + self.num_replicas = num_replicas self.rank = rank - self.epoch = 0 self.drop_last = drop_last - # If the dataset length is evenly divisible by # of replicas, then there - # is no need to drop any data, since the dataset will be split equally. - if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] - # Split to nearest available length that is evenly divisible. - # This is to ensure each rank receives the same amount of data when - # using this Sampler. - self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] - ) - else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] - self.total_size = self.num_samples * self.num_replicas - self.shuffle = shuffle - self.seed = seed - - self.mini_repeat_count = mini_repeat_count - self.batch_size = batch_size - self.repeat_count = repeat_count - self.shuffle = shuffle - - if shuffle: - self.generator = torch.Generator() # Create a local random generator - self.generator.manual_seed(seed) - - def __iter__(self): - # First, we produce indices for each rank. - # That is the distributed part of the sampler. - if self.shuffle: - # deterministically shuffle based on epoch and seed - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] - else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] - - if not self.drop_last: - # add extra samples to make it evenly divisible - padding_size = self.total_size - len(indices) - if padding_size <= len(indices): - indices += indices[:padding_size] - else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] - else: - # remove tail of data to make it evenly divisible. - indices = indices[: self.total_size] - assert len(indices) == self.total_size - - # subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples - - # Second, we repeat the indices on each rank. - # This is the non-distributed part of the sampler. 
- # [2, 4, 3, 1, 0, 6, 5] - # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) - indices = [indices[i : i + self.batch_size] for i in range(0, len(indices), self.batch_size)] - - # [[2, 4, 3], [1, 0, 6], [5]] - # -> [[2, 4, 3], [1, 0, 6]] - indices = [chunk for chunk in indices if len(chunk) == self.batch_size] - - for chunk in indices: - for _ in range(self.repeat_count): - for index in chunk: - for _ in range(self.mini_repeat_count): - yield index - - def __len__(self) -> int: - return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count - - def set_epoch(self, epoch: int) -> None: - self.epoch = epoch From ba62cf4b530602beb660df4bb3319364b464757d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Dec 2025 11:28:22 +0100 Subject: [PATCH 57/78] wip training --- examples/training/grpo_qwen3/finetune_grpo_qwen3.sh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index 7a7201b8a..9bd038db8 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -2,11 +2,9 @@ # Flags for Neuron compilation export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation --cache_dir=$HOME/cache_dir_neuron/" export NEURON_FUSE_SOFTMAX=1 -export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=7 # Async Runtime -export MALLOC_ARENA_MAX=128 # Host OOM mitigation +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime +export MALLOC_ARENA_MAX=64 # Host OOM mitigation # Force NCCL to ignore the AWS OFI plugin -export NCCL_NET_GDR_LEVEL=0 -export FI_PROVIDER=sockets # export FI_EFA_USE_DEVICE_RDMA=1 # export FI_PROVIDER=efa # export FI_EFA_FORK_SAFE=1 @@ -31,7 +29,7 @@ NUM_GENERATIONS=4 # Number of completions per prompt (G in paper) MAX_PROMPT_LENGTH=512 MAX_COMPLETION_LENGTH=512 TEMPERATURE=0.8 -STEPS_PER_GENERATION=8 # Generate every N steps to amortize generation cost +STEPS_PER_GENERATION=4 # Generate every N steps to amortize generation cost if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then MAX_STEPS=5 From 688de3a04ca258d2cd36879d38bafcafa8e4a42b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 9 Dec 2025 15:25:40 +0100 Subject: [PATCH 58/78] chore: fix pyproject.toml for uv --- pyproject.toml | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4b8bac540..5c0ed092b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,8 @@ classifiers = [ ] dependencies = [ "transformers ~= 4.57.1", - "accelerate == 1.8.1", "optimum ~= 2.0.0", "huggingface_hub >= 0.31.4", - "numpy>=1.22.2, <=1.26.4", - "protobuf>=3.20.3", ] [project.urls] @@ -78,16 +75,18 @@ training = [ "trl == 0.24.0", "peft == 0.17.0", "evaluate == 0.4.3", + "accelerate == 1.8.1", +] +neuron = [ + "wheel", + "torch-neuron==1.13.1.2.9.74.0", + "torch==1.13.1.*", + "neuron-cc[tensorflow]==1.22.0.0", + "protobuf", + "torchvision", + "numpy==1.22.2", + "protobuf<=3.20.1", ] -# neuron = [ -# "wheel", -# "torch-neuron==1.13.1.2.9.74.0", -# "torch==1.13.1.*", -# "neuron-cc[tensorflow]==1.22.0.0", -# "protobuf", -# "torchvision", -# "numpy==1.22.3", -# ] neuronx = [ "wheel", "neuronx-cc==2.21.18209.0", @@ -96,6 +95,8 @@ neuronx = [ "torchvision==0.23.*", "neuronx_distributed==0.15.22404", "libneuronxla==2.2.12677.0", + "protobuf>=3.20.3", + "numpy>=1.22.2, <=1.26.4", ] diffusers = 
[ "diffusers==0.35.*", @@ -134,12 +135,26 @@ include = ["optimum*"] [tool.uv] index-strategy = "unsafe-best-match" +conflicts = [ + [ + { extra = "neuron" }, + { extra = "neuronx" }, + ], + [ + { extra = "neuron" }, + { extra = "training" }, + ], + [ + { extra = "neuron" }, + { extra = "vllm" }, + ], +] [[tool.uv.index]] name = "neuron" url = "https://pip.repos.neuron.amazonaws.com/" default = true - + [[tool.uv.index]] name = "pypi" url = "https://pypi.org/simple" From a2880aeab8efe9fc06cb45c883d3082d96e1957f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 10 Dec 2025 16:38:32 +0100 Subject: [PATCH 59/78] chore: update pyproject.toml for SDK 2.26.1 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c0ed092b..b7af2dcd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,8 +89,8 @@ neuron = [ ] neuronx = [ "wheel", - "neuronx-cc==2.21.18209.0", - "torch-neuronx==2.8.0.2.10.13553", + "neuronx-cc==2.21.33363.0", + "torch-neuronx==2.8.0.2.10.16998", "torch==2.8.0.*", "torchvision==0.23.*", "neuronx_distributed==0.15.22404", From 7d8914b87caae5c34a183fef2dab0c2f779512e9 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 18 Dec 2025 18:57:56 +0100 Subject: [PATCH 60/78] feat: improve nan functions for XLA --- .../grpo_qwen3/finetune_grpo_qwen3.sh | 6 ++-- optimum/neuron/trainers/trl_utils.py | 34 +++++++++++++------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index 9bd038db8..264d2d7b7 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -10,14 +10,14 @@ export MALLOC_ARENA_MAX=64 # Host OOM mitigation # export FI_EFA_FORK_SAFE=1 # Variables for training -PROCESSES_PER_NODE=2 +PROCESSES_PER_NODE=1 NUM_EPOCHS=1 # GRPO typically needs fewer epochs than SFT TP_DEGREE=1 BS=1 GRADIENT_ACCUMULATION_STEPS=1 # Smaller for GRPO due to generation overhead LOGGING_STEPS=1 -MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing -# MODEL_NAME="yujiepan/qwen3-tiny-random" # Use smaller model for testing +# MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing +MODEL_NAME="yujiepan/qwen3-tiny-random" # Use smaller model for testing # MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" # MODEL_NAME="HuggingFaceTB/SmolLM2-135M" OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index 080f16173..d271fcf32 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -181,10 +181,10 @@ def nanmin(tensor: torch.Tensor) -> torch.Tensor: Compute the minimum value of a tensor, ignoring NaNs. """ mask = torch.isnan(tensor) - if mask.all(): - return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) - filled = torch.where(mask, torch.tensor(float("inf"), dtype=tensor.dtype, device=tensor.device), tensor) - return torch.min(filled) + filled = torch.where(mask, torch.full_like(tensor, float("inf")), tensor) + min_value = torch.amin(filled) + all_nan = mask.all() + return torch.where(all_nan, torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device), min_value) def nanmax(tensor: torch.Tensor) -> torch.Tensor: @@ -193,19 +193,31 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: Compute the maximum value of a tensor, ignoring NaNs. 
""" mask = torch.isnan(tensor) - if mask.all(): - return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) - filled = torch.where(mask, torch.tensor(float("-inf"), dtype=tensor.dtype, device=tensor.device), tensor) - return torch.max(filled) + filled = torch.where(mask, torch.full_like(tensor, float("-inf")), tensor) + min_value = torch.amax(filled) + all_nan = mask.all() + return torch.where(all_nan, torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device), min_value) -def nanstd(tensor: torch.Tensor) -> torch.Tensor: +def nanstd(tensor: torch.Tensor, unbiased: bool = False) -> torch.Tensor: """ XLA-compatible version of nanstd. Compute the standard deviation of a tensor, ignoring NaNs. """ - variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) - return torch.sqrt(variance) + mask = ~torch.isnan(tensor) + count = mask.sum() + + clean = torch.where(mask, tensor, torch.zeros_like(tensor)) + mean = clean.sum() / count + + diff_squared = torch.where(mask, (clean - mean) ** 2, torch.zeros_like(tensor)) + + if unbiased: + variance = diff_squared.sum() / (count - 1).clamp(min=1) + else: + variance = diff_squared.sum() / count + + return variance.sqrt() class DistributedRepeatSampler(RepeatSampler, DistributedSampler): From ee083cb9be71aa4d1e63ce15fd35ab41a9caea27 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 18 Dec 2025 19:19:44 +0100 Subject: [PATCH 61/78] feat: compute rewards more XLA friendly --- optimum/neuron/trainers/grpo_trainer.py | 80 ++++++++++++++++++++----- 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index c87481a3e..cb2429509 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -567,39 +567,87 @@ def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): device = self.accelerator.device - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + excluded_keys = {"prompt", "completion", "completion_ids"} + keys = [key for key in inputs[0] if key not in excluded_keys] reward_kwargs = {key: [example[key] for example in inputs] for key in keys} reward_kwargs["trainer_state"] = self.state - for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) - ): + # Separate model-based vs callable reward functions by index + model_indices = [] + callable_indices = [] + for i, reward_func in enumerate(self.reward_funcs): if isinstance(reward_func, torch.nn.Module): - if is_conversational(inputs[0]): - from trl.data_utils import apply_chat_template + model_indices.append(i) + else: + callable_indices.append(i) + + # Collect results: list of (index, tensor) tuples + reward_columns = [] + + if model_indices: + # Pre-compute texts once if needed (all models use same text format) + texts = None + is_conv = is_conversational(inputs[0]) + if is_conv: + from trl.data_utils import apply_chat_template + + for i in model_indices: + reward_func = self.reward_funcs[i] + reward_processing_class = self.reward_processing_classes[i] + + if is_conv: messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] 
else: texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( - text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, ) reward_inputs = NeuronTrainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] - torch_xla.sync() - else: - output_reward_func = reward_func( - prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + logits = reward_func(**reward_inputs).logits[:, 0] + + reward_columns.append((i, logits)) + + if callable_indices: + callable_rewards_cpu = [] + + for i in callable_indices: + reward_func = self.reward_funcs[i] + output = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, ) - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + # Use float('nan') not torch.nan - avoids tensor creation + callable_rewards_cpu.append([r if r is not None else float("nan") for r in output]) + + # Single tensor creation and transfer for all callable rewards + if callable_rewards_cpu: + callable_tensor = torch.tensor( + callable_rewards_cpu, dtype=torch.float32, device=device + ) # Shape: [num_callable_funcs, num_samples] + + for local_idx, global_idx in enumerate(callable_indices): + reward_columns.append((global_idx, callable_tensor[local_idx])) + + # Sort by original index to maintain correct column order + reward_columns.sort(key=lambda x: x[0]) + + # Stack all columns at once instead of indexed assignment in loop + rewards_per_func = torch.stack([col for _, col in reward_columns], dim=1) - rewards_per_func = self.accelerator.gather(rewards_per_func) torch_xla.sync() + rewards_per_func = self.accelerator.gather(rewards_per_func) return rewards_per_func def _move_model_to_vllm(self): From 998fd64990cbfd889783279bc5f47151cb6cb8df Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 19 Dec 2025 15:01:29 +0100 Subject: [PATCH 62/78] feat: optimization for XLA --- optimum/neuron/trainers/grpo_trainer.py | 341 ++++++++++++------------ optimum/neuron/trainers/trl_utils.py | 154 +++++++---- 2 files changed, 274 insertions(+), 221 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index cb2429509..153be8e26 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -18,6 +18,7 @@ from typing import Any, Callable import datasets +import numpy as np import torch import torch_xla import torch_xla.core.xla_model as xm @@ -57,11 +58,11 @@ from .trl_utils import ( TRL_VERSION, DistributedRepeatSampler, + batch_pad_sequences, nanmax, nanmin, nanstd, neuron_parallel_compile_tokenizer_decoder_method, - pad_or_truncate_to_length, ) @@ -475,6 +476,14 @@ def make_inputs_require_grad(module, input, output): self.fixed_size_obj_collectives = fixed_size_for_obj_collectives + # Pre-create constant tensors for XLA optimization. + # These are used in clamp operations and comparisons. Creating them once avoids + # repeated tensor allocations that cause XLA graph fragmentation. 
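As a standalone illustration of the comment above: a `torch.tensor(1.0, device=...)` created inside the step function is a fresh constant in every trace, whereas a tensor created once and reused is a stable graph input. A toy sketch (names hypothetical) of how a cached scalar such as `self._one_float` gets used:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    one = torch.tensor(1.0, device=device)  # created once, outside the hot loop

    def normalize(loss_sum: torch.Tensor, token_count: torch.Tensor) -> torch.Tensor:
        # clamp accepts a tensor bound, so the cached `one` can be reused on
        # every step instead of allocating a new scalar tensor each time.
        return loss_sum / token_count.clamp(min=one)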
+ device = xm.xla_device() + self._one_float = torch.tensor(1.0, dtype=torch.float32, device=device) + self._one_long = torch.tensor(1, dtype=torch.long, device=device) + self._inf_float = torch.tensor(float("inf"), dtype=torch.float32, device=device) + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: if dataset is None: dataset = self.train_dataset @@ -511,7 +520,15 @@ def train( def log(self, logs: dict[str, float], start_time: float | None = None) -> None: mode = "train" if self.model.training else "eval" - metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + # Average the metrics. Values can be either floats or CPU tensors, so we handle both. + # Using sum() works for both types; we call .item() only on tensors when computing the average. + metrics = {} + for key, val_list in self._metrics[mode].items(): + if len(val_list) == 0: + continue + # Convert any tensor values to floats for averaging + float_vals = [v.item() if isinstance(v, torch.Tensor) else v for v in val_list] + metrics[key] = sum(float_vals) / len(float_vals) # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. @@ -618,27 +635,27 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): reward_columns.append((i, logits)) if callable_indices: - callable_rewards_cpu = [] + # Use numpy for intermediate storage to avoid Python list overhead + # and enable efficient single-transfer to XLA device + num_samples = len(prompts) + callable_rewards_np = np.empty((len(callable_indices), num_samples), dtype=np.float32) - for i in callable_indices: - reward_func = self.reward_funcs[i] + for local_idx, global_idx in enumerate(callable_indices): + reward_func = self.reward_funcs[global_idx] output = reward_func( prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs, ) - # Use float('nan') not torch.nan - avoids tensor creation - callable_rewards_cpu.append([r if r is not None else float("nan") for r in output]) + for j, r in enumerate(output): + callable_rewards_np[local_idx, j] = r if r is not None else np.nan - # Single tensor creation and transfer for all callable rewards - if callable_rewards_cpu: - callable_tensor = torch.tensor( - callable_rewards_cpu, dtype=torch.float32, device=device - ) # Shape: [num_callable_funcs, num_samples] + # Single tensor creation and transfer from numpy array + callable_tensor = torch.from_numpy(callable_rewards_np).to(device=device) - for local_idx, global_idx in enumerate(callable_indices): - reward_columns.append((global_idx, callable_tensor[local_idx])) + for local_idx, global_idx in enumerate(callable_indices): + reward_columns.append((global_idx, callable_tensor[local_idx])) # Sort by original index to maintain correct column order reward_columns.sort(key=lambda x: x[0]) @@ -757,43 +774,79 @@ def _get_per_token_logps_and_entropies( image_sizes=None, token_type_ids=None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - torch_xla.sync() - batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + total_batch_size = input_ids.size(0) + batch_size = batch_size or total_batch_size # Chunk inputs into smaller batches to reduce memory peak # Ensure input batch size is divisible by `batch_size` to avoid issues with XLA graph compilation. 
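The divisibility guard below exists because a ragged final chunk would change tensor shapes between iterations and force an XLA recompilation; a quick arithmetic check shows the failure mode:

    # e.g. a batch of 10 split into chunks of 4 leaves a trailing chunk of 2,
    # which would be traced as a second graph with different shapes.
    total, chunk = 10, 4
    sizes = [min(chunk, total - s) for s in range(0, total, chunk)]
    assert sizes == [4, 4, 2]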
- if input_ids.size(0) % batch_size != 0: + if total_batch_size % batch_size != 0: raise ValueError( - f"The input_ids batch size must be divisible by `batch_size`, but got {input_ids.shape[0]} and " + f"The input_ids batch size must be divisible by `batch_size`, but got {total_batch_size} and " f"{batch_size}." ) - all_logps = [] - all_entropies = [] - for start in range(0, input_ids.size(0), batch_size): - input_ids_batch = input_ids[start : start + batch_size] - attention_mask_batch = attention_mask[start : start + batch_size] + num_chunks = total_batch_size // batch_size + device = input_ids.device + + # Pre-allocate output tensors to avoid list accumulation and repeated concatenation. + # This creates a single graph for all chunks instead of growing graphs. + all_logps = torch.empty(total_batch_size, logits_to_keep, dtype=torch.float32, device=device) + all_entropies = ( + torch.empty(total_batch_size, logits_to_keep, dtype=torch.float32, device=device) + if compute_entropy + else None + ) + + # Pre-compute VLM slicing indices if needed (avoids .item() calls inside loop). + # For VLMs with image_grid_thw, we need to compute pixel_values slicing indices upfront. + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + # num_images is a list of ints, so we can compute cumulative sums on CPU + cum_imgs = [0] + for n in num_images: + cum_imgs.append(cum_imgs[-1] + n) + + # Compute row boundaries for each sample using CPU-computed indices + rows_per_sample_list = [] + for i in range(len(num_images)): + start_img = cum_imgs[i] + end_img = cum_imgs[i + 1] + rows_per_sample_list.append(rows_per_image[start_img:end_img].sum()) + rows_per_sample = torch.stack(rows_per_sample_list) + # Compute cumulative row indices on device + cum_rows = torch.cat( + [torch.zeros(1, dtype=rows_per_sample.dtype, device=device), rows_per_sample.cumsum(0)] + ) + # Move to CPU once to get all slice indices (single sync instead of per-chunk) + torch_xla.sync() + cum_rows_cpu = cum_rows.cpu().tolist() + + for chunk_idx in range(num_chunks): + start = chunk_idx * batch_size + end = start + batch_size + + input_ids_batch = input_ids[start:end] + attention_mask_batch = attention_mask[start:end] # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: - rows_per_image = image_grid_thw.prod(dim=-1) - rows_per_sample = torch.split(rows_per_image, num_images) - rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) - cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) - # TODO: not support with torch XLA, fix it later. 
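The removed line below is the motivation for the pre-computed indices above: every `.item()` on an XLA tensor blocks until the device catches up, so calling it twice per chunk serializes the loop. A toy sketch of the replacement pattern, moving all index materialization into one host transfer:

    import torch

    lengths = torch.tensor([3, 5, 2, 4])
    cum_rows = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)])

    # One transfer up front instead of two .item() calls per iteration.
    cum_rows_cpu = cum_rows.tolist()
    for i in range(len(lengths)):
        row_start, row_end = int(cum_rows_cpu[i]), int(cum_rows_cpu[i + 1])
        # slice per-sample rows with plain Python ints; no device sync needed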
- row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + # Use pre-computed CPU indices to avoid .item() calls + row_start = int(cum_rows_cpu[start]) + row_end = int(cum_rows_cpu[end]) model_inputs["pixel_values"] = pixel_values[row_start:row_end] - cum_imgs = torch.tensor([0] + num_images).cumsum(0) - img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + img_start = cum_imgs[start] + img_end = cum_imgs[end] model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + model_inputs["pixel_values"] = pixel_values[start:end] + if pixel_attention_mask is not None: - model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start:end] if image_sizes is not None: - model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + model_inputs["image_sizes"] = image_sizes[start:end] if token_type_ids is not None: - model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + model_inputs["token_type_ids"] = token_type_ids[start:end] # Only add logits_to_keep if the model supports it if "logits_to_keep" in self.model_kwarg_keys: @@ -814,22 +867,18 @@ def _get_per_token_logps_and_entropies( completion_ids = input_ids_batch[:, -logits_to_keep:] logps = selective_log_softmax(logits, completion_ids) # compute logprobs - all_logps.append(logps) + + # Write directly to pre-allocated tensor instead of list append + all_logps[start:end] = logps if compute_entropy: with torch.no_grad(): entropies = entropy_from_logits(logits) - all_entropies.append(entropies) - - torch_xla.sync() + all_entropies[start:end] = entropies - logps = torch.cat(all_logps, dim=0) - entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None - - # Force synchronization after computation to ensure graph is re-used. 
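Writing each chunk into a slice of a pre-allocated buffer, as `all_logps`/`all_entropies` now do, keeps every tensor shape static across iterations, which is what lets XLA trace and reuse a single graph; the old append-and-`torch.cat` style kept growing the traced program instead. A minimal sketch contrasting the two styles on arbitrary 2D data:

    import torch

    def chunked_sum_cat(x: torch.Tensor, chunk: int) -> torch.Tensor:
        # x: (N, T); returns per-row sums of shape (N,)
        outs = []
        for s in range(0, x.size(0), chunk):
            outs.append(x[s : s + chunk].sum(-1))
        return torch.cat(outs, dim=0)  # growing list + final cat

    def chunked_sum_prealloc(x: torch.Tensor, chunk: int) -> torch.Tensor:
        # Same result, but every iteration writes a fixed-shape slice.
        out = torch.empty(x.size(0), dtype=x.dtype, device=x.device)
        for s in range(0, x.size(0), chunk):
            out[s : s + chunk] = x[s : s + chunk].sum(-1)
        return out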
torch_xla.sync() - return logps, entropies + return all_logps, all_entropies def _generate_and_score_completions( self, inputs: list[dict[str, torch.Tensor | Any]] @@ -857,77 +906,46 @@ def _generate_and_score_completions( forward_kwargs, ) = self._generate(prompts, images) - # Convert lists of token IDs to padded tensors - prompt_ids = [ - pad_or_truncate_to_length( - torch.tensor(ids), - self.max_prompt_length, - padding_value=self.pad_token_id, - padding_or_truncate_side="left", - ) - for ids in prompt_ids_list - ] - # prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] - # prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_mask = [ - pad_or_truncate_to_length( - torch.ones(len(ids), dtype=torch.long), - self.max_prompt_length, - padding_value=0, - padding_or_truncate_side="left", - ) - for ids in prompt_ids_list - ] - prompt_ids = torch.stack(prompt_ids, dim=0).to(device) - prompt_mask = torch.stack(prompt_mask, dim=0).to(device) - # prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - # prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - # completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - # completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - # completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - # completion_mask = pad(completion_mask, padding_value=0, padding_side="right") - completion_ids = [ - pad_or_truncate_to_length( - torch.tensor(ids), - self.max_completion_length, - padding_value=self.pad_token_id, - padding_or_truncate_side="right", - ) - for ids in completion_ids_list - ] - completion_mask = [ - pad_or_truncate_to_length( - torch.ones(len(ids), dtype=torch.long), - self.max_completion_length, - padding_value=0, - padding_or_truncate_side="right", - ) - for ids in completion_ids_list - ] - completion_ids = torch.stack(completion_ids, dim=0).to(device) - completion_mask = torch.stack(completion_mask, dim=0).to(device) + # Convert lists of token IDs to padded tensors using XLA-optimized batch padding. + # This avoids creating many small tensors and multiple device transfers. 
+ prompt_ids, prompt_mask = batch_pad_sequences( + prompt_ids_list, + target_length=self.max_prompt_length, + padding_value=self.pad_token_id, + padding_side="left", + dtype=torch.long, + device=device, + ) + + completion_ids, completion_mask = batch_pad_sequences( + completion_ids_list, + target_length=self.max_completion_length, + padding_value=self.pad_token_id, + padding_side="right", + dtype=torch.long, + device=device, + ) if sampling_per_token_logps_list is not None: - # sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] - # sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") - sampling_per_token_logps = [ - pad_or_truncate_to_length( - torch.tensor(logps), - self.max_completion_length, - padding_value=0.0, - padding_or_truncate_side="right", - ) - for logps in sampling_per_token_logps_list - ] - sampling_per_token_logps = torch.stack(sampling_per_token_logps, dim=0).to(device) + sampling_per_token_logps, _ = batch_pad_sequences( + sampling_per_token_logps_list, + target_length=self.max_completion_length, + padding_value=0.0, + padding_side="right", + dtype=torch.float32, + device=device, + ) else: sampling_per_token_logps = None - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask. + # Use tensor operations instead of Python list iteration for XLA compatibility. if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + # Check if last token is NOT eos or pad (meaning sequence was truncated) + last_tokens = completion_ids[:, -1] + # A sequence is NOT truncated if its last token is eos or pad + is_not_truncated = (last_tokens == self.eos_token_id) | (last_tokens == self.pad_token_id) + completion_mask = completion_mask * is_not_truncated.unsqueeze(1).long() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) @@ -1038,7 +1056,8 @@ def _generate_and_score_completions( f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." ) - is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + # Use direct comparison instead of torch.zeros_like to avoid tensor allocation + is_std_zero = std_rewards.abs() < 1e-8 if self.scale_rewards != "none": advantages = advantages / (std_rewards + 1e-4) @@ -1111,10 +1130,12 @@ def _generate_and_score_completions( # torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) # ) # But it is not XLA friendly because it involves dynamic indexing before reduction, so we rewrite it as: + # Use pre-created inf constant (cast to proper dtype if needed) + inf_val = self._inf_float.to(dtype=importance_sampling_ratio.dtype) masked_is_ratio_for_min = torch.where( completion_mask.bool(), importance_sampling_ratio, - torch.tensor(float("inf"), device=device, dtype=importance_sampling_ratio.dtype), + inf_val, ) min_importance_sampling_ratio = masked_is_ratio_for_min.min() # importance_sampling_ratio values are >= 0 (torch.exp) so we can use the same computation as for delta. 
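A note on the masked-reduction rewrites above: boolean indexing such as `ratio[mask]` produces a data-dependent shape, which forces XLA to recompile, whereas `torch.where` with a sentinel value keeps every shape static. A minimal standalone sketch of the pattern (illustrative names, not part of the patch):

import torch

def masked_min(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Masked-out entries become +inf so they can never win the min; the
    # reduction then runs over a static shape that XLA can compile once.
    filled = torch.where(mask.bool(), values, torch.full_like(values, float("inf")))
    return filled.min()

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Zero out masked entries and divide by the unmasked count plus a small
    # epsilon, mirroring how the patch computes the mean importance ratio.
    return (values * mask).sum() / (mask.sum() + 1e-10)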
@@ -1133,13 +1154,12 @@ def _generate_and_score_completions( nanmax(self.accelerator.gather(max_importance_sampling_ratio)) ) - # Move metrics and logs to CPU. + # Move metrics and logs to CPU. Keep metrics as CPU tensors instead of calling .item() + # immediately - this defers the sync overhead to when metrics are actually logged. torch_xla.sync() metrics = move_all_tensor_to_cpu(metrics) logs = move_all_tensor_to_cpu(logs) - metrics = {key: [val.item() for val in value] for key, value in metrics.items()} - # Update the actual metrics and logs. self._metrics[mode].update(metrics) for name in self.reward_func_names: @@ -1176,35 +1196,22 @@ def _generate_and_score_completions( return output def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: - # Original code does the following: - # local = entropies[mask.bool()].float() - # # Use a negative pad_value as a sentinel because entropy values are always >= 0. - # # This guarantees that the sentinel cannot collide with any real entropy value. - # pad_value = -1e9 - - # # Pad across processes so that every rank has the same tensor length - # padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) - # gathered = self.accelerator.gather(padded) - - # # Drop sentinel values (safe because no entropy can be negative) - # gathered = gathered[gathered != pad_value] - - # if gathered.numel() == 0: - # return torch.zeros_like(entropies, dtype=torch.bool) - - # entropy_threshold = torch.quantile(gathered, threshold) - # masked_entropies = entropies * mask.float() - # entropy_mask = masked_entropies >= entropy_threshold - # return entropy_mask & mask.bool() # ensure padding tokens are always masked out + """ + Compute a mask for high-entropy tokens (above the given quantile threshold). + This XLA-optimized implementation avoids: + 1. Dynamic indexing (sorted_values[quantile_idx]) by using torch.gather + 2. Repeated tensor creation by using pre-created constants + 3. 
Complex control flow that would cause graph breaks + """ pad_value = -1e9 - device = entropies.device + dtype = entropies.dtype - masked_entropies = torch.where( - mask.bool(), - entropies, - torch.tensor(pad_value, device=device, dtype=entropies.dtype), - ) + # Create pad tensor from pre-allocated constant (avoids allocation in hot path) + # Note: pad_value is negative, so we can't use self._inf_float directly + pad_tensor = torch.full_like(entropies[:1, :1], pad_value).expand_as(entropies) + + masked_entropies = torch.where(mask.bool(), entropies, pad_tensor) local_flat = masked_entropies.view(-1) gathered = self.accelerator.gather(local_flat) @@ -1212,22 +1219,23 @@ def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, thr # Sort gathered values, so that pad_value sentinels are at the beginning sorted_values, _ = torch.sort(gathered) - # Compute the number of valid (non-sentinel) values - num_valid = (sorted_values != pad_value).sum() - num_sentinels = (sorted_values == pad_value).sum() - valid_start_idx = num_sentinels + # Compute the number of valid (non-sentinel) values using a tolerance for float comparison + is_sentinel = sorted_values < (pad_value + 1e-6) # pad_value is -1e9 + num_sentinels = is_sentinel.sum() num_valid_values = gathered.numel() - num_sentinels # Get the quantile index and the corresponding entropy threshold value - quantile_idx = valid_start_idx + (threshold * num_valid_values).long() - quantile_idx = quantile_idx.clamp(max=gathered.numel() - 1) - entropy_threshold = sorted_values[quantile_idx] - - # Handle empty case, if everything is sentinel, set threshold to +inf so no token is selected - has_valid = num_valid > 0 - entropy_threshold = torch.where( - has_valid, entropy_threshold, torch.tensor(float("inf"), device=device, dtype=entropies.dtype) - ) + # Use torch.gather instead of dynamic indexing to maintain XLA compatibility + quantile_idx = num_sentinels + (threshold * num_valid_values.float()).long() + quantile_idx = quantile_idx.clamp(min=0, max=gathered.numel() - 1) + + # Use gather for XLA-compatible indexing (gather works with tensor indices) + entropy_threshold = sorted_values.gather(0, quantile_idx.view(1)).squeeze(0) + + # Handle empty case: if everything is sentinel, set threshold to +inf so no token is selected + has_valid = num_valid_values > 0 + inf_val = self._inf_float.to(dtype=dtype) + entropy_threshold = torch.where(has_valid, entropy_threshold, inf_val) entropy_mask = (entropies > entropy_threshold) & mask.bool() return entropy_mask @@ -1289,9 +1297,8 @@ def _compute_loss(self, model, inputs): if self.importance_sampling_level == "token": log_importance_weights = log_ratio elif self.importance_sampling_level == "sequence": - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp( - min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) - ) + # Use pre-created constant instead of creating tensor in hot path + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1) log_importance_weights = log_importance_weights.unsqueeze(-1) else: raise ValueError( @@ -1320,18 +1327,13 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl + # Use scalar min value for clamp instead of creating tensors. + # PyTorch clamp accepts Python scalars which avoids tensor allocation overhead. 
if self.loss_type == "grpo": - loss = ( - (per_token_loss * completion_mask).sum(-1) - / completion_mask.sum(-1).clamp( - min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) - ) - ).mean() + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1)).mean() loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp( - min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) - ) + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "dr_grpo": loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) @@ -1345,9 +1347,8 @@ def _compute_loss(self, model, inputs): # Log the metrics mode = "train" if self.model.training else "eval" - completion_token_count = completion_mask.sum().clamp( - min=torch.tensor(1.0, dtype=completion_mask.dtype, device=completion_mask.device) - ) + # Use scalar min value for clamp instead of creating tensor + completion_token_count = completion_mask.sum().clamp(min=1) def masked_batch_mean(x): if x.shape[1] == 1: # when importance_sampling_level == "sequence" @@ -1382,10 +1383,10 @@ def masked_batch_mean(x): gathered_clip_ratio = self.accelerator.gather(clip_ratio) metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) - # torch_xla.sync() # Graph break before moving metrics to CPU. + # Move metrics to CPU but keep as tensors. The log() method will call .item() + # when averaging. This defers sync overhead to logging time. metrics = move_all_tensor_to_cpu(metrics) torch_xla.sync() - metrics = {key: [val.item() for val in value] for key, value in metrics.items()} self._metrics[mode].update(metrics) diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index d271fcf32..088c6e079 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -68,53 +68,6 @@ def pad( return output -def pad_or_truncate_to_length( - tensor: torch.Tensor, - length: int, - dim: int = 0, - padding_value: int = 0, - padding_or_truncate_side: Literal["left", "right"] = "right", -) -> torch.Tensor: - """ - Pads or truncates a tensor to a given length along the provided dimension. 
- - Args: - tensor: Input tensor to pad or truncate - length: Target length - dim: Dimension along which to pad/truncate - padding_value: Value to use for padding - padding_or_truncate_side: Side for both padding and truncation - - "left": Pads on left, truncates from left (keeps last tokens) - - "right": Pads on right, truncates from right (keeps first tokens) - """ - current_length = tensor.shape[dim] - if current_length == length: - return tensor - elif current_length > length: - # Truncate - slice_ = [slice(None)] * tensor.dim() - if padding_or_truncate_side == "left": - # Keep last tokens (truncate from left) - slice_[dim] = slice(current_length - length, current_length) - elif padding_or_truncate_side == "right": - # Keep first tokens (truncate from right) - slice_[dim] = slice(0, length) - else: - raise ValueError("padding_or_truncate_side must be 'left' or 'right'") - return tensor[slice_] - else: - # Pad - padding_shape = list(tensor.shape) - padding_shape[dim] = length - current_length - padding = torch.full(padding_shape, padding_value, dtype=tensor.dtype, device=tensor.device) - if padding_or_truncate_side == "left": - return torch.cat([padding, tensor], dim=dim) - elif padding_or_truncate_side == "right": - return torch.cat([tensor, padding], dim=dim) - else: - raise ValueError("padding_or_truncate_side must be 'left' or 'right'") - - def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor: """ Compute the Shannon entropy (in nats) for each row of *logits* in a memory-efficient way. @@ -124,6 +77,9 @@ def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Te along this flattened dimension, reducing peak memory usage. The result is reshaped back to match the input's leading dimensions. + This implementation uses pre-allocated output tensors instead of list accumulation to avoid + XLA graph fragmentation and repeated tensor allocations. + Args: logits (`torch.Tensor`): Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading dimensions @@ -141,14 +97,19 @@ def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Te # Flatten all leading dimensions into one flat_logits = logits.reshape(-1, num_classes) + total_rows = flat_logits.size(0) + + # Pre-allocate output tensor to avoid list accumulation + entropies = torch.empty(total_rows, dtype=logits.dtype, device=logits.device) - entropies = [] - for chunk in flat_logits.split(chunk_size, dim=0): + # Process in chunks, writing directly to pre-allocated tensor + for start in range(0, total_rows, chunk_size): + end = min(start + chunk_size, total_rows) + chunk = flat_logits[start:end] logps = F.log_softmax(chunk, dim=-1) chunk_entropy = -(torch.exp(logps) * logps).sum(-1) - entropies.append(chunk_entropy) + entropies[start:end] = chunk_entropy - entropies = torch.cat(entropies, dim=0) return entropies.reshape(original_shape) @@ -175,6 +136,97 @@ def neuron_parallel_compile_tokenizer_decoder_method( return "dummy" +def batch_pad_sequences( + sequences: list[list[int | float]], + target_length: int, + padding_value: int | float = 0, + padding_side: Literal["left", "right"] = "right", + dtype: torch.dtype = torch.long, + device: torch.device | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + XLA-optimized batch padding of variable-length sequences. + + Unlike per-sequence padding with list comprehensions, this function: + 1. Pre-allocates output arrays using numpy (fast CPU operations) + 2. 
Transfers to device as a single operation (one host->device copy) + 3. Creates the mask alongside the padded sequences (no separate allocation) + + This avoids creating many small tensors and multiple device transfers that cause + XLA graph fragmentation. + + Args: + sequences (`list[list[int | float]]`): + List of variable-length sequences. Each sequence is a list of token IDs (ints) + or log probabilities (floats). + target_length (`int`): + Fixed target length for all sequences. Sequences longer than this will be + truncated; shorter sequences will be padded. + padding_value (`int | float`, *optional*, defaults to `0`): + Value to use for padding positions. + padding_side (`Literal["left", "right"]`, *optional*, defaults to `"right"`): + Side on which to add padding. Also determines truncation behavior: + - `"left"`: Pads on left, truncates from left (keeps last tokens) + - `"right"`: Pads on right, truncates from right (keeps first tokens) + dtype (`torch.dtype`, *optional*, defaults to `torch.long`): + Output tensor dtype for the padded sequences. + device (`torch.device | None`, *optional*, defaults to `None`): + Target device for the output tensors. If `None`, tensors remain on CPU. + + Returns: + `tuple[torch.Tensor, torch.Tensor]`: + A tuple of `(padded_sequences, mask)` where: + - `padded_sequences` has shape `(batch_size, target_length)` and dtype `dtype` + - `mask` has shape `(batch_size, target_length)` and dtype `torch.long`, with + `1` for real tokens and `0` for padding positions + """ + batch_size = len(sequences) + + # Determine numpy dtype for intermediate computation + if dtype in (torch.float32, torch.float64, torch.float16, torch.bfloat16): + np_dtype = np.float32 + else: + np_dtype = np.int64 + + # Pre-allocate numpy arrays (fast CPU operations) + padded = np.full((batch_size, target_length), padding_value, dtype=np_dtype) + mask = np.zeros((batch_size, target_length), dtype=np.int64) + + for i, seq in enumerate(sequences): + seq_len = len(seq) + if seq_len == 0: + continue + + if seq_len >= target_length: + # Truncation needed + if padding_side == "left": + # Keep last target_length tokens + padded[i] = seq[seq_len - target_length :] + else: + # Keep first target_length tokens + padded[i] = seq[:target_length] + mask[i] = 1 + else: + # Padding needed + if padding_side == "left": + start_idx = target_length - seq_len + padded[i, start_idx:] = seq + mask[i, start_idx:] = 1 + else: + padded[i, :seq_len] = seq + mask[i, :seq_len] = 1 + + # Single conversion and transfer to device + padded_tensor = torch.from_numpy(padded).to(dtype=dtype) + mask_tensor = torch.from_numpy(mask).to(dtype=torch.long) + + if device is not None: + padded_tensor = padded_tensor.to(device) + mask_tensor = mask_tensor.to(device) + + return padded_tensor, mask_tensor + + def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ XLA-compatible version of nanmin that doesn't use dynamic indexing. 
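For illustration, a minimal usage sketch of the `batch_pad_sequences` helper added above, assuming it is imported from `optimum.neuron.trainers.trl_utils` (the token values are made up):

import torch

from optimum.neuron.trainers.trl_utils import batch_pad_sequences

prompt_ids_list = [[101, 7, 42], [101, 9]]  # two variable-length prompts
prompt_ids, prompt_mask = batch_pad_sequences(
    prompt_ids_list,
    target_length=8,
    padding_value=0,
    padding_side="left",  # left padding also truncates from the left, keeping the last tokens
    dtype=torch.long,
    device=None,  # stays on CPU here; pass an XLA device to get a single host->device copy
)
# Both outputs have shape (2, 8); prompt_mask is 1 on real tokens and 0 on
# padding, so no separate mask construction pass is needed.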
From 4407687d9ca9f2d6f81d09e2656af1080fe419b7 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 29 Jan 2026 14:36:39 +0100 Subject: [PATCH 63/78] debug: training produces NaNs --- .../grpo_qwen3/finetune_grpo_qwen3.py | 2 +- optimum/neuron/trainers/extras/vllm_client.py | 29 +++++++++++-------- optimum/neuron/trainers/trl_utils.py | 10 +++---- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index c0425d511..d23f7cfd3 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -160,7 +160,7 @@ def train(model_id, tokenizer, dataset, training_args): processing_class=tokenizer, peft_config=lora_config, # To do: disable this fake client, only for development without vLLM server. - # vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), + vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), ) # Train the model diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index 7fce0ef2a..d8d6d15c3 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -149,15 +149,16 @@ def init_communicator(self, device: Union[torch.device, str, int] = 0): class MockVLLMClient(VLLMClient): """ - Mock VLLMClient that generates random completions and triggers XLA compilation without vLLM server. + Mock VLLMClient that generates completions without a vLLM server. - Used for neuron_parallel_compile and testing. Generates random tokens, not real LLM outputs. + Used for neuron_parallel_compile and testing. Generates completions by cycling + through prompt tokens (echo mode), producing deterministic, non-garbage output. Args: tokenizer: Tokenizer for encoding/decoding max_completion_length: Maximum completion length min_completion_length: Minimum completion length (default: 10) - seed: Random seed for reproducibility + seed: Random seed for reproducibility (used for completion length variation) """ def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None): @@ -169,7 +170,7 @@ def __init__(self, tokenizer, max_completion_length=256, min_completion_length=1 logger.warning( "Using MockVLLMClient for neuron_parallel_compile or testing. " - "This generates random dummy completions and should only be used for compilation/testing." + "This generates echo completions and should only be used for compilation/testing." ) def generate( @@ -188,7 +189,7 @@ def generate( generation_kwargs=None, ): """ - Generate random completions with random lengths. + Generate completions by cycling through prompt tokens (echo mode). Returns dict with prompt_ids, completion_ids, and logprobs. 
""" @@ -196,10 +197,9 @@ def generate( completion_ids = [] logprobs = [] - # Determine vocab range (avoid special tokens) + # Fallback tokens if prompt is empty vocab_size = self.tokenizer.vocab_size - min_token_id = min(100, vocab_size - 1) - max_token_id = vocab_size - 1 + fallback_token_id = min(100, vocab_size - 1) for prompt in prompts: # Tokenize prompt @@ -217,12 +217,17 @@ def generate( max_len = min(max_tokens, self.max_completion_length) completion_length = self.random.randint(self.min_completion_length, max_len) - # Generate random tokens from safe vocab range - completion = [self.random.randint(min_token_id, max_token_id) for _ in range(completion_length)] + # Echo mode: cycle through prompt tokens + if len(prompt_tokens) > 0: + completion = [prompt_tokens[i % len(prompt_tokens)] for i in range(completion_length)] + else: + # Fallback if prompt is empty + completion = [fallback_token_id] * completion_length + completion_ids.append(completion) - # Generate realistic random logprobs (typical range: -2 to -10) - completion_logprobs = [-self.random.uniform(2.0, 8.0) for _ in range(completion_length)] + # Logprobs: simulate higher confidence for echoed tokens + completion_logprobs = [-self.random.uniform(0.5, 2.0) for _ in range(completion_length)] logprobs.append(completion_logprobs) return { diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index 088c6e079..1ad434c18 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -234,9 +234,8 @@ def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ mask = torch.isnan(tensor) filled = torch.where(mask, torch.full_like(tensor, float("inf")), tensor) - min_value = torch.amin(filled) - all_nan = mask.all() - return torch.where(all_nan, torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device), min_value) + min_value = torch.min(filled) + return min_value def nanmax(tensor: torch.Tensor) -> torch.Tensor: @@ -246,9 +245,8 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: """ mask = torch.isnan(tensor) filled = torch.where(mask, torch.full_like(tensor, float("-inf")), tensor) - min_value = torch.amax(filled) - all_nan = mask.all() - return torch.where(all_nan, torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device), min_value) + max_value = torch.max(filled) + return max_value def nanstd(tensor: torch.Tensor, unbiased: bool = False) -> torch.Tensor: From 131df144a94f939a9908934205e1a7220565d555 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 29 Jan 2026 18:54:53 +0100 Subject: [PATCH 64/78] fix: no NaNs anymore --- .../grpo_qwen3/finetune_grpo_qwen3.py | 2 +- .../grpo_qwen3/finetune_grpo_qwen3.sh | 8 +-- optimum/neuron/trainers/grpo_trainer.py | 52 ++++++++++++------- 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index d23f7cfd3..c0425d511 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -160,7 +160,7 @@ def train(model_id, tokenizer, dataset, training_args): processing_class=tokenizer, peft_config=lora_config, # To do: disable this fake client, only for development without vLLM server. 
- vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), + # vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), ) # Train the model diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index 264d2d7b7..c4fdbc017 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -10,14 +10,14 @@ export MALLOC_ARENA_MAX=64 # Host OOM mitigation # export FI_EFA_FORK_SAFE=1 # Variables for training -PROCESSES_PER_NODE=1 +PROCESSES_PER_NODE=2 NUM_EPOCHS=1 # GRPO typically needs fewer epochs than SFT TP_DEGREE=1 BS=1 GRADIENT_ACCUMULATION_STEPS=1 # Smaller for GRPO due to generation overhead LOGGING_STEPS=1 -# MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing -MODEL_NAME="yujiepan/qwen3-tiny-random" # Use smaller model for testing +MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing +# MODEL_NAME="yujiepan/qwen3-tiny-random" # Use smaller model for testing # MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" # MODEL_NAME="HuggingFaceTB/SmolLM2-135M" OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" @@ -50,7 +50,7 @@ torchrun $DISTRIBUTED_ARGS finetune_grpo_qwen3.py \ --per_device_train_batch_size $BS \ --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ --gradient_checkpointing \ - --learning_rate 5e-8 \ + --learning_rate 5e-4 \ --bf16 \ --tensor_parallel_size $TP_DEGREE \ --zero_1 \ diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 153be8e26..a774c7cae 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -795,6 +795,8 @@ def _get_per_token_logps_and_entropies( if compute_entropy else None ) + # all_logps = [] + # all_entropies = [] if compute_entropy else None # Pre-compute VLM slicing indices if needed (avoids .item() calls inside loop). # For VLMs with image_grid_thw, we need to compute pixel_values slicing indices upfront. @@ -870,13 +872,15 @@ def _get_per_token_logps_and_entropies( # Write directly to pre-allocated tensor instead of list append all_logps[start:end] = logps + # all_logps.append(logps) if compute_entropy: with torch.no_grad(): entropies = entropy_from_logits(logits) all_entropies[start:end] = entropies + # all_entropies.append(entropies) - torch_xla.sync() + torch_xla.sync() return all_logps, all_entropies @@ -990,7 +994,8 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=self.vllm_importance_sampling_cap + importance_sampling_ratio, + max=torch.tensor(self.vllm_importance_sampling_cap, device=importance_sampling_ratio.device), ) # Compute the per-token log probabilities for the reference model @@ -1198,13 +1203,11 @@ def _generate_and_score_completions( def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: """ Compute a mask for high-entropy tokens (above the given quantile threshold). - - This XLA-optimized implementation avoids: - 1. Dynamic indexing (sorted_values[quantile_idx]) by using torch.gather - 2. Repeated tensor creation by using pre-created constants - 3. 
Complex control flow that would cause graph breaks """ pad_value = -1e9 + gathered = self.accelerator.gather(entropies) + return entropies + pad_value = -1e9 dtype = entropies.dtype # Create pad tensor from pre-allocated constant (avoids allocation in hot path) @@ -1227,7 +1230,10 @@ def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, thr # Get the quantile index and the corresponding entropy threshold value # Use torch.gather instead of dynamic indexing to maintain XLA compatibility quantile_idx = num_sentinels + (threshold * num_valid_values.float()).long() - quantile_idx = quantile_idx.clamp(min=0, max=gathered.numel() - 1) + quantile_idx = quantile_idx.clamp( + min=torch.tensor(0, device=quantile_idx.device), + max=torch.tensor(gathered.numel() - 1, device=quantile_idx.device), + ) # Use gather for XLA-compatible indexing (gather works with tensor indices) entropy_threshold = sorted_values.gather(0, quantile_idx.view(1)).squeeze(0) @@ -1297,8 +1303,10 @@ def _compute_loss(self, model, inputs): if self.importance_sampling_level == "token": log_importance_weights = log_ratio elif self.importance_sampling_level == "sequence": - # Use pre-created constant instead of creating tensor in hot path - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1) + # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp( + min=torch.tensor(1, device=completion_mask.device) + ) log_importance_weights = log_importance_weights.unsqueeze(-1) else: raise ValueError( @@ -1309,11 +1317,15 @@ def _compute_loss(self, model, inputs): # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) coef_1 = torch.exp(log_importance_weights) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + coef_2 = torch.clamp( + coef_1, + torch.tensor(1 - self.epsilon_low, device=coef_1.device), + torch.tensor(1 + self.epsilon_high, device=coef_1.device), + ) # Two-sided clipping if self.args.delta is not None: - coef_1 = torch.clamp(coef_1, max=self.args.delta) + coef_1 = torch.clamp(coef_1, max=torch.tensor(self.args.delta, device=coef_1.device)) per_token_loss1 = coef_1 * advantages.unsqueeze(1) per_token_loss2 = coef_2 * advantages.unsqueeze(1) @@ -1327,13 +1339,17 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl - # Use scalar min value for clamp instead of creating tensors. - # PyTorch clamp accepts Python scalars which avoids tensor allocation overhead. + # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals. 
if self.loss_type == "grpo": - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1)).mean() + loss = ( + (per_token_loss * completion_mask).sum(-1) + / completion_mask.sum(-1).clamp(min=torch.tensor(1, device=completion_mask.device)) + ).mean() loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1) + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp( + min=torch.tensor(1, device=completion_mask.device) + ) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "dr_grpo": loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) @@ -1347,8 +1363,8 @@ def _compute_loss(self, model, inputs): # Log the metrics mode = "train" if self.model.training else "eval" - # Use scalar min value for clamp instead of creating tensor - completion_token_count = completion_mask.sum().clamp(min=1) + # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals + completion_token_count = completion_mask.sum().clamp(min=torch.tensor(1, device=completion_mask.device)) def masked_batch_mean(x): if x.shape[1] == 1: # when importance_sampling_level == "sequence" From 7a0167eac45e8971cde2d4a20c728f8a5b26b27b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 30 Jan 2026 15:29:56 +0100 Subject: [PATCH 65/78] rewrite _get_per_token_logps_and_entropies for better breaks --- .../grpo_qwen3/finetune_grpo_qwen3.sh | 4 +- optimum/neuron/trainers/grpo_trainer.py | 90 +++++++------------ 2 files changed, 35 insertions(+), 59 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index c4fdbc017..ff27a8502 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -26,8 +26,8 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) # GRPO-specific variables NUM_GENERATIONS=4 # Number of completions per prompt (G in paper) -MAX_PROMPT_LENGTH=512 -MAX_COMPLETION_LENGTH=512 +MAX_PROMPT_LENGTH=32 +MAX_COMPLETION_LENGTH=32 TEMPERATURE=0.8 STEPS_PER_GENERATION=4 # Generate every N steps to amortize generation cost diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index a774c7cae..3b294f00e 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -785,70 +785,40 @@ def _get_per_token_logps_and_entropies( ) num_chunks = total_batch_size // batch_size - device = input_ids.device - - # Pre-allocate output tensors to avoid list accumulation and repeated concatenation. - # This creates a single graph for all chunks instead of growing graphs. - all_logps = torch.empty(total_batch_size, logits_to_keep, dtype=torch.float32, device=device) - all_entropies = ( - torch.empty(total_batch_size, logits_to_keep, dtype=torch.float32, device=device) - if compute_entropy - else None - ) - # all_logps = [] - # all_entropies = [] if compute_entropy else None - - # Pre-compute VLM slicing indices if needed (avoids .item() calls inside loop). - # For VLMs with image_grid_thw, we need to compute pixel_values slicing indices upfront. 
- if image_grid_thw is not None and pixel_values is not None: - rows_per_image = image_grid_thw.prod(dim=-1) - # num_images is a list of ints, so we can compute cumulative sums on CPU - cum_imgs = [0] - for n in num_images: - cum_imgs.append(cum_imgs[-1] + n) - - # Compute row boundaries for each sample using CPU-computed indices - rows_per_sample_list = [] - for i in range(len(num_images)): - start_img = cum_imgs[i] - end_img = cum_imgs[i + 1] - rows_per_sample_list.append(rows_per_image[start_img:end_img].sum()) - rows_per_sample = torch.stack(rows_per_sample_list) - # Compute cumulative row indices on device - cum_rows = torch.cat( - [torch.zeros(1, dtype=rows_per_sample.dtype, device=device), rows_per_sample.cumsum(0)] - ) - # Move to CPU once to get all slice indices (single sync instead of per-chunk) - torch_xla.sync() - cum_rows_cpu = cum_rows.cpu().tolist() + + all_logps = [] + all_entropies = [] if compute_entropy else None + + chunked_input_ids = torch.split(input_ids, batch_size, dim=0) + chunked_attention_mask = torch.split(attention_mask, batch_size, dim=0) for chunk_idx in range(num_chunks): - start = chunk_idx * batch_size - end = start + batch_size + input_ids_batch = chunked_input_ids[chunk_idx] + attention_mask_batch = chunked_attention_mask[chunk_idx] - input_ids_batch = input_ids[start:end] - attention_mask_batch = attention_mask[start:end] + # TODO: rewrite this without slicing to avoid XLA graph recompilations + start = chunk_idx * batch_size # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None and pixel_values is not None: - # Use pre-computed CPU indices to avoid .item() calls - row_start = int(cum_rows_cpu[start]) - row_end = int(cum_rows_cpu[end]) + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() model_inputs["pixel_values"] = pixel_values[row_start:row_end] - img_start = cum_imgs[start] - img_end = cum_imgs[end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: - model_inputs["pixel_values"] = pixel_values[start:end] - + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: - model_inputs["pixel_attention_mask"] = pixel_attention_mask[start:end] + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: - model_inputs["image_sizes"] = image_sizes[start:end] + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] if token_type_ids is not None: - model_inputs["token_type_ids"] = token_type_ids[start:end] + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] # Only add logits_to_keep if the model supports it if "logits_to_keep" in self.model_kwarg_keys: @@ -857,6 +827,9 @@ def _get_per_token_logps_and_entropies( model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + # Sync before forward to isolate the forward pass graph + torch_xla.sync() + logits 
= model(**model_inputs).logits # Exclude the last value: it corresponds to the next token pred @@ -870,18 +843,21 @@ def _get_per_token_logps_and_entropies( completion_ids = input_ids_batch[:, -logits_to_keep:] logps = selective_log_softmax(logits, completion_ids) # compute logprobs - # Write directly to pre-allocated tensor instead of list append - all_logps[start:end] = logps - # all_logps.append(logps) - if compute_entropy: with torch.no_grad(): entropies = entropy_from_logits(logits) - all_entropies[start:end] = entropies - # all_entropies.append(entropies) + # Sync after forward to materialize results before list append torch_xla.sync() + all_logps.append(logps) + if compute_entropy: + all_entropies.append(entropies) + + # Single concat at the end - one clean graph + all_logps = torch.cat(all_logps, dim=0) + all_entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + return all_logps, all_entropies def _generate_and_score_completions( From ac366874dd4b774ac2a06d3fc556834d5a97bf6a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 30 Jan 2026 15:55:22 +0100 Subject: [PATCH 66/78] optimize _compute_loss --- optimum/neuron/trainers/grpo_trainer.py | 45 ++++++++++++++----------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 3b294f00e..a88e7f2c2 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -1281,7 +1281,7 @@ def _compute_loss(self, model, inputs): elif self.importance_sampling_level == "sequence": # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp( - min=torch.tensor(1, device=completion_mask.device) + min=self._one_float, ) log_importance_weights = log_importance_weights.unsqueeze(-1) else: @@ -1315,17 +1315,13 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl - # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals. 
if self.loss_type == "grpo": loss = ( - (per_token_loss * completion_mask).sum(-1) - / completion_mask.sum(-1).clamp(min=torch.tensor(1, device=completion_mask.device)) + (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=self._one_float) ).mean() loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp( - min=torch.tensor(1, device=completion_mask.device) - ) + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=self._one_float) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "dr_grpo": loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) @@ -1339,8 +1335,7 @@ def _compute_loss(self, model, inputs): # Log the metrics mode = "train" if self.model.training else "eval" - # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals - completion_token_count = completion_mask.sum().clamp(min=torch.tensor(1, device=completion_mask.device)) + completion_token_count = completion_mask.sum().clamp(min=self._one_float) def masked_batch_mean(x): if x.shape[1] == 1: # when importance_sampling_level == "sequence" @@ -1349,13 +1344,14 @@ def masked_batch_mean(x): return (x * completion_mask).sum() / completion_token_count metrics = defaultdict(list) + metrics_to_gather = {} if self.beta != 0.0: mean_kl = masked_batch_mean(per_token_kl) - metrics["kl"].append(self.accelerator.gather(mean_kl).nanmean()) + metrics_to_gather["kl"] = mean_kl mean_entropy = masked_batch_mean(entropies) - metrics["entropy"].append(self.accelerator.gather(mean_entropy).nanmean()) + metrics_to_gather["entropy"] = mean_entropy # Compute the clipped probability ratios is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) @@ -1366,14 +1362,25 @@ def masked_batch_mean(x): high_clip = masked_batch_mean(is_high_clipped.float()) clip_ratio = masked_batch_mean(is_region_clipped.float()) - gathered_low_clip = self.accelerator.gather(low_clip) - metrics["clip_ratio/low_mean"].append(gathered_low_clip.nanmean()) - metrics["clip_ratio/low_min"].append(nanmin(gathered_low_clip)) - gathered_high_clip = self.accelerator.gather(high_clip) - metrics["clip_ratio/high_mean"].append(gathered_high_clip.nanmean()) - metrics["clip_ratio/high_max"].append(nanmax(gathered_high_clip)) - gathered_clip_ratio = self.accelerator.gather(clip_ratio) - metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean()) + metrics_to_gather["low_clip"] = low_clip + metrics_to_gather["high_clip"] = high_clip + metrics_to_gather["clip_ratio"] = clip_ratio + + stacked_metrics = torch.stack(list(metrics_to_gather.values()), dim=0) + gathered_stacked_metrics = self.accelerator.gather(stacked_metrics) + gathered_metrics = {} + + for i, key in enumerate(metrics_to_gather.keys()): + gathered_metrics[key] = gathered_stacked_metrics[i] + + if self.beta != 0.0: + metrics["kl"].append(gathered_metrics["kl"].nanmean()) + metrics["entropy"].append(gathered_metrics["entropy"].nanmean()) + metrics["clip_ratio/low_mean"].append(gathered_metrics["low_clip"].nanmean()) + metrics["clip_ratio/low_min"].append(nanmin(gathered_metrics["low_clip"])) + metrics["clip_ratio/high_mean"].append(gathered_metrics["high_clip"].nanmean()) + metrics["clip_ratio/high_max"].append(nanmax(gathered_metrics["high_clip"])) + metrics["clip_ratio/region_mean"].append(gathered_metrics["clip_ratio"].nanmean()) 
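A note on the stacked metrics gather above: stacking N scalar metrics into a single tensor replaces N collectives with one. A hedged standalone sketch of the pattern; since `accelerator.gather` concatenates along dim 0, reshaping to `(world_size, N)` recovers one row of metrics per process before reductions such as `nanmean`:

import torch

def gather_metrics_once(accelerator, scalars: dict) -> dict:
    # scalars maps metric names to 0-dim tensors on the XLA device.
    keys = list(scalars.keys())
    stacked = torch.stack([scalars[k] for k in keys], dim=0)  # shape (N,)
    gathered = accelerator.gather(stacked)  # shape (world_size * N,)
    per_rank = gathered.reshape(-1, len(keys))  # shape (world_size, N)
    # Column i holds metric keys[i] from every process; reduce over dim 0.
    return {k: per_rank[:, i] for i, k in enumerate(keys)}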
# Move metrics to CPU but keep as tensors. The log() method will call .item() # when averaging. This defers sync overhead to logging time. From 8c816f6a34b5958d44fa2bcbab80e5b3f20a541f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 30 Jan 2026 16:08:50 +0100 Subject: [PATCH 67/78] optimize _generate_and_score_completions --- optimum/neuron/trainers/grpo_trainer.py | 61 +++++++++++++++---------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index a88e7f2c2..38303ba83 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -1065,23 +1065,34 @@ def _generate_and_score_completions( metrics["frac_reward_zero_std"].append(is_std_zero.float().mean()) # Log prompt and completion texts - self._logs["prompt"].extend( - gather_object_from_data_parallel_group(prompts_text, fixed_size=self.fixed_size_obj_collectives) - ) - self._logs["completion"].extend( - gather_object_from_data_parallel_group(completions_text, fixed_size=self.fixed_size_obj_collectives) + to_gather = [prompts_text, completions_text] + if images is not None: + to_gather.append(images) + + if self.fixed_size_obj_collectives is not None: + fixed_size = len(to_gather) * self.fixed_size_obj_collectives + else: + fixed_size = None + + gathered = gather_object_from_data_parallel_group( + to_gather, + fixed_size=fixed_size, ) + gathered_prompts_text = [item[0] for item in gathered] + gathered_completions_text = [item[1] for item in gathered] + self._logs["prompt"].extend(gathered_prompts_text) + self._logs["completion"].extend(gathered_completions_text) + + if images is not None: + gathered_images = [item[2] for item in gathered] + self._logs["images"].extend(gathered_images) + logs["rewards"] = {} logs["advantages"] = [] for i, name in enumerate(self.reward_func_names): logs["rewards"][name] = rewards_per_func[:, i] logs["advantages"] = all_process_advantages - if images is not None: - self._logs["images"].extend( - gather_object_from_data_parallel_group(images, fixed_size=self.fixed_size_obj_collectives) - ) - if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) # Original code was: @@ -1096,9 +1107,6 @@ def _generate_and_score_completions( # We can simply take the max of the masked delta because values in delta are >= 0 (torch.abs). 
max_delta = delta_masked.max() - metrics["sampling/sampling_logp_difference/mean"].append(self.accelerator.gather(mean_delta).mean()) - metrics["sampling/sampling_logp_difference/max"].append(self.accelerator.gather(max_delta).max()) - # Original code was: # flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] # min_importance_sampling_ratio = ( @@ -1125,18 +1133,23 @@ def _generate_and_score_completions( mean_importance_sampling_ratio = sum_flat_is_ratio / (completion_mask_count + 1e-10) max_importance_sampling_ratio = flat_is_ratio_masked.max() - metrics["sampling/importance_sampling_ratio/min"].append( - nanmin(self.accelerator.gather(min_importance_sampling_ratio)) - ) - metrics["sampling/importance_sampling_ratio/mean"].append( - self.accelerator.gather(mean_importance_sampling_ratio).nanmean() - ) - metrics["sampling/importance_sampling_ratio/max"].append( - nanmax(self.accelerator.gather(max_importance_sampling_ratio)) - ) + sampling_metrics_to_gather = { + "mean_delta": mean_delta, + "max_delta": max_delta, + "min_is_ratio": min_importance_sampling_ratio, + "mean_is_ratio": mean_importance_sampling_ratio, + "max_is_ratio": max_importance_sampling_ratio, + } + + stacked = torch.stack(list(sampling_metrics_to_gather.values()), dim=0) + gathered_stacked = self.accelerator.gather(stacked) + gathered = {k: gathered_stacked[i] for i, k in enumerate(sampling_metrics_to_gather.keys())} + metrics["sampling/sampling_logp_difference/mean"].append(gathered["mean_delta"].mean()) + metrics["sampling/sampling_logp_difference/max"].append(gathered["max_delta"].max()) + metrics["sampling/importance_sampling_ratio/min"].append(nanmin(gathered["min_is_ratio"])) + metrics["sampling/importance_sampling_ratio/mean"].append(gathered["mean_is_ratio"].nanmean()) + metrics["sampling/importance_sampling_ratio/max"].append(nanmax(gathered["max_is_ratio"])) - # Move metrics and logs to CPU. Keep metrics as CPU tensors instead of calling .item() - # immediately - this defers the sync overhead to when metrics are actually logged. torch_xla.sync() metrics = move_all_tensor_to_cpu(metrics) logs = move_all_tensor_to_cpu(logs) From 4fa42bacc2ccd4a7946ec678327f508e5828d5d8 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 11:16:43 +0100 Subject: [PATCH 68/78] fix: use separate model for ref model to avoid XLA NaN issues --- examples/training/grpo_qwen3/finetune_grpo_qwen3.py | 2 +- examples/training/grpo_qwen3/finetune_grpo_qwen3.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py index c0425d511..d23f7cfd3 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py @@ -160,7 +160,7 @@ def train(model_id, tokenizer, dataset, training_args): processing_class=tokenizer, peft_config=lora_config, # To do: disable this fake client, only for development without vLLM server. 
- # vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), + vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), ) # Train the model diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh index ff27a8502..ef7a3d793 100755 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh @@ -26,8 +26,8 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) # GRPO-specific variables NUM_GENERATIONS=4 # Number of completions per prompt (G in paper) -MAX_PROMPT_LENGTH=32 -MAX_COMPLETION_LENGTH=32 +MAX_PROMPT_LENGTH=256 +MAX_COMPLETION_LENGTH=768 TEMPERATURE=0.8 STEPS_PER_GENERATION=4 # Generate every N steps to amortize generation cost From 8fc448ab3a8b7a8591f2b59772e6b11219e8fbe7 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 11:21:55 +0100 Subject: [PATCH 69/78] fix: use separate model for ref model to avoid XLA NaN issues --- optimum/neuron/trainers/grpo_trainer.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py index 38303ba83..f655bfb58 100644 --- a/optimum/neuron/trainers/grpo_trainer.py +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -383,7 +383,11 @@ def make_inputs_require_grad(module, input, output): if self.beta == 0.0: self.ref_model = None elif isinstance(model, NeuronPeftModel): - self.ref_model = None + # Create reference model using base model class + # Original implementation was disabling adapters, but we create a separate model instance instead for XLA compatibility. + base_model_class = model.get_base_model().__class__ + base_trn_config = model.get_base_model().trn_config + self.ref_model = base_model_class.from_pretrained(model_id, base_trn_config) else: # Create reference model using NeuronModelForCausalLM self.ref_model = NeuronModelForCausalLM.from_pretrained( @@ -987,16 +991,9 @@ def _generate_and_score_completions( **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: - with self.model.disable_adapter(): - ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) + # Here the original implementation used `model.disable_adapters()` instead of having a copy for the + # reference model. We removed it here because it broke with XLA. 
+ raise ValueError("Ref model is None but beta is not 0.0") else: ref_per_token_logps = None From 39660dccbccb771554b1ac8f77a5bc58365447ef Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 11:22:19 +0100 Subject: [PATCH 70/78] chore: vllm_client.py remove unused functions --- optimum/neuron/trainers/extras/vllm_client.py | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index d8d6d15c3..78c6a8842 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -15,10 +15,8 @@ import atexit import random -import socket import time from collections import namedtuple -from contextlib import closing from typing import Union import requests @@ -43,22 +41,6 @@ class StatelessProcessGroup: Group = namedtuple("Group", "barrier") -def find_closest_port(host: str, start_port: int, max_attempts: int = 100) -> int: - for port in range(start_port, start_port + max_attempts): - if is_port_available(host, port): - return port - raise RuntimeError(f"No available port found in range {start_port} to {start_port + max_attempts - 1}") - - -def is_port_available(host: str, port: int) -> bool: - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - try: - sock.bind((host, port)) - return True - except OSError: - return False - - class CPUCommunicator: def __init__(self, store, rank): self.rank = rank @@ -90,17 +72,11 @@ def __init__( group_port: int = 51216, connection_timeout: float = 0.0, ): - # free_group_port = find_closest_port(host, group_port) - # if free_group_port != group_port: - # logger.warning( - # f"Requested group_port {group_port} is not available. Using closest available port {free_group_port} instead." - # ) - free_group_port = group_port super().__init__( base_url=base_url, host=host, server_port=server_port, - group_port=free_group_port, + group_port=group_port, connection_timeout=connection_timeout, ) From 6828a9b1bc56902d9b1c0f68b8ef2bc25d2d5474 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 11:54:30 +0100 Subject: [PATCH 71/78] chore: remove useless docstrings in vllm_client.py --- optimum/neuron/trainers/extras/vllm_client.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py index 78c6a8842..f3a0c3fd8 100644 --- a/optimum/neuron/trainers/extras/vllm_client.py +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -129,16 +129,9 @@ class MockVLLMClient(VLLMClient): Used for neuron_parallel_compile and testing. Generates completions by cycling through prompt tokens (echo mode), producing deterministic, non-garbage output. 
- - Args: - tokenizer: Tokenizer for encoding/decoding - max_completion_length: Maximum completion length - min_completion_length: Minimum completion length (default: 10) - seed: Random seed for reproducibility (used for completion length variation) """ def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None): - # Don't call super().__init__() - we don't need server connection self.tokenizer = tokenizer self.max_completion_length = max_completion_length self.min_completion_length = min(min_completion_length, max_completion_length) @@ -164,11 +157,6 @@ def generate( guided_decoding_regex=None, generation_kwargs=None, ): - """ - Generate completions by cycling through prompt tokens (echo mode). - - Returns dict with prompt_ids, completion_ids, and logprobs. - """ prompt_ids = [] completion_ids = [] logprobs = [] @@ -213,17 +201,13 @@ def generate( } def init_communicator(self, device): - """No-op: mock has no communicator.""" pass def update_named_param(self, name, weights): - """No-op: mock has no model to update.""" pass def reset_prefix_cache(self): - """No-op: mock has no cache.""" pass def close_communicator(self): - """No-op: mock has no communicator.""" pass From 69252e83146eca3a6ed40bd80bfef6da39929b53 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 12:01:13 +0100 Subject: [PATCH 72/78] chore: add safeguard for the GRPO feature --- optimum/neuron/trainers/grpo_config.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/optimum/neuron/trainers/grpo_config.py b/optimum/neuron/trainers/grpo_config.py index 89af7c99d..0a498c8b0 100644 --- a/optimum/neuron/trainers/grpo_config.py +++ b/optimum/neuron/trainers/grpo_config.py @@ -39,6 +39,14 @@ class NeuronGRPOConfig(NeuronTrainingArguments, GRPOConfig): with GRPOConfig for GRPO algorithm parameters. """ + experimental: bool = field( + default=False, + metadata={ + "help": "NeuronGRPOTrainer is experimental and not production-ready. Set to `True` to acknowledge this and " + "proceed. If `False` (the default), an error will be raised at initialization." + }, + ) + use_vllm: bool = field( default=True, metadata={ @@ -48,26 +56,27 @@ class NeuronGRPOConfig(NeuronTrainingArguments, GRPOConfig): ) def __post_init__(self): + if not self.experimental: + raise ValueError( + "NeuronGRPOTrainer is experimental and not production-ready. To proceed, set `experimental=True` in " + "your NeuronGRPOConfig. This flag exists to ensure users are aware of the current state of the implementation." + ) + # For now, NeuronGRPOTrainer requires vLLM for generation, no other way is supported. 
if not self.use_vllm:
            raise ValueError("NeuronGRPOTrainer requires `use_vllm` to be set to `True`.")

-        # Handle bf16 default (from GRPOConfig)
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

-        # Call NeuronTrainingArguments.__post_init__ to initialize Neuron-specific settings
        NeuronTrainingArguments.__post_init__(self)

-        # Convert scale_rewards boolean to string (from GRPOConfig)
        self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)

        num_processes = self.world_size
-        # The current default effective batch size
        if self.generation_batch_size is None and self.steps_per_generation is None:
            self.steps_per_generation = self.gradient_accumulation_steps
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
        elif self.generation_batch_size is not None and self.steps_per_generation is None:
-            # Just ensure the value is divisible by the global batch size
            if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
                raise ValueError(
                    f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "

From c2e6582b9d71a2f01614aebaa6e0b9472755826f Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 4 Feb 2026 12:18:13 +0100
Subject: [PATCH 73/78] chore: grpo_trainer.py cleanup

---
 optimum/neuron/trainers/grpo_trainer.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py
index f655bfb58..9008a7dcb 100644
--- a/optimum/neuron/trainers/grpo_trainer.py
+++ b/optimum/neuron/trainers/grpo_trainer.py
@@ -234,7 +234,6 @@ def make_inputs_require_grad(module, input, output):
         self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)

         # Reward functions - for now, only support callable reward functions
-        # TODO: Add support for reward models when they can be properly loaded on Neuron
         if not isinstance(reward_funcs, list):
             reward_funcs = [reward_funcs]

@@ -384,7 +383,8 @@ def make_inputs_require_grad(module, input, output):
             self.ref_model = None
         elif isinstance(model, NeuronPeftModel):
             # Create reference model using base model class
-            # Original implementation was disabling adapters, but we create a separate model instance instead for XLA compatibility.
+            # The original implementation disabled adapters, but we create a separate model
+            # instance instead for XLA compatibility.
             base_model_class = model.get_base_model().__class__
             base_trn_config = model.get_base_model().trn_config
             self.ref_model = base_model_class.from_pretrained(model_id, base_trn_config)
@@ -438,6 +438,7 @@ def make_inputs_require_grad(module, input, output):
             self.vllm_client = MockVLLMClient(tokenizer, max_completion_length=self.max_completion_length)
         else:
             self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+
             # Only main process initializes the communicator for weight updates
             if self.accelerator.is_main_process:
                 self.vllm_client.init_communicator(device="cpu")
@@ -686,6 +687,7 @@ def _move_model_to_vllm(self):
             # Clean up parameter name for vLLM
             name = self._fix_param_name_to_vllm(name)

+            # TODO: Currently not supported; to be implemented in upcoming PRs as part of the vLLM integration.
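+            # The commented-out flow below mirrors the upstream trl trainer: in server mode,
+            # only the main process would push each updated parameter to vLLM via update_named_param.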
# if self.vllm_mode == "server" and self.accelerator.is_main_process: # self.vllm_client.update_named_param(name, weight) # elif self.vllm_mode == "colocate": @@ -1191,9 +1193,6 @@ def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, thr Compute a mask for high-entropy tokens (above the given quantile threshold). """ pad_value = -1e9 - gathered = self.accelerator.gather(entropies) - return entropies - pad_value = -1e9 dtype = entropies.dtype # Create pad tensor from pre-allocated constant (avoids allocation in hot path) From 22e1c229a1241ef2319305b8f49dbb0c180a0724 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 15:56:50 +0100 Subject: [PATCH 74/78] chore: untrack example --- .../grpo_qwen3/finetune_grpo_qwen3.py | 202 ------------------ .../grpo_qwen3/finetune_grpo_qwen3.sh | 75 ------- 2 files changed, 277 deletions(-) delete mode 100755 examples/training/grpo_qwen3/finetune_grpo_qwen3.py delete mode 100755 examples/training/grpo_qwen3/finetune_grpo_qwen3.sh diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py b/examples/training/grpo_qwen3/finetune_grpo_qwen3.py deleted file mode 100755 index d23f7cfd3..000000000 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.py +++ /dev/null @@ -1,202 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Example script for fine-tuning a Qwen3 model using GRPO (Group Relative Policy Optimization) on Neuron devices. - -This script demonstrates how to use NeuronGRPOTrainer to train a model with reinforcement learning -using reward functions. GRPO is particularly effective for reasoning tasks and instruction following. - -For more information about GRPO, see: https://huggingface.co/papers/2402.03300 -""" - -from dataclasses import dataclass, field - -import torch -from datasets import load_dataset -from peft import LoraConfig -from transformers import AutoTokenizer, HfArgumentParser - -from optimum.neuron import NeuronGRPOConfig, NeuronGRPOTrainer -from optimum.neuron.models.training import NeuronModelForCausalLM -from optimum.neuron.trainers.extras import MockVLLMClient - - -x = MockVLLMClient # To avoid linter warning about unused import - -# ============================================================================= -# Reward Functions -# ============================================================================= -# GRPO requires reward functions to score the generated completions. -# These can be: -# 1. Model-based: Use a reward model to score completions -# 2. Rule-based: Custom Python functions that compute rewards -# -# For this example, we use simple rule-based rewards for demonstration. - - -def length_reward( - prompts: list[str], completions: list[str], completion_ids: list[list[int]], **kwargs -) -> list[float]: - """ - Simple reward function that rewards longer responses (up to a point). 
- """ - rewards = [] - for completion in completion_ids: - # Reward based on length, but cap at 100 tokens to avoid overly long responses - length = len(completion) - reward = min(length / 50.0, 2.0) # Scale: 0-2 - rewards.append(reward) - return rewards - - -def unique_words_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]: - """ - Reward function that encourages diversity by rewarding unique words. - """ - rewards = [] - for completion in completions: - if isinstance(completion, list): - completion = completion[0]["content"] - words = completion.lower().split() - unique_words = len(set(words)) - total_words = len(words) - # Reward diversity: ratio of unique words - reward = unique_words / max(total_words, 1) - rewards.append(reward) - return rewards - - -# ============================================================================= -# Data Loading and Preprocessing Function -# ============================================================================= -# GRPO requires datasets with a "prompt" column. The trainer will generate -# multiple completions for each prompt and score them using reward functions. - - -def load_grpo_dataset(): - """ - Load and prepare a dataset for GRPO training. - - For this example, we use the "trl-internal-testing/zen" dataset which is - a simple test dataset. In practice, you'd use a dataset appropriate for - your task (e.g., math problems, coding tasks, instruction following). - - Returns: - Dataset with "prompt" column - """ - # Load a simple test dataset - # This dataset has prompts in the "prompt" column - # dataset = load_dataset("trl-lib/DeepMath-103K", split="train") - # dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") - dataset = load_dataset("trl-lib/tldr", split="train") - - return dataset - - -# ============================================================================= -# Model Loading and Training Loop Function -# ============================================================================= -def train(model_id, tokenizer, dataset, training_args): - """ - Train the model using GRPO. - - Args: - model_id: HuggingFace model ID or path - tokenizer: Tokenizer for the model - dataset: Training dataset with "prompt" column - training_args: NeuronTrainingArguments - """ - # NOTE: Models with custom modeling implementation need a TrainingNeuronConfig - # This is automatically created when using NeuronTrainingArguments - trn_config = training_args.trn_config - dtype = torch.bfloat16 if training_args.bf16 else torch.float32 - model = NeuronModelForCausalLM.from_pretrained( - model_id, - trn_config, - torch_dtype=dtype, - # Use FlashAttention2 for better performance - # attn_implementation="flash_attention_2", - attn_implementation="eager", - ) - - # LoRA configuration for efficient fine-tuning - lora_config = LoraConfig( - r=64, - lora_alpha=128, - lora_dropout=0.05, - target_modules=["q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"], - bias="none", - task_type="CAUSAL_LM", - ) - - grpo_config = training_args - - # Define reward functions - # You can use multiple reward functions - they will be summed - reward_funcs = [ - length_reward, - unique_words_reward, - ] - - # Create the GRPO trainer - trainer = NeuronGRPOTrainer( - model=model, - reward_funcs=reward_funcs, - args=grpo_config, - train_dataset=dataset, - processing_class=tokenizer, - peft_config=lora_config, - # To do: disable this fake client, only for development without vLLM server. 
- vllm_client=MockVLLMClient(tokenizer, max_completion_length=grpo_config.max_completion_length), - ) - - # Train the model - trainer.train() - - -# ============================================================================= -# Defining the script-specific arguments -# ============================================================================= -@dataclass -class ScriptArguments: - model_id: str = field( - metadata={"help": "The model that you want to train from the Hugging Face hub."}, - ) - - -# ============================================================================= -# Main Function -# ============================================================================= -if __name__ == "__main__": - parser = HfArgumentParser((ScriptArguments, NeuronGRPOConfig)) - script_args, training_args = parser.parse_args_into_dataclasses() - - tokenizer = AutoTokenizer.from_pretrained(script_args.model_id) - - # Ensure tokenizer has pad token - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # Load dataset - dataset = load_grpo_dataset() - - # Start training - train( - model_id=script_args.model_id, - tokenizer=tokenizer, - dataset=dataset, - training_args=training_args, - ) diff --git a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh b/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh deleted file mode 100755 index ef7a3d793..000000000 --- a/examples/training/grpo_qwen3/finetune_grpo_qwen3.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash -# Flags for Neuron compilation -export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation --cache_dir=$HOME/cache_dir_neuron/" -export NEURON_FUSE_SOFTMAX=1 -export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime -export MALLOC_ARENA_MAX=64 # Host OOM mitigation -# Force NCCL to ignore the AWS OFI plugin -# export FI_EFA_USE_DEVICE_RDMA=1 -# export FI_PROVIDER=efa -# export FI_EFA_FORK_SAFE=1 - -# Variables for training -PROCESSES_PER_NODE=2 -NUM_EPOCHS=1 # GRPO typically needs fewer epochs than SFT -TP_DEGREE=1 -BS=1 -GRADIENT_ACCUMULATION_STEPS=1 # Smaller for GRPO due to generation overhead -LOGGING_STEPS=1 -MODEL_NAME="Qwen/Qwen3-0.6B" # Use smaller model for testing -# MODEL_NAME="yujiepan/qwen3-tiny-random" # Use smaller model for testing -# MODEL_NAME="michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" -# MODEL_NAME="HuggingFaceTB/SmolLM2-135M" -OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-grpo-finetuned" -DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE" -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) - -# GRPO-specific variables -NUM_GENERATIONS=4 # Number of completions per prompt (G in paper) -MAX_PROMPT_LENGTH=256 -MAX_COMPLETION_LENGTH=768 -TEMPERATURE=0.8 -STEPS_PER_GENERATION=4 # Generate every N steps to amortize generation cost - -if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then - MAX_STEPS=5 -else - MAX_STEPS=100 # Limit steps for testing -fi - -# Note: Adjust these parameters based on your hardware and task -# - Increase num_generations for better exploration (but slower training) -# - Adjust temperature for sampling diversity -# - Tune epsilon and beta for GRPO algorithm sensitivity - -torchrun $DISTRIBUTED_ARGS finetune_grpo_qwen3.py \ - --model_id $MODEL_NAME \ - --num_train_epochs $NUM_EPOCHS \ - --do_train \ - --max_steps $MAX_STEPS \ - --per_device_train_batch_size $BS \ - --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ - --gradient_checkpointing \ - --learning_rate 5e-4 \ - --bf16 \ - --tensor_parallel_size 
$TP_DEGREE \
-    --zero_1 \
-    --optimizer_use_master_weights false \
-    --optimizer_use_fp32_grad_acc false \
-    --async_save \
-    --logging_steps $LOGGING_STEPS \
-    --output_dir $OUTPUT_DIR \
-    --lr_scheduler_type "constant" \
-    --overwrite_output_dir \
-    --num_generations $NUM_GENERATIONS \
-    --max_prompt_length $MAX_PROMPT_LENGTH \
-    --max_completion_length $MAX_COMPLETION_LENGTH \
-    --temperature $TEMPERATURE \
-    --steps_per_generation $STEPS_PER_GENERATION \
-    --epsilon 0.1 \
-    --beta 0.01
-
-echo "================================"
-echo "Training completed!"
-echo "Model saved to: $OUTPUT_DIR"
-echo "================================"

From ba1ac45daf586209bce3b50b63dbf8e20ddfe481 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 4 Feb 2026 15:57:09 +0100
Subject: [PATCH 75/78] chore: clean trl_utils.py

---
 optimum/neuron/trainers/trl_utils.py | 52 ++++++----------------------
 1 file changed, 10 insertions(+), 42 deletions(-)

diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py
index 1ad434c18..ff3ee9054 100644
--- a/optimum/neuron/trainers/trl_utils.py
+++ b/optimum/neuron/trainers/trl_utils.py
@@ -45,7 +45,7 @@ def pad(
     """
     batch_size = len(tensors)
    if max_length is None:
-        max_length = np.max([t.shape[0] for t in tensors]).tolist()
+        max_length = max([t.shape[0] for t in tensors])

    output_shape = (max_length,) + tensors[0].shape[1:]

@@ -70,47 +70,15 @@ def pad(

 def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
     """
-    Compute the Shannon entropy (in nats) for each row of *logits* in a memory-efficient way.
+    Compute the Shannon entropy (in nats) for each row of *logits*.

-    Instead of materializing the full softmax for all rows at once, the logits are flattened to shape (N, num_classes),
-    where N is the product of all leading dimensions. Computation is then performed in chunks of size `chunk_size`
-    along this flattened dimension, reducing peak memory usage. The result is reshaped back to match the input's
-    leading dimensions.
-
-    This implementation uses pre-allocated output tensors instead of list accumulation to avoid
-    XLA graph fragmentation and repeated tensor allocations.
-
-    Args:
-        logits (`torch.Tensor`):
-            Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading dimensions
-            are preserved in the output.
-        chunk_size (`int`, *optional*, defaults to `128`):
-            Number of rows from the flattened logits to process per iteration. Smaller values reduce memory usage at
-            the cost of more iterations.
-
-    Returns:
-        `torch.Tensor`:
-            Entropy values with shape `logits.shape[:-1]`.
+    The original implementation in trl.trainer.utils.entropy_from_logits provides a memory-efficient chunked
+    alternative, but it accumulates results in a list, which can lead to graph fragmentation on XLA devices.
+    Here we keep things simple and compute the entropy in one go.
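+    The tradeoff is peak memory, since the full log-softmax is materialized in one shot; the `chunk_size`
+    argument is kept in the signature but is no longer used.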
""" - original_shape = logits.shape[:-1] # all dims except num_classes - num_classes = logits.shape[-1] - - # Flatten all leading dimensions into one - flat_logits = logits.reshape(-1, num_classes) - total_rows = flat_logits.size(0) - - # Pre-allocate output tensor to avoid list accumulation - entropies = torch.empty(total_rows, dtype=logits.dtype, device=logits.device) - - # Process in chunks, writing directly to pre-allocated tensor - for start in range(0, total_rows, chunk_size): - end = min(start + chunk_size, total_rows) - chunk = flat_logits[start:end] - logps = F.log_softmax(chunk, dim=-1) - chunk_entropy = -(torch.exp(logps) * logps).sum(-1) - entropies[start:end] = chunk_entropy - - return entropies.reshape(original_shape) + logps = F.log_softmax(logits, dim=-1) + entropy = -(torch.exp(logps) * logps).sum(-1) + return entropy def neuron_parallel_compile_tokenizer_decoder_method( @@ -233,7 +201,7 @@ def nanmin(tensor: torch.Tensor) -> torch.Tensor: Compute the minimum value of a tensor, ignoring NaNs. """ mask = torch.isnan(tensor) - filled = torch.where(mask, torch.full_like(tensor, float("inf")), tensor) + filled = torch.where(mask, torch.tensor(float("inf"), device=tensor.device), tensor) min_value = torch.min(filled) return min_value @@ -244,7 +212,7 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: Compute the maximum value of a tensor, ignoring NaNs. """ mask = torch.isnan(tensor) - filled = torch.where(mask, torch.full_like(tensor, float("-inf")), tensor) + filled = torch.where(mask, torch.tensor(float("-inf"), device=tensor.device), tensor) max_value = torch.max(filled) return max_value From 5a19bd6981ac81050e216e7c16fba315eaacb43e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 15:58:58 +0100 Subject: [PATCH 76/78] chore: clean trl_utils.py --- optimum/neuron/trainers/trl_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index ff3ee9054..22e74a35c 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -45,7 +45,7 @@ def pad( """ batch_size = len(tensors) if max_length is None: - max_length = max([t.shape[0] for t in tensors]) + max_length = max(t.shape[0] for t in tensors) output_shape = (max_length,) + tensors[0].shape[1:] From 51b68be44ffdada1eac1464a2db20cb55913a37e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 16:01:35 +0100 Subject: [PATCH 77/78] chore: clean trl_utils.py --- optimum/neuron/trainers/trl_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py index 22e74a35c..60e402a4e 100644 --- a/optimum/neuron/trainers/trl_utils.py +++ b/optimum/neuron/trainers/trl_utils.py @@ -185,12 +185,8 @@ def batch_pad_sequences( mask[i, :seq_len] = 1 # Single conversion and transfer to device - padded_tensor = torch.from_numpy(padded).to(dtype=dtype) - mask_tensor = torch.from_numpy(mask).to(dtype=torch.long) - - if device is not None: - padded_tensor = padded_tensor.to(device) - mask_tensor = mask_tensor.to(device) + padded_tensor = torch.from_numpy(padded).to(dtype=dtype, device=device) + mask_tensor = torch.from_numpy(mask).to(dtype=torch.long, device=device) return padded_tensor, mask_tensor @@ -231,7 +227,7 @@ def nanstd(tensor: torch.Tensor, unbiased: bool = False) -> torch.Tensor: diff_squared = torch.where(mask, (clean - mean) ** 2, torch.zeros_like(tensor)) if unbiased: - 
variance = diff_squared.sum() / (count - 1).clamp(min=1) + variance = diff_squared.sum() / (count - 1).clamp(min=torch.tensor(1.0, device=tensor.device)) else: variance = diff_squared.sum() / count From bce138b410c5e8e475f29f43714e04af7020e33e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 4 Feb 2026 16:18:44 +0100 Subject: [PATCH 78/78] fix: add training extra for doc building --- .github/actions/install_optimum_neuron/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/install_optimum_neuron/action.yml b/.github/actions/install_optimum_neuron/action.yml index 590ce13ff..cc46565a3 100644 --- a/.github/actions/install_optimum_neuron/action.yml +++ b/.github/actions/install_optimum_neuron/action.yml @@ -7,4 +7,4 @@ runs: shell: bash run: | source aws_neuron_venv_pytorch/bin/activate - python -m pip install .[neuronx,tests] + python -m pip install .[neuronx,tests,training]
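
With the PATCH 72 safeguard in place, building a GRPO configuration now requires an explicit opt-in to the
experimental trainer. A minimal sketch of what that looks like from user code, assuming the `optimum.neuron`
exports used by the example removed in PATCH 74 (`NeuronGRPOConfig`, `NeuronGRPOTrainer`) stay public; the
remaining arguments are standard training-arguments fields:

    from optimum.neuron import NeuronGRPOConfig

    # Without experimental=True, NeuronGRPOConfig.__post_init__ raises a ValueError.
    config = NeuronGRPOConfig(
        output_dir="qwen3-grpo",
        experimental=True,  # acknowledge that NeuronGRPOTrainer is not production-ready
        use_vllm=True,      # required: vLLM is the only supported generation path for now
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
    )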
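The nan-aware reductions touched by PATCHES 75-77 share one idea: NaN entries are masked to sentinel values so
the reductions run on fixed-shape tensors, avoiding the data-dependent shapes that boolean indexing would
introduce under XLA tracing. A small usage sketch, assuming `nanmin` and `nanmax` remain module-level helpers
in `optimum.neuron.trainers.trl_utils`:

    import torch
    from optimum.neuron.trainers.trl_utils import nanmax, nanmin

    rewards = torch.tensor([1.0, float("nan"), 3.0])
    # NaN slots are filled with +inf (for min) or -inf (for max) before reducing.
    print(nanmin(rewards))  # tensor(1.)
    print(nanmax(rewards))  # tensor(3.)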