From 8c94f85cdf8908d2c67855861e09aa2c060a4845 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 23 Jan 2026 12:28:55 -0800 Subject: [PATCH 01/21] tries to get fp4 working Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/defaults.yaml | 8 + .../recipes/esm2_native_te/modeling_esm_te.py | 657 ++++++++++++++++++ .../recipes/esm2_native_te/train_fsdp2.py | 33 +- 3 files changed, 684 insertions(+), 14 deletions(-) create mode 100644 bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index baace7c805..87c624aca0 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -51,6 +51,14 @@ fp8_config: fp8_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + fp4_model_init_kwargs: + enabled: false # If this is set to true, fp4_config.enabled must also be set to true. + # Optimizer config adamw_kwargs: lr: 4e-4 diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py new file mode 100644 index 0000000000..a12e3ef32e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -0,0 +1,657 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from typing import Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs + + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
+AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention. This controls + whether the dimensions of the intermediate hidden states is 'batch first' + ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, + `b` batch size, `h` the number of heads, `d` head size. Note that these + formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatentations/splits and + also enables the argument `fuse_wgrad_accumulation`. 
+ micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. + self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] 
= () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + for layer_module in self.layers: + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving 
the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + @classmethod + def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): + """Override the default get_init_context method to allow for fp8 model initialization.""" + return [] + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. 
+ """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. 
+ """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys = ("lm_head.decoder.weight",) + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.init_weights() + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. 
+ **kwargs: Additional arguments. + """ + with transformer_engine.pytorch.autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmEmbeddings(nn.Module): + """Modified version of EsmEmbeddings to support THD inputs.""" + + def __init__(self, config): + """Initialize a NVEsmEmbeddings.""" + super().__init__() + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + dtype=config.dtype, + ) + + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + + if config.position_embedding_type != "rotary": + raise ValueError( + "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " + f"{config.position_embedding_type}" + ) + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEmbeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + if ( + kwargs.get("cu_seq_lens_q") is not None + and kwargs.get("cu_seq_lens_k") is not None + and kwargs.get("max_length_q") is not None + and kwargs.get("max_length_k") is not None + ): + using_thd = True + attention_mask = None + else: + using_thd = False + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. 
If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + + if not using_thd: + # BSHD token dropout correction + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() + mask_ratio_observed = n_masked_per_seq / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) + + else: + src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + # We need to find the number of masked tokens in each sequence in the padded batch. 
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. 
+ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 28409e0c17..174cfcdd48 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -29,6 +29,8 @@ from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + # This import seems to be needed with meta device init and AutoModel.from_config from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 @@ -57,12 +59,6 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled and not args.fp8_config.enabled: - raise ValueError( - "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" - ) - if args.fp8_stats_config.enabled: fp8_stats_file = 
args.fp8_stats_config.fp8_stats_file fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" @@ -84,12 +80,18 @@ def main(args: DictConfig) -> float | None: ) # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) - + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + elif args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + else: + print("No FP8 or FP4 config enabled, using default bfloat16") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -99,9 +101,9 @@ def main(args: DictConfig) -> float | None: # versions of weights are kept. 
with ( torch.device("meta") if args.use_meta_device else nullcontext(), - transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe, **args.fp8_config.fp8_model_init_kwargs), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + # model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) @@ -164,8 +166,11 @@ def main(args: DictConfig) -> float | None: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + fp_context = transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe) if args.fp8_config.enabled else nullcontext() + fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp4_recipe) if args.fp4_config.enabled else fp_context + # Note: FOr NVFP4 it looks like its just autocast? https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting # Forward pass with mixed precision. - with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe): + with fp_context: outputs = model(**batch) # Backward pass. 
From 8514d7bc5fca9676dfec990666de3f536a211459 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 23 Jan 2026 12:50:14 -0800 Subject: [PATCH 02/21] refactors fp8 stats logs Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp8_debugging_stats.yaml | 7 ++++++- .../esm2_native_te/hydra_config/defaults.yaml | 7 ++++--- .../recipes/esm2_native_te/train_fsdp2.py | 20 +++++++++---------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml index 7544bbedcf..9653d8a044 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad, fprop] + freq: 1 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 87c624aca0..a8e6ac88a1 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -84,7 +84,8 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: + +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py 
b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 174cfcdd48..1bb569541c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -59,16 +59,16 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - if args.fp8_stats_config.enabled: - fp8_stats_file = args.fp8_stats_config.fp8_stats_file - fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") + if args.quant_stats_config.enabled: + quant_stats_file = args.quant_stats_config.quant_stats_file + quant_log_dir = Path(args.quant_stats_config.quant_log_dir) / f"rank_{dist_config.rank}" + quant_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {quant_log_dir}") te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") debug_api.initialize( - config_file=fp8_stats_file, + config_file=quant_stats_file, feature_dirs=[te_features_dir], - log_dir=fp8_log_dir, + log_dir=quant_log_dir, default_logging_enabled=True, ) @@ -125,7 +125,7 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). 
@@ -184,7 +184,7 @@ def main(args: DictConfig) -> float | None: optimizer.step() scheduler.step() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.step() optimizer.zero_grad() @@ -228,7 +228,7 @@ def main(args: DictConfig) -> float | None: # Clean up distributed training perf_logger.finish() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.end_debug() torch.distributed.destroy_process_group() From 09840d059ed51facbafb16586447cd2b256d7ebc Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 23 Jan 2026 15:59:09 -0800 Subject: [PATCH 03/21] fp4 debugging stats yaml Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp4_debugging_stats.yaml | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml new file mode 100644 index 0000000000..81d6f4a421 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,24 @@ +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Match the actual linear layers within attention that support FP8 stats + layer_types: [layernorm_qkv, proj, fc1, fc2] + transformer_engine: + # Uncomment once https://github.com/NVIDIA/TransformerEngine/pull/2296 is merged. + # LogFp4TensorStats: + # enabled: True + # tensors_struct: + # - tensor: activation + # stats: [underflows%, scale_inv_min, scale_inv_max, mse] + # freq: 10 + # - tensor: gradient + # stats: [underflows%, scale_inv_min, scale_inv_max, mse] + # freq: 10 + # - tensor: weight + # stats: [underflows%, scale_inv_min, scale_inv_max, mse] + # freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] #TODO: Can you get underflows% here? 
+ tensors: [dgrad, wgrad, fprop] + freq: 1 From 0f62e9885c4e29b52499484476294d864d43eba3 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Sat, 24 Jan 2026 11:18:23 -0800 Subject: [PATCH 04/21] BF16 last 6 layers Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/modeling_esm_te.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index a12e3ef32e..0a7ed6e5b1 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -41,7 +41,7 @@ from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel from transformers.utils import logging from transformers.utils.generic import TransformersKwargs - +from contextlib import nullcontext logger = logging.get_logger(__name__) @@ -199,22 +199,36 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + # Set some layers to BF16. (28-33) (This will be from a config later). 
+ # TODO: Also make sure this is only for FP4, not FP8 + layers_to_bf16 = {self.layers[-1], + self.layers[-2], + self.layers[-3], + self.layers[-4], + self.layers[-5], + self.layers[-6]} for layer_module in self.layers: + if layer_module in layers_to_bf16: + fp_context = transformer_engine.pytorch.autocast(enabled=False) + else: + fp_context = nullcontext() + if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + with fp_context: + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) hidden_states = self.emb_layer_norm_after(hidden_states) From 9c32457e92e6294b51cc8f5c3cf5a0bf4eedaeb7 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 26 Jan 2026 10:47:05 -0800 Subject: [PATCH 05/21] sets bf16 layers thru cli Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/L1_650M.yaml | 9 +++++++++ .../esm2_native_te/hydra_config/defaults.yaml | 2 ++ .../recipes/esm2_native_te/modeling_esm_te.py | 13 ++++++------- .../recipes/esm2_native_te/train_fsdp2.py | 3 ++- 4 files changed, 19 insertions(+), 8 
deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index fd027601df..fae42a0b51 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -17,3 +17,12 @@ wandb_init_args: checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" + +# Layers explicitly set to BF16 in case of NVFP4 training. +bf16_layers: + - 27 + - 28 + - 29 + - 30 + - 31 + - 32 \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index a8e6ac88a1..6f6cf85341 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -89,3 +89,5 @@ quant_stats_config: enabled: false quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats + +bf16_layers: null \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 0a7ed6e5b1..f59516ff74 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + bf16_layers: Optional[list[int]] = None, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. 
@@ -111,7 +112,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type - + self.bf16_layers = bf16_layers # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -201,12 +202,10 @@ def forward( # Set some layers to BF16. (28-33) (This will be from a config later). # TODO: Also make sure this is only for FP4, not FP8 - layers_to_bf16 = {self.layers[-1], - self.layers[-2], - self.layers[-3], - self.layers[-4], - self.layers[-5], - self.layers[-6]} + layers_to_bf16 = set() + if self.config.bf16_layers is not None: + layers_to_bf16 = set(self.layers[layer_idx] for layer_idx in self.config.bf16_layers) + for layer_module in self.layers: if layer_module in layers_to_bf16: fp_context = transformer_engine.pytorch.autocast(enabled=False) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 1bb569541c..bbde4f1b61 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -91,7 +91,8 @@ def main(args: DictConfig) -> float | None: else: print("No FP8 or FP4 config enabled, using default bfloat16") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + bf16_layers = OmegaConf.to_container(args.bf16_layers, resolve=True) if args.bf16_layers is not None and args.fp4_config.enabled else None + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16, bf16_layers=bf16_layers) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. 
if args.use_sequence_packing: config.attn_input_format = "thd" From 196d0c7f0a8b064f5b886a7cd55a8b7ef190433b Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 26 Jan 2026 11:38:20 -0800 Subject: [PATCH 06/21] donwgrade te version Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/requirements.txt b/bionemo-recipes/recipes/esm2_native_te/requirements.txt index b18607fd7a..0602ca8a83 100644 --- a/bionemo-recipes/recipes/esm2_native_te/requirements.txt +++ b/bionemo-recipes/recipes/esm2_native_te/requirements.txt @@ -8,6 +8,6 @@ torchdata torchmetrics tqdm transformer_engine[pytorch] -transformers +transformers==4.57.3 wandb nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect From 265911da3e16f624becba32c4b88bb0a6528e416 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 14:00:22 -0800 Subject: [PATCH 07/21] layer specific autocast Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/L1_650M.yaml | 12 +++++-- .../esm2_native_te/hydra_config/defaults.yaml | 4 ++- .../recipes/esm2_native_te/modeling_esm_te.py | 20 ++++++------ .../recipes/esm2_native_te/train_ddp.py | 2 +- .../recipes/esm2_native_te/train_fsdp2.py | 32 +++++++++++++------ 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index fae42a0b51..f6108d5736 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -19,10 +19,18 @@ checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" # Layers explicitly set to BF16 in case of NVFP4 training. 
-bf16_layers: +fp8_layers: - 27 - 28 - 29 - 30 - 31 - - 32 \ No newline at end of file + - 32 + +fp4_layers: + - 0 + - 14 + - 15 + - 16 + +use_fp32_optimizer_weights: true \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 6f6cf85341..da6f9f47cb 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -90,4 +90,6 @@ quant_stats_config: quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats -bf16_layers: null \ No newline at end of file +fp8_layers: null +fp4_layers: null +use_fp32_optimizer_weights: false \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index f59516ff74..b2d1a75df9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -166,6 +166,7 @@ def _init_method(x): for i in range(config.num_hidden_layers) ] ) + self.layer_number_quantized_recipe_map = None self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, @@ -200,20 +201,17 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - # Set some layers to BF16. (28-33) (This will be from a config later). 
- # TODO: Also make sure this is only for FP4, not FP8 - layers_to_bf16 = set() - if self.config.bf16_layers is not None: - layers_to_bf16 = set(self.layers[layer_idx] for layer_idx in self.config.bf16_layers) - - for layer_module in self.layers: - if layer_module in layers_to_bf16: - fp_context = transformer_engine.pytorch.autocast(enabled=False) - else: - fp_context = nullcontext() + # Utilize the layer number quantized recipe map to determine the context for each layer. + for layer_number, layer_module in enumerate(self.layers): + fp_recipe = self.layer_number_quantized_recipe_map[layer_number] if layer_number in self.layer_number_quantized_recipe_map else None if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) + + if fp_recipe is not None: + fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) + else: + fp_context = nullcontext() with fp_context: hidden_states = layer_module( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 1027703f3c..002cbb94db 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -113,7 +113,7 @@ def main(args: DictConfig) -> float | None: device_ids=[dist_config.local_rank], output_device=dist_config.local_rank, device_mesh=device_mesh["ddp"], - ) + ) #TODO: Try BF16 compute weights with FP32 master weights here. # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. 
train_dataloader, dataset_or_sampler = ( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index bbde4f1b61..d8a377c96c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -84,15 +84,16 @@ def main(args: DictConfig) -> float | None: fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) - elif args.fp4_config.enabled: + + if args.fp4_config.enabled: fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs ) - else: - print("No FP8 or FP4 config enabled, using default bfloat16") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - bf16_layers = OmegaConf.to_container(args.bf16_layers, resolve=True) if args.bf16_layers is not None and args.fp4_config.enabled else None - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16, bf16_layers=bf16_layers) + fp8_layers = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None + fp4_layers = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None + + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. 
if args.use_sequence_packing: config.attn_input_format = "thd" @@ -112,9 +113,21 @@ def main(args: DictConfig) -> float | None: transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"]) # TODO: Update mixed precision policy to set it to FP#2 fully_shard(model, mesh=device_mesh["dp"]) + # Create a layer map for the transformer stack. + layer_number_quantized_recipe_map = {} + for layer_number, layer in enumerate(transformer_stack): + + if layer_number in fp8_layers: + layer_number_quantized_recipe_map[layer_number] = fp8_recipe + elif layer_number in fp4_layers: + layer_number_quantized_recipe_map[layer_number] = fp4_recipe + else: + layer_number_quantized_recipe_map[layer_number] = None + + model.esm.encoder.layer_number_quantized_recipe_map = layer_number_quantized_recipe_map # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. # Note, this should happen before we create the optimizer. if args.use_meta_device: @@ -167,11 +180,12 @@ def main(args: DictConfig) -> float | None: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - fp_context = transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe) if args.fp8_config.enabled else nullcontext() - fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp4_recipe) if args.fp4_config.enabled else fp_context + # Note: FOr NVFP4 it looks like its just autocast? https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting # Forward pass with mixed precision. - with fp_context: + # Make the FP context just MXFP8. Then use NVFP4 for certain layers. 
+ # with fp_context: #TODO: I think I can get rid of this, and just do it inside forward. + with transformer_engine.pytorch.autocast(): outputs = model(**batch) # Backward pass. From c5e472bde5aacb6ef84bf19f0c5cf7b3d24ca5fb Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 15:01:47 -0800 Subject: [PATCH 08/21] enables layer specific fp recipes Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/L1_650M.yaml | 24 ++++++++++++++++++- .../recipes/esm2_native_te/modeling_esm_te.py | 18 ++++++++++++-- .../recipes/esm2_native_te/train_fsdp2.py | 8 ++----- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index f6108d5736..e39c4b3985 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -20,6 +20,14 @@ checkpoint: # Layers explicitly set to BF16 in case of NVFP4 training. 
fp8_layers: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 - 27 - 28 - 29 @@ -28,9 +36,23 @@ fp8_layers: - 32 fp4_layers: - - 0 + - 9 + - 10 + - 11 + - 12 + - 13 - 14 - 15 - 16 + - 17 + - 18 + - 19 + - 20 + - 21 + - 22 + - 23 + - 24 + - 25 + - 26 use_fp32_optimizer_weights: true \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index b2d1a75df9..2053cdd4b1 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -37,6 +37,7 @@ MaskedLMOutput, TokenClassifierOutput, ) +import transformer_engine.common.recipe from transformers.models.esm.configuration_esm import EsmConfig from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel from transformers.utils import logging @@ -54,6 +55,13 @@ "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", } +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = (transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling) +FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling) + class NVEsmConfig(EsmConfig): """NVEsmConfig is a configuration for the NVEsm model.""" @@ -208,10 +216,16 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - if fp_recipe is not None: + # If BF16 desired --> use autocast(false) so it goes to BF16. + # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. + # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4. 
+ if isinstance(fp_recipe, FP8_RECIPES): + fp_context = nullcontext() + elif isinstance(fp_recipe, FP4_RECIPES): fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) else: - fp_context = nullcontext() + fp_context = transformer_engine.pytorch.autocast(enabled=False) + # TODO(@jomitchell): Double check that this works, make a funciton for it then unit test it. with fp_context: hidden_states = layer_module( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index d8a377c96c..0d25cb94c7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -179,13 +179,9 @@ def main(args: DictConfig) -> float | None: while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - - # Note: FOr NVFP4 it looks like its just autocast? https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting - # Forward pass with mixed precision. - # Make the FP context just MXFP8. Then use NVFP4 for certain layers. - # with fp_context: #TODO: I think I can get rid of this, and just do it inside forward. - with transformer_engine.pytorch.autocast(): + # Use an outer FP8 recipe. + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): outputs = model(**batch) # Backward pass. 
From 2e2229cee53a53fccf68c4fa8b398aadbd068f9f Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 15:35:44 -0800 Subject: [PATCH 09/21] adds fp32 optim weights with bf16 compute weights Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/train_fsdp2.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 0d25cb94c7..3a5a3c4e99 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -22,10 +22,12 @@ import torch import transformer_engine import transformer_engine.pytorch + from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy from torch.optim import AdamW + from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM @@ -93,7 +95,7 @@ def main(args: DictConfig) -> float | None: fp8_layers = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None fp4_layers = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -112,9 +114,14 @@ def main(args: DictConfig) -> float | None: # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. 
transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward + reduce_dtype=torch.bfloat16, # Gradient reductions in BF16 + output_dtype=torch.bfloat16, # Forward output dtype + ) for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"]) # TODO: Update mixed precision policy to set it to FP#2 - fully_shard(model, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) # TODO: Update mixed precision policy to set it to FP#2 + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) # Create a layer map for the transformer stack. layer_number_quantized_recipe_map = {} From 8acd2a94eb27c19f112367b1c97218350e7781a4 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 16:00:25 -0800 Subject: [PATCH 10/21] enables grad reduce in fp32 for better precision Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 3a5a3c4e99..c665308ebe 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -116,11 +116,11 @@ def main(args: DictConfig) -> float | None: mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward - reduce_dtype=torch.bfloat16, # Gradient reductions in BF16 + reduce_dtype=torch.float32, # Gradient reductions in FP32 output_dtype=torch.bfloat16, # Forward output dtype ) for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) # TODO: Update mixed precision policy to set it to FP#2 + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) 
fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) # Create a layer map for the transformer stack. From 589479d2eac54d8e947be444e9c819230a7a93ee Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 17:38:15 -0800 Subject: [PATCH 11/21] adds FusedAdam for fun Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index c665308ebe..861133f8f2 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -28,6 +28,7 @@ from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy from torch.optim import AdamW +from transformer_engine.pytorch.optimizers import FusedAdam from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM @@ -151,6 +152,15 @@ def main(args: DictConfig) -> float | None: # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + # optimizer = FusedAdam(model.parameters(), + # lr=4e-4, + # betas=(0.9, 0.98), + # eps=1e-8, + # weight_decay=0.01, + # master_weights=True, + # master_weight_dtype=torch.float32, + # ) + # Note: Got an error about mixed torch.Tensor and DTensor here, so using AdamW instead. scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. 
From c978c495b4f1fe9ef5ceda1015a263d7e2638dc2 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 3 Feb 2026 15:58:42 -0800 Subject: [PATCH 12/21] fixes up debugging yaml and adds dockerfile for te tot Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/Dockerfile | 15 ++++++++++++++- .../esm2_native_te/fp4_debugging_stats.yaml | 17 +++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile index b940874af3..c388ddf563 100644 --- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile @@ -1,9 +1,22 @@ # syntax=docker/dockerfile:1.4 FROM nvcr.io/nvidia/pytorch:25.12-py3 +# Install sccache for faster builds +RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + apt-get update \ + && apt-get install -y sccache \ + && rm -rf /var/lib/apt/lists/* + +# Uninstall pre-installed Transformer Engine and install from source +RUN pip uninstall -y transformer-engine && \ + NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 \ + pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main + +# Install BioNeMo requirements +WORKDIR /workspace/bionemo RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ PIP_CONSTRAINT= pip install -r /requirements.txt -WORKDIR /workspace/bionemo COPY . . 
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index 81d6f4a421..83de927644 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -4,8 +4,7 @@ example_fp8_tensor_stat_collection: # Match the actual linear layers within attention that support FP8 stats layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: - # Uncomment once https://github.com/NVIDIA/TransformerEngine/pull/2296 is merged. - # LogFp4TensorStats: + # LogFp8TensorStats: # enabled: True # tensors_struct: # - tensor: activation @@ -14,11 +13,17 @@ example_fp8_tensor_stat_collection: # - tensor: gradient # stats: [underflows%, scale_inv_min, scale_inv_max, mse] # freq: 10 - # - tensor: weight - # stats: [underflows%, scale_inv_min, scale_inv_max, mse] - # freq: 10 + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 10 + - tensor: gradient + stats: [underflows%, mse] + freq: 10 LogTensorStats: enabled: True stats: [max, min, mean, std, l1_norm] #TODO: Can you get underflows% here? 
tensors: [dgrad, wgrad, fprop] - freq: 1 + freq: 10 From c353607b9ad43a04da6eb100179b98079f1a2eab Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 3 Feb 2026 16:31:09 -0800 Subject: [PATCH 13/21] inject layer regex patterns for fp4 fp8 Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp4_debugging_stats.yaml | 38 +++++---- .../esm2_native_te/hydra_config/L1_650M.yaml | 3 +- .../esm2_native_te/hydra_config/defaults.yaml | 1 + .../recipes/esm2_native_te/train_fsdp2.py | 80 ++++++++++++++++++- 4 files changed, 101 insertions(+), 21 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index 83de927644..480be54ddd 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -1,18 +1,10 @@ -example_fp8_tensor_stat_collection: +example_fp4_tensor_stat_collection: enabled: True layers: - # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv, proj, fc1, fc2] + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' transformer_engine: - # LogFp8TensorStats: - # enabled: True - # tensors_struct: - # - tensor: activation - # stats: [underflows%, scale_inv_min, scale_inv_max, mse] - # freq: 10 - # - tensor: gradient - # stats: [underflows%, scale_inv_min, scale_inv_max, mse] - # freq: 10 LogNvfp4TensorStats: enabled: True tensors_struct: @@ -22,8 +14,20 @@ example_fp8_tensor_stat_collection: - tensor: gradient stats: [underflows%, mse] freq: 10 - LogTensorStats: - enabled: True - stats: [max, min, mean, std, l1_norm] #TODO: Can you get underflows% here? 
- tensors: [dgrad, wgrad, fprop] - freq: 10 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 5-9 (1-indexed as layers.6 through layers.10 in the naming) + # This matches: model.esm.encoder.layers.([6-9]|10).*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 + - tensor: gradient + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index e39c4b3985..d71fc0375d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -18,7 +18,7 @@ wandb_init_args: checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" -# Layers explicitly set to BF16 in case of NVFP4 training. +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: - 1 - 2 @@ -34,6 +34,7 @@ fp8_layers: - 30 - 31 - 32 + - 33 fp4_layers: - 9 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index da6f9f47cb..00287ee83d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -90,6 +90,7 @@ quant_stats_config: quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
fp8_layers: null fp4_layers: null use_fp32_optimizer_weights: false \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 861133f8f2..6746aaaaa2 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import tempfile from contextlib import nullcontext from pathlib import Path @@ -22,6 +23,7 @@ import torch import transformer_engine import transformer_engine.pytorch +import yaml from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh @@ -48,6 +50,65 @@ logger.setLevel(logging.INFO) +def generate_layer_regex(layer_numbers: list[int]) -> str: + """Generate a regex pattern to match specific layer numbers (1-indexed). + + Args: + layer_numbers: List of layer numbers (1-indexed, as shown in logs). + + Returns: + Regex pattern string for matching those layers' linear sublayers. + """ + if not layer_numbers: + return "" + # Use alternation for arbitrary layer lists: (1|2|3|4|5) + layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) + return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, +) -> str: + """Update the quant stats YAML config with layer-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. + fp4_layers: List of layer numbers for FP4 (1-indexed). + fp8_layers: List of layer numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (may be a temp file). 
+ """ + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + # Update FP4 section if it exists and fp4_layers is provided + if fp4_layers and "example_fp4_tensor_stat_collection" in config: + fp4_regex = generate_layer_regex(fp4_layers) + config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + + # Update FP8 section if it exists and fp8_layers is provided + if fp8_layers and "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + + # Write to a temp file to avoid modifying the original + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + # Log the updated config for visibility + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + @hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") def main(args: DictConfig) -> float | None: """Train ESM-2 with TE layers using fsdp2. 
@@ -62,8 +123,24 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) + # Parse layer lists first (1-indexed from args, used for both quant stats and internal recipe mapping) + fp8_layers_1indexed = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None + fp4_layers_1indexed = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None + + # Convert to 0-indexed for internal use + fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if fp8_layers_1indexed else None + fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed else None + if args.quant_stats_config.enabled: quant_stats_file = args.quant_stats_config.quant_stats_file + + # Update the quant stats config with layer-specific regex patterns (using 1-indexed layer numbers) + quant_stats_file = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=fp4_layers_1indexed, + fp8_layers=fp8_layers_1indexed, + ) + quant_log_dir = Path(args.quant_stats_config.quant_log_dir) / f"rank_{dist_config.rank}" quant_log_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Logging quant stats to {quant_log_dir}") @@ -92,9 +169,6 @@ def main(args: DictConfig) -> float | None: fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs ) - # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". 
- fp8_layers = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None - fp4_layers = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. From 879c50d42dadae6d81e87ef5cbc6662a8ceba1d8 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 3 Feb 2026 17:40:27 -0800 Subject: [PATCH 14/21] enables fp4 layer with nothing Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp4_debugging_stats.yaml | 4 +- .../recipes/esm2_native_te/train_fsdp2.py | 50 +++++++++++++------ 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index 480be54ddd..d2e9da08e6 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -26,8 +26,8 @@ example_fp8_tensor_stat_collection: enabled: True tensors_struct: - tensor: activation - stats: [underflows%, scale_inv_min, scale_inv_max, mse] + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] freq: 10 - tensor: gradient - stats: [underflows%, scale_inv_min, scale_inv_max, mse] + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] freq: 10 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 6746aaaaa2..f089b0db1e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -50,17 +50,19 @@ logger.setLevel(logging.INFO) -def generate_layer_regex(layer_numbers: list[int]) -> str: +def 
generate_layer_regex(layer_numbers: list[int] | None) -> str: """Generate a regex pattern to match specific layer numbers (1-indexed). Args: layer_numbers: List of layer numbers (1-indexed, as shown in logs). + If empty or None, returns a pattern that matches nothing. Returns: Regex pattern string for matching those layers' linear sublayers. """ if not layer_numbers: - return "" + # Return a pattern that matches nothing (non-existent layer) + return r"model\.esm\.encoder\.layers\.DISABLED_NO_LAYERS_SPECIFIED" # Use alternation for arbitrary layer lists: (1|2|3|4|5) layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" @@ -80,21 +82,40 @@ def update_quant_stats_config( Returns: Path to the updated config file (may be a temp file). + + Raises: + ValueError: If fp4_layers and fp8_layers have overlapping layer numbers. """ + # Check for overlapping layers + fp4_set = set(fp4_layers) if fp4_layers else set() + fp8_set = set(fp8_layers) if fp8_layers else set() + overlap = fp4_set & fp8_set + if overlap: + raise ValueError( + f"fp4_layers and fp8_layers cannot have overlapping layer numbers. 
" + f"Found overlap: {sorted(overlap)}" + ) + with open(config_file, "r") as f: config = yaml.safe_load(f) - # Update FP4 section if it exists and fp4_layers is provided - if fp4_layers and "example_fp4_tensor_stat_collection" in config: + # Update FP4 section if it exists (always update, even if empty to disable matching) + if "example_fp4_tensor_stat_collection" in config: fp4_regex = generate_layer_regex(fp4_layers) config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex - logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + if fp4_layers: + logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + else: + logger.info("FP4 layers empty - regex set to match nothing") - # Update FP8 section if it exists and fp8_layers is provided - if fp8_layers and "example_fp8_tensor_stat_collection" in config: + # Update FP8 section if it exists (always update, even if empty to disable matching) + if "example_fp8_tensor_stat_collection" in config: fp8_regex = generate_layer_regex(fp8_layers) config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex - logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + if fp8_layers: + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + else: + logger.info("FP8 layers empty - regex set to match nothing") # Write to a temp file to avoid modifying the original temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) @@ -127,9 +148,9 @@ def main(args: DictConfig) -> float | None: fp8_layers_1indexed = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None fp4_layers_1indexed = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None - # Convert to 0-indexed for internal use - fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if 
fp8_layers_1indexed else None - fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed else None + # Convert to 0-indexed for internal use (use 'is not None' to handle empty lists correctly) + fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if fp8_layers_1indexed is not None else None + fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed is not None else None if args.quant_stats_config.enabled: quant_stats_file = args.quant_stats_config.quant_stats_file @@ -200,11 +221,12 @@ def main(args: DictConfig) -> float | None: # Create a layer map for the transformer stack. layer_number_quantized_recipe_map = {} + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() for layer_number, layer in enumerate(transformer_stack): - - if layer_number in fp8_layers: + if layer_number in fp8_layers_set: layer_number_quantized_recipe_map[layer_number] = fp8_recipe - elif layer_number in fp4_layers: + elif layer_number in fp4_layers_set: layer_number_quantized_recipe_map[layer_number] = fp4_recipe else: layer_number_quantized_recipe_map[layer_number] = None From 5154fea683c0c500734ef7b33dbed1a7de25075f Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 4 Feb 2026 13:37:43 -0800 Subject: [PATCH 15/21] pins autotokenizer to previous revision Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/dataset.py | 2 +- .../recipes/esm2_native_te/fp4_debugging_stats.yaml | 8 ++++---- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 9 ++++++++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index c915f30ea8..8b8c0f06bf 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -56,7 +56,7 @@ def create_tokenized_dataset( ) dataset = 
dataset.shuffle(seed=42, buffer_size=buffer_size) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision="d81c2e5aec37b5e794d0482e3996fb045a137792") def tokenize_function(examples): """Tokenize the protein sequences.""" diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index d2e9da08e6..d56739a6a6 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -10,10 +10,10 @@ example_fp4_tensor_stat_collection: tensors_struct: - tensor: activation stats: [underflows%, mse] - freq: 10 + freq: 100 - tensor: gradient stats: [underflows%, mse] - freq: 10 + freq: 100 example_fp8_tensor_stat_collection: enabled: True @@ -27,7 +27,7 @@ example_fp8_tensor_stat_collection: tensors_struct: - tensor: activation stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] - freq: 10 + freq: 100 - tensor: gradient stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] - freq: 10 + freq: 100 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index f089b0db1e..be62b7f5de 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -17,6 +17,7 @@ import tempfile from contextlib import nullcontext from pathlib import Path +from torch.profiler import profile, ProfilerActivity import hydra import nvdlfw_inspect.api as debug_api @@ -294,8 +295,14 @@ def main(args: DictConfig) -> float | None: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 # Use an outer FP8 recipe. 
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None): outputs = model(**batch) + + # if step == 5: # Profile step 5 + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + # with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + # outputs = model(**batch) + # logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) # Backward pass. loss = outputs.loss From d880fc93f09632688e2504f666cd2d1676993f3b Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 4 Feb 2026 18:32:28 -0800 Subject: [PATCH 16/21] adds Dockerfile.te_testing for TE build from src but bad perf on it Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/Dockerfile | 17 ++------------ .../esm2_native_te/Dockerfile.te_testing | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile index c388ddf563..71a793b1b7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile @@ -1,22 +1,9 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:25.12-py3 +FROM nvcr.io/nvidia/pytorch:25.11-py3 -# Install sccache for faster builds -RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ - --mount=type=cache,target=/var/cache/apt,sharing=locked \ - apt-get update \ - && apt-get install -y sccache \ - && rm -rf /var/lib/apt/lists/* - -# Uninstall pre-installed Transformer Engine and install from source -RUN pip uninstall -y transformer-engine && \ - NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch 
NVTE_BUILD_DEBUG=1 \ - pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main - -# Install BioNeMo requirements -WORKDIR /workspace/bionemo RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ PIP_CONSTRAINT= pip install -r /requirements.txt +WORKDIR /workspace/bionemo COPY . . diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing b/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing new file mode 100644 index 0000000000..c388ddf563 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing @@ -0,0 +1,22 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:25.12-py3 + +# Install sccache for faster builds +RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + apt-get update \ + && apt-get install -y sccache \ + && rm -rf /var/lib/apt/lists/* + +# Uninstall pre-installed Transformer Engine and install from source +RUN pip uninstall -y transformer-engine && \ + NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 \ + pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main + +# Install BioNeMo requirements +WORKDIR /workspace/bionemo +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +COPY . . 
From b6042f17b6c8e203b694d52cc999bbeee8e44054 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 9 Feb 2026 10:24:34 -0800 Subject: [PATCH 17/21] enables tokenizer revision Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/dataset.py | 7 +++++-- .../hydra_config/L1_15B_perf_test.yaml | 1 + .../recipes/esm2_native_te/hydra_config/L1_3B.yaml | 1 + .../esm2_native_te/hydra_config/L1_650M.yaml | 2 +- .../esm2_native_te/hydra_config/defaults.yaml | 3 ++- .../recipes/esm2_native_te/modeling_esm_te.py | 1 + .../recipes/esm2_native_te/train_fsdp2.py | 14 +++++++++----- 7 files changed, 20 insertions(+), 9 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index 8b8c0f06bf..78def1ef96 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -42,6 +42,7 @@ def create_tokenized_dataset( max_seq_length: int = 1024, buffer_size: int = 10_000, use_lazy_tokenization: bool = True, + tokenizer_revision: str | None = None, ): """Create a tokenized dataset.""" logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}") @@ -56,7 +57,7 @@ def create_tokenized_dataset( ) dataset = dataset.shuffle(seed=42, buffer_size=buffer_size) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision="d81c2e5aec37b5e794d0482e3996fb045a137792") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None) def tokenize_function(examples): """Tokenize the protein sequences.""" @@ -167,6 +168,7 @@ def create_thd_dataloader( use_stateful_dataloader: bool = False, mlm_probability: float = 0.15, pad_sequences_to_be_divisible_by: int | None = None, + tokenizer_revision: str | None = None, ): """Create a dataloader that packs up to the maximum number of tokens per batch. 
@@ -186,7 +188,7 @@ def create_thd_dataloader( mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value. This is useful for context parallelism. Defaults to None. - + tokenizer_revision: The revision of the tokenizer to use. Defaults to None. Returns: A dataloader that can be used for training. """ @@ -196,6 +198,7 @@ def create_thd_dataloader( load_dataset_kwargs=load_dataset_kwargs, max_seq_length=max_seq_length, buffer_size=buffer_size, + tokenizer_revision=tokenizer_revision, ) assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset." diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml index 0b91c5608a..2b6f602e3e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml @@ -8,6 +8,7 @@ num_train_steps: 500 dataset: micro_batch_size: 12 + tokenizer_revision: "f29e20d2b10d0aba2036831df65cdca1befe926f" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml index e8e47d908f..3e055907ca 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml @@ -8,6 +8,7 @@ num_train_steps: 10_000 dataset: micro_batch_size: 16 + tokenizer_revision: "86a86f18e6bb1eb4bcf91c594e1c0ad446d8eec6" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index d71fc0375d..fc1153dc31 100644 --- 
a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -8,7 +8,7 @@ num_train_steps: 200 dataset: micro_batch_size: 4 - + tokenizer_revision: "d81c2e5aec37b5e794d0482e3996fb045a137792" # WandB config wandb_init_args: name: "esm2_t33_650M_UR50D" diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 00287ee83d..0cbc271212 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -13,6 +13,7 @@ cp_size: 1 use_sequence_packing: false dataset: tokenizer_name: ${model_tag} + tokenizer_revision: null micro_batch_size: ??? num_workers: 1 max_seq_length: 1024 @@ -93,4 +94,4 @@ quant_stats_config: # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: null fp4_layers: null -use_fp32_optimizer_weights: false \ No newline at end of file +use_fp32_master_weights: null \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 2053cdd4b1..1932c0f4c7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -216,6 +216,7 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) + import pdb; pdb.set_trace() # If BF16 desired --> use autocast(false) so it goes to BF16. # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4. 
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index be62b7f5de..cfbe109a42 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -192,7 +192,7 @@ def main(args: DictConfig) -> float | None: fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs ) - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -216,10 +216,14 @@ def main(args: DictConfig) -> float | None: reduce_dtype=torch.float32, # Gradient reductions in FP32 output_dtype=torch.bfloat16, # Forward output dtype ) - for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) - fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) - + if args.use_fp32_master_weights: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + else: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) # Create a layer map for the transformer stack. 
layer_number_quantized_recipe_map = {} fp8_layers_set = set(fp8_layers) if fp8_layers else set() From 269961bd72866d9af2f9b9bc691415d0a16e5074 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 10 Feb 2026 11:35:07 -0800 Subject: [PATCH 18/21] adds nsight profiling Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/defaults.yaml | 10 ++++ .../recipes/esm2_native_te/modeling_esm_te.py | 6 +- .../recipes/esm2_native_te/train_fsdp2.py | 56 +++++++++++++++---- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 0cbc271212..2682cb3a96 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -91,6 +91,16 @@ quant_stats_config: quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats +# Nsight Systems profiling config. +# To use, wrap your launch command with: +# nsys profile -s none -t cuda,nvtx -o --force-overwrite true \ +# --capture-range=cudaProfilerApi --capture-range-end=stop +nsys_profiling: + enabled: false + start_step: 5 # Step at which to start CUDA profiler capture + end_step: 8 # Step at which to stop CUDA profiler capture + ranks: [0] # Which ranks to profile (list of ints) + # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: null fp4_layers: null diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 1932c0f4c7..bd9ae9674e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -27,6 +27,7 @@ # TODO: put import guard around transformer_engine here, with an informative error message around # installation and the nvidia docker container. 
import torch +import torch.cuda.nvtx as nvtx import transformer_engine.pytorch from torch import nn from torch.nn import CrossEntropyLoss @@ -216,7 +217,6 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - import pdb; pdb.set_trace() # If BF16 desired --> use autocast(false) so it goes to BF16. # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4. @@ -228,6 +228,7 @@ def forward( fp_context = transformer_engine.pytorch.autocast(enabled=False) # TODO(@jomitchell): Double check that this works, make a funciton for it then unit test it. + nvtx.range_push(f"encoder_layer_{layer_number}") with fp_context: hidden_states = layer_module( hidden_states, @@ -241,8 +242,11 @@ def forward( max_seqlen_kv=kwargs.get("max_length_k", None), pad_between_seqs=kwargs.get("pad_between_seqs", None), ) + nvtx.range_pop() # encoder_layer_N + nvtx.range_push("emb_layer_norm_after") hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index cfbe109a42..bc6c45d793 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -22,6 +22,7 @@ import hydra import nvdlfw_inspect.api as debug_api import torch +import torch.cuda.nvtx as nvtx import transformer_engine import transformer_engine.pytorch import yaml @@ -292,37 +293,60 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + # Nsight Systems profiling setup. 
+ nsys_enabled = args.nsys_profiling.enabled + nsys_start_step = args.nsys_profiling.start_step if nsys_enabled else -1 + nsys_end_step = args.nsys_profiling.end_step if nsys_enabled else -1 + nsys_ranks = set(OmegaConf.to_container(args.nsys_profiling.ranks, resolve=True)) if nsys_enabled else set() + nsys_profiling_active = False + + if nsys_enabled and dist_config.rank in nsys_ranks: + logger.info( + f"Nsight profiling enabled for rank {dist_config.rank}: " + f"will capture steps [{nsys_start_step}, {nsys_end_step})" + ) + # Training loop step = start_step while step < args.num_train_steps: for batch in train_dataloader: + # --- Nsys: start profiler at the configured step --- + if nsys_enabled and step == nsys_start_step and dist_config.rank in nsys_ranks: + logger.info(f"[Rank {dist_config.rank}] Starting nsys capture at step {step}") + torch.cuda.cudart().cudaProfilerStart() + nsys_profiling_active = True + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - # Use an outer FP8 recipe. + # --- Forward pass --- + nvtx.range_push(f"step_{step}") + nvtx.range_push("forward") with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None): outputs = model(**batch) - - # if step == 5: # Profile step 5 - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - # with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - # outputs = model(**batch) - # logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + nvtx.range_pop() # forward - # Backward pass. + # --- Backward pass --- + nvtx.range_push("backward") loss = outputs.loss loss.backward() + nvtx.range_pop() # backward - # Compute and clip gradient norms. 
+ # --- Grad clip --- + nvtx.range_push("clip_grad_norm") total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + nvtx.range_pop() # clip_grad_norm - # Step optimizer. + # --- Optimizer step --- + nvtx.range_push("optimizer_step") optimizer.step() scheduler.step() + nvtx.range_pop() # optimizer_step if args.quant_stats_config.enabled: debug_api.step() optimizer.zero_grad() + nvtx.range_pop() # step_N perf_logger.log_step( step=step, @@ -345,6 +369,12 @@ max_checkpoints=args.checkpoint.max_checkpoints, ) + # --- Nsys: stop profiler before end_step runs, so capture covers [start_step, end_step) --- + if nsys_profiling_active and step + 1 >= nsys_end_step: + logger.info(f"[Rank {dist_config.rank}] Stopping nsys capture after step {step}") + torch.cuda.cudart().cudaProfilerStop() + nsys_profiling_active = False + step += 1 if step >= args.num_train_steps: break @@ -361,6 +391,12 @@ dist_config=dist_config, ) + # Ensure nsys profiler is stopped if training ended before end_step. 
+ if nsys_profiling_active: + logger.info(f"[Rank {dist_config.rank}] Stopping nsys capture at end of training (step {step})") + torch.cuda.cudart().cudaProfilerStop() + nsys_profiling_active = False + # Clean up distributed training perf_logger.finish() if args.quant_stats_config.enabled: From 514e499d0a024bd4796db712341d07088893ae7e Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 10 Feb 2026 11:39:57 -0800 Subject: [PATCH 19/21] adds gemm profiling scripts Signed-off-by: Jonathan Mitchell --- .../gemm_benchmarking/gemm_benchmark.py | 683 +++++++++++++++ .../gemm_benchmark_withshapes.py | 717 +++++++++++++++ .../gemm_benchmarking/profiler_gemm.py | 558 ++++++++++++ .../gemm_benchmarking/roofline.py | 522 +++++++++++ .../roofline_prequantized.py | 753 ++++++++++++++++ .../roofline_prequantized_with_shapes.py | 772 ++++++++++++++++ .../roofline_prequantized_with_shapes_mb.py | 829 ++++++++++++++++++ .../roofline_prequantized_withtorchao.py | 753 ++++++++++++++++ 8 files changed, 5587 insertions(+) create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py create mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_withtorchao.py diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py 
b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py new file mode 100644 index 0000000000..f9819b5bc8 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py @@ -0,0 +1,683 @@ + +import argparse +import time +import torch +import matplotlib.pyplot as plt +import numpy as np +from dataclasses import dataclass +from typing import Optional + +# Optional TE import - gracefully handle if not available +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling + TE_AVAILABLE = True +except ImportError: + TE_AVAILABLE = False + print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") + +# Check for Blackwell (SM100+) for FP4 support +def is_blackwell_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 # SM100 = Blackwell + + +@dataclass +class BenchmarkResult: + """Container for benchmark results.""" + tflops: float + avg_time_ms: float + shape: tuple[int, int, int] + precision: str + + +def compute_gemm_flops(M: int, K: int, N: int) -> int: + """ + Compute theoretical FLOP count for GEMM C = A @ B. 
+ + A: (M, K), B: (K, N), C: (M, N) + Each output element requires K multiply-adds = 2K FLOPs + Total: 2 * M * N * K + """ + return 2 * M * N * K + + +def benchmark_torch_matmul( + M: int, + K: int, + N: int, + dtype: torch.dtype, + num_warmup: int = 10, + num_iters: int = 100 +) -> BenchmarkResult: + """Benchmark torch.matmul at specified precision.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=dtype, device=device) + B = torch.randn(K, N, dtype=dtype, device=device) + + # Warmup - critical for accurate timing + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + precision_name = { + torch.bfloat16: "BF16", + torch.float16: "FP16", + torch.float32: "FP32", + }.get(dtype, str(dtype)) + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision=precision_name + ) + + +def benchmark_te_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Keep autocast context open for warmup and timing + with 
te.autocast(enabled=True, recipe=fp8_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + + +def benchmark_te_fp8_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (requires fp8_dtype argument) + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP8 + A_fp8 = quantizer.quantize(A) + B_fp8 = quantizer.quantize(B) + + # Output buffer + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp8, False, # A, transA + B_fp8, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # 
workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp8, False, + B_fp8, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + except Exception as e: + print(f"Warning: FP8 prequantized benchmark failed: {e}") + return None + + +def benchmark_te_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp4_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + 
torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + + +def benchmark_te_fp4_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (uses default kFloat4E2M1, but being explicit) + quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP4 + A_fp4 = quantizer.quantize(A) + B_fp4 = quantizer.quantize(B) + + # Output buffer + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp4, False, # A, transA + B_fp4, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + 
start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp4, False, + B_fp4, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + except Exception as e: + print(f"Warning: FP4 prequantized benchmark failed: {e}") + return None + + +def get_default_matrix_shapes() -> list[tuple[int, int, int]]: + """Return default matrix shapes for benchmarking (square matrices).""" + return [ + (256, 256, 256), + (512, 512, 512), + (768, 768, 768), + (1024, 1024, 1024), + (1536, 1536, 1536), + (2048, 2048, 2048), + (3072, 3072, 3072), + (4096, 4096, 4096), + (6144, 6144, 6144), + (8192, 8192, 8192), + (16384, 16384, 16384), + ] + + +def warmup_gpu(duration_seconds: float = 5.0): + """ + Warmup the GPU to stabilize clocks before benchmarking. + + Runs sustained matmuls to bring GPU out of idle state and + get clocks/thermals to steady state. 
+ """ + print(f"Warming up GPU for {duration_seconds:.1f} seconds...") + + device = torch.device("cuda") + + # Use a moderate size that keeps GPU busy without OOM + M, K, N = 4096, 4096, 4096 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + torch.cuda.synchronize() + start_time = time.time() + + while time.time() - start_time < duration_seconds: + # Run a batch of matmuls + for _ in range(10): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Clear memory + del A, B + torch.cuda.empty_cache() + + print("GPU warmup complete.\n") + + +def run_benchmarks( + shapes: list[tuple[int, int, int]], + num_warmup: int = 10, + num_iters: int = 100, + include_fp8: bool = True, + include_fp4: bool = True, + gpu_warmup_seconds: float = 5.0, + pre_quantize: bool = False +) -> dict[str, list[float]]: + """Run all benchmarks and return results organized by precision.""" + + results = {"BF16": [], "MXFP8": [], "NVFP4": []} + + # Check hardware capabilities + has_blackwell = is_blackwell_available() + run_fp8 = include_fp8 and TE_AVAILABLE + run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell + + # Print header + gpu_name = torch.cuda.get_device_name(0) + print(f"\nGEMM Benchmark on {gpu_name}") + print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") + if pre_quantize: + print("Mode: Pre-quantized inputs (raw kernel throughput)") + else: + print("Mode: Full quantize path (realistic training overhead)") + if not has_blackwell and include_fp4: + print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") + + # Warmup GPU to stabilize clocks + if gpu_warmup_seconds > 0: + warmup_gpu(gpu_warmup_seconds) + + print("=" * 80) + + # Build header dynamically + header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" + if run_fp8: + header += f" {'MXFP8 TFLOPS':>14}" + if run_fp4: + header += f" {'NVFP4 TFLOPS':>14}" + header += f" {'Best Speedup':>12}" + print(header) + print("-" * 
80) + + # Select benchmark functions based on pre_quantize flag + fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 + fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 + + for M, K, N in shapes: + shape_str = f"{M}x{K}x{N}" + + # BF16 benchmark + bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) + results["BF16"].append(bf16_result.tflops) + + row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" + best_tflops = bf16_result.tflops + + # FP8 benchmark + if run_fp8: + fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp8_result: + results["MXFP8"].append(fp8_result.tflops) + row += f" {fp8_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp8_result.tflops) + else: + results["MXFP8"].append(0) + row += f" {'N/A':>14}" + + # FP4 benchmark (Blackwell only) + if run_fp4: + fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp4_result: + results["NVFP4"].append(fp4_result.tflops) + row += f" {fp4_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp4_result.tflops) + else: + results["NVFP4"].append(0) + row += f" {'N/A':>14}" + + speedup = best_tflops / bf16_result.tflops + row += f" {speedup:>11.2f}x" + print(row) + + print("=" * 80) + + # Remove empty precision results + results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} + + return results + + +def create_plot( + shapes: list[tuple[int, int, int]], + results: dict[str, list[float]], + output_path: str = "gemm_benchmark.png", + title: Optional[str] = None +): + """Create a bar plot matching the style of the reference image.""" + + gpu_name = torch.cuda.get_device_name(0) + if title is None: + title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" + + # Create labels for x-axis + labels = [f"{m}x{k}x{n}" for m, k, n in shapes] + x = np.arange(len(labels)) + + # Determine bar width based on number of kernels + num_kernels = len(results) 
+ bar_width = 0.8 / num_kernels + + # Color scheme matching the reference plot + colors = { + "BF16": "#808080", # Gray + "MXFP8": "#4B0082", # Indigo/Purple + "NVFP4": "#B22222", # Firebrick red (for future use) + } + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Plot bars for each precision + for i, (precision, tflops_list) in enumerate(results.items()): + offset = (i - num_kernels / 2 + 0.5) * bar_width + color = colors.get(precision, f"C{i}") + bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) + + # Customize the plot + ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) + ax.set_ylabel("Performance (TFLOPS)", fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) + + # Add legend + ax.legend(title="Kernel", loc='upper left', fontsize=10) + + # Add grid for readability + ax.yaxis.grid(True, linestyle='--', alpha=0.7) + ax.set_axisbelow(True) + + # Set y-axis to start at 0 + ax.set_ylim(bottom=0) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the plot + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"\nPlot saved to: {output_path}") + + return fig, ax + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmarking with TFLOPS measurement and plotting" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="gemm_benchmark.png", + help="Output path for the plot (default: gemm_benchmark.png)" + ) + parser.add_argument( + "--num-warmup", + type=int, + default=10, + help="Number of warmup iterations (default: 10)" + ) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of timed iterations (default: 100)" + ) + parser.add_argument( + "--gpu-warmup", + type=float, + default=5.0, + help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" + ) + parser.add_argument( + "--no-fp8", + action="store_true", + help="Skip FP8 
benchmarks" + ) + parser.add_argument( + "--no-fp4", + action="store_true", + help="Skip FP4 benchmarks (only available on Blackwell)" + ) + parser.add_argument( + "--pre-quantize", + action="store_true", + help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" + ) + parser.add_argument( + "--shapes", + type=str, + default=None, + help="Comma-separated list of square matrix sizes (e.g., '1024,2048,4096')" + ) + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("Error: CUDA is not available. This script requires a GPU.") + return 1 + + # Parse custom shapes if provided + if args.shapes: + sizes = [int(s.strip()) for s in args.shapes.split(",")] + shapes = [(s, s, s) for s in sizes] + else: + shapes = get_default_matrix_shapes() + + # Run benchmarks + results = run_benchmarks( + shapes=shapes, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + include_fp8=not args.no_fp8, + include_fp4=not args.no_fp4, + gpu_warmup_seconds=args.gpu_warmup, + pre_quantize=args.pre_quantize + ) + + # Create plot + create_plot(shapes, results, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py new file mode 100644 index 0000000000..a9c8df9d89 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py @@ -0,0 +1,717 @@ +import argparse +import time +import torch +import matplotlib.pyplot as plt +import numpy as np +from dataclasses import dataclass +from typing import Optional + +# Optional TE import - gracefully handle if not available +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling + TE_AVAILABLE = True 
+except ImportError: + TE_AVAILABLE = False + print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") + +# Check for Blackwell (SM100+) for FP4 support +def is_blackwell_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 # SM100 = Blackwell + + +@dataclass +class BenchmarkResult: + """Container for benchmark results.""" + tflops: float + avg_time_ms: float + shape: tuple[int, int, int] + precision: str + + +def compute_gemm_flops(M: int, K: int, N: int) -> int: + """ + Compute theoretical FLOP count for GEMM C = A @ B. + + A: (M, K), B: (K, N), C: (M, N) + Each output element requires K multiply-adds = 2K FLOPs + Total: 2 * M * N * K + """ + return 2 * M * N * K + + +def benchmark_torch_matmul( + M: int, + K: int, + N: int, + dtype: torch.dtype, + num_warmup: int = 10, + num_iters: int = 100 +) -> BenchmarkResult: + """Benchmark torch.matmul at specified precision.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=dtype, device=device) + B = torch.randn(K, N, dtype=dtype, device=device) + + # Warmup - critical for accurate timing + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + precision_name = { + torch.bfloat16: "BF16", + torch.float16: "FP16", + torch.float32: "FP32", + }.get(dtype, str(dtype)) + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision=precision_name + ) + + 
+def benchmark_te_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp8_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + + +def benchmark_te_fp8_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (requires fp8_dtype argument) + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + + # Create BF16 tensors in the layout expected by tex.generic_gemm. + # + # Important: tex.generic_gemm has its own conventions for (A, B, transa, transb) and + # expected output orientation. 
With the settings used below (transa=False, transb=True), + # and with A shaped (K, M) and B shaped (K, N), TransformerEngine expects the output D + # to be shaped (N, M) (note the swapped order). This is fine for benchmarking throughput + # (FLOP count is still 2*M*N*K); it's just a layout convention. + A = torch.randn(K, M, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + # Pre-quantize to FP8 + A_fp8 = quantizer.quantize(A) + B_fp8 = quantizer.quantize(B) + + # Output buffer (see note above about expected (N, M) orientation) + D = torch.empty(N, M, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp8, False, # A, transA + B_fp8, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp8, False, + B_fp8, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + 
precision="MXFP8" + ) + except Exception as e: + print(f"Warning: FP8 prequantized benchmark failed: {e}") + return None + + +def benchmark_te_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp4_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + + +def benchmark_te_fp4_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (uses default kFloat4E2M1, but being explicit) + quantizer = 
te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + + # Create BF16 tensors in the layout expected by tex.generic_gemm. + # See FP8 pre-quantized path above for rationale (including expected D orientation). + A = torch.randn(K, M, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + # Pre-quantize to FP4 + A_fp4 = quantizer.quantize(A) + B_fp4 = quantizer.quantize(B) + + # Output buffer (see note above about expected (N, M) orientation) + D = torch.empty(N, M, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp4, False, # A, transA + B_fp4, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp4, False, + B_fp4, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + except Exception as e: + print(f"Warning: FP4 prequantized benchmark failed: {e}") 
+ return None + + +def get_default_matrix_shapes() -> list[tuple[int, int, int]]: + """Return default matrix shapes for benchmarking (square matrices).""" + return [ + (256, 256, 256), + (512, 512, 512), + (768, 768, 768), + (1024, 1024, 1024), + (1536, 1536, 1536), + (2048, 2048, 2048), + (3072, 3072, 3072), + (4096, 4096, 4096), + (6144, 6144, 6144), + (8192, 8192, 8192), + (16384, 16384, 16384), + ] + + +def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: + """Parse a shapes argument into a list of (M, K, N) tuples. + + Supports either: + - Square sizes: "1024,2048,4096" -> [(1024, 1024, 1024), ...] + - Explicit triplets: "8192x5120x15360,8192x5120x5120" + """ + items = [s.strip() for s in shapes_arg.split(",") if s.strip()] + if not items: + raise ValueError("Empty --shapes argument.") + + shapes: list[tuple[int, int, int]] = [] + for item in items: + if "x" in item: + parts = [p.strip() for p in item.lower().split("x")] + if len(parts) != 3: + raise ValueError(f"Invalid shape '{item}'. Expected format 'MxKxN'.") + m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) + shapes.append((m, k, n)) + else: + size = int(item) + shapes.append((size, size, size)) + + return shapes + + +def warmup_gpu(duration_seconds: float = 5.0): + """ + Warmup the GPU to stabilize clocks before benchmarking. + + Runs sustained matmuls to bring GPU out of idle state and + get clocks/thermals to steady state. 
+ """ + print(f"Warming up GPU for {duration_seconds:.1f} seconds...") + + device = torch.device("cuda") + + # Use a moderate size that keeps GPU busy without OOM + M, K, N = 4096, 4096, 4096 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + torch.cuda.synchronize() + start_time = time.time() + + while time.time() - start_time < duration_seconds: + # Run a batch of matmuls + for _ in range(10): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Clear memory + del A, B + torch.cuda.empty_cache() + + print("GPU warmup complete.\n") + + +def run_benchmarks( + shapes: list[tuple[int, int, int]], + num_warmup: int = 10, + num_iters: int = 100, + include_fp8: bool = True, + include_fp4: bool = True, + gpu_warmup_seconds: float = 5.0, + pre_quantize: bool = False +) -> dict[str, list[float]]: + """Run all benchmarks and return results organized by precision.""" + + results = {"BF16": [], "MXFP8": [], "NVFP4": []} + + # Check hardware capabilities + has_blackwell = is_blackwell_available() + run_fp8 = include_fp8 and TE_AVAILABLE + run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell + + # Print header + gpu_name = torch.cuda.get_device_name(0) + print(f"\nGEMM Benchmark on {gpu_name}") + print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") + if pre_quantize: + print("Mode: Pre-quantized inputs (raw kernel throughput)") + else: + print("Mode: Full quantize path (realistic training overhead)") + if not has_blackwell and include_fp4: + print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") + + # Warmup GPU to stabilize clocks + if gpu_warmup_seconds > 0: + warmup_gpu(gpu_warmup_seconds) + + print("=" * 80) + + # Build header dynamically + header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" + if run_fp8: + header += f" {'MXFP8 TFLOPS':>14}" + if run_fp4: + header += f" {'NVFP4 TFLOPS':>14}" + header += f" {'Best Speedup':>12}" + print(header) + print("-" * 
80) + + # Select benchmark functions based on pre_quantize flag + fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 + fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 + + for M, K, N in shapes: + shape_str = f"{M}x{K}x{N}" + + # BF16 benchmark + bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) + results["BF16"].append(bf16_result.tflops) + + row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" + best_tflops = bf16_result.tflops + + # FP8 benchmark + if run_fp8: + fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp8_result: + results["MXFP8"].append(fp8_result.tflops) + row += f" {fp8_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp8_result.tflops) + else: + results["MXFP8"].append(0) + row += f" {'N/A':>14}" + + # FP4 benchmark (Blackwell only) + if run_fp4: + fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp4_result: + results["NVFP4"].append(fp4_result.tflops) + row += f" {fp4_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp4_result.tflops) + else: + results["NVFP4"].append(0) + row += f" {'N/A':>14}" + + speedup = best_tflops / bf16_result.tflops + row += f" {speedup:>11.2f}x" + print(row) + + print("=" * 80) + + # Remove empty precision results + results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} + + return results + + +def create_plot( + shapes: list[tuple[int, int, int]], + results: dict[str, list[float]], + output_path: str = "gemm_benchmark.png", + title: Optional[str] = None +): + """Create a bar plot matching the style of the reference image.""" + + gpu_name = torch.cuda.get_device_name(0) + if title is None: + title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" + + # Create labels for x-axis + labels = [f"{m}x{k}x{n}" for m, k, n in shapes] + x = np.arange(len(labels)) + + # Determine bar width based on number of kernels + num_kernels = len(results) 
+ bar_width = 0.8 / num_kernels + + # Color scheme matching the reference plot + colors = { + "BF16": "#808080", # Gray + "MXFP8": "#4B0082", # Indigo/Purple + "NVFP4": "#B22222", # Firebrick red (for future use) + } + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Plot bars for each precision + for i, (precision, tflops_list) in enumerate(results.items()): + offset = (i - num_kernels / 2 + 0.5) * bar_width + color = colors.get(precision, f"C{i}") + bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) + + # Customize the plot + ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) + ax.set_ylabel("Performance (TFLOPS)", fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) + + # Add legend + ax.legend(title="Kernel", loc='upper left', fontsize=10) + + # Add grid for readability + ax.yaxis.grid(True, linestyle='--', alpha=0.7) + ax.set_axisbelow(True) + + # Set y-axis to start at 0 + ax.set_ylim(bottom=0) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the plot + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"\nPlot saved to: {output_path}") + + return fig, ax + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmarking with TFLOPS measurement and plotting" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="gemm_benchmark.png", + help="Output path for the plot (default: gemm_benchmark.png)" + ) + parser.add_argument( + "--num-warmup", + type=int, + default=10, + help="Number of warmup iterations (default: 10)" + ) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of timed iterations (default: 100)" + ) + parser.add_argument( + "--gpu-warmup", + type=float, + default=5.0, + help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" + ) + parser.add_argument( + "--no-fp8", + action="store_true", + help="Skip FP8 
benchmarks" + ) + parser.add_argument( + "--no-fp4", + action="store_true", + help="Skip FP4 benchmarks (only available on Blackwell)" + ) + parser.add_argument( + "--pre-quantize", + action="store_true", + help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" + ) + parser.add_argument( + "--shapes", + type=str, + default=None, + help=( + "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " + "or explicit triplets like '8192x5120x15360,8192x5120x5120'." + ), + ) + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("Error: CUDA is not available. This script requires a GPU.") + return 1 + + # Parse custom shapes if provided + if args.shapes: + shapes = parse_shapes_arg(args.shapes) + else: + shapes = get_default_matrix_shapes() + + # Run benchmarks + results = run_benchmarks( + shapes=shapes, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + include_fp8=not args.no_fp8, + include_fp4=not args.no_fp4, + gpu_warmup_seconds=args.gpu_warmup, + pre_quantize=args.pre_quantize + ) + + # Create plot + create_plot(shapes, results, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py new file mode 100644 index 0000000000..1d97e9bb11 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python3 +""" +GEMM Profiler with Power/Clock Monitoring + +Detailed profiling of a specific GEMM size with GPU telemetry to understand +performance characteristics and potential throttling. 
+ +Usage: + python profiler_gemm.py --size 1536 --precision bf16 + python profiler_gemm.py --size 1536 --precision fp8 --pre-quantize + python profiler_gemm.py --size 1536 --precision fp4 --pre-quantize --with-leading-kernel +""" + +import argparse +import time +import threading +import torch +from dataclasses import dataclass, field +from typing import Optional, List +import subprocess +import json + +# Try to import pynvml for GPU monitoring +try: + import pynvml + PYNVML_AVAILABLE = True +except ImportError: + PYNVML_AVAILABLE = False + print("Warning: pynvml not available. Install with: pip install pynvml") + +# Optional TE import +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling + TE_AVAILABLE = True +except ImportError: + TE_AVAILABLE = False + print("Warning: Transformer Engine not available.") + + +@dataclass +class GPUTelemetry: + """Container for GPU telemetry samples.""" + timestamps: List[float] = field(default_factory=list) + power_watts: List[float] = field(default_factory=list) + temperature_c: List[int] = field(default_factory=list) + sm_clock_mhz: List[int] = field(default_factory=list) + memory_clock_mhz: List[int] = field(default_factory=list) + gpu_utilization: List[int] = field(default_factory=list) + + +class GPUMonitor: + """Background thread for monitoring GPU telemetry.""" + + def __init__(self, device_id: int = 0, sample_interval_ms: float = 10): + self.device_id = device_id + self.sample_interval = sample_interval_ms / 1000.0 + self.telemetry = GPUTelemetry() + self._running = False + self._thread = None + self._handle = None + + def start(self): + if not PYNVML_AVAILABLE: + print("Warning: pynvml not available, skipping GPU monitoring") + return + + pynvml.nvmlInit() + self._handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id) + self._running = True + self._thread = 
    except Exception:
        pass
    return None
torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + if with_leading_kernel: + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + # Warmup + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Start monitoring + monitor = GPUMonitor(sample_interval_ms=5) + monitor.start() + + # Give monitor a moment to start + time.sleep(0.01) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if with_leading_kernel: + _ = torch.matmul(A_large, B_large) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + # Stop monitoring + telemetry = monitor.stop() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + tflops = (flops / (avg_time_ms / 1000.0)) / 1e12 + + return tflops, avg_time_ms, telemetry + + +def profile_fp8_gemm(M: int, K: int, N: int, num_warmup: int, num_iters: int, + with_leading_kernel: bool, pre_quantize: bool) -> tuple: + """Profile FP8 GEMM with telemetry.""" + if not TE_AVAILABLE: + return None, None, None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + if pre_quantize: + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) + A_fp8 = quantizer.quantize(A) + B_fp8 = quantizer.quantize(B) + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + workspace_size = 32 * 1024 * 1024 + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + if with_leading_kernel: + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 
4096, dtype=torch.bfloat16, device=device) + A_large_fp8 = quantizer.quantize(A_large) + B_large_fp8 = quantizer.quantize(B_large) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + def run_gemm(): + tex.generic_gemm( + A_fp8, False, B_fp8, True, D, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + def run_large_gemm(): + tex.generic_gemm( + A_large_fp8, False, B_large_fp8, True, D_large, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + else: + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + if with_leading_kernel: + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Warmup + if pre_quantize: + for _ in range(num_warmup): + run_gemm() + else: + with te.autocast(enabled=True, recipe=fp8_recipe): + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Start monitoring + monitor = GPUMonitor(sample_interval_ms=5) + monitor.start() + time.sleep(0.01) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if pre_quantize: + if with_leading_kernel: + run_large_gemm() + + start_event.record() + for _ in range(num_iters): + run_gemm() + end_event.record() + else: + with te.autocast(enabled=True, recipe=fp8_recipe): + if with_leading_kernel: + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + telemetry = monitor.stop() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + tflops = (flops / (avg_time_ms / 1000.0)) / 1e12 + 
+ return tflops, avg_time_ms, telemetry + + +def profile_fp4_gemm(M: int, K: int, N: int, num_warmup: int, num_iters: int, + with_leading_kernel: bool, pre_quantize: bool) -> tuple: + """Profile FP4 GEMM with telemetry.""" + if not TE_AVAILABLE: + return None, None, None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + if pre_quantize: + quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) + A_fp4 = quantizer.quantize(A) + B_fp4 = quantizer.quantize(B) + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + workspace_size = 32 * 1024 * 1024 + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + if with_leading_kernel: + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp4 = quantizer.quantize(A_large) + B_large_fp4 = quantizer.quantize(B_large) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + def run_gemm(): + tex.generic_gemm( + A_fp4, False, B_fp4, True, D, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + def run_large_gemm(): + tex.generic_gemm( + A_large_fp4, False, B_large_fp4, True, D_large, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + else: + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + if with_leading_kernel: + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Warmup + if pre_quantize: + for _ in range(num_warmup): + 
run_gemm() + else: + with te.autocast(enabled=True, recipe=fp4_recipe): + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Start monitoring + monitor = GPUMonitor(sample_interval_ms=5) + monitor.start() + time.sleep(0.01) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if pre_quantize: + if with_leading_kernel: + run_large_gemm() + + start_event.record() + for _ in range(num_iters): + run_gemm() + end_event.record() + else: + with te.autocast(enabled=True, recipe=fp4_recipe): + if with_leading_kernel: + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + telemetry = monitor.stop() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + tflops = (flops / (avg_time_ms / 1000.0)) / 1e12 + + return tflops, avg_time_ms, telemetry + + +def print_telemetry_summary(telemetry: GPUTelemetry, gpu_info: dict): + """Print summary of GPU telemetry.""" + if not telemetry.timestamps: + print("\nNo telemetry data collected (pynvml not available?)") + return + + print("\n" + "=" * 60) + print("GPU TELEMETRY SUMMARY") + print("=" * 60) + + # Power + avg_power = sum(telemetry.power_watts) / len(telemetry.power_watts) + max_power = max(telemetry.power_watts) + min_power = min(telemetry.power_watts) + power_limit = gpu_info.get('power_limit_w', 0) if gpu_info else 0 + print(f"\nPower (W):") + print(f" Avg: {avg_power:.1f} Min: {min_power:.1f} Max: {max_power:.1f} Limit: {power_limit:.0f}") + if power_limit > 0: + print(f" Utilization: {100 * avg_power / power_limit:.1f}% of limit") + + # Temperature + avg_temp = sum(telemetry.temperature_c) / len(telemetry.temperature_c) + max_temp = max(telemetry.temperature_c) + print(f"\nTemperature (°C):") + print(f" Avg: {avg_temp:.0f} Max: {max_temp:.0f}") + + # SM Clock + avg_sm = sum(telemetry.sm_clock_mhz) / 
len(telemetry.sm_clock_mhz) + max_sm = max(telemetry.sm_clock_mhz) + min_sm = min(telemetry.sm_clock_mhz) + max_sm_possible = gpu_info.get('max_sm_clock_mhz', 0) if gpu_info else 0 + print(f"\nSM Clock (MHz):") + print(f" Avg: {avg_sm:.0f} Min: {min_sm} Max: {max_sm} GPU Max: {max_sm_possible}") + if max_sm_possible > 0: + print(f" Running at: {100 * avg_sm / max_sm_possible:.1f}% of max clock") + + # Check for throttling indicators + print("\n" + "-" * 60) + print("THROTTLING ANALYSIS:") + if power_limit > 0 and max_power >= power_limit * 0.95: + print(" ⚠️ Power usage near limit - possible power throttling") + else: + print(" ✓ Power usage below limit") + + if max_sm_possible > 0 and avg_sm < max_sm_possible * 0.9: + print(f" ⚠️ SM clocks below max ({avg_sm:.0f} vs {max_sm_possible} MHz)") + else: + print(" ✓ SM clocks near max") + + clock_variance = max_sm - min_sm + if clock_variance > 100: + print(f" ⚠️ Clock variance: {clock_variance} MHz (unstable clocks)") + else: + print(f" ✓ Clock variance: {clock_variance} MHz (stable)") + + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="GEMM Profiler with GPU Telemetry") + parser.add_argument("--size", "-s", type=int, default=1536, + help="Matrix size (square MxKxN)") + parser.add_argument("--precision", "-p", choices=['bf16', 'fp8', 'fp4'], default='bf16', + help="Precision to benchmark") + parser.add_argument("--num-warmup", type=int, default=50, + help="Warmup iterations") + parser.add_argument("--num-iters", type=int, default=500, + help="Timed iterations") + parser.add_argument("--pre-quantize", action="store_true", + help="Use pre-quantized inputs (FP8/FP4 only)") + parser.add_argument("--with-leading-kernel", action="store_true", + help="Run a large GEMM before the timed kernels") + parser.add_argument("--compare", action="store_true", + help="Run both with and without leading kernel for comparison") + parser.add_argument("--gpu-warmup", type=float, default=3.0, + help="Seconds 
to warm up GPU before profiling") + + args = parser.parse_args() + + M = K = N = args.size + + # Get GPU info + gpu_info = get_gpu_info() + print("\n" + "=" * 70) + print("GEMM PROFILER") + print("=" * 70) + if gpu_info: + print(f"GPU: {gpu_info['name']}") + print(f"Power Limit: {gpu_info['power_limit_w']:.0f}W") + print(f"Max SM Clock: {gpu_info['max_sm_clock_mhz']} MHz") + + print(f"\nConfiguration:") + print(f" Shape: {M}x{K}x{N}") + print(f" Precision: {args.precision.upper()}") + print(f" Iterations: {args.num_warmup} warmup + {args.num_iters} timed") + print(f" Pre-quantize: {args.pre_quantize}") + + # GPU warmup + if args.gpu_warmup > 0: + print(f"\nWarming up GPU for {args.gpu_warmup:.1f} seconds...") + device = torch.device("cuda") + warmup_a = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + warmup_b = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + start = time.time() + while time.time() - start < args.gpu_warmup: + _ = torch.matmul(warmup_a, warmup_b) + torch.cuda.synchronize() + del warmup_a, warmup_b + torch.cuda.empty_cache() + print("GPU warmup complete.") + + if args.compare: + # Run both configurations + configs = [ + ("Without leading kernel", False), + ("With leading kernel", True), + ] + else: + configs = [( + "With leading kernel" if args.with_leading_kernel else "Without leading kernel", + args.with_leading_kernel + )] + + results = [] + + for config_name, use_leading in configs: + print(f"\n{'='*60}") + print(f"Configuration: {config_name}") + print('='*60) + + if args.precision == 'bf16': + tflops, avg_ms, telemetry = profile_bf16_gemm( + M, K, N, args.num_warmup, args.num_iters, use_leading + ) + elif args.precision == 'fp8': + tflops, avg_ms, telemetry = profile_fp8_gemm( + M, K, N, args.num_warmup, args.num_iters, use_leading, args.pre_quantize + ) + elif args.precision == 'fp4': + tflops, avg_ms, telemetry = profile_fp4_gemm( + M, K, N, args.num_warmup, args.num_iters, use_leading, args.pre_quantize + ) + + 
if tflops is not None: + print(f"\nResults:") + print(f" TFLOPS: {tflops:.1f}") + print(f" Avg time: {avg_ms:.4f} ms") + print_telemetry_summary(telemetry, gpu_info) + results.append((config_name, tflops, avg_ms)) + else: + print("Benchmark failed or not available") + + # Print comparison summary + if len(results) == 2: + print(f"\n{'=' * 60}") + print("COMPARISON SUMMARY") + print("=" * 60) + name1, tflops1, ms1 = results[0] + name2, tflops2, ms2 = results[1] + + print(f"\n {name1}:") + print(f" {tflops1:.1f} TFLOPS, {ms1:.4f} ms") + print(f"\n {name2}:") + print(f" {tflops2:.1f} TFLOPS, {ms2:.4f} ms") + + diff_pct = 100 * (tflops2 - tflops1) / tflops1 + print(f"\n Difference: {diff_pct:+.1f}%") + + if diff_pct < -5: + print(f"\n ⚠️ Leading kernel hurts performance") + print(f" Likely cause: power/thermal throttling from the leading kernel") + elif diff_pct > 5: + print(f"\n ✓ Leading kernel helps performance") + print(f" Without it, CPU dispatch overhead was being measured") + else: + print(f"\n ~ Minimal difference") + print(f" CPU dispatch overhead is not significant for this size") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py new file mode 100644 index 0000000000..6b2e8373c0 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py @@ -0,0 +1,522 @@ +#!/usr/bin/env python3 +""" +GEMM Benchmarking Script with TFLOPS Measurement and Plotting + +Benchmarks matrix multiplication performance across different precisions +(BF16, FP8 via Transformer Engine) and generates a comparison plot. 
Usage:
    python roofline.py [--output plot.png] [--num-iters 100]
