diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py
index 358f023e9a..6daeb66a67 100644
--- a/bionemo-recipes/models/llama3/modeling_llama_te.py
+++ b/bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 from collections import OrderedDict
+import math
 from typing import Unpack
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import transformer_engine.pytorch
 import transformers
 from transformer_engine.pytorch.attention import InferenceParams
@@ -44,6 +46,13 @@ class NVLlamaConfig(LlamaConfig):
 
     attn_input_format: str = "thd"
     self_attn_mask_type: str = "padding_causal"
+    use_moe: bool = False
+    moe_num_experts: int = 8
+    moe_top_k: int = 1
+    moe_capacity_factor: float = 1.25
+    moe_min_capacity: int = 4
+    moe_drop_tokens: bool = True
+    moe_aux_loss_coef: float = 0.01
 
 
 class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -51,7 +60,7 @@ class NVLlamaPreTrainedModel(PreTrainedModel):
 
     config_class = NVLlamaConfig
    base_model_prefix = "model"
-    _no_split_modules = ("TransformerLayer",)
+    _no_split_modules = ("TransformerLayer", "NVLlamaMoETransformerLayer")
     _skip_keys_device_placement = ("past_key_values",)
 
     def init_empty_weights(self):
@@ -89,6 +98,255 @@ def _init_weights(self, module):
         super()._init_weights(module)
 
 
