diff --git a/.github/actions/install_optimum_neuron/action.yml b/.github/actions/install_optimum_neuron/action.yml index 590ce13ff..cc46565a3 100644 --- a/.github/actions/install_optimum_neuron/action.yml +++ b/.github/actions/install_optimum_neuron/action.yml @@ -7,4 +7,4 @@ runs: shell: bash run: | source aws_neuron_venv_pytorch/bin/activate - python -m pip install .[neuronx,tests] + python -m pip install .[neuronx,tests,training] diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 260ad04b6..913bd32ec 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -37,8 +37,10 @@ "trainers": [ "NeuronTrainer", "NeuronSFTTrainer", + "NeuronGRPOTrainer", "NeuronTrainingArguments", "NeuronSFTConfig", + "NeuronGRPOConfig", ], "modeling_traced": ["NeuronTracedModel"], "modeling": [ @@ -156,6 +158,8 @@ from .models.inference.yolos import NeuronYolosForObjectDetection from .pipelines import pipeline from .trainers import ( + NeuronGRPOConfig, + NeuronGRPOTrainer, NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainer, diff --git a/optimum/neuron/models/training/transformations_utils.py b/optimum/neuron/models/training/transformations_utils.py index 7c8ec9281..a53e57841 100644 --- a/optimum/neuron/models/training/transformations_utils.py +++ b/optimum/neuron/models/training/transformations_utils.py @@ -101,7 +101,7 @@ def peft_type(self) -> str | None: return self._peft_type @peft_type.setter - def peft_type(self, value: str): + def peft_type(self, value: str | None): self._peft_type = value @abstractmethod @@ -533,6 +533,9 @@ def _lora_adapt_state_dict( f"{module_fully_qualified_name}.{name}.lora_A.{param_name}" for name in self.linear_names ] + if not all(name in state_dict for name in lora_A_weight_names): + continue + logger.warning("Taking the mean of the LoRA A weights since there is only one LoRA A weight after fusing.") lora_A_weight = torch.mean( torch.stack([state_dict.pop(name) for name in lora_A_weight_names], dim=0), @@ -650,9 +653,7 @@ def _lora_to_original_weights( break if weight_name is None or to_duplicate_name is None: - raise ValueError( - f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}." - ) + continue # When saved, the name of the adapter is removed in the weight qualified name since weights for each # adapter are saved separately. @@ -700,9 +701,7 @@ def _lora_to_original_weights( if to_concat_and_duplicate_name is not None and to_unfuse_name is not None: break if to_concat_and_duplicate_name is None or to_unfuse_name is None: - raise ValueError( - f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}." 
- ) + continue weight_name_without_adapter_name = remove_adapter_name(to_concat_and_duplicate_name) linear_sharded_weights = sharded_state_dicts[weight_name_without_adapter_name] @@ -1100,6 +1099,9 @@ def _lora_adapt_state_dict( lora_A_weight_names = [lora_A_q_name, lora_A_k_name, lora_A_v_name] + if not all(name in state_dict for name in lora_A_weight_names): + continue + logger.warning("Taking the mean of the LoRA A weights since there is only one LoRA A weight after fusing.") lora_A_weight = torch.mean( torch.stack([state_dict.pop(name) for name in lora_A_weight_names], dim=0), diff --git a/optimum/neuron/trainers/__init__.py b/optimum/neuron/trainers/__init__.py index df612b2dc..1d87a937d 100644 --- a/optimum/neuron/trainers/__init__.py +++ b/optimum/neuron/trainers/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .grpo_config import NeuronGRPOConfig +from .grpo_trainer import NeuronGRPOTrainer from .sft_config import NeuronSFTConfig from .sft_trainer import NeuronSFTTrainer from .training_args import NeuronTrainingArguments diff --git a/optimum/neuron/trainers/extras/__init__.py b/optimum/neuron/trainers/extras/__init__.py new file mode 100644 index 000000000..5ce9115ed --- /dev/null +++ b/optimum/neuron/trainers/extras/__init__.py @@ -0,0 +1,19 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vllm_client import MockVLLMClient, VLLMClient + + +__all__ = ["VLLMClient", "MockVLLMClient"] diff --git a/optimum/neuron/trainers/extras/vllm_client.py b/optimum/neuron/trainers/extras/vllm_client.py new file mode 100644 index 000000000..f3a0c3fd8 --- /dev/null +++ b/optimum/neuron/trainers/extras/vllm_client.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
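+
+"""Client helpers for talking to a vLLM server from Neuron training processes.
+
+Illustrative usage (the URL is a placeholder and assumes a vLLM server is
+already running and reachable):
+
+    client = VLLMClient(base_url="http://localhost:8000")
+    client.init_communicator(device="cpu")
+    out = client.generate(prompts=["Hello"], n=4, max_tokens=32)
+    # `out` is a dict with "prompt_ids", "completion_ids" and "logprobs".
+"""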
+ +import atexit +import random +import time +from collections import namedtuple +from typing import Union + +import requests +import torch +import torch_xla +from optimum.utils import logging +from trl.extras.vllm_client import VLLMClient as TRLVLLMClient +from trl.import_utils import is_vllm_available + + +if is_vllm_available(): + from vllm.distributed.utils import StatelessProcessGroup +else: + + class StatelessProcessGroup: + pass + + +logger = logging.get_logger() + +# Set up the communication group for weight broadcasting using CPU communicator +Group = namedtuple("Group", "barrier") + + +class CPUCommunicator: + def __init__(self, store, rank): + self.rank = rank + self.store = store + self.group = Group(barrier=self.barrier) + + def broadcast(self, tensor, src): + # Move tensor to CPU to ensure compatibility with vLLM server + if tensor.device.type == "xla": + tensor = tensor.cpu() + torch_xla.sync() + self.store.broadcast_obj(tensor, src=self.rank) + + def barrier(self): + self.store.barrier() + + def __del__(self): + del self.store + + +class VLLMClient(TRLVLLMClient): + """VLLMClient with CPU-based communication for Neuron environments.""" + + def __init__( + self, + base_url: str | None = None, + host: str = "0.0.0.0", + server_port: int = 8000, + group_port: int = 51216, + connection_timeout: float = 0.0, + ): + super().__init__( + base_url=base_url, + host=host, + server_port=server_port, + group_port=group_port, + connection_timeout=connection_timeout, + ) + + def init_communicator(self, device: Union[torch.device, str, int] = 0): + # Get the world size from the server + url = f"{self.base_url}/get_world_size/" + response = requests.get(url) + if response.status_code == 200: + vllm_world_size = response.json()["world_size"] + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + world_size = vllm_world_size + 1 # add the client to the world + self.rank = vllm_world_size # the client's rank is the last process + + # Initialize weight update group + url = f"{self.base_url}/init_communicator/" + + # Use dummy UUID for CPU/Neuron environments + client_device_uuid = "42" + + # In the server side, the host is set to 0.0.0.0 + response = self.session.post( + url, + json={ + "host": "0.0.0.0", + "port": self.group_port, + "world_size": world_size, + "client_device_uuid": client_device_uuid, + }, + ) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + # Brief delay to allow server initialization. While not strictly required (client socket will retry on + # connection failure), this prevents log warnings like: + # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 + time.sleep(0.1) + + pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) + self.communicator = CPUCommunicator(pg, self.rank) + + # When the client object is deleted, close the weight update group + atexit.register(self.close_communicator) + + +class MockVLLMClient(VLLMClient): + """ + Mock VLLMClient that generates completions without a vLLM server. + + Used for neuron_parallel_compile and testing. Generates completions by cycling + through prompt tokens (echo mode), producing deterministic, non-garbage output. 
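+
+    Example (illustrative only; the prompt and sizes below are arbitrary):
+
+        client = MockVLLMClient(tokenizer, max_completion_length=64, seed=0)
+        out = client.generate(prompts=["Hello Neuron"], n=2, max_tokens=16)
+        # Each prompt yields 2 echo completions built from its own tokens,
+        # returned in out["completion_ids"] with simulated out["logprobs"].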
+ """ + + def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None): + self.tokenizer = tokenizer + self.max_completion_length = max_completion_length + self.min_completion_length = min(min_completion_length, max_completion_length) + self.random = random.Random(seed) + + logger.warning( + "Using MockVLLMClient for neuron_parallel_compile or testing. " + "This generates echo completions and should only be used for compilation/testing." + ) + + def generate( + self, + prompts: list[str], + images=None, + n: int = 1, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 256, + repetition_penalty: float = 1.0, + truncate_prompt_tokens=None, + guided_decoding_regex=None, + generation_kwargs=None, + ): + prompt_ids = [] + completion_ids = [] + logprobs = [] + + # Fallback tokens if prompt is empty + vocab_size = self.tokenizer.vocab_size + fallback_token_id = min(100, vocab_size - 1) + + for prompt in prompts: + # Tokenize prompt + prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False) + + # Truncate if needed + if truncate_prompt_tokens is not None and len(prompt_tokens) > truncate_prompt_tokens: + prompt_tokens = prompt_tokens[-truncate_prompt_tokens:] + + prompt_ids.append(prompt_tokens) + + # Generate n completions per prompt + for _ in range(n): + # Random completion length within bounds + max_len = min(max_tokens, self.max_completion_length) + completion_length = self.random.randint(self.min_completion_length, max_len) + + # Echo mode: cycle through prompt tokens + if len(prompt_tokens) > 0: + completion = [prompt_tokens[i % len(prompt_tokens)] for i in range(completion_length)] + else: + # Fallback if prompt is empty + completion = [fallback_token_id] * completion_length + + completion_ids.append(completion) + + # Logprobs: simulate higher confidence for echoed tokens + completion_logprobs = [-self.random.uniform(0.5, 2.0) for _ in range(completion_length)] + logprobs.append(completion_logprobs) + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + "logprobs": logprobs, + } + + def init_communicator(self, device): + pass + + def update_named_param(self, name, weights): + pass + + def reset_prefix_cache(self): + pass + + def close_communicator(self): + pass diff --git a/optimum/neuron/trainers/grpo_config.py b/optimum/neuron/trainers/grpo_config.py new file mode 100644 index 000000000..0a498c8b0 --- /dev/null +++ b/optimum/neuron/trainers/grpo_config.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
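+
+"""Configuration for Neuron-optimized GRPO training.
+
+Illustrative instantiation (the values are placeholders; see the checks in
+`NeuronGRPOConfig.__post_init__` for the constraints they must satisfy):
+
+    args = NeuronGRPOConfig(
+        "my-model-GRPO",            # output_dir
+        experimental=True,          # required acknowledgement flag
+        use_vllm=True,              # vLLM generation is the only supported mode
+        num_generations=8,          # must be >= 2
+        per_device_train_batch_size=8,
+    )
+"""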
+ +from dataclasses import dataclass, field + +from ..utils.import_utils import is_trl_available +from .training_args import NeuronTrainingArguments +from .trl_utils import TRL_VERSION + + +if is_trl_available(): + from trl import GRPOConfig +else: + + @dataclass + class GRPOConfig: + def __init__(self, *args, **kwargs): + raise RuntimeError(f"You need to install the `trl=={TRL_VERSION}` library to use the `NeuronGRPOConfig`.") + + +@dataclass +class NeuronGRPOConfig(NeuronTrainingArguments, GRPOConfig): + """ + Configuration class for Neuron-optimized GRPO training. + + This class combines NeuronTrainingArguments for Trainium-specific settings + with GRPOConfig for GRPO algorithm parameters. + """ + + experimental: bool = field( + default=False, + metadata={ + "help": "NeuronGRPOTrainer is experimental and not production-ready. Set to `True` to acknowledge this and " + "proceed. If `False` (the default), an error will be raised at initialization." + }, + ) + + use_vllm: bool = field( + default=True, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for " + "generation instead of the default model.generate(). Requires `vllm` to be installed. Required for NeuronGRPOTrainer." + }, + ) + + def __post_init__(self): + if not self.experimental: + raise ValueError( + "NeuronGRPOTrainer is experimental and not production-ready. To proceed, set `experimental=True` in " + "your NeuronGRPOConfig. This flag exists to ensure users are aware of the current state of the implementation." + ) + + # For now, NeuronGRPOTrainer requires vLLM for generation, no other way is supported. + if not self.use_vllm: + raise ValueError("NeuronGRPOTrainer requires `use_vllm` to be set to `True`.") + + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + NeuronTrainingArguments.__post_init__(self) + + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + + num_processes = self.world_size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + self.steps_per_generation = self.generation_batch_size // ( + self.per_device_train_batch_size * num_processes + ) + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.do_eval and self.eval_strategy != "no": + # Just ensure the value is divisible by the global batch size + if (self.per_device_eval_batch_size * num_processes) % self.num_generations != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by num_generations ({self.num_generations})." 
+ ) + + # The generation batch must contain full prompt groups (no partials), so it must be divisible by + # num_generations. + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) + + if self.num_generations < 2: + raise ValueError( + "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." + ) + + if self.delta is not None and self.use_liger_loss: + raise ValueError("Liger loss does not support two-sided GRPO loss yet.") diff --git a/optimum/neuron/trainers/grpo_trainer.py b/optimum/neuron/trainers/grpo_trainer.py new file mode 100644 index 000000000..9008a7dcb --- /dev/null +++ b/optimum/neuron/trainers/grpo_trainer.py @@ -0,0 +1,1414 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from collections import defaultdict, deque +from typing import Any, Callable + +import datasets +import numpy as np +import torch +import torch_xla +import torch_xla.core.xla_model as xm +from accelerate.utils import set_seed +from neuronx_distributed import parallel_layers +from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_rank, + get_pipeline_model_parallel_rank, + get_tensor_model_parallel_rank, +) +from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu +from optimum.utils import logging +from torch.utils.data import Dataset, IterableDataset, Sampler +from transformers import ( + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_wandb_available, +) +from transformers.utils import is_rich_available + +from ..accelerate.utils import ( + broadcast_object_to_pipeline_model_parallel_group, + broadcast_object_to_tensor_model_parallel_group, + gather_object_from_data_parallel_group, +) +from ..models.training import NeuronModelForCausalLM +from ..peft import NeuronPeftModel, get_peft_model +from ..peft.utils.vllm import get_original_merged_weights_for_vllm +from ..utils import is_precompilation, is_trl_available +from ..utils.import_utils import is_peft_available +from .extras import MockVLLMClient, VLLMClient +from .grpo_config import NeuronGRPOConfig +from .training_args import NeuronTrainingArguments +from .transformers import NeuronTrainer +from .trl_utils import ( + TRL_VERSION, + DistributedRepeatSampler, + batch_pad_sequences, + nanmax, + nanmin, + nanstd, + neuron_parallel_compile_tokenizer_decoder_method, +) + + +if is_wandb_available(): + import wandb + +if is_trl_available(): + from trl import GRPOConfig, GRPOTrainer + from trl.data_utils import is_conversational, maybe_apply_chat_template + from trl.extras.vllm_client import VLLMClient as TRLVLLMClient + from trl.trainer.utils import ( + RepeatSampler, + disable_dropout_in_model, + entropy_from_logits, + 
identity, + print_prompt_completions_sample, + selective_log_softmax, + ) +else: + + class GRPOTrainer: + pass + + class GRPOConfig: + pass + + class TRLVLLMClient: + pass + + +if is_peft_available(): + from peft import PeftConfig +else: + + class PeftConfig: + pass + + +# Create a new class that inherits from NeuronTrainer to use this class instead of the transformers Trainer, +# but has the same methods and attributes as GRPOTrainer. +# We can then inherit from this class to create our NeuronGRPOTrainer. +_GRPOTrainer = type( + "_GRPOTrainer", + (NeuronTrainer,), + GRPOTrainer.__dict__.copy(), +) + + +logger = logging.get_logger() + + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] + + +class NeuronGRPOTrainer(_GRPOTrainer): + """ + `GRPOTrainer` adapted for Neuron (Trainium) devices. + + This algorithm was initially proposed in the paper [DeepSeekMath: Pushing the Limits + of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). + """ + + def __init__( + self, + model: str | PreTrainedModel | torch.nn.Module, + reward_funcs: RewardFunc | list[RewardFunc], + args: GRPOConfig | None = None, + train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None, + eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, + peft_config: PeftConfig | None = None, + vllm_client: TRLVLLMClient | None = None, + fixed_size_for_obj_collectives: int | None = 10 * 1024 * 1024, # 10 MB + ): + if not is_trl_available(required_version=TRL_VERSION): + raise RuntimeError(f"Using NeuronGRPOTrainer requires trl=={TRL_VERSION}.") + + # Patch tokenizer decode method for Neuron parallel compilation to avoid failures. + if is_precompilation() and hasattr(processing_class, "_decode"): + processing_class._decode = neuron_parallel_compile_tokenizer_decoder_method.__get__( + processing_class, processing_class.__class__ + ) + + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = NeuronGRPOConfig(f"{model_name}-GRPO") + elif isinstance(args, NeuronTrainingArguments) and not isinstance(args, NeuronGRPOConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token + dict_args.pop("push_to_hub_token", None) + args = NeuronGRPOConfig(**dict_args) + + # Model + if isinstance(model, str): + model = NeuronModelForCausalLM.from_pretrained(model, args.trn_config, **args.model_init_kwargs or {}) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." 
+ ) + model_id = model.config._name_or_path + + # Processing class + if processing_class is None: + from transformers import AutoProcessor + + processing_class = AutoProcessor.from_pretrained(model_id, truncation_side="left") + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Store tokens and token IDs for generation and reward computation + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + # Store model forward signature keys for checking supported kwargs + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + # PEFT configuration and model wrapping + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. + self.num_virtual_tokens = 0 + + if peft_config is not None and not isinstance(model, NeuronPeftModel): + # Enable gradient checkpointing if needed + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + if args.gradient_checkpointing and ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + + if model.active_adapter in model.peft_config: + peft_model_config = model.peft_config[model.active_adapter] + self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) + + # Reward functions - for now, only support callable reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + raise NotImplementedError( + "Loading reward models from model IDs is not yet implemented for NeuronGRPOTrainer. " + "Please provide either a PreTrainedModel or a custom callable reward function." + ) + if isinstance(reward_func, PreTrainedModel): + raise NotImplementedError( + "Using PreTrainedModel reward functions is not yet fully implemented for NeuronGRPOTrainer. " + "Please use a custom callable reward function for now." 
+ ) + # Custom callable reward function + self.reward_func_names.append(reward_func.__name__) + + self.reward_funcs = reward_funcs + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + self.reward_weights = self.reward_weights.to(xm.xla_device()) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError( + f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " + f"reward functions ({len(reward_funcs)})." + ) + + # Note: We skip the loop that sets up tokenizers for PreTrainedModel reward functions + # since we currently raise errors for those anyway + self.reward_processing_classes = reward_processing_classes + + # Training arguments + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size + self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap + self.use_liger_loss = args.use_liger_loss + self.loss_type = args.loss_type + self.scale_rewards = args.scale_rewards + self.importance_sampling_level = args.importance_sampling_level + self.mask_truncated_completions = args.mask_truncated_completions + self.top_entropy_quantile = args.top_entropy_quantile + + # Validate liger kernel configuration + if self.use_liger_loss: + raise RuntimeError( + "Liger Kernel loss is not supported on Neuron devices. " + "Please set use_liger_loss=False in your GRPOConfig." + ) + + # Neuron GRPO only supports vLLM generation + if not self.use_vllm: + raise NotImplementedError( + "NeuronGRPOTrainer currently only supports vLLM generation. " + "Please set use_vllm=True in your GRPOConfig." + ) + + # Only server mode is supported for now + if self.vllm_mode != "server": + raise NotImplementedError( + "NeuronGRPOTrainer currently only supports vLLM server mode. " + "Please set vllm_mode='server' in your GRPOConfig." + ) + + if self._is_vlm: + raise NotImplementedError( + "Vision-language models are not yet supported in NeuronGRPOTrainer. " + "Please use text-only models for now." 
+ ) + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + raise NotImplementedError( + "Iterable datasets are not yet supported in NeuronGRPOTrainer. Please use a standard dataset instead." + ) + + # Multi-step + self.num_iterations = args.num_iterations + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self._step = 0 + self._buffered_inputs = None + + # Suppress FLOP estimation warning + model.warnings_issued["estimate_tokens"] = True + + # Initialize NeuronTrainer + NeuronTrainer.__init__( + self, + model=model, + args=args, + data_collator=identity, # No data collation is needed in GRPO + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + ) + + # Set _train_batch_size for compatibility with GRPOTrainer's get_train_dataloader + # NeuronTrainer doesn't set this, but GRPOTrainer expects it + self._train_batch_size = args.train_batch_size + + # Set FSDP flag to False (NeuronTrainer doesn't support FSDP) + # GRPOTrainer's methods check this attribute + self.is_fsdp_enabled = False + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + self.ref_model = None + elif isinstance(model, NeuronPeftModel): + # Create reference model using base model class + # Original implementation was disabling adapters, but we create a separate model + # instance instead for XLA compatibility. + base_model_class = model.get_base_model().__class__ + base_trn_config = model.get_base_model().trn_config + self.ref_model = base_model_class.from_pretrained(model_id, base_trn_config) + else: + # Create reference model using NeuronModelForCausalLM + self.ref_model = NeuronModelForCausalLM.from_pretrained( + model_id, args.trn_config, **args.model_init_kwargs or {} + ) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + self._logs = { + "images": deque(maxlen=args.generation_batch_size), + "prompt": deque(maxlen=args.generation_batch_size), + "completion": deque(maxlen=args.generation_batch_size), + "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + "advantages": deque(maxlen=args.generation_batch_size), + } + + # Ensure each process receives a unique seed + set_seed(args.seed, device_specific=True) + + # vLLM setup - server mode only + if vllm_client is not None: + # Use injected client (for testing, mocking, or custom implementations) + self.vllm_client = vllm_client + else: + # Default: Create VLLMClient from args + from ..utils import is_vllm_available + + if not is_vllm_available(): + raise ImportError("vLLM is not available. 
Please install vLLM to use NeuronGRPOTrainer.") + + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + + # For `neuron_parallel_compile`, use a mock VLLM client that doesn't make actual server requests. + if is_precompilation(): + self.vllm_client = MockVLLMClient(tokenizer, max_completion_length=self.max_completion_length) + else: + self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) + + # Only main process initializes the communicator for weight updates + if self.accelerator.is_main_process: + self.vllm_client.init_communicator(device="cpu") + + # vLLM specific sampling arguments + self.guided_decoding_regex = args.vllm_guided_decoding_regex + + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # Synchronize all processes after vLLM client setup + self.accelerator.wait_for_everyone() + + # Gradient accumulation requires scaled loss + self.model_accepts_loss_kwargs = False + + # Add tags for models + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + # Prepare reference model + if self.ref_model is not None: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + # Sync reference model callback + if args.sync_ref_model: + from trl.trainer.callbacks import SyncRefModelCallback + + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + # Prepare reward functions + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + self.dp_rank = get_data_parallel_rank() + self.tp_rank = get_tensor_model_parallel_rank() + self.pp_rank = get_pipeline_model_parallel_rank() + + self.fixed_size_obj_collectives = fixed_size_for_obj_collectives + + # Pre-create constant tensors for XLA optimization. + # These are used in clamp operations and comparisons. Creating them once avoids + # repeated tensor allocations that cause XLA graph fragmentation. 
+ device = xm.xla_device() + self._one_float = torch.tensor(1.0, dtype=torch.float32, device=device) + self._one_long = torch.tensor(1, dtype=torch.long, device=device) + self._inf_float = torch.tensor(float("inf"), dtype=torch.float32, device=device) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: + if dataset is None: + dataset = self.train_dataset + if self.accelerator.num_processes == 1: + sampler = RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + else: + trn_config = self.accelerator.state.trn_config + num_replicas = trn_config.data_parallel_size + rank = parallel_layers.parallel_state.get_data_parallel_rank() + sampler = DistributedRepeatSampler( + dataset=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + num_replicas=num_replicas, + rank=rank, + ) + return sampler + + def train( + self, + resume_from_checkpoint: str | bool | None = None, + ): + return NeuronTrainer.train(self, resume_from_checkpoint=resume_from_checkpoint) + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + # Average the metrics. Values can be either floats or CPU tensors, so we handle both. + # Using sum() works for both types; we call .item() only on tensors when computing the average. + metrics = {} + for key, val_list in self._metrics[mode].items(): + if len(val_list) == 0: + continue + # Convert any tensor values to floats for averaging + float_vals = [v.item() if isinstance(v, torch.Tensor) else v for v in val_list] + metrics[key] = sum(float_vals) / len(float_vals) + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + + # Using the NeuronTrainer log method instead of super().log. 
+ NeuronTrainer.log(self, logs) + + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + if is_rich_available(): + print_prompt_completions_sample( + self._logs["prompt"], + self._logs["completion"], + self._logs["rewards"], + self._logs["advantages"], + self.state.global_step, + self.num_completions_to_print, + ) + + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + import pandas as pd + + table = { + "step": [str(self.state.global_step)] * len(self._logs["prompt"]), + "prompt": self._logs["prompt"], + "completion": self._logs["completion"], + **self._logs["rewards"], + "advantage": self._logs["advantages"], + } + + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) + + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + wandb.log({"completions": wandb.Table(dataframe=df)}) + + def _save_checkpoint(self, model=None, trial=None, metrics=None): + return NeuronTrainer._save_checkpoint(self) + + def _prepare_inputs(self, inputs: Any) -> dict[str, Any]: + # Explicitly call GRPOTrainer's _prepare_inputs + return GRPOTrainer._prepare_inputs(self, inputs) + + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + + excluded_keys = {"prompt", "completion", "completion_ids"} + keys = [key for key in inputs[0] if key not in excluded_keys] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + reward_kwargs["trainer_state"] = self.state + + # Separate model-based vs callable reward functions by index + model_indices = [] + callable_indices = [] + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, torch.nn.Module): + model_indices.append(i) + else: + callable_indices.append(i) + + # Collect results: list of (index, tensor) tuples + reward_columns = [] + + if model_indices: + # Pre-compute texts once if needed (all models use same text format) + texts = None + is_conv = is_conversational(inputs[0]) + + if is_conv: + from trl.data_utils import apply_chat_template + + for i in model_indices: + reward_func = self.reward_funcs[i] + reward_processing_class = self.reward_processing_classes[i] + + if is_conv: + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, + ) + reward_inputs = NeuronTrainer._prepare_inputs(self, reward_inputs) + + with torch.inference_mode(): + logits = reward_func(**reward_inputs).logits[:, 0] + + reward_columns.append((i, logits)) + + if callable_indices: + # Use numpy for intermediate storage to avoid Python list overhead + # and enable efficient single-transfer to XLA device + num_samples = len(prompts) + callable_rewards_np = np.empty((len(callable_indices), num_samples), dtype=np.float32) + + for local_idx, global_idx in enumerate(callable_indices): + reward_func = self.reward_funcs[global_idx] + output = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + for j, r in 
enumerate(output): + callable_rewards_np[local_idx, j] = r if r is not None else np.nan + + # Single tensor creation and transfer from numpy array + callable_tensor = torch.from_numpy(callable_rewards_np).to(device=device) + + for local_idx, global_idx in enumerate(callable_indices): + reward_columns.append((global_idx, callable_tensor[local_idx])) + + # Sort by original index to maintain correct column order + reward_columns.sort(key=lambda x: x[0]) + + # Stack all columns at once instead of indexed assignment in loop + rewards_per_func = torch.stack([col for _, col in reward_columns], dim=1) + + torch_xla.sync() + rewards_per_func = self.accelerator.gather(rewards_per_func) + return rewards_per_func + + def _move_model_to_vllm(self): + if isinstance(self.model, NeuronPeftModel): + # Get original (unsharded, untransformed) merged weights for vLLM + original_weights = get_original_merged_weights_for_vllm(self.model) + # For now, we only support CPU communicator in Neuron environments. + # The CPU communicator moves weights to CPU before broadcasting, but to avoid a lot of device -> host moves, + # we move the weights to CPU here once before broadcasting, and the communicator will just broadcast them. + original_weights = move_all_tensor_to_cpu(original_weights) + torch_xla.sync() + + # Send weights to vLLM server (only main process for server mode) + for name, weight in original_weights.items(): + # Clean up parameter name for vLLM + name = self._fix_param_name_to_vllm(name) + + # TODO: Currently not supported, to implement asap in later PRs with vLLM integration. + # if self.vllm_mode == "server" and self.accelerator.is_main_process: + # self.vllm_client.update_named_param(name, weight) + # elif self.vllm_mode == "colocate": + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + # llm_model.load_weights([(name, weight)]) + else: + for name, param in self.model.named_parameters(): + name = self._fix_param_name_to_vllm(name) + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == "colocate": + self.llm.reset_prefix_cache() + + def _generate_single_turn(self, prompts: list[str], images: list | None): + if self.state.global_step != getattr(self, "_last_loaded_step", -1): + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Take unique prompts since we have num_generations duplicates + # Use maybe_apply_chat_template to handle both conversational and simple formats + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + ordered_set_of_prompts = prompts_text[:: self.num_generations] + + if images is not None: + ordered_set_of_images = images[:: self.num_generations] + else: + ordered_set_of_images = None + + # Generate on main process only, then broadcast to all ranks + if self.tp_rank == self.pp_rank == 0: + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else 
self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) + else: + output = None + + # Broadcast output to all ranks + trn_config = self.accelerator.state.trn_config + if trn_config.tensor_parallel_size > 1: + output = broadcast_object_to_tensor_model_parallel_group( + output, fixed_size=self.fixed_size_obj_collectives + ) + if trn_config.pipeline_parallel_size > 1: + output = broadcast_object_to_pipeline_model_parallel_group( + output, fixed_size=self.fixed_size_obj_collectives + ) + + # Repeat prompt_ids num_generations times to match completion_ids + prompt_ids = [ids for ids in output["prompt_ids"] for _ in range(self.num_generations)] + completion_ids = output["completion_ids"] + logprobs = output["logprobs"] + + # No forward_kwargs for mock vLLM + forward_kwargs = {} + + return prompt_ids, completion_ids, logprobs, forward_kwargs + + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size: int | None = None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + total_batch_size = input_ids.size(0) + batch_size = batch_size or total_batch_size # Chunk inputs into smaller batches to reduce memory peak + + # Ensure input batch size is divisible by `batch_size` to avoid issues with XLA graph compilation. + if total_batch_size % batch_size != 0: + raise ValueError( + f"The input_ids batch size must be divisible by `batch_size`, but got {total_batch_size} and " + f"{batch_size}." 
+ ) + + num_chunks = total_batch_size // batch_size + + all_logps = [] + all_entropies = [] if compute_entropy else None + + chunked_input_ids = torch.split(input_ids, batch_size, dim=0) + chunked_attention_mask = torch.split(attention_mask, batch_size, dim=0) + + for chunk_idx in range(num_chunks): + input_ids_batch = chunked_input_ids[chunk_idx] + attention_mask_batch = chunked_attention_mask[chunk_idx] + + # TODO: rewrite this without slicing to avoid XLA graph recompilations + start = chunk_idx * batch_size + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + # Sync before forward to isolate the forward pass graph + torch_xla.sync() + + logits = model(**model_inputs).logits + + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + # Divide logits by sampling temperature. 
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + + completion_ids = input_ids_batch[:, -logits_to_keep:] + logps = selective_log_softmax(logits, completion_ids) # compute logprobs + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + + # Sync after forward to materialize results before list append + torch_xla.sync() + + all_logps.append(logps) + if compute_entropy: + all_entropies.append(entropies) + + # Single concat at the end - one clean graph + all_logps = torch.cat(all_logps, dim=0) + all_entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + + return all_logps, all_entropies + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + ( + prompt_ids_list, + completion_ids_list, + num_items_in_batch, + sampling_per_token_logps_list, + forward_kwargs, + ) = self._generate(prompts, images) + + # Convert lists of token IDs to padded tensors using XLA-optimized batch padding. + # This avoids creating many small tensors and multiple device transfers. + prompt_ids, prompt_mask = batch_pad_sequences( + prompt_ids_list, + target_length=self.max_prompt_length, + padding_value=self.pad_token_id, + padding_side="left", + dtype=torch.long, + device=device, + ) + + completion_ids, completion_mask = batch_pad_sequences( + completion_ids_list, + target_length=self.max_completion_length, + padding_value=self.pad_token_id, + padding_side="right", + dtype=torch.long, + device=device, + ) + + if sampling_per_token_logps_list is not None: + sampling_per_token_logps, _ = batch_pad_sequences( + sampling_per_token_logps_list, + target_length=self.max_completion_length, + padding_value=0.0, + padding_side="right", + dtype=torch.float32, + device=device, + ) + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask. + # Use tensor operations instead of Python list iteration for XLA compatibility. 
+ if self.mask_truncated_completions: + # Check if last token is NOT eos or pad (meaning sequence was truncated) + last_tokens = completion_ids[:, -1] + # A sequence is NOT truncated if its last token is eos or pad + is_not_truncated = (last_tokens == self.eos_token_id) | (last_tokens == self.pad_token_id) + completion_mask = completion_mask * is_not_truncated.unsqueeze(1).long() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + with torch.no_grad(): + # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. + # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. + generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) + importance_sampling_ratio = torch.clamp( + importance_sampling_ratio, + max=torch.tensor(self.vllm_importance_sampling_cap, device=importance_sampling_ratio.device), + ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + # Here the original implementation used `model.disable_adapters()` instead of having a copy for the + # reference model. We removed it here because it broke with XLA. 
+ raise ValueError("Ref model is None but beta is not 0.0") + else: + ref_per_token_logps = None + + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = rewards - mean_grouped_rewards + + if self.scale_rewards in ["group", "none"]: + # If self.scale_rewards = "none", we'll still log group level std + std_rewards = rewards.view(-1, self.num_generations).std(dim=1) + std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0) + elif self.scale_rewards == "batch": + # Compute global std + std_rewards = rewards.std().expand_as(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." 
+ ) + + # Use direct comparison instead of torch.zeros_like to avoid tensor allocation + is_std_zero = std_rewards.abs() < 1e-8 + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + + metrics = defaultdict(list) + logs = {} + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]) + metrics[f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]) + metrics[f"rewards/{reward_func_name}/std"].append(std_func_rewards) + + metrics["reward"].append(mean_grouped_rewards.mean()) + metrics["reward_std"].append(std_rewards.mean()) + metrics["frac_reward_zero_std"].append(is_std_zero.float().mean()) + + # Log prompt and completion texts + to_gather = [prompts_text, completions_text] + if images is not None: + to_gather.append(images) + + if self.fixed_size_obj_collectives is not None: + fixed_size = len(to_gather) * self.fixed_size_obj_collectives + else: + fixed_size = None + + gathered = gather_object_from_data_parallel_group( + to_gather, + fixed_size=fixed_size, + ) + gathered_prompts_text = [item[0] for item in gathered] + gathered_completions_text = [item[1] for item in gathered] + self._logs["prompt"].extend(gathered_prompts_text) + self._logs["completion"].extend(gathered_completions_text) + + if images is not None: + gathered_images = [item[2] for item in gathered] + self._logs["images"].extend(gathered_images) + + logs["rewards"] = {} + logs["advantages"] = [] + for i, name in enumerate(self.reward_func_names): + logs["rewards"][name] = rewards_per_func[:, i] + logs["advantages"] = all_process_advantages + + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + # Original code was: + # delta = delta[completion_mask.bool()] + # mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + # But it is not XLA friendly because it involves dynamic indexing before reduction, so we rewrite it as: + completion_mask_count = completion_mask.sum() + delta_masked = delta * completion_mask + sum_delta = delta_masked.sum() + mean_delta = sum_delta / (completion_mask_count + 1e-10) + # We can simply take the max of the masked delta because values in delta are >= 0 (torch.abs). 
+ max_delta = delta_masked.max() + + # Original code was: + # flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + # min_importance_sampling_ratio = ( + # torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # mean_importance_sampling_ratio = ( + # torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # max_importance_sampling_ratio = ( + # torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + # ) + # But it is not XLA friendly because it involves dynamic indexing before reduction, so we rewrite it as: + # Use pre-created inf constant (cast to proper dtype if needed) + inf_val = self._inf_float.to(dtype=importance_sampling_ratio.dtype) + masked_is_ratio_for_min = torch.where( + completion_mask.bool(), + importance_sampling_ratio, + inf_val, + ) + min_importance_sampling_ratio = masked_is_ratio_for_min.min() + # importance_sampling_ratio values are >= 0 (torch.exp) so we can use the same computation as for delta. + flat_is_ratio_masked = importance_sampling_ratio * completion_mask + sum_flat_is_ratio = flat_is_ratio_masked.sum() + mean_importance_sampling_ratio = sum_flat_is_ratio / (completion_mask_count + 1e-10) + max_importance_sampling_ratio = flat_is_ratio_masked.max() + + sampling_metrics_to_gather = { + "mean_delta": mean_delta, + "max_delta": max_delta, + "min_is_ratio": min_importance_sampling_ratio, + "mean_is_ratio": mean_importance_sampling_ratio, + "max_is_ratio": max_importance_sampling_ratio, + } + + stacked = torch.stack(list(sampling_metrics_to_gather.values()), dim=0) + gathered_stacked = self.accelerator.gather(stacked) + gathered = {k: gathered_stacked[i] for i, k in enumerate(sampling_metrics_to_gather.keys())} + metrics["sampling/sampling_logp_difference/mean"].append(gathered["mean_delta"].mean()) + metrics["sampling/sampling_logp_difference/max"].append(gathered["max_delta"].max()) + metrics["sampling/importance_sampling_ratio/min"].append(nanmin(gathered["min_is_ratio"])) + metrics["sampling/importance_sampling_ratio/mean"].append(gathered["mean_is_ratio"].nanmean()) + metrics["sampling/importance_sampling_ratio/max"].append(nanmax(gathered["max_is_ratio"])) + + torch_xla.sync() + metrics = move_all_tensor_to_cpu(metrics) + logs = move_all_tensor_to_cpu(logs) + + # Update the actual metrics and logs. 
+ self._metrics[mode].update(metrics) + for name in self.reward_func_names: + self._logs["rewards"][name].extend(logs["rewards"][name].tolist()) + self._logs["advantages"].extend(logs["advantages"].tolist()) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if self.use_vllm and self.vllm_importance_sampling_correction: + output["importance_sampling_ratio"] = importance_sampling_ratio + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + + return output + + def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: + """ + Compute a mask for high-entropy tokens (above the given quantile threshold). + """ + pad_value = -1e9 + dtype = entropies.dtype + + # Create pad tensor from pre-allocated constant (avoids allocation in hot path) + # Note: pad_value is negative, so we can't use self._inf_float directly + pad_tensor = torch.full_like(entropies[:1, :1], pad_value).expand_as(entropies) + + masked_entropies = torch.where(mask.bool(), entropies, pad_tensor) + + local_flat = masked_entropies.view(-1) + gathered = self.accelerator.gather(local_flat) + + # Sort gathered values, so that pad_value sentinels are at the beginning + sorted_values, _ = torch.sort(gathered) + + # Compute the number of valid (non-sentinel) values using a tolerance for float comparison + is_sentinel = sorted_values < (pad_value + 1e-6) # pad_value is -1e9 + num_sentinels = is_sentinel.sum() + num_valid_values = gathered.numel() - num_sentinels + + # Get the quantile index and the corresponding entropy threshold value + # Use torch.gather instead of dynamic indexing to maintain XLA compatibility + quantile_idx = num_sentinels + (threshold * num_valid_values.float()).long() + quantile_idx = quantile_idx.clamp( + min=torch.tensor(0, device=quantile_idx.device), + max=torch.tensor(gathered.numel() - 1, device=quantile_idx.device), + ) + + # Use gather for XLA-compatible indexing (gather works with tensor indices) + entropy_threshold = sorted_values.gather(0, quantile_idx.view(1)).squeeze(0) + + # Handle empty case: if everything is sentinel, set threshold to +inf so no token is selected + has_valid = num_valid_values > 0 + inf_val = self._inf_float.to(dtype=dtype) + entropy_threshold = torch.where(has_valid, entropy_threshold, inf_val) + + entropy_mask = (entropies > entropy_threshold) & mask.bool() + return entropy_mask + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + else: + return self._compute_loss(model, inputs) + + def _compute_loss(self, model, inputs): + # Compute 
the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + with self.metrics_collector.time_metric("forward_pass", inputs=inputs): + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=self.args.per_device_train_batch_size, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + ) + + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + + # Compute the loss + advantages = inputs["advantages"] + # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, + # old_per_token_logps == per_token_logps. In this case we can skip its computation + # (see _generate_and_score_completions) and instead use per_token_logps.detach(). + # The exception is when using vLLM, where we always compute old_per_token_logps + # for importance sampling + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "token": + log_importance_weights = log_ratio + elif self.importance_sampling_level == "sequence": + # Use tensor min value for clamp to avoid torch neuron SDK bug with Python literals + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp( + min=self._one_float, + ) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) 
shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp( + coef_1, + torch.tensor(1 - self.epsilon_low, device=coef_1.device), + torch.tensor(1 + self.epsilon_high, device=coef_1.device), + ) + + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=torch.tensor(self.args.delta, device=coef_1.device)) + + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + + if self.use_vllm and self.vllm_importance_sampling_correction: + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == "grpo": + loss = ( + (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=self._one_float) + ).mean() + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "bnpo": + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=self._one_float) + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "dapo": + normalizer = inputs["num_items_in_batch"] + loss = (per_token_loss * completion_mask).sum() / normalizer + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # Log the metrics + mode = "train" if self.model.training else "eval" + + completion_token_count = completion_mask.sum().clamp(min=self._one_float) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count + + metrics = defaultdict(list) + metrics_to_gather = {} + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + metrics_to_gather["kl"] = mean_kl + + mean_entropy = masked_batch_mean(entropies) + metrics_to_gather["entropy"] = mean_entropy + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + metrics_to_gather["low_clip"] = low_clip + metrics_to_gather["high_clip"] = high_clip + metrics_to_gather["clip_ratio"] = clip_ratio + + stacked_metrics = torch.stack(list(metrics_to_gather.values()), dim=0) + gathered_stacked_metrics = self.accelerator.gather(stacked_metrics) + gathered_metrics = {} + + for i, key in enumerate(metrics_to_gather.keys()): + gathered_metrics[key] = gathered_stacked_metrics[i] + + if self.beta != 0.0: + metrics["kl"].append(gathered_metrics["kl"].nanmean()) + metrics["entropy"].append(gathered_metrics["entropy"].nanmean()) + metrics["clip_ratio/low_mean"].append(gathered_metrics["low_clip"].nanmean()) + metrics["clip_ratio/low_min"].append(nanmin(gathered_metrics["low_clip"])) + 
metrics["clip_ratio/high_mean"].append(gathered_metrics["high_clip"].nanmean()) + metrics["clip_ratio/high_max"].append(nanmax(gathered_metrics["high_clip"])) + metrics["clip_ratio/region_mean"].append(gathered_metrics["clip_ratio"].nanmean()) + + # Move metrics to CPU but keep as tensors. The log() method will call .item() + # when averaging. This defers sync overhead to logging time. + metrics = move_all_tensor_to_cpu(metrics) + torch_xla.sync() + + self._metrics[mode].update(metrics) + + return loss + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): + """ + Evaluation and prediction are not supported in NeuronGRPOTrainer. + + The trainer is designed for training only. NeuronTrainer does not provide + evaluation loop functionality, and GRPO-specific evaluation would require + significant additional implementation. + """ + raise NotImplementedError( + "Evaluation and prediction are not supported in NeuronGRPOTrainer. " + "The trainer is designed for training only." + ) diff --git a/optimum/neuron/trainers/metrics/collector.py b/optimum/neuron/trainers/metrics/collector.py index a6d9db3e8..59da8571a 100644 --- a/optimum/neuron/trainers/metrics/collector.py +++ b/optimum/neuron/trainers/metrics/collector.py @@ -206,9 +206,9 @@ def get_metric_window_stats(self, metric_name: str) -> dict: def get_metric_unit(self, metric_name: str) -> str: """Get the unit for a specific metric.""" for plugin in self.active_plugins: - if plugin.handles_metric(metric_name): - units = plugin.get_metric_units() - return units.get(metric_name, "") + metric_units = plugin.get_metric_units() + if metric_name in metric_units: + return metric_units[metric_name] return "" def get_all_metric_units(self) -> dict[str, str]: diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index 48138912f..748af77ef 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn +import torch_xla import torch_xla.core.xla_model as xm from accelerate.utils import AutocastKwargs, DataLoaderConfiguration from neuronx_distributed.parallel_layers.parallel_state import ( @@ -103,7 +104,7 @@ from ..utils.misc import is_main_worker, is_precompilation from .metrics import TrainingMetricsCollector from .training_args import NeuronTrainingArguments -from .utils import XLAPrefetchIterator +from .utils import XLAPrefetchIterator, move_inputs_to_device logger = logging.get_logger() @@ -907,7 +908,7 @@ def setup_training( self.running_loss = torch.zeros(1, dtype=torch.double, device=xm.xla_device()) self.grad_norm = None - xm.mark_step() + torch_xla.sync() def get_batch_samples( self, @@ -943,10 +944,7 @@ def get_batch_samples( if self.pp_size == 1 and device is not None and device.type == "xla": if prefetch_size is None: - for idx, batch in enumerate(batch_samples): - batch_samples[idx] = { - k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() - } + batch_samples = move_inputs_to_device(batch_samples, device) else: batch_samples = XLAPrefetchIterator(batch_samples, prefetch_size) @@ -999,9 +997,23 @@ def compute_loss( return (loss, outputs) if return_outputs else loss + def _prepare_inputs(self, inputs: Any) -> Any: + """ + Prepare inputs before feeding them to the model. + + This is a no-op for standard NeuronTrainer as inputs are already moved to device in get_batch_samples(). 
+ Subclasses can override this method for custom preprocessing (e.g., GRPOTrainer uses + this for generation, scoring, and tokenization). + + """ + return inputs + def training_step( self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch: int | torch.Tensor | None = None ) -> torch.Tensor: + # Prepare inputs (no-op for base trainer, overridden by subclasses for custom preprocessing) + inputs = self._prepare_inputs(inputs) + manager = self.autocast_smart_context_manager() with manager: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) @@ -1031,7 +1043,7 @@ def maybe_log_train_step_metrics(self): return if self.control.should_log: - xm.mark_step() + torch_xla.sync() running_loss_div = self.running_loss / self.dp_size reduced_loss = xm.all_reduce(xm.REDUCE_SUM, running_loss_div, groups=get_data_parallel_replica_groups()) reduced_loss = reduced_loss.detach() @@ -1080,7 +1092,7 @@ def log_closure(): def maybe_save_checkpoint(self): """Save checkpoint if saving is due.""" if self.control.should_save: - xm.mark_step() + torch_xla.sync() def save_closure(self): self._save_checkpoint() @@ -1157,7 +1169,7 @@ def _train( self.metrics_collector.start_metric("total_step") for inputs in batch_samples: - xm.mark_step() + torch_xla.sync() step += 1 do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch @@ -1173,7 +1185,7 @@ def _train( if do_sync_step: self.accelerator.gradient_state.sync_gradients = True - xm.mark_step() + torch_xla.sync() # Gradient clipping self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) @@ -1194,7 +1206,7 @@ def _train( self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) - xm.mark_step() + torch_xla.sync() self.metrics_collector.stop_metric("throughput") self.metrics_collector.stop_metric("total_step") self.metrics_collector.end_gradient_accumulation_cycle(step_number=self.state.global_step) @@ -1209,12 +1221,12 @@ def _train( # PyTorch/XLA relies on the data loader to insert the mark_step for # each step. Since we are breaking the loop early, we need to manually # insert the mark_step here. 
-                        xm.mark_step()
+                        torch_xla.sync()
                         break
 
                 # We also need to break out of the nested loop
                 if self.control.should_epoch_stop or self.control.should_training_stop:
-                    xm.mark_step()
+                    torch_xla.sync()
                     break
 
             if step < 0:
@@ -1226,7 +1238,7 @@ def _train(
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            xm.mark_step()
+            torch_xla.sync()
 
             if self.control.should_training_stop:
                 break
@@ -1360,19 +1372,8 @@ def report_and_save_summary_metrics(self):
         # Group and format metrics for better readability
         for metric_name, value in summary_metrics.items():
             if isinstance(value, float):
-                if "time" in metric_name:
-                    logger.info(f"{metric_name}: {value:.4f}s")
-                elif "per_sec" in metric_name:
-                    logger.info(f"{metric_name}: {value:.2f}")
-                elif (
-                    "mfu" in metric_name
-                    or "efficiency" in metric_name
-                    or "consistency" in metric_name
-                    or "percent" in metric_name
-                ):
-                    logger.info(f"{metric_name}: {value:.2f}%")
-                else:
-                    logger.info(f"{metric_name}: {value:.2f}")
+                unit = self.metrics_collector.get_metric_unit(metric_name)
+                logger.info(f"{metric_name}: {value:.2f}{unit}")
             else:
                 logger.info(f"{metric_name}: {value}")
         logger.info("=" * 80)
diff --git a/optimum/neuron/trainers/trl_utils.py b/optimum/neuron/trainers/trl_utils.py
index 4046dd93b..60e402a4e 100644
--- a/optimum/neuron/trainers/trl_utils.py
+++ b/optimum/neuron/trainers/trl_utils.py
@@ -13,4 +13,258 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Literal
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from optimum.utils import logging
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from trl.trainer.utils import RepeatSampler
+
+from ..utils import is_precompilation
+
+
+logger = logging.get_logger()
+TRL_VERSION = "0.24.0"
+
+
+def pad(
+    tensors: list[torch.Tensor],
+    padding_value: int = 0,
+    padding_side: str = "right",
+    max_length: int | None = None,
+) -> torch.Tensor:
+    """
+    Pads a list of tensors to the same shape along the first dimension.
+    It differs from `trl` by enforcing the same sequence length for all tensors, which is required to avoid
+    recompilation.
+    """
+    batch_size = len(tensors)
+    if max_length is None:
+        max_length = max(t.shape[0] for t in tensors)
+
+    output_shape = (max_length,) + tensors[0].shape[1:]
+
+    # Create an output tensor filled with the padding value
+    output = torch.full((batch_size, *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
+
+    for i, t in enumerate(tensors):
+        if padding_side == "left":
+            seq_start = output_shape[0] - t.shape[0]
+        elif padding_side == "right":
+            seq_start = 0
+        else:
+            raise ValueError("padding_side must be 'left' or 'right'")
+
+        # Define the slices
+        seq_slice = slice(seq_start, seq_start + t.shape[0])
+        slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
+        output[i][slices] = t
+
+    return output
+
+
+def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
+    """
+    Compute the Shannon entropy (in nats) for each row of *logits*.
+
+    The original implementation in trl.trainer.utils.entropy_from_logits provides a memory-efficient alternative,
+    but it accumulates results in a list, which can lead to graph fragmentation on XLA devices.
+    Here we keep things simple and compute the entropy in one go.
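+
+    Example (illustrative sketch only; the tensor below is an assumption, not data seen during training):
+        >>> logits = torch.tensor([[0.0, 0.0, 0.0, 0.0]])  # uniform over 4 tokens
+        >>> entropy_from_logits(logits)  # one entropy per row, equal to log(4) here
+        tensor([1.3863])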
+ """ + logps = F.log_softmax(logits, dim=-1) + entropy = -(torch.exp(logps) * logps).sum(-1) + return entropy + + +def neuron_parallel_compile_tokenizer_decoder_method( + self, + token_ids: int | list[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool | None = None, + **kwargs, +) -> str: + """ + Patched `tokenizer._decode` method for `neuron_parallel_compile`. + This is needed because any tensor operation on the XLA device during `neuron_parallel_compile` produces rubbish + results, which is not an issue in general, but causes failure when the token IDS end up being out of range for the + tokenizer vocabulary. + """ + if not is_precompilation(): + raise RuntimeError("This patch method should only be used with `neuron_parallel_compile`.") + + # We log the token IDs to force the data mouvement to CPU, which would happen during actual decoding. + logger.debug("Using patched tokenizer.decode method for Neuron parallel compilation, token_ids = ", token_ids) + + # Returns a dummy string, we do not care about the value in this context. + return "dummy" + + +def batch_pad_sequences( + sequences: list[list[int | float]], + target_length: int, + padding_value: int | float = 0, + padding_side: Literal["left", "right"] = "right", + dtype: torch.dtype = torch.long, + device: torch.device | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + XLA-optimized batch padding of variable-length sequences. + + Unlike per-sequence padding with list comprehensions, this function: + 1. Pre-allocates output arrays using numpy (fast CPU operations) + 2. Transfers to device as a single operation (one host->device copy) + 3. Creates the mask alongside the padded sequences (no separate allocation) + + This avoids creating many small tensors and multiple device transfers that cause + XLA graph fragmentation. + + Args: + sequences (`list[list[int | float]]`): + List of variable-length sequences. Each sequence is a list of token IDs (ints) + or log probabilities (floats). + target_length (`int`): + Fixed target length for all sequences. Sequences longer than this will be + truncated; shorter sequences will be padded. + padding_value (`int | float`, *optional*, defaults to `0`): + Value to use for padding positions. + padding_side (`Literal["left", "right"]`, *optional*, defaults to `"right"`): + Side on which to add padding. Also determines truncation behavior: + - `"left"`: Pads on left, truncates from left (keeps last tokens) + - `"right"`: Pads on right, truncates from right (keeps first tokens) + dtype (`torch.dtype`, *optional*, defaults to `torch.long`): + Output tensor dtype for the padded sequences. + device (`torch.device | None`, *optional*, defaults to `None`): + Target device for the output tensors. If `None`, tensors remain on CPU. 
+ + Returns: + `tuple[torch.Tensor, torch.Tensor]`: + A tuple of `(padded_sequences, mask)` where: + - `padded_sequences` has shape `(batch_size, target_length)` and dtype `dtype` + - `mask` has shape `(batch_size, target_length)` and dtype `torch.long`, with + `1` for real tokens and `0` for padding positions + """ + batch_size = len(sequences) + + # Determine numpy dtype for intermediate computation + if dtype in (torch.float32, torch.float64, torch.float16, torch.bfloat16): + np_dtype = np.float32 + else: + np_dtype = np.int64 + + # Pre-allocate numpy arrays (fast CPU operations) + padded = np.full((batch_size, target_length), padding_value, dtype=np_dtype) + mask = np.zeros((batch_size, target_length), dtype=np.int64) + + for i, seq in enumerate(sequences): + seq_len = len(seq) + if seq_len == 0: + continue + + if seq_len >= target_length: + # Truncation needed + if padding_side == "left": + # Keep last target_length tokens + padded[i] = seq[seq_len - target_length :] + else: + # Keep first target_length tokens + padded[i] = seq[:target_length] + mask[i] = 1 + else: + # Padding needed + if padding_side == "left": + start_idx = target_length - seq_len + padded[i, start_idx:] = seq + mask[i, start_idx:] = 1 + else: + padded[i, :seq_len] = seq + mask[i, :seq_len] = 1 + + # Single conversion and transfer to device + padded_tensor = torch.from_numpy(padded).to(dtype=dtype, device=device) + mask_tensor = torch.from_numpy(mask).to(dtype=torch.long, device=device) + + return padded_tensor, mask_tensor + + +def nanmin(tensor: torch.Tensor) -> torch.Tensor: + """ + XLA-compatible version of nanmin that doesn't use dynamic indexing. + Compute the minimum value of a tensor, ignoring NaNs. + """ + mask = torch.isnan(tensor) + filled = torch.where(mask, torch.tensor(float("inf"), device=tensor.device), tensor) + min_value = torch.min(filled) + return min_value + + +def nanmax(tensor: torch.Tensor) -> torch.Tensor: + """ + XLA-compatible version of nanmax that doesn't use dynamic indexing. + Compute the maximum value of a tensor, ignoring NaNs. + """ + mask = torch.isnan(tensor) + filled = torch.where(mask, torch.tensor(float("-inf"), device=tensor.device), tensor) + max_value = torch.max(filled) + return max_value + + +def nanstd(tensor: torch.Tensor, unbiased: bool = False) -> torch.Tensor: + """ + XLA-compatible version of nanstd. + Compute the standard deviation of a tensor, ignoring NaNs. 
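+
+    Example (illustrative sketch; the values below are assumptions):
+        >>> nanstd(torch.tensor([1.0, 3.0, float("nan")]))  # population std of [1.0, 3.0]
+        tensor(1.)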
+ """ + mask = ~torch.isnan(tensor) + count = mask.sum() + + clean = torch.where(mask, tensor, torch.zeros_like(tensor)) + mean = clean.sum() / count + + diff_squared = torch.where(mask, (clean - mean) ** 2, torch.zeros_like(tensor)) + + if unbiased: + variance = diff_squared.sum() / (count - 1).clamp(min=torch.tensor(1.0, device=tensor.device)) + else: + variance = diff_squared.sum() / count + + return variance.sqrt() + + +class DistributedRepeatSampler(RepeatSampler, DistributedSampler): + def __init__( + self, + dataset: Dataset, + mini_repeat_count: int, + batch_size: int = 1, + repeat_count: int = 1, + num_replicas: int | None = None, + rank: int | None = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ): + # Initialize RepeatSampler with the actual sampling logic + RepeatSampler.__init__( + self, + data_source=dataset, + mini_repeat_count=mini_repeat_count, + batch_size=batch_size, + repeat_count=repeat_count, + shuffle=shuffle, + seed=seed, + ) + + # Store DistributedSampler attributes for interface compatibility + # (but we don't use them for actual distribution) + if num_replicas is None: + num_replicas = dist.get_world_size() if dist.is_available() else 1 + if rank is None: + rank = dist.get_rank() if dist.is_available() else 0 + + self.num_replicas = num_replicas + self.rank = rank + self.drop_last = drop_last diff --git a/optimum/neuron/trainers/utils.py b/optimum/neuron/trainers/utils.py index 595f546cb..065be0b60 100644 --- a/optimum/neuron/trainers/utils.py +++ b/optimum/neuron/trainers/utils.py @@ -14,9 +14,23 @@ # limitations under the License. import torch +import torch_xla import torch_xla.core.xla_model as xm +def move_inputs_to_device(inputs, device: torch.device): + if isinstance(inputs, torch.Tensor): + return inputs.to(device) + elif isinstance(inputs, dict): + return {k: move_inputs_to_device(v, device) for k, v in inputs.items()} + elif isinstance(inputs, list): + return [move_inputs_to_device(v, device) for v in inputs] + elif isinstance(inputs, tuple): + return tuple(move_inputs_to_device(v, device) for v in inputs) + else: + return inputs + + class XLAPrefetchIterator: def __init__(self, examples: list[dict[str, torch.Tensor]], prefetch_size: int = 1): self.examples = examples @@ -28,7 +42,7 @@ def __init__(self, examples: list[dict[str, torch.Tensor]], prefetch_size: int = def _prefetch(self): while len(self.buffer) < self.prefetch_size and self.current_index < len(self.examples): example = self.examples[self.current_index] - example_on_xla = {k: v.to(xm.xla_device()) for k, v in example.items()} + example_on_xla = move_inputs_to_device(example, xm.xla_device()) self.buffer.append(example_on_xla) self.current_index += 1 @@ -38,7 +52,7 @@ def __iter__(self): def __next__(self): if not self.buffer: raise StopIteration - xm.mark_step() + torch_xla.sync() next_example = self.buffer.pop(0) self._prefetch() return next_example diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 0265cb514..223880e3d 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -49,6 +49,7 @@ "is_neuronx_available", "is_torch_neuronx_available", "is_trl_available", + "is_vllm_available", ], "input_generators": [ "DTYPE_MAPPER", @@ -117,6 +118,7 @@ is_neuronx_available, is_torch_neuronx_available, is_trl_available, + is_vllm_available, ) from .input_generators import ( DTYPE_MAPPER,