+ + A: (M, K), B: (K, N), C: (M, N) + Each output element requires K multiply-adds = 2K FLOPs + Total: 2 * M * N * K + """ + return 2 * M * N * K + + +def benchmark_torch_matmul( + M: int, + K: int, + N: int, + dtype: torch.dtype, + num_warmup: int = 10, + num_iters: int = 100 +) -> BenchmarkResult: + """Benchmark torch.matmul at specified precision.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=dtype, device=device) + B = torch.randn(K, N, dtype=dtype, device=device) + + # Warmup - critical for accurate timing + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + precision_name = { + torch.bfloat16: "BF16", + torch.float16: "FP16", + torch.float32: "FP32", + }.get(dtype, str(dtype)) + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision=precision_name + ) + + +def benchmark_te_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Keep autocast context open for warmup and timing + with 
te.autocast(enabled=True, recipe=fp8_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + + +def benchmark_te_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp4_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), 
+ precision="NVFP4" + ) + + +def get_default_matrix_shapes() -> list[tuple[int, int, int]]: + """Return default matrix shapes for benchmarking (square matrices).""" + return [ + (256, 256, 256), + (512, 512, 512), + (768, 768, 768), + (1024, 1024, 1024), + (1536, 1536, 1536), + (2048, 2048, 2048), + (3072, 3072, 3072), + (4096, 4096, 4096), + (6144, 6144, 6144), + (8192, 8192, 8192), + (16384, 16384, 16384), + ] + + +def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: + """Parse a shapes argument into a list of (M, K, N) tuples. + + Supports either: + - Square sizes: "1024,2048,4096" -> [(1024,1024,1024), ...] + - Explicit triplets: "8192x5120x15360,8192x5120x5120" + """ + items = [s.strip() for s in shapes_arg.split(",") if s.strip()] + if not items: + raise ValueError("Empty --shapes argument.") + + shapes: list[tuple[int, int, int]] = [] + for item in items: + if "x" in item: + parts = [p.strip() for p in item.lower().split("x")] + if len(parts) != 3: + raise ValueError(f"Invalid shape '{item}'. Expected format 'MxKxN'.") + m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) + shapes.append((m, k, n)) + else: + size = int(item) + shapes.append((size, size, size)) + + return shapes + + +def warmup_gpu(duration_seconds: float = 5.0): + """ + Warmup the GPU to stabilize clocks before benchmarking. + + Runs sustained matmuls to bring GPU out of idle state and + get clocks/thermals to steady state. 
+ """ + print(f"Warming up GPU for {duration_seconds:.1f} seconds...") + + device = torch.device("cuda") + + # Use a moderate size that keeps GPU busy without OOM + M, K, N = 4096, 4096, 4096 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + torch.cuda.synchronize() + start_time = time.time() + + while time.time() - start_time < duration_seconds: + # Run a batch of matmuls + for _ in range(10): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Clear memory + del A, B + torch.cuda.empty_cache() + + print("GPU warmup complete.\n") + + +def run_benchmarks( + shapes: list[tuple[int, int, int]], + num_warmup: int = 10, + num_iters: int = 100, + include_fp8: bool = True, + include_fp4: bool = True, + gpu_warmup_seconds: float = 5.0 +) -> dict[str, list[float]]: + """Run all benchmarks and return results organized by precision.""" + + results = {"BF16": [], "MXFP8": [], "NVFP4": []} + + # Check hardware capabilities + has_blackwell = is_blackwell_available() + run_fp8 = include_fp8 and TE_AVAILABLE + run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell + + # Print header + gpu_name = torch.cuda.get_device_name(0) + print(f"\nGEMM Benchmark on {gpu_name}") + print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") + if not has_blackwell and include_fp4: + print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") + + # Warmup GPU to stabilize clocks + if gpu_warmup_seconds > 0: + warmup_gpu(gpu_warmup_seconds) + + print("=" * 80) + + # Build header dynamically + header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" + if run_fp8: + header += f" {'MXFP8 TFLOPS':>14}" + if run_fp4: + header += f" {'NVFP4 TFLOPS':>14}" + header += f" {'Best Speedup':>12}" + print(header) + print("-" * 80) + + for M, K, N in shapes: + shape_str = f"{M}x{K}x{N}" + + # BF16 benchmark + bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) + 
results["BF16"].append(bf16_result.tflops) + + row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" + best_tflops = bf16_result.tflops + + # FP8 benchmark + if run_fp8: + fp8_result = benchmark_te_fp8(M, K, N, num_warmup, num_iters) + if fp8_result: + results["MXFP8"].append(fp8_result.tflops) + row += f" {fp8_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp8_result.tflops) + else: + results["MXFP8"].append(0) + row += f" {'N/A':>14}" + + # FP4 benchmark (Blackwell only) + if run_fp4: + fp4_result = benchmark_te_fp4(M, K, N, num_warmup, num_iters) + if fp4_result: + results["NVFP4"].append(fp4_result.tflops) + row += f" {fp4_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp4_result.tflops) + else: + results["NVFP4"].append(0) + row += f" {'N/A':>14}" + + speedup = best_tflops / bf16_result.tflops + row += f" {speedup:>11.2f}x" + print(row) + + print("=" * 80) + + # Remove empty precision results + results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} + + return results + + +def create_plot( + shapes: list[tuple[int, int, int]], + results: dict[str, list[float]], + output_path: str = "gemm_benchmark.png", + title: Optional[str] = None +): + """Create a bar plot matching the style of the reference image.""" + + gpu_name = torch.cuda.get_device_name(0) + if title is None: + title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" + + # Create labels for x-axis + labels = [f"{m}x{k}x{n}" for m, k, n in shapes] + x = np.arange(len(labels)) + + # Determine bar width based on number of kernels + num_kernels = len(results) + bar_width = 0.8 / num_kernels + + # Color scheme matching the reference plot + colors = { + "BF16": "#808080", # Gray + "MXFP8": "#4B0082", # Indigo/Purple + "NVFP4": "#B22222", # Firebrick red (for future use) + } + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Plot bars for each precision + for i, (precision, tflops_list) in enumerate(results.items()): + offset = (i - num_kernels / 2 + 0.5) * 
bar_width + color = colors.get(precision, f"C{i}") + bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) + + # Customize the plot + ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) + ax.set_ylabel("Performance (TFLOPS)", fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) + + # Add legend + ax.legend(title="Kernel", loc='upper left', fontsize=10) + + # Add grid for readability + ax.yaxis.grid(True, linestyle='--', alpha=0.7) + ax.set_axisbelow(True) + + # Set y-axis to start at 0 + ax.set_ylim(bottom=0) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the plot + output_path_obj = Path(output_path) + supported_formats = set(fig.canvas.get_supported_filetypes().keys()) + suffix = output_path_obj.suffix.lower().lstrip(".") + if suffix not in supported_formats: + output_path_obj = output_path_obj.with_suffix(".png") + print( + f"Warning: Output extension '.{suffix}' is not supported by matplotlib; " + f"saving to '{output_path_obj}' instead." 
+ ) + plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight") + print(f"\nPlot saved to: {output_path}") + + return fig, ax + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmarking with TFLOPS measurement and plotting" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="gemm_benchmark.png", + help="Output path for the plot (default: gemm_benchmark.png)" + ) + parser.add_argument( + "--num-warmup", + type=int, + default=10, + help="Number of warmup iterations (default: 10)" + ) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of timed iterations (default: 100)" + ) + parser.add_argument( + "--gpu-warmup", + type=float, + default=5.0, + help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" + ) + parser.add_argument( + "--no-fp8", + action="store_true", + help="Skip FP8 benchmarks" + ) + parser.add_argument( + "--no-fp4", + action="store_true", + help="Skip FP4 benchmarks (only available on Blackwell)" + ) + parser.add_argument( + "--shapes", + type=str, + default=None, + help=( + "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " + "or explicit triplets like '8192x5120x15360,8192x5120x5120'." + ), + ) + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("Error: CUDA is not available. 
This script requires a GPU.") + return 1 + + # Parse custom shapes if provided + if args.shapes: + shapes = parse_shapes_arg(args.shapes) + else: + shapes = get_default_matrix_shapes() + + # Run benchmarks + results = run_benchmarks( + shapes=shapes, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + include_fp8=not args.no_fp8, + include_fp4=not args.no_fp4, + gpu_warmup_seconds=args.gpu_warmup + ) + + # Create plot + create_plot(shapes, results, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py new file mode 100644 index 0000000000..a8dbb436d2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py @@ -0,0 +1,753 @@ +#!/usr/bin/env python3 +""" +GEMM Benchmarking Script with TFLOPS Measurement and Plotting + +Benchmarks matrix multiplication performance across different precisions +(BF16, FP8 via Transformer Engine) and generates a comparison plot. + +Usage: + python gemm_benchmark.py [--output plot.png] [--num-iters 100] +""" + +import argparse +import time +import torch +import matplotlib.pyplot as plt +import numpy as np +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +# Optional TE import - gracefully handle if not available +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling + TE_AVAILABLE = True +except ImportError: + TE_AVAILABLE = False + print("Warning: Transformer Engine not available. 
FP8/FP4 benchmarks will be skipped.") + +# Check for Blackwell (SM100+) for FP4 support +def is_blackwell_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 # SM100 = Blackwell + + +@dataclass +class BenchmarkResult: + """Container for benchmark results.""" + tflops: float + avg_time_ms: float + shape: tuple[int, int, int] + precision: str + + +def compute_gemm_flops(M: int, K: int, N: int) -> int: + """ + Compute theoretical FLOP count for GEMM C = A @ B. + + A: (M, K), B: (K, N), C: (M, N) + Each output element requires K multiply-adds = 2K FLOPs + Total: 2 * M * N * K + """ + return 2 * M * N * K + + +def benchmark_torch_matmul( + M: int, + K: int, + N: int, + dtype: torch.dtype, + num_warmup: int = 10, + num_iters: int = 100 +) -> BenchmarkResult: + """Benchmark torch.matmul at specified precision.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=dtype, device=device) + B = torch.randn(K, N, dtype=dtype, device=device) + + # Large matrix for leading kernel to saturate GPU + A_large = torch.randn(4096, 4096, dtype=dtype, device=device) + B_large = torch.randn(4096, 4096, dtype=dtype, device=device) + + # Warmup - critical for accurate timing + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + # Start with a long-running kernel to avoid measuring CPU dispatch overhead + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = torch.matmul(A_large, B_large) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / 
avg_time_s) / 1e12 + + precision_name = { + torch.bfloat16: "BF16", + torch.float16: "FP16", + torch.float32: "FP32", + }.get(dtype, str(dtype)) + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision=precision_name + ) + + +def benchmark_te_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + # Large tensors for leading kernel + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp8_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + + +def benchmark_te_fp8_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 
100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (requires fp8_dtype argument) + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP8 + A_fp8 = quantizer.quantize(A) + B_fp8 = quantizer.quantize(B) + + # Large tensors for leading kernel + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp8 = quantizer.quantize(A_large) + B_large_fp8 = quantizer.quantize(B_large) + + # Output buffers + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp8, False, # A, transA + B_fp8, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + tex.generic_gemm( + A_large_fp8, False, B_large_fp8, True, D_large, None, + 
tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp8, False, + B_fp8, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + except Exception as e: + print(f"Warning: FP8 prequantized benchmark failed: {e}") + return None + + +def benchmark_te_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + # Large tensors for leading kernel + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp4_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel 
keeps GPU busy while CPU queues the timed kernels + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + + +def benchmark_te_fp4_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (uses default kFloat4E2M1, but being explicit) + quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP4 + A_fp4 = quantizer.quantize(A) + B_fp4 = quantizer.quantize(B) + + # Large tensors for leading kernel + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp4 = quantizer.quantize(A_large) + B_large_fp4 = quantizer.quantize(B_large) + + # Output buffers + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp4, False, # A, transA + B_fp4, True, # B, 
transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + tex.generic_gemm( + A_large_fp4, False, B_large_fp4, True, D_large, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp4, False, + B_fp4, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + except Exception as e: + print(f"Warning: FP4 prequantized benchmark failed: {e}") + return None + + +def get_default_matrix_shapes() -> list[tuple[int, int, int]]: + """Return default matrix shapes for benchmarking (square matrices).""" + return [ + (256, 256, 256), + (512, 512, 512), + (768, 768, 768), + (1024, 1024, 1024), + (1536, 1536, 1536), + (2048, 2048, 2048), + (3072, 3072, 3072), + (4096, 4096, 4096), + (6144, 6144, 6144), + (8192, 8192, 8192), + (16384, 16384, 16384), + ] + + +def warmup_gpu(duration_seconds: float = 5.0): + """ + Warmup the GPU to stabilize clocks before benchmarking. 
+ + Runs sustained matmuls to bring GPU out of idle state and + get clocks/thermals to steady state. + """ + print(f"Warming up GPU for {duration_seconds:.1f} seconds...") + + device = torch.device("cuda") + + # Use a moderate size that keeps GPU busy without OOM + M, K, N = 4096, 4096, 4096 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + torch.cuda.synchronize() + start_time = time.time() + + while time.time() - start_time < duration_seconds: + # Run a batch of matmuls + for _ in range(10): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Clear memory + del A, B + torch.cuda.empty_cache() + + print("GPU warmup complete.\n") + + +def run_benchmarks( + shapes: list[tuple[int, int, int]], + num_warmup: int = 10, + num_iters: int = 100, + include_fp8: bool = True, + include_fp4: bool = True, + gpu_warmup_seconds: float = 5.0, + pre_quantize: bool = False +) -> dict[str, list[float]]: + """Run all benchmarks and return results organized by precision.""" + + results = {"BF16": [], "MXFP8": [], "NVFP4": []} + + # Check hardware capabilities + has_blackwell = is_blackwell_available() + run_fp8 = include_fp8 and TE_AVAILABLE + run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell + + # Print header + gpu_name = torch.cuda.get_device_name(0) + print(f"\nGEMM Benchmark on {gpu_name}") + print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") + if pre_quantize: + print("Mode: Pre-quantized inputs (raw kernel throughput)") + else: + print("Mode: Full quantize path (realistic training overhead)") + if not has_blackwell and include_fp4: + print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") + + # Warmup GPU to stabilize clocks + if gpu_warmup_seconds > 0: + warmup_gpu(gpu_warmup_seconds) + + print("=" * 80) + + # Build header dynamically + header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" + if run_fp8: + header += f" {'MXFP8 TFLOPS':>14}" + if run_fp4: + 
header += f" {'NVFP4 TFLOPS':>14}" + header += f" {'Best Speedup':>12}" + print(header) + print("-" * 80) + + # Select benchmark functions based on pre_quantize flag + fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 + fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 + + for M, K, N in shapes: + shape_str = f"{M}x{K}x{N}" + + # BF16 benchmark + bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) + results["BF16"].append(bf16_result.tflops) + + row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" + best_tflops = bf16_result.tflops + + # FP8 benchmark + if run_fp8: + fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp8_result: + results["MXFP8"].append(fp8_result.tflops) + row += f" {fp8_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp8_result.tflops) + else: + results["MXFP8"].append(0) + row += f" {'N/A':>14}" + + # FP4 benchmark (Blackwell only) + if run_fp4: + fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp4_result: + results["NVFP4"].append(fp4_result.tflops) + row += f" {fp4_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp4_result.tflops) + else: + results["NVFP4"].append(0) + row += f" {'N/A':>14}" + + speedup = best_tflops / bf16_result.tflops + row += f" {speedup:>11.2f}x" + print(row) + + print("=" * 80) + + # Remove empty precision results + results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} + + return results + + +def create_plot( + shapes: list[tuple[int, int, int]], + results: dict[str, list[float]], + output_path: str = "gemm_benchmark.png", + title: Optional[str] = None +): + """Create a bar plot matching the style of the reference image.""" + + gpu_name = torch.cuda.get_device_name(0) + if title is None: + title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" + + # Create labels for x-axis + labels = [f"{m}x{k}x{n}" for m, k, n in shapes] + x = 
np.arange(len(labels)) + + # Determine bar width based on number of kernels + num_kernels = len(results) + bar_width = 0.8 / num_kernels + + # Color scheme matching the reference plot + colors = { + "BF16": "#808080", # Gray + "MXFP8": "#4B0082", # Indigo/Purple + "NVFP4": "#B22222", # Firebrick red (for future use) + } + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Plot bars for each precision + for i, (precision, tflops_list) in enumerate(results.items()): + offset = (i - num_kernels / 2 + 0.5) * bar_width + color = colors.get(precision, f"C{i}") + bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) + + # Customize the plot + ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) + ax.set_ylabel("Performance (TFLOPS)", fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) + + # Add legend + ax.legend(title="Kernel", loc='upper left', fontsize=10) + + # Add grid for readability + ax.yaxis.grid(True, linestyle='--', alpha=0.7) + ax.set_axisbelow(True) + + # Set y-axis to start at 0 + ax.set_ylim(bottom=0) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the plot + output_path_obj = Path(output_path) + supported_formats = set(fig.canvas.get_supported_filetypes().keys()) + suffix = output_path_obj.suffix.lower().lstrip(".") + if suffix not in supported_formats: + output_path_obj = output_path_obj.with_suffix(".png") + print( + f"Warning: Output extension '.{suffix}' is not supported by matplotlib; " + f"saving to '{output_path_obj}' instead." 
+ ) + plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight") + print(f"\nPlot saved to: {output_path}") + + return fig, ax + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmarking with TFLOPS measurement and plotting" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="gemm_benchmark.png", + help="Output path for the plot (default: gemm_benchmark.png)" + ) + parser.add_argument( + "--num-warmup", + type=int, + default=10, + help="Number of warmup iterations (default: 10)" + ) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of timed iterations (default: 100)" + ) + parser.add_argument( + "--gpu-warmup", + type=float, + default=5.0, + help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" + ) + parser.add_argument( + "--no-fp8", + action="store_true", + help="Skip FP8 benchmarks" + ) + parser.add_argument( + "--no-fp4", + action="store_true", + help="Skip FP4 benchmarks (only available on Blackwell)" + ) + parser.add_argument( + "--pre-quantize", + action="store_true", + help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" + ) + parser.add_argument( + "--shapes", + type=str, + default=None, + help="Comma-separated list of square matrix sizes (e.g., '1024,2048,4096')" + ) + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("Error: CUDA is not available. 
This script requires a GPU.") + return 1 + + # Parse custom shapes if provided + if args.shapes: + sizes = [int(s.strip()) for s in args.shapes.split(",")] + shapes = [(s, s, s) for s in sizes] + else: + shapes = get_default_matrix_shapes() + + # Run benchmarks + results = run_benchmarks( + shapes=shapes, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + include_fp8=not args.no_fp8, + include_fp4=not args.no_fp4, + gpu_warmup_seconds=args.gpu_warmup, + pre_quantize=args.pre_quantize + ) + + # Create plot + create_plot(shapes, results, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py new file mode 100644 index 0000000000..671cce37c2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py @@ -0,0 +1,772 @@ +#!/usr/bin/env python3 +""" +GEMM Benchmarking Script with TFLOPS Measurement and Plotting + +Benchmarks matrix multiplication performance across different precisions +(BF16, FP8 via Transformer Engine) and generates a comparison plot. 
+ +Usage: + python gemm_benchmark.py [--output plot.png] [--num-iters 100] + python roofline_prequantized_with_shapes.py --output gemm_benchmark_expected_shapes.png --num-warmup 100 --num-iters 100 --gpu-warmup 100 --shapes 8192x5120x20480,8192x20480x5120 +""" + +import argparse +import time +import torch +import matplotlib.pyplot as plt +import numpy as np +from dataclasses import dataclass +from typing import Optional + +# Optional TE import - gracefully handle if not available +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling + TE_AVAILABLE = True +except ImportError: + TE_AVAILABLE = False + print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") + +# Check for Blackwell (SM100+) for FP4 support +def is_blackwell_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 # SM100 = Blackwell + + +@dataclass +class BenchmarkResult: + """Container for benchmark results.""" + tflops: float + avg_time_ms: float + shape: tuple[int, int, int] + precision: str + + +def compute_gemm_flops(M: int, K: int, N: int) -> int: + """ + Compute theoretical FLOP count for GEMM C = A @ B. 
+ + A: (M, K), B: (K, N), C: (M, N) + Each output element requires K multiply-adds = 2K FLOPs + Total: 2 * M * N * K + """ + return 2 * M * N * K + + +def benchmark_torch_matmul( + M: int, + K: int, + N: int, + dtype: torch.dtype, + num_warmup: int = 10, + num_iters: int = 100 +) -> BenchmarkResult: + """Benchmark torch.matmul at specified precision.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=dtype, device=device) + B = torch.randn(K, N, dtype=dtype, device=device) + + # Large matrix for leading kernel to saturate GPU + A_large = torch.randn(4096, 4096, dtype=dtype, device=device) + B_large = torch.randn(4096, 4096, dtype=dtype, device=device) + + # Warmup - critical for accurate timing + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + # Start with a long-running kernel to avoid measuring CPU dispatch overhead + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = torch.matmul(A_large, B_large) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + precision_name = { + torch.bfloat16: "BF16", + torch.float16: "FP16", + torch.float32: "FP32", + }.get(dtype, str(dtype)) + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision=precision_name + ) + + +def benchmark_te_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = 
compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + # Large tensors for leading kernel + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp8_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + + +def benchmark_te_fp8_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (requires fp8_dtype argument) + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP8 + A_fp8 
= quantizer.quantize(A) + B_fp8 = quantizer.quantize(B) + + # Large tensors for leading kernel + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp8 = quantizer.quantize(A_large) + B_large_fp8 = quantizer.quantize(B_large) + + # Output buffers + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp8, False, # A, transA + B_fp8, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + tex.generic_gemm( + A_large_fp8, False, B_large_fp8, True, D_large, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp8, False, + B_fp8, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = 
(flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + except Exception as e: + print(f"Warning: FP8 prequantized benchmark failed: {e}") + return None + + +def benchmark_te_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + # Large tensors for leading kernel + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp4_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + + +def benchmark_te_fp4_prequantized( + M: int, + K: int, + N: int, + 
num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (uses default kFloat4E2M1, but being explicit) + quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP4 + A_fp4 = quantizer.quantize(A) + B_fp4 = quantizer.quantize(B) + + # Large tensors for leading kernel + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp4 = quantizer.quantize(A_large) + B_large_fp4 = quantizer.quantize(B_large) + + # Output buffers + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace_size = 32 * 1024 * 1024 # 32MB workspace + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp4, False, # A, transA + B_fp4, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps 
GPU busy while CPU queues the timed kernels + tex.generic_gemm( + A_large_fp4, False, B_large_fp4, True, D_large, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp4, False, + B_fp4, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + except Exception as e: + print(f"Warning: FP4 prequantized benchmark failed: {e}") + return None + + +def get_default_matrix_shapes() -> list[tuple[int, int, int]]: + """Return default matrix shapes for benchmarking (square matrices).""" + return [ + (256, 256, 256), + (512, 512, 512), + (768, 768, 768), + (1024, 1024, 1024), + (1536, 1536, 1536), + (2048, 2048, 2048), + (3072, 3072, 3072), + (4096, 4096, 4096), + (6144, 6144, 6144), + (8192, 8192, 8192), + (16384, 16384, 16384), + ] + + +def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: + """Parse a shapes argument into a list of (M, K, N) tuples. + + Supports either: + - Square sizes: "1024,2048,4096" -> [(1024,1024,1024), ...] + - Explicit triplets: "8192x5120x15360,8192x5120x5120" + """ + items = [s.strip() for s in shapes_arg.split(",") if s.strip()] + if not items: + raise ValueError("Empty --shapes argument.") + + shapes: list[tuple[int, int, int]] = [] + for item in items: + if "x" in item: + parts = [p.strip() for p in item.lower().split("x")] + if len(parts) != 3: + raise ValueError(f"Invalid shape '{item}'. 
Expected format 'MxKxN'.") + m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) + shapes.append((m, k, n)) + else: + size = int(item) + shapes.append((size, size, size)) + + return shapes + + +def warmup_gpu(duration_seconds: float = 5.0): + """ + Warmup the GPU to stabilize clocks before benchmarking. + + Runs sustained matmuls to bring GPU out of idle state and + get clocks/thermals to steady state. + """ + print(f"Warming up GPU for {duration_seconds:.1f} seconds...") + + device = torch.device("cuda") + + # Use a moderate size that keeps GPU busy without OOM + M, K, N = 4096, 4096, 4096 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + torch.cuda.synchronize() + start_time = time.time() + + while time.time() - start_time < duration_seconds: + # Run a batch of matmuls + for _ in range(10): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Clear memory + del A, B + torch.cuda.empty_cache() + + print("GPU warmup complete.\n") + + +def run_benchmarks( + shapes: list[tuple[int, int, int]], + num_warmup: int = 10, + num_iters: int = 100, + include_fp8: bool = True, + include_fp4: bool = True, + gpu_warmup_seconds: float = 5.0, + pre_quantize: bool = False +) -> dict[str, list[float]]: + """Run all benchmarks and return results organized by precision.""" + + results = {"BF16": [], "MXFP8": [], "NVFP4": []} + + # Check hardware capabilities + has_blackwell = is_blackwell_available() + run_fp8 = include_fp8 and TE_AVAILABLE + run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell + + # Print header + gpu_name = torch.cuda.get_device_name(0) + print(f"\nGEMM Benchmark on {gpu_name}") + print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") + if pre_quantize: + print("Mode: Pre-quantized inputs (raw kernel throughput)") + else: + print("Mode: Full quantize path (realistic training overhead)") + if not has_blackwell and include_fp4: + print("Note: NVFP4 requires 
Blackwell (SM100+), skipping FP4 benchmarks") + + # Warmup GPU to stabilize clocks + if gpu_warmup_seconds > 0: + warmup_gpu(gpu_warmup_seconds) + + print("=" * 80) + + # Build header dynamically + header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" + if run_fp8: + header += f" {'MXFP8 TFLOPS':>14}" + if run_fp4: + header += f" {'NVFP4 TFLOPS':>14}" + header += f" {'Best Speedup':>12}" + print(header) + print("-" * 80) + + # Select benchmark functions based on pre_quantize flag + fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 + fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 + + for M, K, N in shapes: + shape_str = f"{M}x{K}x{N}" + + # BF16 benchmark + bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) + results["BF16"].append(bf16_result.tflops) + + row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" + best_tflops = bf16_result.tflops + + # FP8 benchmark + if run_fp8: + fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp8_result: + results["MXFP8"].append(fp8_result.tflops) + row += f" {fp8_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp8_result.tflops) + else: + results["MXFP8"].append(0) + row += f" {'N/A':>14}" + + # FP4 benchmark (Blackwell only) + if run_fp4: + fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) + if fp4_result: + results["NVFP4"].append(fp4_result.tflops) + row += f" {fp4_result.tflops:>14.1f}" + best_tflops = max(best_tflops, fp4_result.tflops) + else: + results["NVFP4"].append(0) + row += f" {'N/A':>14}" + + speedup = best_tflops / bf16_result.tflops + row += f" {speedup:>11.2f}x" + print(row) + + print("=" * 80) + + # Remove empty precision results + results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} + + return results + + +def create_plot( + shapes: list[tuple[int, int, int]], + results: dict[str, list[float]], + output_path: str = "gemm_benchmark.png", + title: 
Optional[str] = None +): + """Create a bar plot matching the style of the reference image.""" + + gpu_name = torch.cuda.get_device_name(0) + if title is None: + title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" + + # Create labels for x-axis + labels = [f"{m}x{k}x{n}" for m, k, n in shapes] + x = np.arange(len(labels)) + + # Determine bar width based on number of kernels + num_kernels = len(results) + bar_width = 0.8 / num_kernels + + # Color scheme matching the reference plot + colors = { + "BF16": "#808080", # Gray + "MXFP8": "#4B0082", # Indigo/Purple + "NVFP4": "#B22222", # Firebrick red (for future use) + } + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Plot bars for each precision + for i, (precision, tflops_list) in enumerate(results.items()): + offset = (i - num_kernels / 2 + 0.5) * bar_width + color = colors.get(precision, f"C{i}") + bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) + + # Customize the plot + ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) + ax.set_ylabel("Performance (TFLOPS)", fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) + + # Add legend + ax.legend(title="Kernel", loc='upper left', fontsize=10) + + # Add grid for readability + ax.yaxis.grid(True, linestyle='--', alpha=0.7) + ax.set_axisbelow(True) + + # Set y-axis to start at 0 + ax.set_ylim(bottom=0) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the plot + plt.savefig(output_path, dpi=150, bbox_inches='tight') + print(f"\nPlot saved to: {output_path}") + + return fig, ax + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmarking with TFLOPS measurement and plotting" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="gemm_benchmark.png", + help="Output path for the plot (default: gemm_benchmark.png)" + ) + parser.add_argument( + "--num-warmup", + 
type=int, + default=10, + help="Number of warmup iterations (default: 10)" + ) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of timed iterations (default: 100)" + ) + parser.add_argument( + "--gpu-warmup", + type=float, + default=5.0, + help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" + ) + parser.add_argument( + "--no-fp8", + action="store_true", + help="Skip FP8 benchmarks" + ) + parser.add_argument( + "--no-fp4", + action="store_true", + help="Skip FP4 benchmarks (only available on Blackwell)" + ) + parser.add_argument( + "--pre-quantize", + action="store_true", + help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" + ) + parser.add_argument( + "--shapes", + type=str, + default=None, + help=( + "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " + "or explicit triplets like '8192x5120x15360,8192x5120x5120'." + ), + ) + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("Error: CUDA is not available. 
This script requires a GPU.") + return 1 + + # Parse custom shapes if provided + if args.shapes: + shapes = parse_shapes_arg(args.shapes) + else: + shapes = get_default_matrix_shapes() + + # Run benchmarks + results = run_benchmarks( + shapes=shapes, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + include_fp8=not args.no_fp8, + include_fp4=not args.no_fp4, + gpu_warmup_seconds=args.gpu_warmup, + pre_quantize=args.pre_quantize + ) + + # Create plot + create_plot(shapes, results, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py new file mode 100644 index 0000000000..73539635e9 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py @@ -0,0 +1,829 @@ +#!/usr/bin/env python3 +""" +GEMM Benchmarking Script with TFLOPS Measurement and Plotting + +Benchmarks matrix multiplication performance across different precisions +(BF16, FP8 via Transformer Engine) and generates a comparison plot. + +Usage: + python gemm_benchmark.py [--output plot.png] [--num-iters 100] +""" + +import argparse +import time +import torch +import matplotlib.pyplot as plt +import numpy as np +from dataclasses import dataclass +from pathlib import Path +from typing import Optional +import math + +# Optional TE import - gracefully handle if not available +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling + TE_AVAILABLE = True +except ImportError: + TE_AVAILABLE = False + print("Warning: Transformer Engine not available. 
FP8/FP4 benchmarks will be skipped.") + +# Check for Blackwell (SM100+) for FP4 support +def is_blackwell_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 # SM100 = Blackwell + + +@dataclass +class BenchmarkResult: + """Container for benchmark results.""" + tflops: float + avg_time_ms: float + shape: tuple[int, int, int] + precision: str + + +def compute_gemm_flops(M: int, K: int, N: int) -> int: + """ + Compute theoretical FLOP count for GEMM C = A @ B. + + A: (M, K), B: (K, N), C: (M, N) + Each output element requires K multiply-adds = 2K FLOPs + Total: 2 * M * N * K + """ + return 2 * M * N * K + + +def benchmark_torch_matmul( + M: int, + K: int, + N: int, + dtype: torch.dtype, + num_warmup: int = 10, + num_iters: int = 100 +) -> BenchmarkResult: + """Benchmark torch.matmul at specified precision.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, dtype=dtype, device=device) + B = torch.randn(K, N, dtype=dtype, device=device) + + # Large matrix for leading kernel to saturate GPU + A_large = torch.randn(4096, 4096, dtype=dtype, device=device) + B_large = torch.randn(4096, 4096, dtype=dtype, device=device) + + # Warmup - critical for accurate timing + for _ in range(num_warmup): + _ = torch.matmul(A, B) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + # Start with a long-running kernel to avoid measuring CPU dispatch overhead + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = torch.matmul(A_large, B_large) + + start_event.record() + for _ in range(num_iters): + _ = torch.matmul(A, B) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / 
avg_time_s) / 1e12 + + precision_name = { + torch.bfloat16: "BF16", + torch.float16: "FP16", + torch.float32: "FP32", + }.get(dtype, str(dtype)) + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision=precision_name + ) + + +def benchmark_te_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + # Large tensors for leading kernel + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp8_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + + +def benchmark_te_fp8_prequantized( + M: int, + K: int, + N: int, + workspace_size: int, + num_warmup: int = 
10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (requires fp8_dtype argument) + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP8 + A_fp8 = quantizer.quantize(A) + B_fp8 = quantizer.quantize(B) + + # Large tensors for leading kernel + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp8 = quantizer.quantize(A_large) + B_large_fp8 = quantizer.quantize(B_large) + + # Output buffers + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp8, False, # A, transA + B_fp8, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + tex.generic_gemm( + A_large_fp8, False, B_large_fp8, True, D_large, None, + tex.DType.kBFloat16, None, 
tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp8, False, + B_fp8, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="MXFP8" + ) + except Exception as e: + print( + f"Warning: FP8 prequantized benchmark failed for shape {M}x{K}x{N}: {e}\n" + f" Tip: try increasing --workspace-mb (current={workspace_size / (1024 * 1024):.0f}MB) " + "or run without --pre-quantize to use te.Linear()." + ) + return None + + +def benchmark_te_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + + # Large tensors for leading kernel + linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + # Keep autocast context open for warmup and timing + with te.autocast(enabled=True, recipe=fp4_recipe): + # Warmup + for _ in range(num_warmup): + _ = linear(x) + torch.cuda.synchronize() + + # Timed 
iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + _ = linear_large(x_large) + + start_event.record() + for _ in range(num_iters): + _ = linear(x) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + + +def benchmark_te_fp4_prequantized( + M: int, + K: int, + N: int, + workspace_size: int, + num_warmup: int = 10, + num_iters: int = 100 +) -> Optional[BenchmarkResult]: + """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + if not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + # Create quantizer (uses default kFloat4E2M1, but being explicit) + quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + + # Create BF16 tensors + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM + + # Pre-quantize to FP4 + A_fp4 = quantizer.quantize(A) + B_fp4 = quantizer.quantize(B) + + # Large tensors for leading kernel + A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + A_large_fp4 = quantizer.quantize(A_large) + B_large_fp4 = quantizer.quantize(B_large) + + # Output buffers + D = torch.empty(M, N, dtype=torch.bfloat16, device=device) + D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + # Allocate workspace (estimate size) + workspace = torch.empty(workspace_size, dtype=torch.uint8, 
device=device) + + # Warmup + for _ in range(num_warmup): + tex.generic_gemm( + A_fp4, False, # A, transA + B_fp4, True, # B, transB + D, # output + None, # quantizer (None = no output quantization) + tex.DType.kBFloat16, # output_dtype + None, # bias + tex.DType.kBFloat16, # bias_type (unused when bias=None) + False, # gelu + None, # gelu_in + False, # grad + workspace, # workspace + workspace_size, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + torch.cuda.synchronize() + + # Timed iterations with leading kernel + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Leading kernel keeps GPU busy while CPU queues the timed kernels + tex.generic_gemm( + A_large_fp4, False, B_large_fp4, True, D_large, None, + tex.DType.kBFloat16, None, tex.DType.kBFloat16, + False, None, False, workspace, workspace_size, False, False, + ) + + start_event.record() + for _ in range(num_iters): + tex.generic_gemm( + A_fp4, False, + B_fp4, True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + workspace, + workspace_size, + False, + False, + ) + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / num_iters + avg_time_s = avg_time_ms / 1000.0 + + tflops = (flops / avg_time_s) / 1e12 + + return BenchmarkResult( + tflops=tflops, + avg_time_ms=avg_time_ms, + shape=(M, K, N), + precision="NVFP4" + ) + except Exception as e: + print( + f"Warning: FP4 prequantized benchmark failed for shape {M}x{K}x{N}: {e}\n" + f" Tip: try increasing --workspace-mb (current={workspace_size / (1024 * 1024):.0f}MB)." 
def get_default_matrix_shapes() -> list[tuple[int, int, int]]:
    """Return the default (M, K, N) GEMM shapes benchmarked when --shapes is not given.

    All defaults are square, sweeping from cache-resident (256) to clearly
    compute-bound (16384).
    """
    return [
        (256, 256, 256),
        (512, 512, 512),
        (768, 768, 768),
        (1024, 1024, 1024),
        (1536, 1536, 1536),
        (2048, 2048, 2048),
        (3072, 3072, 3072),
        (4096, 4096, 4096),
        (6144, 6144, 6144),
        (8192, 8192, 8192),
        (16384, 16384, 16384),
    ]


def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]:
    """Parse a --shapes argument into a list of (M, K, N) tuples.

    Two item forms may be mixed freely in one comma-separated list:
      - a bare size "1024"                -> square shape (1024, 1024, 1024)
      - an explicit triplet "8192x5120x64" -> (8192, 5120, 64)

    Raises:
        ValueError: if the argument is empty, a triplet does not have exactly
            three parts, or a part is not an integer.
    """
    items = [s.strip() for s in shapes_arg.split(",") if s.strip()]
    if not items:
        raise ValueError("Empty --shapes argument.")

    shapes: list[tuple[int, int, int]] = []
    for item in items:
        if "x" in item:
            # Lower-case first so "MxKxN" with an upper-case X still splits.
            parts = [p.strip() for p in item.lower().split("x")]
            if len(parts) != 3:
                raise ValueError(f"Invalid shape '{item}'. Expected format 'MxKxN'.")
            m, k, n = (int(parts[0]), int(parts[1]), int(parts[2]))
            shapes.append((m, k, n))
        else:
            size = int(item)
            shapes.append((size, size, size))

    return shapes


def warmup_gpu(duration_seconds: float = 5.0):
    """Run sustained BF16 matmuls for *duration_seconds* to stabilize GPU clocks.

    Brings the GPU out of idle power states so the timed benchmarks that follow
    are not skewed by clock/thermal ramp-up.
    """
    print(f"Warming up GPU for {duration_seconds:.1f} seconds...")

    device = torch.device("cuda")

    # Moderate size: keeps the GPU busy without risking OOM.
    M, K, N = 4096, 4096, 4096
    A = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    B = torch.randn(K, N, dtype=torch.bfloat16, device=device)

    torch.cuda.synchronize()
    start_time = time.time()

    while time.time() - start_time < duration_seconds:
        # Queue a batch of matmuls, then synchronize so the wall-clock check
        # reflects completed GPU work rather than queued work.
        for _ in range(10):
            _ = torch.matmul(A, B)
        torch.cuda.synchronize()

    # Release the warmup buffers before benchmarking allocates its own.
    del A, B
    torch.cuda.empty_cache()

    print("GPU warmup complete.\n")


def run_benchmarks(
    shapes: list[tuple[int, int, int]],
    num_warmup: int = 10,
    num_iters: int = 100,
    include_fp8: bool = True,
    include_fp4: bool = True,
    gpu_warmup_seconds: float = 5.0,
    pre_quantize: bool = False,
    workspace_mb: int = 32,
) -> dict[str, list[float]]:
    """Run BF16 (and, hardware permitting, MXFP8/NVFP4) GEMM benchmarks.

    Prints a per-shape results table and returns the measured TFLOPS organized
    by precision; precisions with no successful measurement are dropped from
    the returned dict.
    """
    results = {"BF16": [], "MXFP8": [], "NVFP4": []}

    # Hardware / software capability gating.
    has_blackwell = is_blackwell_available()
    has_fp8 = is_fp8_available()
    run_fp8 = include_fp8 and TE_PYTORCH_AVAILABLE and has_fp8
    run_fp4 = include_fp4 and TE_PYTORCH_AVAILABLE and has_blackwell
    workspace_size = int(workspace_mb) * 1024 * 1024

    # Print header.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"\nGEMM Benchmark on {gpu_name}")
    major, minor = torch.cuda.get_device_capability()
    print(f"CUDA capability: SM{major}{minor}")
    print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}")
    if pre_quantize:
        print("Mode: Pre-quantized inputs (raw kernel throughput)")
        print(f"Workspace: {workspace_mb}MB")
    else:
        print("Mode: Full quantize path (realistic training overhead)")
    if (include_fp8 or include_fp4) and not TE_PYTORCH_AVAILABLE:
        msg = "Note: FP8/FP4 requested but transformer_engine.pytorch import failed; skipping FP8/FP4."
        if TE_IMPORT_ERROR:
            msg += f" ImportError: {TE_IMPORT_ERROR}"
        print(msg)
    if pre_quantize and (include_fp8 or include_fp4) and not TE_TORCH_EXT_AVAILABLE:
        msg = (
            "Note: --pre-quantize requires transformer_engine_torch (tex.generic_gemm). "
            "transformer_engine_torch import failed; skipping FP8/FP4 pre-quantized benchmarks."
        )
        if TE_IMPORT_ERROR:
            msg += f" ImportError: {TE_IMPORT_ERROR}"
        print(msg)
    # BUGFIX: this check previously read TE_AVAILABLE, a flag name this script
    # does not define (the rest of the file uses TE_PYTORCH_AVAILABLE), which
    # could raise NameError on non-FP8 hardware.
    if include_fp8 and TE_PYTORCH_AVAILABLE and not has_fp8:
        print("Note: FP8 requested but this GPU does not support FP8 Tensor Cores; skipping FP8.")
    if not has_blackwell and include_fp4:
        print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks")

    # BUGFIX: the pre-quantized path used to `return {"BF16": results["BF16"]}`
    # here when transformer_engine_torch was missing — i.e. an *empty* result
    # dict, before any BF16 benchmark ran. Fall back to BF16-only instead,
    # matching the "skipping FP8/FP4 pre-quantized benchmarks" note above.
    if pre_quantize and not TE_TORCH_EXT_AVAILABLE:
        run_fp8 = False
        run_fp4 = False

    # Warmup GPU to stabilize clocks.
    if gpu_warmup_seconds > 0:
        warmup_gpu(gpu_warmup_seconds)

    print("=" * 80)

    # Build the table header dynamically based on which precisions will run.
    header = f"{'Shape':<20} {'BF16 TFLOPS':>14}"
    if run_fp8:
        header += f" {'MXFP8 TFLOPS':>14}"
    if run_fp4:
        header += f" {'NVFP4 TFLOPS':>14}"
    header += f" {'Best Speedup':>12}"
    print(header)
    print("-" * 80)

    # Select benchmark functions based on the pre_quantize flag. The
    # pre-quantized path uses tex.generic_gemm and needs the workspace size.
    if pre_quantize and TE_TORCH_EXT_AVAILABLE:
        fp8_benchmark_fn = lambda m, k, n, nw, ni: benchmark_te_fp8_prequantized(  # noqa: E731
            m, k, n, workspace_size=workspace_size, num_warmup=nw, num_iters=ni
        )
        fp4_benchmark_fn = lambda m, k, n, nw, ni: benchmark_te_fp4_prequantized(  # noqa: E731
            m, k, n, workspace_size=workspace_size, num_warmup=nw, num_iters=ni
        )
    else:
        fp8_benchmark_fn = benchmark_te_fp8
        fp4_benchmark_fn = benchmark_te_fp4

    for M, K, N in shapes:
        shape_str = f"{M}x{K}x{N}"

        # BF16 baseline always runs.
        bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters)
        results["BF16"].append(bf16_result.tflops)

        row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}"
        best_tflops = bf16_result.tflops

        # FP8 benchmark.
        if run_fp8:
            fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters)
            if fp8_result:
                results["MXFP8"].append(fp8_result.tflops)
                row += f" {fp8_result.tflops:>14.1f}"
                best_tflops = max(best_tflops, fp8_result.tflops)
            else:
                results["MXFP8"].append(0)
                row += f" {'N/A':>14}"

        # FP4 benchmark (Blackwell only).
        if run_fp4:
            fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters)
            if fp4_result:
                results["NVFP4"].append(fp4_result.tflops)
                row += f" {fp4_result.tflops:>14.1f}"
                best_tflops = max(best_tflops, fp4_result.tflops)
            else:
                results["NVFP4"].append(0)
                row += f" {'N/A':>14}"

        speedup = best_tflops / bf16_result.tflops
        row += f" {speedup:>11.2f}x"
        print(row)

    print("=" * 80)

    # Drop precisions that produced no successful measurements.
    results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)}

    return results


def create_plot(
    shapes: list[tuple[int, int, int]],
    results: dict[str, list[float]],
    output_path: str = "gemm_benchmark.png",
    title: Optional[str] = None
):
    """Create a grouped bar plot of TFLOPS per shape and precision.

    Falls back to a .png extension when the requested extension is not a
    matplotlib-supported filetype.

    Returns:
        (fig, ax) matplotlib handles of the saved plot.
    """
    gpu_name = torch.cuda.get_device_name(0)
    if title is None:
        title = f"Absolute Performance Comparison\nMeasured on {gpu_name}"

    # X-axis labels, one group per shape.
    labels = [f"{m}x{k}x{n}" for m, k, n in shapes]
    x = np.arange(len(labels))

    # Bar width shrinks as more precisions are plotted side by side.
    num_kernels = len(results)
    bar_width = 0.8 / num_kernels

    # Color scheme matching the reference plot.
    colors = {
        "BF16": "#808080",   # Gray
        "MXFP8": "#4B0082",  # Indigo/Purple
        "NVFP4": "#B22222",  # Firebrick red
    }

    fig, ax = plt.subplots(figsize=(14, 8))

    for i, (precision, tflops_list) in enumerate(results.items()):
        offset = (i - num_kernels / 2 + 0.5) * bar_width
        color = colors.get(precision, f"C{i}")
        ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color)

    ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12)
    ax.set_ylabel("Performance (TFLOPS)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10)
    ax.legend(title="Kernel", loc='upper left', fontsize=10)
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)
    ax.set_ylim(bottom=0)
    plt.tight_layout()

    # Save, substituting .png when the extension is unsupported.
    output_path_obj = Path(output_path)
    supported_formats = set(fig.canvas.get_supported_filetypes().keys())
    suffix = output_path_obj.suffix.lower().lstrip(".")
    if suffix not in supported_formats:
        output_path_obj = output_path_obj.with_suffix(".png")
        print(
            f"Warning: Output extension '.{suffix}' is not supported by matplotlib; "
            f"saving to '{output_path_obj}' instead."
        )
    plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight")
    print(f"\nPlot saved to: {output_path_obj}")

    return fig, ax


def main():
    """CLI entry point: parse arguments, run the benchmarks, save the plot."""
    parser = argparse.ArgumentParser(
        description="GEMM Benchmarking with TFLOPS measurement and plotting"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="gemm_benchmark.png",
        help="Output path for the plot (default: gemm_benchmark.png)"
    )
    parser.add_argument(
        "--num-warmup",
        type=int,
        default=10,
        help="Number of warmup iterations (default: 10)"
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of timed iterations (default: 100)"
    )
    parser.add_argument(
        "--gpu-warmup",
        type=float,
        default=5.0,
        help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)"
    )
    parser.add_argument(
        "--no-fp8",
        action="store_true",
        help="Skip FP8 benchmarks"
    )
    parser.add_argument(
        "--no-fp4",
        action="store_true",
        help="Skip FP4 benchmarks (only available on Blackwell)"
    )
    parser.add_argument(
        "--pre-quantize",
        action="store_true",
        help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)"
    )
    parser.add_argument(
        "--workspace-mb",
        type=int,
        default=32,
        help="Workspace size in MB for pre-quantized generic_gemm() path (default: 32).",
    )
    parser.add_argument(
        "--shapes",
        type=str,
        default=None,
        help=(
            "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' "
            "or explicit triplets like '8192x5120x15360,8192x5120x5120'."
        ),
    )
    args = parser.parse_args()

    # A GPU is mandatory; bail out early with a nonzero exit code otherwise.
    if not torch.cuda.is_available():
        print("Error: CUDA is not available. This script requires a GPU.")
        return 1

    # Parse custom shapes if provided.
    if args.shapes:
        shapes = parse_shapes_arg(args.shapes)
    else:
        shapes = get_default_matrix_shapes()

    results = run_benchmarks(
        shapes=shapes,
        num_warmup=args.num_warmup,
        num_iters=args.num_iters,
        include_fp8=not args.no_fp8,
        include_fp4=not args.no_fp4,
        gpu_warmup_seconds=args.gpu_warmup,
        pre_quantize=args.pre_quantize,
        workspace_mb=args.workspace_mb,
    )

    create_plot(shapes, results, args.output)

    return 0


if __name__ == "__main__":
    exit(main())
def is_blackwell_available() -> bool:
    """Return True when the current CUDA device is Blackwell-class (SM100+)."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 10  # SM100 = Blackwell