+def _te_device() -> str:
+    return "meta" if torch.get_default_device() == torch.device("meta") else "cuda"
+
+
+class NVLlamaMoEFeedForward(nn.Module):
+    """MoE feed-forward network using Transformer Engine grouped GEMMs."""
+
+    def __init__(self, config: NVLlamaConfig):
+        super().__init__()
+        if config.moe_top_k < 1:
+            raise ValueError("moe_top_k must be >= 1.")
+
+        self.num_experts = config.moe_num_experts
+        self.hidden_size = config.hidden_size
+        self.ffn_hidden_size = config.intermediate_size
+        self.top_k = config.moe_top_k
+        self.capacity_factor = config.moe_capacity_factor
+        self.min_capacity = config.moe_min_capacity
+        self.drop_tokens = config.moe_drop_tokens
+
+        self.router = nn.Linear(
+            config.hidden_size,
+            config.moe_num_experts,
+            bias=False,
+            dtype=config.dtype,
+            device=_te_device(),
+        )
+
+        def _init_method(x):
+            torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
+
+        self.fc1_gate = transformer_engine.pytorch.GroupedLinear(
+            config.moe_num_experts,  # num_gemms is the first positional argument: one GEMM per expert
+            config.hidden_size,
+            config.intermediate_size,
+            bias=False,
+            params_dtype=config.dtype,
+            device=_te_device(),
+            init_method=_init_method,
+        )
+        self.fc1_up = transformer_engine.pytorch.GroupedLinear(
+            config.moe_num_experts,
+            config.hidden_size,
+            config.intermediate_size,
+            bias=False,
+            params_dtype=config.dtype,
+            device=_te_device(),
+            init_method=_init_method,
+        )
+        self.fc2 = transformer_engine.pytorch.GroupedLinear(
+            config.moe_num_experts,
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            params_dtype=config.dtype,
+            device=_te_device(),
+            init_method=_init_method,
+        )
+
+    def _compute_capacity(self, num_tokens: int) -> int:
+        base_capacity = math.ceil(self.capacity_factor * num_tokens * self.top_k / self.num_experts)
+        return max(self.min_capacity, base_capacity)
+
+    def _select_expert_tokens(
+        self,
+        flat_expert_idx: torch.Tensor,
+        flat_probs: torch.Tensor,
+        capacity: int,
+    ) -> torch.Tensor:
+        if not self.drop_tokens:
+            return torch.ones_like(flat_expert_idx, dtype=torch.bool)
+
+        selected = torch.zeros_like(flat_expert_idx, dtype=torch.bool)
+        for expert_id in range(self.num_experts):
+            expert_mask = flat_expert_idx == expert_id
+            if not torch.any(expert_mask):
+                continue
+            expert_positions = torch.nonzero(expert_mask, as_tuple=False).squeeze(-1)
+            if expert_positions.numel() <= capacity:
+                selected[expert_positions] = True
+                continue
+            expert_probs = flat_probs[expert_positions]
+            top_positions = torch.topk(expert_probs, k=capacity, sorted=False).indices
+            selected[expert_positions[top_positions]] = True
+        return selected
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
+        original_shape = hidden_states.shape
+        if hidden_states.dim() == 3:
+            hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+        logits = self.router(hidden_states)
+        router_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
+        topk_probs, topk_indices = torch.topk(router_probs, k=self.top_k, dim=-1)
+        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
+
+        routing_map = torch.zeros(
+            hidden_states.size(0),
+            self.num_experts,
+            dtype=torch.int32,
+            device=hidden_states.device,
+        )
+        routing_map.scatter_(1, topk_indices, 1)
+
+        importance = router_probs.mean(dim=0)  # mean router probability per expert keeps the aux loss O(1) in token count
+        load = routing_map.to(router_probs.dtype).sum(dim=0)
+        aux_loss = self.num_experts * torch.sum(importance * load) / (
+            hidden_states.size(0) * self.top_k
+        )
+        load_fraction = load / (hidden_states.size(0) * self.top_k)
+        load_entropy = -torch.sum(load_fraction * torch.log(load_fraction + 1e-9))
+        load_max = load_fraction.max()
+
+        token_ids = torch.arange(hidden_states.size(0), device=hidden_states.device)
+        flat_token_idx = token_ids[:, None].expand(-1, self.top_k).reshape(-1)
+        flat_expert_idx = topk_indices.reshape(-1)
+        flat_probs = topk_probs.to(hidden_states.dtype).reshape(-1)
+
+        capacity = self._compute_capacity(hidden_states.size(0))
+        selected_mask = self._select_expert_tokens(flat_expert_idx, flat_probs, capacity)
+        dropped_tokens = flat_expert_idx.numel() - selected_mask.sum()
+
+        selected_token_idx = flat_token_idx[selected_mask]
+        selected_expert_idx = flat_expert_idx[selected_mask]
+        selected_probs = flat_probs[selected_mask]
+
+        if selected_token_idx.numel() == 0:
+            output = hidden_states.new_zeros((hidden_states.size(0), self.hidden_size))
+            if len(original_shape) == 3:
+                output = output.view(original_shape)
+            stats = {
+                "load_entropy": load_entropy,
+                "load_max": load_max,
+                "dropped_tokens": dropped_tokens.to(hidden_states.dtype),
+                "capacity": torch.tensor(capacity, device=hidden_states.device, dtype=hidden_states.dtype),
+            }
+            return output, aux_loss, stats
+
+        sort_order = torch.argsort(selected_expert_idx, stable=True)
+        selected_token_idx = selected_token_idx[sort_order]
+        selected_expert_idx = selected_expert_idx[sort_order]
+        selected_probs = selected_probs[sort_order]
+
+        permuted = hidden_states.index_select(0, selected_token_idx)
+        m_splits = torch.bincount(selected_expert_idx, minlength=self.num_experts).tolist()
+        gate_out = self.fc1_gate(permuted, m_splits)
+        up_out = self.fc1_up(permuted, m_splits)
+        moe_out = F.silu(gate_out) * up_out
+        moe_out = self.fc2(moe_out, m_splits)
+
+        output = hidden_states.new_zeros((hidden_states.size(0), self.hidden_size))
+        output.index_add_(0, selected_token_idx, moe_out * selected_probs.unsqueeze(-1))
+
+        if len(original_shape) == 3:
+            output = output.view(original_shape)
+        stats = {
+            "load_entropy": load_entropy,
+            "load_max": load_max,
+            "dropped_tokens": dropped_tokens.to(hidden_states.dtype),
+            "capacity": torch.tensor(capacity, device=hidden_states.device, dtype=hidden_states.dtype),
+        }
+        return output, aux_loss, stats
+
+
+class NVLlamaMoETransformerLayer(nn.Module):
+    """Transformer block with TE attention and MoE MLP."""
+
+    def __init__(self, config: NVLlamaConfig, layer_number: int):
+        super().__init__()
+        self.attn_input_format = config.attn_input_format
+        self.self_attn_mask_type = config.self_attn_mask_type
+
+        self.attn_norm = transformer_engine.pytorch.RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            dtype=config.dtype,
+            device=_te_device(),
+        )
+        self.self_attn = transformer_engine.pytorch.MultiheadAttention(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_gqa_groups=config.num_key_value_heads,
+            attention_dropout=0.0,
+            layernorm_epsilon=config.rms_norm_eps,
+            bias=False,
+            attn_mask_type=config.self_attn_mask_type,
+            qkv_format=config.attn_input_format,
+            qkv_weight_interleaved=True,
+            layer_number=layer_number,
+            params_dtype=config.dtype,
+            device=_te_device(),
+        )
+        self.mlp_norm = transformer_engine.pytorch.RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            dtype=config.dtype,
+            device=_te_device(),
+        )
+        self.mlp = NVLlamaMoEFeedForward(config)
+        self.last_moe_aux_loss: torch.Tensor | None = None
+        self.last_moe_load_entropy: torch.Tensor | None = None
+        self.last_moe_load_max: torch.Tensor | None = None
+        self.last_moe_dropped_tokens: torch.Tensor | None = None
+        self.last_moe_capacity: torch.Tensor | None = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        rotary_pos_emb: torch.Tensor | None = None,
+        inference_params: InferenceParams | None = None,
+        cu_seqlens_q: torch.Tensor | None = None,
+        cu_seqlens_kv: torch.Tensor | None = None,
+        cu_seqlens_q_padded: torch.Tensor | None = None,
+        cu_seqlens_kv_padded: torch.Tensor | None = None,
+        max_seqlen_q: int | None = None,
+        max_seqlen_kv: int | None = None,
+        pad_between_seqs: bool | None = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        attn_out = self.self_attn(
+            hidden_states,
+            attention_mask=attention_mask,
+            attn_mask_type=self.self_attn_mask_type,
+            rotary_pos_emb=rotary_pos_emb,
+            inference_params=inference_params,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
+            pad_between_seqs=pad_between_seqs,
+        )
+        hidden_states = residual + attn_out
+
+        residual = hidden_states
+        hidden_states = self.mlp_norm(hidden_states)
+        mlp_out, aux_loss, stats = self.mlp(hidden_states)
+        self.last_moe_aux_loss = aux_loss
+        self.last_moe_load_entropy = stats["load_entropy"]
+        self.last_moe_load_max = stats["load_max"]
+        self.last_moe_dropped_tokens = stats["dropped_tokens"]
+        self.last_moe_capacity = stats["capacity"]
+        hidden_states = residual + mlp_out
+        return hidden_states
+
+
 class NVLlamaModel(NVLlamaPreTrainedModel):
     """Llama3 model implemented in Transformer Engine."""
 