@dataclass
class BenchmarkResult:
    """Container for a single GEMM benchmark measurement."""

    tflops: float                # achieved throughput
    avg_time_ms: float           # mean per-iteration kernel time
    shape: tuple[int, int, int]  # (M, K, N)
    precision: str               # e.g. "BF16", "MXFP8", "NVFP4"


def compute_gemm_flops(M: int, K: int, N: int) -> int:
    """Compute the theoretical FLOP count for GEMM C = A @ B.

    A: (M, K), B: (K, N), C: (M, N). Each output element requires K
    multiply-adds = 2K FLOPs, so the total is 2 * M * N * K.
    """
    return 2 * M * N * K


def benchmark_torch_matmul(
    M: int,
    K: int,
    N: int,
    dtype: torch.dtype,
    num_warmup: int = 10,
    num_iters: int = 100
) -> BenchmarkResult:
    """Benchmark torch.matmul at the given dtype and return the measured TFLOPS."""
    device = torch.device("cuda")
    flops = compute_gemm_flops(M, K, N)

    A = torch.randn(M, K, dtype=dtype, device=device)
    B = torch.randn(K, N, dtype=dtype, device=device)

    # Large matrices used as a "leading kernel" to saturate the GPU before timing.
    A_large = torch.randn(4096, 4096, dtype=dtype, device=device)
    B_large = torch.randn(4096, 4096, dtype=dtype, device=device)

    # Warmup - critical for accurate timing.
    for _ in range(num_warmup):
        _ = torch.matmul(A, B)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Leading kernel keeps GPU busy while CPU queues the timed kernels,
    # so CPU dispatch overhead is not measured.
    _ = torch.matmul(A_large, B_large)

    start_event.record()
    for _ in range(num_iters):
        _ = torch.matmul(A, B)
    end_event.record()
    torch.cuda.synchronize()

    avg_time_ms = start_event.elapsed_time(end_event) / num_iters
    tflops = (flops / (avg_time_ms / 1000.0)) / 1e12

    precision_name = {
        torch.bfloat16: "BF16",
        torch.float16: "FP16",
        torch.float32: "FP32",
    }.get(dtype, str(dtype))

    return BenchmarkResult(
        tflops=tflops,
        avg_time_ms=avg_time_ms,
        shape=(M, K, N),
        precision=precision_name
    )


def _benchmark_te_linear(recipe, precision: str, M, K, N, num_warmup, num_iters) -> BenchmarkResult:
    """Shared timing core for te.Linear under te.autocast (FP8 and FP4 paths).

    The FP8 and FP4 benchmarks only differ by recipe and precision label;
    everything else (layer setup, warmup, leading kernel, event timing) is
    identical and lives here.
    """
    device = torch.device("cuda")
    flops = compute_gemm_flops(M, K, N)

    # TE Linear: input (M, K) @ weight (K, N) -> output (M, N).
    linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device)
    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)

    # Large layer/input pair used as the leading kernel.
    linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device)
    x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)

    # Keep the autocast context open for warmup and timing.
    with te.autocast(enabled=True, recipe=recipe):
        for _ in range(num_warmup):
            _ = linear(x)
        torch.cuda.synchronize()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        # Leading kernel keeps GPU busy while CPU queues the timed kernels.
        _ = linear_large(x_large)

        start_event.record()
        for _ in range(num_iters):
            _ = linear(x)
        end_event.record()

        torch.cuda.synchronize()

    avg_time_ms = start_event.elapsed_time(end_event) / num_iters
    tflops = (flops / (avg_time_ms / 1000.0)) / 1e12

    return BenchmarkResult(
        tflops=tflops,
        avg_time_ms=avg_time_ms,
        shape=(M, K, N),
        precision=precision
    )