@@ -104,37 +362,42 @@ def __init__(self, config: LlamaConfig):
         def _init_method(x):
             torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
 
-        self.layers = nn.ModuleList(
-            [
-                transformer_engine.pytorch.TransformerLayer(
-                    hidden_size=config.hidden_size,
-                    ffn_hidden_size=config.intermediate_size,
-                    num_attention_heads=config.num_attention_heads,
-                    bias=False,
-                    layernorm_epsilon=config.rms_norm_eps,
-                    hidden_dropout=0,
-                    attention_dropout=0,
-                    fuse_qkv_params=True,
-                    qkv_weight_interleaved=True,
-                    normalization="RMSNorm",
-                    activation="swiglu",
-                    attn_input_format=config.attn_input_format,
-                    self_attn_mask_type=config.self_attn_mask_type,
-                    num_gqa_groups=config.num_key_value_heads,
-                    layer_number=layer_idx + 1,
-                    params_dtype=config.dtype,
-                    device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
-                    init_method=_init_method,
-                    output_layer_init_method=_init_method,
-                )
-                for layer_idx in range(config.num_hidden_layers)
-            ]
-        )
+        if config.use_moe:
+            self.layers = nn.ModuleList(
+                [NVLlamaMoETransformerLayer(config, layer_idx + 1) for layer_idx in range(config.num_hidden_layers)]
+            )
+        else:
+            self.layers = nn.ModuleList(
+                [
+                    transformer_engine.pytorch.TransformerLayer(
+                        hidden_size=config.hidden_size,
+                        ffn_hidden_size=config.intermediate_size,
+                        num_attention_heads=config.num_attention_heads,
+                        bias=False,
+                        layernorm_epsilon=config.rms_norm_eps,
+                        hidden_dropout=0,
+                        attention_dropout=0,
+                        fuse_qkv_params=True,
+                        qkv_weight_interleaved=True,
+                        normalization="RMSNorm",
+                        activation="swiglu",
+                        attn_input_format=config.attn_input_format,
+                        self_attn_mask_type=config.self_attn_mask_type,
+                        num_gqa_groups=config.num_key_value_heads,
+                        layer_number=layer_idx + 1,
+                        params_dtype=config.dtype,
+                        device=_te_device(),
+                        init_method=_init_method,
+                        output_layer_init_method=_init_method,
+                    )
+                    for layer_idx in range(config.num_hidden_layers)
+                ]
+            )
         self.norm = transformer_engine.pytorch.RMSNorm(
             config.hidden_size,
             eps=config.rms_norm_eps,
             dtype=config.dtype,
-            device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
+            device=_te_device(),
         )
 
         # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