def benchmark_te_fp8(
    M: int,
    K: int,
    N: int,
    num_warmup: int = 10,
    num_iters: int = 100
) -> Optional[BenchmarkResult]:
    """Benchmark MXFP8 GEMM via a Transformer Engine Linear layer."""
    if not TE_AVAILABLE:
        return None
    recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
    return _benchmark_te_linear(recipe, "MXFP8", M, K, N, num_warmup, num_iters)


def benchmark_te_fp4(
    M: int,
    K: int,
    N: int,
    num_warmup: int = 10,
    num_iters: int = 100
) -> Optional[BenchmarkResult]:
    """Benchmark NVFP4 GEMM via a Transformer Engine Linear layer (Blackwell only)."""
    if not TE_AVAILABLE:
        return None
    if not is_blackwell_available():
        return None
    # NOTE(review): Format.E2M1 is assumed to exist in the installed TE
    # release — confirm against transformer_engine.common.recipe.Format.
    recipe = NVFP4BlockScaling(fp4_format=Format.E2M1)
    return _benchmark_te_linear(recipe, "NVFP4", M, K, N, num_warmup, num_iters)


def _benchmark_prequantized_gemm(
    make_quantizer, precision: str, warn_label: str, M, K, N, num_warmup, num_iters
) -> Optional[BenchmarkResult]:
    """Shared timing core for tex.generic_gemm on pre-quantized inputs.

    The FP8 and FP4 pre-quantized benchmarks only differ by quantizer and
    label; the quantize/warmup/timing sequence is identical and lives here.
    make_quantizer is a zero-arg factory so that quantizer construction errors
    are caught by the same best-effort except clause as the rest of the setup.
    """
    device = torch.device("cuda")
    flops = compute_gemm_flops(M, K, N)

    try:
        quantizer = make_quantizer()

        A = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        B = torch.randn(N, K, dtype=torch.bfloat16, device=device)  # Transposed for GEMM

        # Pre-quantize outside the timed region: we measure raw kernel throughput.
        A_q = quantizer.quantize(A)
        B_q = quantizer.quantize(B)

        # Large tensors for the leading kernel.
        A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
        B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
        A_large_q = quantizer.quantize(A_large)
        B_large_q = quantizer.quantize(B_large)

        # Output buffers.
        D = torch.empty(M, N, dtype=torch.bfloat16, device=device)
        D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device)

        workspace_size = 32 * 1024 * 1024  # 32MB workspace (estimate)
        workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device)

        def gemm(a, b, out):
            # out = a @ b.T in BF16; no output quantization, bias, gelu or accumulation.
            tex.generic_gemm(
                a, False,             # A, transA
                b, True,              # B, transB
                out,                  # output
                None,                 # quantizer (None = no output quantization)
                tex.DType.kBFloat16,  # output_dtype
                None,                 # bias
                tex.DType.kBFloat16,  # bias_type (unused when bias=None)
                False,                # gelu
                None,                 # gelu_in
                False,                # grad
                workspace,            # workspace
                workspace_size,       # workspace_size
                False,                # accumulate
                False,                # use_split_accumulator
            )

        for _ in range(num_warmup):
            gemm(A_q, B_q, D)
        torch.cuda.synchronize()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        # Leading kernel keeps GPU busy while CPU queues the timed kernels.
        gemm(A_large_q, B_large_q, D_large)

        start_event.record()
        for _ in range(num_iters):
            gemm(A_q, B_q, D)
        end_event.record()
        torch.cuda.synchronize()

        avg_time_ms = start_event.elapsed_time(end_event) / num_iters
        tflops = (flops / (avg_time_ms / 1000.0)) / 1e12

        return BenchmarkResult(
            tflops=tflops,
            avg_time_ms=avg_time_ms,
            shape=(M, K, N),
            precision=precision
        )
    except Exception as e:  # noqa: BLE001 - benchmark is best-effort; report and skip.
        print(f"Warning: {warn_label} prequantized benchmark failed: {e}")
        return None


def benchmark_te_fp8_prequantized(
    M: int,
    K: int,
    N: int,
    num_warmup: int = 10,
    num_iters: int = 100
) -> Optional[BenchmarkResult]:
    """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput)."""
    if not TE_AVAILABLE:
        return None
    return _benchmark_prequantized_gemm(
        lambda: te.MXFP8Quantizer(tex.DType.kFloat8E4M3),
        "MXFP8", "FP8", M, K, N, num_warmup, num_iters,
    )


def benchmark_te_fp4_prequantized(
    M: int,
    K: int,
    N: int,
    num_warmup: int = 10,
    num_iters: int = 100
) -> Optional[BenchmarkResult]:
    """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, raw kernel throughput)."""
    if not TE_AVAILABLE:
        return None
    if not is_blackwell_available():
        return None
    return _benchmark_prequantized_gemm(
        lambda: te.NVFP4Quantizer(tex.DType.kFloat4E2M1),
        "NVFP4", "FP4", M, K, N, num_warmup, num_iters,
    )


def get_default_matrix_shapes() -> list[tuple[int, int, int]]:
    """Return default matrix shapes for benchmarking (square matrices)."""
    return [
        (256, 256, 256),
        (512, 512, 512),
        (768, 768, 768),
        (1024, 1024, 1024),
        (1536, 1536, 1536),
        (2048, 2048, 2048),
        (3072, 3072, 3072),
        (4096, 4096, 4096),
        (6144, 6144, 6144),
        (8192, 8192, 8192),
        (16384, 16384, 16384),
    ]


def warmup_gpu(duration_seconds: float = 5.0):
    """Run sustained BF16 matmuls for *duration_seconds* to stabilize GPU clocks.

    Brings the GPU out of idle power states so the timed benchmarks that follow
    are not skewed by clock/thermal ramp-up.
    """
    print(f"Warming up GPU for {duration_seconds:.1f} seconds...")

    device = torch.device("cuda")

    # Moderate size: keeps the GPU busy without risking OOM.
    M, K, N = 4096, 4096, 4096
    A = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    B = torch.randn(K, N, dtype=torch.bfloat16, device=device)

    torch.cuda.synchronize()
    start_time = time.time()

    while time.time() - start_time < duration_seconds:
        # Queue a batch of matmuls, then wait so the wall-clock check is accurate.
        for _ in range(10):
            _ = torch.matmul(A, B)
        torch.cuda.synchronize()

    del A, B
    torch.cuda.empty_cache()

    print("GPU warmup complete.\n")


def run_benchmarks(
    shapes: list[tuple[int, int, int]],
    num_warmup: int = 10,
    num_iters: int = 100,
    include_fp8: bool = True,
    include_fp4: bool = True,
    gpu_warmup_seconds: float = 5.0,
    pre_quantize: bool = False
) -> dict[str, list[float]]:
    """Run all benchmarks and return measured TFLOPS organized by precision.

    Precisions with no successful measurement are dropped from the result.
    """
    results = {"BF16": [], "MXFP8": [], "NVFP4": []}

    # Hardware capability gating.
    has_blackwell = is_blackwell_available()
    run_fp8 = include_fp8 and TE_AVAILABLE
    run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell

    # Print header.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"\nGEMM Benchmark on {gpu_name}")
    print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}")
    if pre_quantize:
        print("Mode: Pre-quantized inputs (raw kernel throughput)")
    else:
        print("Mode: Full quantize path (realistic training overhead)")
    if not has_blackwell and include_fp4:
        print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks")

    # Warmup GPU to stabilize clocks.
    if gpu_warmup_seconds > 0:
        warmup_gpu(gpu_warmup_seconds)

    print("=" * 80)

    # Build the table header dynamically.
    header = f"{'Shape':<20} {'BF16 TFLOPS':>14}"
    if run_fp8:
        header += f" {'MXFP8 TFLOPS':>14}"
    if run_fp4:
        header += f" {'NVFP4 TFLOPS':>14}"
    header += f" {'Best Speedup':>12}"
    print(header)
    print("-" * 80)

    # Select benchmark functions based on the pre_quantize flag.
    fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8
    fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4

    for M, K, N in shapes:
        shape_str = f"{M}x{K}x{N}"

        # BF16 baseline always runs.
        bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters)
        results["BF16"].append(bf16_result.tflops)

        row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}"
        best_tflops = bf16_result.tflops

        # FP8 benchmark.
        if run_fp8:
            fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters)
            if fp8_result:
                results["MXFP8"].append(fp8_result.tflops)
                row += f" {fp8_result.tflops:>14.1f}"
                best_tflops = max(best_tflops, fp8_result.tflops)
            else:
                results["MXFP8"].append(0)
                row += f" {'N/A':>14}"

        # FP4 benchmark (Blackwell only).
        if run_fp4:
            fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters)
            if fp4_result:
                results["NVFP4"].append(fp4_result.tflops)
                row += f" {fp4_result.tflops:>14.1f}"
                best_tflops = max(best_tflops, fp4_result.tflops)
            else:
                results["NVFP4"].append(0)
                row += f" {'N/A':>14}"

        speedup = best_tflops / bf16_result.tflops
        row += f" {speedup:>11.2f}x"
        print(row)

    print("=" * 80)

    # Drop precisions that produced no successful measurements.
    results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)}

    return results
def create_plot(
    shapes: list[tuple[int, int, int]],
    results: dict[str, list[float]],
    output_path: str = "gemm_benchmark.png",
    title: Optional[str] = None
):
    """Create a grouped bar plot of TFLOPS per shape and precision.

    Falls back to a .png extension when the requested extension is not a
    matplotlib-supported filetype.

    Returns:
        (fig, ax) matplotlib handles of the saved plot.
    """
    gpu_name = torch.cuda.get_device_name(0)
    if title is None:
        title = f"Absolute Performance Comparison\nMeasured on {gpu_name}"

    # X-axis labels, one group per shape.
    labels = [f"{m}x{k}x{n}" for m, k, n in shapes]
    x = np.arange(len(labels))

    # Bar width shrinks as more precisions are plotted side by side.
    num_kernels = len(results)
    bar_width = 0.8 / num_kernels

    # Color scheme matching the reference plot.
    colors = {
        "BF16": "#808080",   # Gray
        "MXFP8": "#4B0082",  # Indigo/Purple
        "NVFP4": "#B22222",  # Firebrick red
    }

    fig, ax = plt.subplots(figsize=(14, 8))

    # Plot one bar series per precision, offset within each shape group.
    for i, (precision, tflops_list) in enumerate(results.items()):
        offset = (i - num_kernels / 2 + 0.5) * bar_width
        color = colors.get(precision, f"C{i}")
        ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color)

    ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12)
    ax.set_ylabel("Performance (TFLOPS)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10)
    ax.legend(title="Kernel", loc='upper left', fontsize=10)
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)
    ax.set_ylim(bottom=0)
    plt.tight_layout()

    # Save, substituting .png when the extension is unsupported.
    output_path_obj = Path(output_path)
    supported_formats = set(fig.canvas.get_supported_filetypes().keys())
    suffix = output_path_obj.suffix.lower().lstrip(".")
    if suffix not in supported_formats:
        output_path_obj = output_path_obj.with_suffix(".png")
        print(
            f"Warning: Output extension '.{suffix}' is not supported by matplotlib; "
            f"saving to '{output_path_obj}' instead."
        )
    plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight")
    # BUGFIX: report the path actually written. output_path_obj may differ from
    # the requested output_path after the PNG fallback above; this previously
    # printed the original (possibly wrong) output_path.
    print(f"\nPlot saved to: {output_path_obj}")

    return fig, ax