@@ -173,6 +436,12 @@ def forward(
         """
         all_hidden_states = []
         output_hidden_states = kwargs.get("output_hidden_states", False)
+        moe_aux_loss: torch.Tensor | None = None
+        moe_load_entropy: torch.Tensor | None = None
+        moe_load_max: torch.Tensor | None = None
+        moe_dropped_tokens: torch.Tensor | None = None
+        moe_capacity: torch.Tensor | None = None
+        moe_layers = 0
 
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -238,6 +507,34 @@
                 max_seqlen_kv=kwargs.get("max_length_k", None),
                 pad_between_seqs=kwargs.get("pad_between_seqs", None),
             )
+            if self.config.use_moe and isinstance(decoder_layer, NVLlamaMoETransformerLayer):
+                if decoder_layer.last_moe_aux_loss is not None:
+                    moe_aux_loss = (
+                        decoder_layer.last_moe_aux_loss
+                        if moe_aux_loss is None
+                        else moe_aux_loss + decoder_layer.last_moe_aux_loss
+                    )
+                    moe_load_entropy = (
+                        decoder_layer.last_moe_load_entropy
+                        if moe_load_entropy is None
+                        else moe_load_entropy + decoder_layer.last_moe_load_entropy
+                    )
+                    moe_load_max = (
+                        decoder_layer.last_moe_load_max
+                        if moe_load_max is None
+                        else moe_load_max + decoder_layer.last_moe_load_max
+                    )
+                    moe_dropped_tokens = (
+                        decoder_layer.last_moe_dropped_tokens
+                        if moe_dropped_tokens is None
+                        else moe_dropped_tokens + decoder_layer.last_moe_dropped_tokens
+                    )
+                    moe_capacity = (
+                        decoder_layer.last_moe_capacity
+                        if moe_capacity is None
+                        else moe_capacity + decoder_layer.last_moe_capacity
+                    )
+                    moe_layers += 1
 
         hidden_states = self.norm(hidden_states)
 
@@ -250,11 +547,18 @@
             # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
             hidden_states = _pad_input(hidden_states, indices, batch_size, max_seqlen)
 
-        return BaseModelOutputWithPast(
+        outputs = BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
             hidden_states=all_hidden_states if output_hidden_states else None,
         )
+        if moe_aux_loss is not None and moe_layers > 0:
+            outputs.moe_aux_loss = moe_aux_loss / moe_layers
+            outputs.moe_load_entropy = moe_load_entropy / moe_layers
+            outputs.moe_load_max = moe_load_max / moe_layers
+            outputs.moe_dropped_tokens = moe_dropped_tokens / moe_layers
+            outputs.moe_capacity = moe_capacity / moe_layers
+        return outputs
 
 
 class NVLlamaForCausalLM(NVLlamaPreTrainedModel, transformers.GenerationMixin):
@@ -334,14 +638,24 @@ def forward(
         loss = None
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            moe_aux_loss = getattr(outputs, "moe_aux_loss", None)
+            if moe_aux_loss is not None:
+                loss = loss + self.config.moe_aux_loss_coef * moe_aux_loss
 
-        return CausalLMOutputWithPast(
+        lm_outputs = CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+        if hasattr(outputs, "moe_aux_loss"):
+            lm_outputs.moe_aux_loss = outputs.moe_aux_loss
+            lm_outputs.moe_load_entropy = outputs.moe_load_entropy
+            lm_outputs.moe_load_max = outputs.moe_load_max
+            lm_outputs.moe_dropped_tokens = outputs.moe_dropped_tokens
+            lm_outputs.moe_capacity = outputs.moe_capacity
+        return lm_outputs
 
 
 class NVLlamaForSequenceClassification(  # noqa: D101
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index cba8836de3..1e9a5f9175 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -3,6 +3,15 @@ use_te: true # Whether to use TransformerEngine layers through NVLlamaForCausalL
 config_name_or_path: ??? # E.g., meta-llama/Llama-3.2-1B or ./model_configs/meta-llama/Llama-3.2-1B
 config_kwargs: {}
 
+# MoE config (only used when use_te: true and use_moe: true)
+use_moe: false
+moe_num_experts: 8
+moe_top_k: 1
+moe_capacity_factor: 1.25
+moe_min_capacity: 4
+moe_drop_tokens: true
+moe_aux_loss_coef: 0.01
+
 num_train_steps: ???
 grad_acc_steps: 1 # Gradient accumulation steps - effective batch = micro_batch_size * num_gpus * grad_acc_steps
 
diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
index 56ccd7a349..c8aa27ba86 100644
--- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
@@ -64,6 +64,16 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
             "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(),
             "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(),
         }
+        if args.use_moe:
+            metrics_dict.update(
+                {
+                    "train/moe_aux_loss": torchmetrics.MeanMetric(),
+                    "train/moe_load_entropy": torchmetrics.MeanMetric(),
+                    "train/moe_load_max": torchmetrics.MeanMetric(),
+                    "train/moe_dropped_tokens": torchmetrics.MeanMetric(),
+                    "train/moe_capacity": torchmetrics.MeanMetric(),
+                }
+            )
 
         self.metrics = torchmetrics.MetricCollection(metrics_dict)
 
         # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging.
@@ -85,6 +95,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
         self.num_unpadded_tokens = 0
         self.running_loss = 0.0
         self.grad_acc_step_count = 0
+        self.running_moe_aux_loss = 0.0
+        self.running_moe_load_entropy = 0.0
+        self.running_moe_load_max = 0.0
+        self.running_moe_dropped_tokens = 0.0
+        self.running_moe_capacity = 0.0
 
         # Whether to step debug_api.step() after each step
         self.fp8_stats_enabled = args.fp8_stats_config.enabled
@@ -105,6 +120,12 @@ def log_micro_step(self, batch: dict[str, torch.Tensor], outputs: CausalLMOutput
             # Fallback for pure sequence packing with no padding: all tokens are unpadded
             self.num_unpadded_tokens += batch["input_ids"].numel()
         self.running_loss += outputs.loss.item()
+        if hasattr(outputs, "moe_aux_loss") and outputs.moe_aux_loss is not None:
+            self.running_moe_aux_loss += outputs.moe_aux_loss.item()
+            self.running_moe_load_entropy += outputs.moe_load_entropy.item()
+            self.running_moe_load_max += outputs.moe_load_max.item()
+            self.running_moe_dropped_tokens += outputs.moe_dropped_tokens.item()
+            self.running_moe_capacity += outputs.moe_capacity.item()
 
     def log_step(
         self,
@@ -136,6 +157,14 @@
         self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time)
         self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time)
         self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens / self.logging_frequency)
+        if "train/moe_aux_loss" in self.metrics:
+            self.metrics["train/moe_aux_loss"].update(self.running_moe_aux_loss / self.grad_acc_step_count)
+            self.metrics["train/moe_load_entropy"].update(self.running_moe_load_entropy / self.grad_acc_step_count)
+            self.metrics["train/moe_load_max"].update(self.running_moe_load_max / self.grad_acc_step_count)
+            self.metrics["train/moe_dropped_tokens"].update(
+                self.running_moe_dropped_tokens / self.grad_acc_step_count
+            )
+            self.metrics["train/moe_capacity"].update(self.running_moe_capacity / self.grad_acc_step_count)
 
         if self._profiler is not None:
             self._profiler.step()
@@ -165,6 +194,11 @@
         self.num_unpadded_tokens = 0
         self.running_loss = 0.0
         self.grad_acc_step_count = 0
+        self.running_moe_aux_loss = 0.0
+        self.running_moe_load_entropy = 0.0
+        self.running_moe_load_max = 0.0
+        self.running_moe_dropped_tokens = 0.0
+        self.running_moe_capacity = 0.0
 
     def finish(self):
         """Finish the logger and close the progress bar."""
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index 10a28a27cf..5e959840e6 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -85,6 +85,14 @@ def main(args: DictConfig) -> float | None:
 
     # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
     config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+    if args.use_te and isinstance(config, NVLlamaConfig):
+        config.use_moe = args.use_moe
+        config.moe_num_experts = args.moe_num_experts
+        config.moe_top_k = args.moe_top_k
+        config.moe_capacity_factor = args.moe_capacity_factor
+        config.moe_min_capacity = args.moe_min_capacity
+        config.moe_drop_tokens = args.moe_drop_tokens
+        config.moe_aux_loss_coef = args.moe_aux_loss_coef
 
     # Optionally use transformer engine to initialize only fp8 versions of weights by setting
     # `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8