This script requires a GPU.") + return 1 + + # Parse custom shapes if provided + if args.shapes: + sizes = [int(s.strip()) for s in args.shapes.split(",")] + shapes = [(s, s, s) for s in sizes] + else: + shapes = get_default_matrix_shapes() + + # Run benchmarks + results = run_benchmarks( + shapes=shapes, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + include_fp8=not args.no_fp8, + include_fp4=not args.no_fp4, + gpu_warmup_seconds=args.gpu_warmup, + pre_quantize=args.pre_quantize + ) + + # Create plot + create_plot(shapes, results, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file From e81cbb45f7a1d3f7c0bca21c75405d92b295dcf4 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 20 Feb 2026 08:45:50 -0800 Subject: [PATCH 20/21] NVFP4 with Nsight code - This is the code I did perf benchmarks for - Contains Nsight profiling results - Also contains NVFP4 / NVFP8 smart logger - Contains layer-wise selection - Home base code as of 02/20/2026. 
- Will make targeted code from here Signed-off-by: Jonathan Mitchell --- .vscode/settings.json | 3 +- .../gemm_benchmarking/gemm_benchmark.py | 683 --------------- .../gemm_benchmark_withshapes.py | 717 --------------- .../gemm_benchmarking/profiler_gemm.py | 558 ------------ .../gemm_benchmarking/roofline.py | 522 ----------- .../roofline_prequantized.py | 753 ---------------- .../roofline_prequantized_with_shapes.py | 772 ---------------- .../roofline_prequantized_with_shapes_mb.py | 829 ------------------ .../roofline_prequantized_withtorchao.py | 753 ---------------- 9 files changed, 2 insertions(+), 5588 deletions(-) delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py delete mode 100644 bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_withtorchao.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 41bac6a7e9..e6a3603dab 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -26,5 +26,6 @@ "editor.rulers": [ 120 ], - "autoDocstring.docstringFormat": "google-notypes" + "autoDocstring.docstringFormat": "google-notypes", + "search.exclude": { "**/logs/**": true }, } diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py deleted file mode 100644 index 
f9819b5bc8..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark.py +++ /dev/null @@ -1,683 +0,0 @@ - -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from typing import Optional - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. 
- - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) - - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with 
te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp8_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (requires fp8_dtype argument) - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP8 - A_fp8 = quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - - # Output buffer - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp8, False, # A, transA - B_fp8, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # 
workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp8, False, - B_fp8, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - except Exception as e: - print(f"Warning: FP8 prequantized benchmark failed: {e}") - return None - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - 
torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - - -def benchmark_te_fp4_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (uses default kFloat4E2M1, but being explicit) - quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP4 - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - - # Output buffer - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp4, False, # A, transA - B_fp4, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - 
start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp4, False, - B_fp4, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - except Exception as e: - print(f"Warning: FP4 prequantized benchmark failed: {e}") - return None - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. - - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. 
- """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0, - pre_quantize: bool = False -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - run_fp8 = include_fp8 and TE_AVAILABLE - run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if pre_quantize: - print("Mode: Pre-quantized inputs (raw kernel throughput)") - else: - print("Mode: Full quantize path (realistic training overhead)") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 
80) - - # Select benchmark functions based on pre_quantize flag - fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 - fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) 
- bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - plt.savefig(output_path, dpi=150, bbox_inches='tight') - print(f"\nPlot saved to: {output_path}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 
benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--pre-quantize", - action="store_true", - help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help="Comma-separated list of square matrix sizes (e.g., '1024,2048,4096')" - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - sizes = [int(s.strip()) for s in args.shapes.split(",")] - shapes = [(s, s, s) for s in sizes] - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup, - pre_quantize=args.pre_quantize - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py deleted file mode 100644 index a9c8df9d89..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/gemm_benchmark_withshapes.py +++ /dev/null @@ -1,717 +0,0 @@ -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from typing import Optional - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = 
True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. - - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) 
- - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp8_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (requires fp8_dtype argument) - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - - # Create BF16 tensors in the layout expected by tex.generic_gemm. - # - # Important: tex.generic_gemm has its own conventions for (A, B, transa, transb) and - # expected output orientation. 
With the settings used below (transa=False, transb=True), - # and with A shaped (K, M) and B shaped (K, N), TransformerEngine expects the output D - # to be shaped (N, M) (note the swapped order). This is fine for benchmarking throughput - # (FLOP count is still 2*M*N*K); it's just a layout convention. - A = torch.randn(K, M, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - # Pre-quantize to FP8 - A_fp8 = quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - - # Output buffer (see note above about expected (N, M) orientation) - D = torch.empty(N, M, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp8, False, # A, transA - B_fp8, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp8, False, - B_fp8, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - 
precision="MXFP8" - ) - except Exception as e: - print(f"Warning: FP8 prequantized benchmark failed: {e}") - return None - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - import sys; sys.exit(0) - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - - -def benchmark_te_fp4_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (uses default kFloat4E2M1, but being explicit) - quantizer = 
te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - - # Create BF16 tensors in the layout expected by tex.generic_gemm. - # See FP8 pre-quantized path above for rationale (including expected D orientation). - A = torch.randn(K, M, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - # Pre-quantize to FP4 - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - - # Output buffer (see note above about expected (N, M) orientation) - D = torch.empty(N, M, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp4, False, # A, transA - B_fp4, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp4, False, - B_fp4, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - except Exception as e: - print(f"Warning: FP4 prequantized benchmark failed: {e}") 
- return None - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: - """Parse a shapes argument into a list of (M, K, N) tuples. - - Supports either: - - Square sizes: "1024,2048,4096" -> [(1024, 1024, 1024), ...] - - Explicit triplets: "8192x5120x15360,8192x5120x5120" - """ - items = [s.strip() for s in shapes_arg.split(",") if s.strip()] - if not items: - raise ValueError("Empty --shapes argument.") - - shapes: list[tuple[int, int, int]] = [] - for item in items: - if "x" in item: - parts = [p.strip() for p in item.lower().split("x")] - if len(parts) != 3: - raise ValueError(f"Invalid shape '{item}'. Expected format 'MxKxN'.") - m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) - shapes.append((m, k, n)) - else: - size = int(item) - shapes.append((size, size, size)) - - return shapes - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. - - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. 
- """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0, - pre_quantize: bool = False -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - run_fp8 = include_fp8 and TE_AVAILABLE - run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if pre_quantize: - print("Mode: Pre-quantized inputs (raw kernel throughput)") - else: - print("Mode: Full quantize path (realistic training overhead)") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 
80) - - # Select benchmark functions based on pre_quantize flag - fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 - fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) 
- bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - plt.savefig(output_path, dpi=150, bbox_inches='tight') - print(f"\nPlot saved to: {output_path}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 
benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--pre-quantize", - action="store_true", - help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help=( - "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " - "or explicit triplets like '8192x5120x15360,8192x5120x5120'." - ), - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - shapes = parse_shapes_arg(args.shapes) - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup, - pre_quantize=args.pre_quantize - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py deleted file mode 100644 index 1d97e9bb11..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/profiler_gemm.py +++ /dev/null @@ -1,558 +0,0 @@ -#!/usr/bin/env python3 -""" -GEMM Profiler with Power/Clock Monitoring - -Detailed profiling of a specific GEMM size with GPU telemetry to understand -performance characteristics and potential throttling. 
- -Usage: - python profiler_gemm.py --size 1536 --precision bf16 - python profiler_gemm.py --size 1536 --precision fp8 --pre-quantize - python profiler_gemm.py --size 1536 --precision fp4 --pre-quantize --with-leading-kernel -""" - -import argparse -import time -import threading -import torch -from dataclasses import dataclass, field -from typing import Optional, List -import subprocess -import json - -# Try to import pynvml for GPU monitoring -try: - import pynvml - PYNVML_AVAILABLE = True -except ImportError: - PYNVML_AVAILABLE = False - print("Warning: pynvml not available. Install with: pip install pynvml") - -# Optional TE import -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available.") - - -@dataclass -class GPUTelemetry: - """Container for GPU telemetry samples.""" - timestamps: List[float] = field(default_factory=list) - power_watts: List[float] = field(default_factory=list) - temperature_c: List[int] = field(default_factory=list) - sm_clock_mhz: List[int] = field(default_factory=list) - memory_clock_mhz: List[int] = field(default_factory=list) - gpu_utilization: List[int] = field(default_factory=list) - - -class GPUMonitor: - """Background thread for monitoring GPU telemetry.""" - - def __init__(self, device_id: int = 0, sample_interval_ms: float = 10): - self.device_id = device_id - self.sample_interval = sample_interval_ms / 1000.0 - self.telemetry = GPUTelemetry() - self._running = False - self._thread = None - self._handle = None - - def start(self): - if not PYNVML_AVAILABLE: - print("Warning: pynvml not available, skipping GPU monitoring") - return - - pynvml.nvmlInit() - self._handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id) - self._running = True - self._thread = 
threading.Thread(target=self._monitor_loop, daemon=True) - self._thread.start() - - def stop(self) -> GPUTelemetry: - self._running = False - if self._thread: - self._thread.join(timeout=1.0) - if PYNVML_AVAILABLE: - pynvml.nvmlShutdown() - return self.telemetry - - def _monitor_loop(self): - start_time = time.perf_counter() - while self._running: - try: - now = time.perf_counter() - start_time - - # Power - power_mw = pynvml.nvmlDeviceGetPowerUsage(self._handle) - power_w = power_mw / 1000.0 - - # Temperature - temp = pynvml.nvmlDeviceGetTemperature(self._handle, pynvml.NVML_TEMPERATURE_GPU) - - # Clocks - sm_clock = pynvml.nvmlDeviceGetClockInfo(self._handle, pynvml.NVML_CLOCK_SM) - mem_clock = pynvml.nvmlDeviceGetClockInfo(self._handle, pynvml.NVML_CLOCK_MEM) - - # Utilization - util = pynvml.nvmlDeviceGetUtilizationRates(self._handle) - - self.telemetry.timestamps.append(now) - self.telemetry.power_watts.append(power_w) - self.telemetry.temperature_c.append(temp) - self.telemetry.sm_clock_mhz.append(sm_clock) - self.telemetry.memory_clock_mhz.append(mem_clock) - self.telemetry.gpu_utilization.append(util.gpu) - - except Exception as e: - pass # Ignore sampling errors - - time.sleep(self.sample_interval) - - -def get_gpu_info(): - """Get current GPU info using nvidia-smi.""" - try: - result = subprocess.run( - ['nvidia-smi', '--query-gpu=name,power.limit,clocks.max.sm,clocks.max.memory', - '--format=csv,noheader,nounits'], - capture_output=True, text=True - ) - if result.returncode == 0: - parts = result.stdout.strip().split(', ') - return { - 'name': parts[0], - 'power_limit_w': float(parts[1]), - 'max_sm_clock_mhz': int(parts[2]), - 'max_mem_clock_mhz': int(parts[3]), - } - except: - pass - return None - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - return 2 * M * N * K - - -def profile_bf16_gemm(M: int, K: int, N: int, num_warmup: int, num_iters: int, - with_leading_kernel: bool) -> tuple: - """Profile BF16 GEMM with telemetry.""" - device = 
torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - if with_leading_kernel: - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - # Warmup - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Start monitoring - monitor = GPUMonitor(sample_interval_ms=5) - monitor.start() - - # Give monitor a moment to start - time.sleep(0.01) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - if with_leading_kernel: - _ = torch.matmul(A_large, B_large) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - # Stop monitoring - telemetry = monitor.stop() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - tflops = (flops / (avg_time_ms / 1000.0)) / 1e12 - - return tflops, avg_time_ms, telemetry - - -def profile_fp8_gemm(M: int, K: int, N: int, num_warmup: int, num_iters: int, - with_leading_kernel: bool, pre_quantize: bool) -> tuple: - """Profile FP8 GEMM with telemetry.""" - if not TE_AVAILABLE: - return None, None, None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - if pre_quantize: - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) - A_fp8 = quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - workspace_size = 32 * 1024 * 1024 - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - if with_leading_kernel: - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 
4096, dtype=torch.bfloat16, device=device) - A_large_fp8 = quantizer.quantize(A_large) - B_large_fp8 = quantizer.quantize(B_large) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - def run_gemm(): - tex.generic_gemm( - A_fp8, False, B_fp8, True, D, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - def run_large_gemm(): - tex.generic_gemm( - A_large_fp8, False, B_large_fp8, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - else: - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - if with_leading_kernel: - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Warmup - if pre_quantize: - for _ in range(num_warmup): - run_gemm() - else: - with te.autocast(enabled=True, recipe=fp8_recipe): - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Start monitoring - monitor = GPUMonitor(sample_interval_ms=5) - monitor.start() - time.sleep(0.01) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - if pre_quantize: - if with_leading_kernel: - run_large_gemm() - - start_event.record() - for _ in range(num_iters): - run_gemm() - end_event.record() - else: - with te.autocast(enabled=True, recipe=fp8_recipe): - if with_leading_kernel: - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - telemetry = monitor.stop() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - tflops = (flops / (avg_time_ms / 1000.0)) / 1e12 - 
- return tflops, avg_time_ms, telemetry - - -def profile_fp4_gemm(M: int, K: int, N: int, num_warmup: int, num_iters: int, - with_leading_kernel: bool, pre_quantize: bool) -> tuple: - """Profile FP4 GEMM with telemetry.""" - if not TE_AVAILABLE: - return None, None, None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - if pre_quantize: - quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - workspace_size = 32 * 1024 * 1024 - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - if with_leading_kernel: - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp4 = quantizer.quantize(A_large) - B_large_fp4 = quantizer.quantize(B_large) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - def run_gemm(): - tex.generic_gemm( - A_fp4, False, B_fp4, True, D, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - def run_large_gemm(): - tex.generic_gemm( - A_large_fp4, False, B_large_fp4, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - else: - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - if with_leading_kernel: - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Warmup - if pre_quantize: - for _ in range(num_warmup): - 
run_gemm() - else: - with te.autocast(enabled=True, recipe=fp4_recipe): - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Start monitoring - monitor = GPUMonitor(sample_interval_ms=5) - monitor.start() - time.sleep(0.01) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - if pre_quantize: - if with_leading_kernel: - run_large_gemm() - - start_event.record() - for _ in range(num_iters): - run_gemm() - end_event.record() - else: - with te.autocast(enabled=True, recipe=fp4_recipe): - if with_leading_kernel: - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - telemetry = monitor.stop() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - tflops = (flops / (avg_time_ms / 1000.0)) / 1e12 - - return tflops, avg_time_ms, telemetry - - -def print_telemetry_summary(telemetry: GPUTelemetry, gpu_info: dict): - """Print summary of GPU telemetry.""" - if not telemetry.timestamps: - print("\nNo telemetry data collected (pynvml not available?)") - return - - print("\n" + "=" * 60) - print("GPU TELEMETRY SUMMARY") - print("=" * 60) - - # Power - avg_power = sum(telemetry.power_watts) / len(telemetry.power_watts) - max_power = max(telemetry.power_watts) - min_power = min(telemetry.power_watts) - power_limit = gpu_info.get('power_limit_w', 0) if gpu_info else 0 - print(f"\nPower (W):") - print(f" Avg: {avg_power:.1f} Min: {min_power:.1f} Max: {max_power:.1f} Limit: {power_limit:.0f}") - if power_limit > 0: - print(f" Utilization: {100 * avg_power / power_limit:.1f}% of limit") - - # Temperature - avg_temp = sum(telemetry.temperature_c) / len(telemetry.temperature_c) - max_temp = max(telemetry.temperature_c) - print(f"\nTemperature (°C):") - print(f" Avg: {avg_temp:.0f} Max: {max_temp:.0f}") - - # SM Clock - avg_sm = sum(telemetry.sm_clock_mhz) / 
len(telemetry.sm_clock_mhz) - max_sm = max(telemetry.sm_clock_mhz) - min_sm = min(telemetry.sm_clock_mhz) - max_sm_possible = gpu_info.get('max_sm_clock_mhz', 0) if gpu_info else 0 - print(f"\nSM Clock (MHz):") - print(f" Avg: {avg_sm:.0f} Min: {min_sm} Max: {max_sm} GPU Max: {max_sm_possible}") - if max_sm_possible > 0: - print(f" Running at: {100 * avg_sm / max_sm_possible:.1f}% of max clock") - - # Check for throttling indicators - print("\n" + "-" * 60) - print("THROTTLING ANALYSIS:") - if power_limit > 0 and max_power >= power_limit * 0.95: - print(" ⚠️ Power usage near limit - possible power throttling") - else: - print(" ✓ Power usage below limit") - - if max_sm_possible > 0 and avg_sm < max_sm_possible * 0.9: - print(f" ⚠️ SM clocks below max ({avg_sm:.0f} vs {max_sm_possible} MHz)") - else: - print(" ✓ SM clocks near max") - - clock_variance = max_sm - min_sm - if clock_variance > 100: - print(f" ⚠️ Clock variance: {clock_variance} MHz (unstable clocks)") - else: - print(f" ✓ Clock variance: {clock_variance} MHz (stable)") - - print("=" * 60) - - -def main(): - parser = argparse.ArgumentParser(description="GEMM Profiler with GPU Telemetry") - parser.add_argument("--size", "-s", type=int, default=1536, - help="Matrix size (square MxKxN)") - parser.add_argument("--precision", "-p", choices=['bf16', 'fp8', 'fp4'], default='bf16', - help="Precision to benchmark") - parser.add_argument("--num-warmup", type=int, default=50, - help="Warmup iterations") - parser.add_argument("--num-iters", type=int, default=500, - help="Timed iterations") - parser.add_argument("--pre-quantize", action="store_true", - help="Use pre-quantized inputs (FP8/FP4 only)") - parser.add_argument("--with-leading-kernel", action="store_true", - help="Run a large GEMM before the timed kernels") - parser.add_argument("--compare", action="store_true", - help="Run both with and without leading kernel for comparison") - parser.add_argument("--gpu-warmup", type=float, default=3.0, - help="Seconds 
to warm up GPU before profiling") - - args = parser.parse_args() - - M = K = N = args.size - - # Get GPU info - gpu_info = get_gpu_info() - print("\n" + "=" * 70) - print("GEMM PROFILER") - print("=" * 70) - if gpu_info: - print(f"GPU: {gpu_info['name']}") - print(f"Power Limit: {gpu_info['power_limit_w']:.0f}W") - print(f"Max SM Clock: {gpu_info['max_sm_clock_mhz']} MHz") - - print(f"\nConfiguration:") - print(f" Shape: {M}x{K}x{N}") - print(f" Precision: {args.precision.upper()}") - print(f" Iterations: {args.num_warmup} warmup + {args.num_iters} timed") - print(f" Pre-quantize: {args.pre_quantize}") - - # GPU warmup - if args.gpu_warmup > 0: - print(f"\nWarming up GPU for {args.gpu_warmup:.1f} seconds...") - device = torch.device("cuda") - warmup_a = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - warmup_b = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - start = time.time() - while time.time() - start < args.gpu_warmup: - _ = torch.matmul(warmup_a, warmup_b) - torch.cuda.synchronize() - del warmup_a, warmup_b - torch.cuda.empty_cache() - print("GPU warmup complete.") - - if args.compare: - # Run both configurations - configs = [ - ("Without leading kernel", False), - ("With leading kernel", True), - ] - else: - configs = [( - "With leading kernel" if args.with_leading_kernel else "Without leading kernel", - args.with_leading_kernel - )] - - results = [] - - for config_name, use_leading in configs: - print(f"\n{'='*60}") - print(f"Configuration: {config_name}") - print('='*60) - - if args.precision == 'bf16': - tflops, avg_ms, telemetry = profile_bf16_gemm( - M, K, N, args.num_warmup, args.num_iters, use_leading - ) - elif args.precision == 'fp8': - tflops, avg_ms, telemetry = profile_fp8_gemm( - M, K, N, args.num_warmup, args.num_iters, use_leading, args.pre_quantize - ) - elif args.precision == 'fp4': - tflops, avg_ms, telemetry = profile_fp4_gemm( - M, K, N, args.num_warmup, args.num_iters, use_leading, args.pre_quantize - ) - - 
if tflops is not None: - print(f"\nResults:") - print(f" TFLOPS: {tflops:.1f}") - print(f" Avg time: {avg_ms:.4f} ms") - print_telemetry_summary(telemetry, gpu_info) - results.append((config_name, tflops, avg_ms)) - else: - print("Benchmark failed or not available") - - # Print comparison summary - if len(results) == 2: - print(f"\n{'=' * 60}") - print("COMPARISON SUMMARY") - print("=" * 60) - name1, tflops1, ms1 = results[0] - name2, tflops2, ms2 = results[1] - - print(f"\n {name1}:") - print(f" {tflops1:.1f} TFLOPS, {ms1:.4f} ms") - print(f"\n {name2}:") - print(f" {tflops2:.1f} TFLOPS, {ms2:.4f} ms") - - diff_pct = 100 * (tflops2 - tflops1) / tflops1 - print(f"\n Difference: {diff_pct:+.1f}%") - - if diff_pct < -5: - print(f"\n ⚠️ Leading kernel hurts performance") - print(f" Likely cause: power/thermal throttling from the leading kernel") - elif diff_pct > 5: - print(f"\n ✓ Leading kernel helps performance") - print(f" Without it, CPU dispatch overhead was being measured") - else: - print(f"\n ~ Minimal difference") - print(f" CPU dispatch overhead is not significant for this size") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py deleted file mode 100644 index 6b2e8373c0..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline.py +++ /dev/null @@ -1,522 +0,0 @@ -#!/usr/bin/env python3 -""" -GEMM Benchmarking Script with TFLOPS Measurement and Plotting - -Benchmarks matrix multiplication performance across different precisions -(BF16, FP8 via Transformer Engine) and generates a comparison plot. 
- -Usage: - python gemm_benchmark.py [--output plot.png] [--num-iters 100] -""" - -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. 
- - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) - - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with 
te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), 
- precision="NVFP4" - ) - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: - """Parse a shapes argument into a list of (M, K, N) tuples. - - Supports either: - - Square sizes: "1024,2048,4096" -> [(1024,1024,1024), ...] - - Explicit triplets: "8192x5120x15360,8192x5120x5120" - """ - items = [s.strip() for s in shapes_arg.split(",") if s.strip()] - if not items: - raise ValueError("Empty --shapes argument.") - - shapes: list[tuple[int, int, int]] = [] - for item in items: - if "x" in item: - parts = [p.strip() for p in item.lower().split("x")] - if len(parts) != 3: - raise ValueError(f"Invalid shape '{item}'. Expected format 'MxKxN'.") - m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) - shapes.append((m, k, n)) - else: - size = int(item) - shapes.append((size, size, size)) - - return shapes - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. - - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. 
- """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0 -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - run_fp8 = include_fp8 and TE_AVAILABLE - run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 80) - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - 
results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = benchmark_te_fp8(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = benchmark_te_fp4(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) - bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * 
bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - output_path_obj = Path(output_path) - supported_formats = set(fig.canvas.get_supported_filetypes().keys()) - suffix = output_path_obj.suffix.lower().lstrip(".") - if suffix not in supported_formats: - output_path_obj = output_path_obj.with_suffix(".png") - print( - f"Warning: Output extension '.{suffix}' is not supported by matplotlib; " - f"saving to '{output_path_obj}' instead." 
- ) - plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight") - print(f"\nPlot saved to: {output_path}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help=( - "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " - "or explicit triplets like '8192x5120x15360,8192x5120x5120'." - ), - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. 
This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - shapes = parse_shapes_arg(args.shapes) - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py deleted file mode 100644 index a8dbb436d2..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized.py +++ /dev/null @@ -1,753 +0,0 @@ -#!/usr/bin/env python3 -""" -GEMM Benchmarking Script with TFLOPS Measurement and Plotting - -Benchmarks matrix multiplication performance across different precisions -(BF16, FP8 via Transformer Engine) and generates a comparison plot. - -Usage: - python gemm_benchmark.py [--output plot.png] [--num-iters 100] -""" - -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. 
FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. - - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Large matrix for leading kernel to saturate GPU - A_large = torch.randn(4096, 4096, dtype=dtype, device=device) - B_large = torch.randn(4096, 4096, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - # Start with a long-running kernel to avoid measuring CPU dispatch overhead - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = torch.matmul(A_large, B_large) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / 
avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) - - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp8_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 
100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (requires fp8_dtype argument) - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP8 - A_fp8 = quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp8 = quantizer.quantize(A_large) - B_large_fp8 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp8, False, # A, transA - B_fp8, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp8, False, B_large_fp8, True, D_large, None, - 
tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp8, False, - B_fp8, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - except Exception as e: - print(f"Warning: FP8 prequantized benchmark failed: {e}") - return None - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel 
keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - - -def benchmark_te_fp4_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (uses default kFloat4E2M1, but being explicit) - quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP4 - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp4 = quantizer.quantize(A_large) - B_large_fp4 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp4, False, # A, transA - B_fp4, True, # B, 
transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp4, False, B_large_fp4, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp4, False, - B_fp4, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - except Exception as e: - print(f"Warning: FP4 prequantized benchmark failed: {e}") - return None - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. 
- - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. - """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0, - pre_quantize: bool = False -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - run_fp8 = include_fp8 and TE_AVAILABLE - run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if pre_quantize: - print("Mode: Pre-quantized inputs (raw kernel throughput)") - else: - print("Mode: Full quantize path (realistic training overhead)") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - 
header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 80) - - # Select benchmark functions based on pre_quantize flag - fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 - fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = 
np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) - bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - output_path_obj = Path(output_path) - supported_formats = set(fig.canvas.get_supported_filetypes().keys()) - suffix = output_path_obj.suffix.lower().lstrip(".") - if suffix not in supported_formats: - output_path_obj = output_path_obj.with_suffix(".png") - print( - f"Warning: Output extension '.{suffix}' is not supported by matplotlib; " - f"saving to '{output_path_obj}' instead." 
- ) - plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight") - print(f"\nPlot saved to: {output_path}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--pre-quantize", - action="store_true", - help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help="Comma-separated list of square matrix sizes (e.g., '1024,2048,4096')" - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. 
This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - sizes = [int(s.strip()) for s in args.shapes.split(",")] - shapes = [(s, s, s) for s in sizes] - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup, - pre_quantize=args.pre_quantize - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py deleted file mode 100644 index 671cce37c2..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes.py +++ /dev/null @@ -1,772 +0,0 @@ -#!/usr/bin/env python3 -""" -GEMM Benchmarking Script with TFLOPS Measurement and Plotting - -Benchmarks matrix multiplication performance across different precisions -(BF16, FP8 via Transformer Engine) and generates a comparison plot. 
- -Usage: - python gemm_benchmark.py [--output plot.png] [--num-iters 100] - python roofline_prequantized_with_shapes.py --output gemm_benchmark_expected_shapes.png --num-warmup 100 --num-iters 100 --gpu-warmup 100 --shapes 8192x5120x20480,8192x20480x5120 -""" - -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from typing import Optional - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. 
- - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Large matrix for leading kernel to saturate GPU - A_large = torch.randn(4096, 4096, dtype=dtype, device=device) - B_large = torch.randn(4096, 4096, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - # Start with a long-running kernel to avoid measuring CPU dispatch overhead - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = torch.matmul(A_large, B_large) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) - - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = 
compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp8_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (requires fp8_dtype argument) - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP8 - A_fp8 
= quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp8 = quantizer.quantize(A_large) - B_large_fp8 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp8, False, # A, transA - B_fp8, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp8, False, B_large_fp8, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp8, False, - B_fp8, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = 
(flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - except Exception as e: - print(f"Warning: FP8 prequantized benchmark failed: {e}") - return None - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - - -def benchmark_te_fp4_prequantized( - M: int, - K: int, - N: int, - 
num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (uses default kFloat4E2M1, but being explicit) - quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP4 - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp4 = quantizer.quantize(A_large) - B_large_fp4 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp4, False, # A, transA - B_fp4, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps 
GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp4, False, B_large_fp4, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp4, False, - B_fp4, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - except Exception as e: - print(f"Warning: FP4 prequantized benchmark failed: {e}") - return None - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: - """Parse a shapes argument into a list of (M, K, N) tuples. - - Supports either: - - Square sizes: "1024,2048,4096" -> [(1024,1024,1024), ...] - - Explicit triplets: "8192x5120x15360,8192x5120x5120" - """ - items = [s.strip() for s in shapes_arg.split(",") if s.strip()] - if not items: - raise ValueError("Empty --shapes argument.") - - shapes: list[tuple[int, int, int]] = [] - for item in items: - if "x" in item: - parts = [p.strip() for p in item.lower().split("x")] - if len(parts) != 3: - raise ValueError(f"Invalid shape '{item}'. 
Expected format 'MxKxN'.") - m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) - shapes.append((m, k, n)) - else: - size = int(item) - shapes.append((size, size, size)) - - return shapes - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. - - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. - """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0, - pre_quantize: bool = False -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - run_fp8 = include_fp8 and TE_AVAILABLE - run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if pre_quantize: - print("Mode: Pre-quantized inputs (raw kernel throughput)") - else: - print("Mode: Full quantize path (realistic training overhead)") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires 
Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 80) - - # Select benchmark functions based on pre_quantize flag - fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 - fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: 
Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) - bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - plt.savefig(output_path, dpi=150, bbox_inches='tight') - print(f"\nPlot saved to: {output_path}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - 
type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--pre-quantize", - action="store_true", - help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help=( - "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " - "or explicit triplets like '8192x5120x15360,8192x5120x5120'." - ), - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. 
This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - shapes = parse_shapes_arg(args.shapes) - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup, - pre_quantize=args.pre_quantize - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py deleted file mode 100644 index 73539635e9..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_with_shapes_mb.py +++ /dev/null @@ -1,829 +0,0 @@ -#!/usr/bin/env python3 -""" -GEMM Benchmarking Script with TFLOPS Measurement and Plotting - -Benchmarks matrix multiplication performance across different precisions -(BF16, FP8 via Transformer Engine) and generates a comparison plot. - -Usage: - python gemm_benchmark.py [--output plot.png] [--num-iters 100] -""" - -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from pathlib import Path -from typing import Optional -import math - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. 
FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. - - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Large matrix for leading kernel to saturate GPU - A_large = torch.randn(4096, 4096, dtype=dtype, device=device) - B_large = torch.randn(4096, 4096, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - # Start with a long-running kernel to avoid measuring CPU dispatch overhead - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = torch.matmul(A_large, B_large) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / 
avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) - - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp8_prequantized( - M: int, - K: int, - N: int, - workspace_size: int, - num_warmup: int = 
10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (requires fp8_dtype argument) - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP8 - A_fp8 = quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp8 = quantizer.quantize(A_large) - B_large_fp8 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp8, False, # A, transA - B_fp8, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp8, False, B_large_fp8, True, D_large, None, - tex.DType.kBFloat16, None, 
tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp8, False, - B_fp8, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - except Exception as e: - print( - f"Warning: FP8 prequantized benchmark failed for shape {M}x{K}x{N}: {e}\n" - f" Tip: try increasing --workspace-mb (current={workspace_size / (1024 * 1024):.0f}MB) " - "or run without --pre-quantize to use te.Linear()." - ) - return None - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed 
iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - - -def benchmark_te_fp4_prequantized( - M: int, - K: int, - N: int, - workspace_size: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (uses default kFloat4E2M1, but being explicit) - quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP4 - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp4 = quantizer.quantize(A_large) - B_large_fp4 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace = torch.empty(workspace_size, dtype=torch.uint8, 
device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp4, False, # A, transA - B_fp4, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp4, False, B_large_fp4, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp4, False, - B_fp4, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - except Exception as e: - print( - f"Warning: FP4 prequantized benchmark failed for shape {M}x{K}x{N}: {e}\n" - f" Tip: try increasing --workspace-mb (current={workspace_size / (1024 * 1024):.0f}MB)." 
- ) - return None - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: - """Parse a shapes argument into a list of (M, K, N) tuples. - - Supports either: - - Square sizes: "1024,2048,4096" -> [(1024,1024,1024), ...] - - Explicit triplets: "8192x5120x15360,8192x5120x5120" - """ - items = [s.strip() for s in shapes_arg.split(",") if s.strip()] - if not items: - raise ValueError("Empty --shapes argument.") - - shapes: list[tuple[int, int, int]] = [] - for item in items: - if "x" in item: - parts = [p.strip() for p in item.lower().split("x")] - if len(parts) != 3: - raise ValueError(f"Invalid shape '{item}'. Expected format 'MxKxN'.") - m, k, n = (int(parts[0]), int(parts[1]), int(parts[2])) - shapes.append((m, k, n)) - else: - size = int(item) - shapes.append((size, size, size)) - - return shapes - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. - - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. 
- """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0, - pre_quantize: bool = False, - workspace_mb: int = 32, -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - has_fp8 = is_fp8_available() - run_fp8 = include_fp8 and TE_PYTORCH_AVAILABLE and has_fp8 - run_fp4 = include_fp4 and TE_PYTORCH_AVAILABLE and has_blackwell - workspace_size = int(workspace_mb) * 1024 * 1024 - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - major, minor = torch.cuda.get_device_capability() - print(f"CUDA capability: SM{major}{minor}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if pre_quantize: - print("Mode: Pre-quantized inputs (raw kernel throughput)") - else: - print("Mode: Full quantize path (realistic training overhead)") - if pre_quantize: - print(f"Workspace: {workspace_mb}MB") - if (include_fp8 or include_fp4) and not TE_PYTORCH_AVAILABLE: - msg = "Note: FP8/FP4 requested but transformer_engine.pytorch import failed; skipping FP8/FP4." 
- if TE_IMPORT_ERROR: - msg += f" ImportError: {TE_IMPORT_ERROR}" - print(msg) - if pre_quantize and (include_fp8 or include_fp4) and not TE_TORCH_EXT_AVAILABLE: - msg = ( - "Note: --pre-quantize requires transformer_engine_torch (tex.generic_gemm). " - "transformer_engine_torch import failed; skipping FP8/FP4 pre-quantized benchmarks." - ) - if TE_IMPORT_ERROR: - msg += f" ImportError: {TE_IMPORT_ERROR}" - print(msg) - if include_fp8 and TE_AVAILABLE and not has_fp8: - print("Note: FP8 requested but this GPU does not support FP8 Tensor Cores; skipping FP8.") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 80) - - # Select benchmark functions based on pre_quantize flag - if pre_quantize: - # Pre-quantized path uses tex.generic_gemm, which requires the TE torch extension. 
- if not TE_TORCH_EXT_AVAILABLE: - return {"BF16": results["BF16"]} - fp8_benchmark_fn = lambda m, k, n, nw, ni: benchmark_te_fp8_prequantized( # noqa: E731 - m, k, n, workspace_size=workspace_size, num_warmup=nw, num_iters=ni - ) - fp4_benchmark_fn = lambda m, k, n, nw, ni: benchmark_te_fp4_prequantized( # noqa: E731 - m, k, n, workspace_size=workspace_size, num_warmup=nw, num_iters=ni - ) - else: - fp8_benchmark_fn = benchmark_te_fp8 - fp4_benchmark_fn = benchmark_te_fp4 - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute 
Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) - bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - output_path_obj = Path(output_path) - supported_formats = set(fig.canvas.get_supported_filetypes().keys()) - suffix = output_path_obj.suffix.lower().lstrip(".") - if suffix not in supported_formats: - output_path_obj = output_path_obj.with_suffix(".png") - print( - f"Warning: Output extension '.{suffix}' is not supported by matplotlib; " - f"saving to '{output_path_obj}' instead." 
- ) - plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight") - print(f"\nPlot saved to: {output_path_obj}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--pre-quantize", - action="store_true", - help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" - ) - parser.add_argument( - "--workspace-mb", - type=int, - default=32, - help="Workspace size in MB for pre-quantized generic_gemm() path (default: 32).", - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help=( - "Comma-separated list of GEMM shapes. Either square sizes like '1024,2048,4096' " - "or explicit triplets like '8192x5120x15360,8192x5120x5120'." - ), - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. 
This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - shapes = parse_shapes_arg(args.shapes) - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup, - pre_quantize=args.pre_quantize, - workspace_mb=args.workspace_mb, - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_withtorchao.py b/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_withtorchao.py deleted file mode 100644 index a8dbb436d2..0000000000 --- a/bionemo-recipes/recipes/esm2_native_te/gemm_benchmarking/roofline_prequantized_withtorchao.py +++ /dev/null @@ -1,753 +0,0 @@ -#!/usr/bin/env python3 -""" -GEMM Benchmarking Script with TFLOPS Measurement and Plotting - -Benchmarks matrix multiplication performance across different precisions -(BF16, FP8 via Transformer Engine) and generates a comparison plot. - -Usage: - python gemm_benchmark.py [--output plot.png] [--num-iters 100] -""" - -import argparse -import time -import torch -import matplotlib.pyplot as plt -import numpy as np -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -# Optional TE import - gracefully handle if not available -try: - import transformer_engine.pytorch as te - import transformer_engine_torch as tex - from transformer_engine.common.recipe import Format, MXFP8BlockScaling, NVFP4BlockScaling - TE_AVAILABLE = True -except ImportError: - TE_AVAILABLE = False - print("Warning: Transformer Engine not available. 
FP8/FP4 benchmarks will be skipped.") - -# Check for Blackwell (SM100+) for FP4 support -def is_blackwell_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -@dataclass -class BenchmarkResult: - """Container for benchmark results.""" - tflops: float - avg_time_ms: float - shape: tuple[int, int, int] - precision: str - - -def compute_gemm_flops(M: int, K: int, N: int) -> int: - """ - Compute theoretical FLOP count for GEMM C = A @ B. - - A: (M, K), B: (K, N), C: (M, N) - Each output element requires K multiply-adds = 2K FLOPs - Total: 2 * M * N * K - """ - return 2 * M * N * K - - -def benchmark_torch_matmul( - M: int, - K: int, - N: int, - dtype: torch.dtype, - num_warmup: int = 10, - num_iters: int = 100 -) -> BenchmarkResult: - """Benchmark torch.matmul at specified precision.""" - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - A = torch.randn(M, K, dtype=dtype, device=device) - B = torch.randn(K, N, dtype=dtype, device=device) - - # Large matrix for leading kernel to saturate GPU - A_large = torch.randn(4096, 4096, dtype=dtype, device=device) - B_large = torch.randn(4096, 4096, dtype=dtype, device=device) - - # Warmup - critical for accurate timing - for _ in range(num_warmup): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Timed iterations using CUDA events - # Start with a long-running kernel to avoid measuring CPU dispatch overhead - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = torch.matmul(A_large, B_large) - - start_event.record() - for _ in range(num_iters): - _ = torch.matmul(A, B) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / 
avg_time_s) / 1e12 - - precision_name = { - torch.bfloat16: "BF16", - torch.float16: "FP16", - torch.float32: "FP32", - }.get(dtype, str(dtype)) - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision=precision_name - ) - - -def benchmark_te_fp8( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM via Transformer Engine Linear layer.""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp8_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - - -def benchmark_te_fp8_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 
100 -) -> Optional[BenchmarkResult]: - """Benchmark FP8 GEMM with pre-quantized inputs (measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (requires fp8_dtype argument) - quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP8 - A_fp8 = quantizer.quantize(A) - B_fp8 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp8 = quantizer.quantize(A_large) - B_large_fp8 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp8, False, # A, transA - B_fp8, True, # B, transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp8, False, B_large_fp8, True, D_large, None, - 
tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp8, False, - B_fp8, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="MXFP8" - ) - except Exception as e: - print(f"Warning: FP8 prequantized benchmark failed: {e}") - return None - - -def benchmark_te_fp4( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM via Transformer Engine Linear layer (Blackwell only).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - # TE Linear: input (M, K) @ weight (K, N) -> output (M, N) - linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) - x = torch.randn(M, K, dtype=torch.bfloat16, device=device) - - # Large tensors for leading kernel - linear_large = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) - x_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - - fp4_recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) - - # Keep autocast context open for warmup and timing - with te.autocast(enabled=True, recipe=fp4_recipe): - # Warmup - for _ in range(num_warmup): - _ = linear(x) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel 
keeps GPU busy while CPU queues the timed kernels - _ = linear_large(x_large) - - start_event.record() - for _ in range(num_iters): - _ = linear(x) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - - -def benchmark_te_fp4_prequantized( - M: int, - K: int, - N: int, - num_warmup: int = 10, - num_iters: int = 100 -) -> Optional[BenchmarkResult]: - """Benchmark FP4 GEMM with pre-quantized inputs (Blackwell only, measures raw kernel throughput).""" - if not TE_AVAILABLE: - return None - - if not is_blackwell_available(): - return None - - device = torch.device("cuda") - flops = compute_gemm_flops(M, K, N) - - try: - # Create quantizer (uses default kFloat4E2M1, but being explicit) - quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) - - # Create BF16 tensors - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(N, K, dtype=torch.bfloat16, device=device) # Transposed for GEMM - - # Pre-quantize to FP4 - A_fp4 = quantizer.quantize(A) - B_fp4 = quantizer.quantize(B) - - # Large tensors for leading kernel - A_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - B_large = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) - A_large_fp4 = quantizer.quantize(A_large) - B_large_fp4 = quantizer.quantize(B_large) - - # Output buffers - D = torch.empty(M, N, dtype=torch.bfloat16, device=device) - D_large = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) - - # Allocate workspace (estimate size) - workspace_size = 32 * 1024 * 1024 # 32MB workspace - workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device) - - # Warmup - for _ in range(num_warmup): - tex.generic_gemm( - A_fp4, False, # A, transA - B_fp4, True, # B, 
transB - D, # output - None, # quantizer (None = no output quantization) - tex.DType.kBFloat16, # output_dtype - None, # bias - tex.DType.kBFloat16, # bias_type (unused when bias=None) - False, # gelu - None, # gelu_in - False, # grad - workspace, # workspace - workspace_size, # workspace_size - False, # accumulate - False, # use_split_accumulator - ) - torch.cuda.synchronize() - - # Timed iterations with leading kernel - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Leading kernel keeps GPU busy while CPU queues the timed kernels - tex.generic_gemm( - A_large_fp4, False, B_large_fp4, True, D_large, None, - tex.DType.kBFloat16, None, tex.DType.kBFloat16, - False, None, False, workspace, workspace_size, False, False, - ) - - start_event.record() - for _ in range(num_iters): - tex.generic_gemm( - A_fp4, False, - B_fp4, True, - D, - None, - tex.DType.kBFloat16, - None, - tex.DType.kBFloat16, - False, - None, - False, - workspace, - workspace_size, - False, - False, - ) - end_event.record() - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / num_iters - avg_time_s = avg_time_ms / 1000.0 - - tflops = (flops / avg_time_s) / 1e12 - - return BenchmarkResult( - tflops=tflops, - avg_time_ms=avg_time_ms, - shape=(M, K, N), - precision="NVFP4" - ) - except Exception as e: - print(f"Warning: FP4 prequantized benchmark failed: {e}") - return None - - -def get_default_matrix_shapes() -> list[tuple[int, int, int]]: - """Return default matrix shapes for benchmarking (square matrices).""" - return [ - (256, 256, 256), - (512, 512, 512), - (768, 768, 768), - (1024, 1024, 1024), - (1536, 1536, 1536), - (2048, 2048, 2048), - (3072, 3072, 3072), - (4096, 4096, 4096), - (6144, 6144, 6144), - (8192, 8192, 8192), - (16384, 16384, 16384), - ] - - -def warmup_gpu(duration_seconds: float = 5.0): - """ - Warmup the GPU to stabilize clocks before benchmarking. 
- - Runs sustained matmuls to bring GPU out of idle state and - get clocks/thermals to steady state. - """ - print(f"Warming up GPU for {duration_seconds:.1f} seconds...") - - device = torch.device("cuda") - - # Use a moderate size that keeps GPU busy without OOM - M, K, N = 4096, 4096, 4096 - A = torch.randn(M, K, dtype=torch.bfloat16, device=device) - B = torch.randn(K, N, dtype=torch.bfloat16, device=device) - - torch.cuda.synchronize() - start_time = time.time() - - while time.time() - start_time < duration_seconds: - # Run a batch of matmuls - for _ in range(10): - _ = torch.matmul(A, B) - torch.cuda.synchronize() - - # Clear memory - del A, B - torch.cuda.empty_cache() - - print("GPU warmup complete.\n") - - -def run_benchmarks( - shapes: list[tuple[int, int, int]], - num_warmup: int = 10, - num_iters: int = 100, - include_fp8: bool = True, - include_fp4: bool = True, - gpu_warmup_seconds: float = 5.0, - pre_quantize: bool = False -) -> dict[str, list[float]]: - """Run all benchmarks and return results organized by precision.""" - - results = {"BF16": [], "MXFP8": [], "NVFP4": []} - - # Check hardware capabilities - has_blackwell = is_blackwell_available() - run_fp8 = include_fp8 and TE_AVAILABLE - run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell - - # Print header - gpu_name = torch.cuda.get_device_name(0) - print(f"\nGEMM Benchmark on {gpu_name}") - print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") - if pre_quantize: - print("Mode: Pre-quantized inputs (raw kernel throughput)") - else: - print("Mode: Full quantize path (realistic training overhead)") - if not has_blackwell and include_fp4: - print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") - - # Warmup GPU to stabilize clocks - if gpu_warmup_seconds > 0: - warmup_gpu(gpu_warmup_seconds) - - print("=" * 80) - - # Build header dynamically - header = f"{'Shape':<20} {'BF16 TFLOPS':>14}" - if run_fp8: - header += f" {'MXFP8 TFLOPS':>14}" - if run_fp4: - 
header += f" {'NVFP4 TFLOPS':>14}" - header += f" {'Best Speedup':>12}" - print(header) - print("-" * 80) - - # Select benchmark functions based on pre_quantize flag - fp8_benchmark_fn = benchmark_te_fp8_prequantized if pre_quantize else benchmark_te_fp8 - fp4_benchmark_fn = benchmark_te_fp4_prequantized if pre_quantize else benchmark_te_fp4 - - for M, K, N in shapes: - shape_str = f"{M}x{K}x{N}" - - # BF16 benchmark - bf16_result = benchmark_torch_matmul(M, K, N, torch.bfloat16, num_warmup, num_iters) - results["BF16"].append(bf16_result.tflops) - - row = f"{shape_str:<20} {bf16_result.tflops:>14.1f}" - best_tflops = bf16_result.tflops - - # FP8 benchmark - if run_fp8: - fp8_result = fp8_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp8_result: - results["MXFP8"].append(fp8_result.tflops) - row += f" {fp8_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp8_result.tflops) - else: - results["MXFP8"].append(0) - row += f" {'N/A':>14}" - - # FP4 benchmark (Blackwell only) - if run_fp4: - fp4_result = fp4_benchmark_fn(M, K, N, num_warmup, num_iters) - if fp4_result: - results["NVFP4"].append(fp4_result.tflops) - row += f" {fp4_result.tflops:>14.1f}" - best_tflops = max(best_tflops, fp4_result.tflops) - else: - results["NVFP4"].append(0) - row += f" {'N/A':>14}" - - speedup = best_tflops / bf16_result.tflops - row += f" {speedup:>11.2f}x" - print(row) - - print("=" * 80) - - # Remove empty precision results - results = {k: v for k, v in results.items() if v and any(x > 0 for x in v)} - - return results - - -def create_plot( - shapes: list[tuple[int, int, int]], - results: dict[str, list[float]], - output_path: str = "gemm_benchmark.png", - title: Optional[str] = None -): - """Create a bar plot matching the style of the reference image.""" - - gpu_name = torch.cuda.get_device_name(0) - if title is None: - title = f"Absolute Performance Comparison\nMeasured on {gpu_name}" - - # Create labels for x-axis - labels = [f"{m}x{k}x{n}" for m, k, n in shapes] - x = 
np.arange(len(labels)) - - # Determine bar width based on number of kernels - num_kernels = len(results) - bar_width = 0.8 / num_kernels - - # Color scheme matching the reference plot - colors = { - "BF16": "#808080", # Gray - "MXFP8": "#4B0082", # Indigo/Purple - "NVFP4": "#B22222", # Firebrick red (for future use) - } - - fig, ax = plt.subplots(figsize=(14, 8)) - - # Plot bars for each precision - for i, (precision, tflops_list) in enumerate(results.items()): - offset = (i - num_kernels / 2 + 0.5) * bar_width - color = colors.get(precision, f"C{i}") - bars = ax.bar(x + offset, tflops_list, bar_width, label=precision, color=color) - - # Customize the plot - ax.set_xlabel("Matrix Shape (MxKxN)", fontsize=12) - ax.set_ylabel("Performance (TFLOPS)", fontsize=12) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xticks(x) - ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) - - # Add legend - ax.legend(title="Kernel", loc='upper left', fontsize=10) - - # Add grid for readability - ax.yaxis.grid(True, linestyle='--', alpha=0.7) - ax.set_axisbelow(True) - - # Set y-axis to start at 0 - ax.set_ylim(bottom=0) - - # Adjust layout to prevent label cutoff - plt.tight_layout() - - # Save the plot - output_path_obj = Path(output_path) - supported_formats = set(fig.canvas.get_supported_filetypes().keys()) - suffix = output_path_obj.suffix.lower().lstrip(".") - if suffix not in supported_formats: - output_path_obj = output_path_obj.with_suffix(".png") - print( - f"Warning: Output extension '.{suffix}' is not supported by matplotlib; " - f"saving to '{output_path_obj}' instead." 
- ) - plt.savefig(str(output_path_obj), dpi=150, bbox_inches="tight") - print(f"\nPlot saved to: {output_path}") - - return fig, ax - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Benchmarking with TFLOPS measurement and plotting" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="gemm_benchmark.png", - help="Output path for the plot (default: gemm_benchmark.png)" - ) - parser.add_argument( - "--num-warmup", - type=int, - default=10, - help="Number of warmup iterations (default: 10)" - ) - parser.add_argument( - "--num-iters", - type=int, - default=100, - help="Number of timed iterations (default: 100)" - ) - parser.add_argument( - "--gpu-warmup", - type=float, - default=5.0, - help="GPU warmup duration in seconds (default: 5.0, set to 0 to disable)" - ) - parser.add_argument( - "--no-fp8", - action="store_true", - help="Skip FP8 benchmarks" - ) - parser.add_argument( - "--no-fp4", - action="store_true", - help="Skip FP4 benchmarks (only available on Blackwell)" - ) - parser.add_argument( - "--pre-quantize", - action="store_true", - help="Use pre-quantized inputs to measure raw kernel throughput (excludes quantization overhead)" - ) - parser.add_argument( - "--shapes", - type=str, - default=None, - help="Comma-separated list of square matrix sizes (e.g., '1024,2048,4096')" - ) - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("Error: CUDA is not available. 
This script requires a GPU.") - return 1 - - # Parse custom shapes if provided - if args.shapes: - sizes = [int(s.strip()) for s in args.shapes.split(",")] - shapes = [(s, s, s) for s in sizes] - else: - shapes = get_default_matrix_shapes() - - # Run benchmarks - results = run_benchmarks( - shapes=shapes, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - include_fp8=not args.no_fp8, - include_fp4=not args.no_fp4, - gpu_warmup_seconds=args.gpu_warmup, - pre_quantize=args.pre_quantize - ) - - # Create plot - create_plot(shapes, results, args.output) - - return 0 - - -if __name__ == "__main__": - exit(main()) \ No newline at end of file From 6c12f2acb07ae4fec9b04e7ebb24267cd37c4b15 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 20 Feb 2026 13:11:00 -0800 Subject: [PATCH 21/21] adds stuff to docker ignore Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/.dockerignore | 34 ++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/.dockerignore b/bionemo-recipes/recipes/esm2_native_te/.dockerignore index e67ca715ce..ff0577a466 100644 --- a/bionemo-recipes/recipes/esm2_native_te/.dockerignore +++ b/bionemo-recipes/recipes/esm2_native_te/.dockerignore @@ -1,10 +1,34 @@ +# Docker Dockerfile +Dockerfile.* +.dockerignore + +# Docs README.md -checkpoint_export/ -outputs/ -.ruff_cache + +# Python caches __pycache__ .pytest_cache -.ruff.toml -.dockerignore +.ruff_cache .venv/ + +# Linting +.ruff.toml + +# Profiling & debugging artifacts +memory_snapshots/ +nsight_profiling/ +*.nsys-rep +*.sqlite +logs/ +wandb/ + +# Hydra / training outputs +outputs/ +checkpoints/ + +# Checkpoint export +checkpoint_export/ + +# Temp / scratch +j/