diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py
index 10029a79f..139f05526 100644
--- a/atom/plugin/vllm/model_wrapper.py
+++ b/atom/plugin/vllm/model_wrapper.py
@@ -1,13 +1,16 @@
 from collections.abc import Iterable
+from contextlib import nullcontext
 import importlib
 
 import torch
 import torch.nn as nn
+from torch.profiler import record_function
 from aiter.dist.parallel_state import (
     get_pp_group,
     get_tp_group,
 )
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.model_executor.models.interfaces import (
     SupportsPP,
     SupportsQuant,
@@ -25,6 +28,7 @@ import logging
 
 logger = logging.getLogger("atom")
 
+_ATOM_OOT_TORCH_PROFILE_ACTIVE = False
 
 
 _ATOM_MODEL_CLASSES: dict[str, str] = {
@@ -57,6 +61,193 @@ def _prepare_env(atom_config) -> None:
 
     logger.info("Init aiter dist for using aiter custom collective ops")
     init_aiter_dist(config=atom_config)
+
+    _patch_vllm_profile_labels()
+
+
+def _is_torch_profile_enabled(vllm_config: VllmConfig) -> bool:
+    """Return True when a torch profiler output dir is configured for this run."""
+    profiler_config = getattr(vllm_config, "profiler_config", None)
+    return profiler_config is not None and bool(
+        getattr(profiler_config, "torch_profiler_dir", "")
+    )
+
+
+def _patch_vllm_profile_labels() -> None:
+    """Monkey-patch vLLM so torch-profiler traces carry per-step ATOM labels."""
+    from vllm.profiler.wrapper import TorchProfilerWrapper
+    from vllm.v1.worker import gpu_model_runner as gpu_model_runner_mod
+    from vllm.v1 import utils as v1_utils_mod
+
+    global _ATOM_OOT_TORCH_PROFILE_ACTIVE
+
+    if not getattr(TorchProfilerWrapper, "_atom_step_label_patched", False):
+        original_call_start = TorchProfilerWrapper._call_start
+        original_call_stop = TorchProfilerWrapper._call_stop
+
+        def _wrapped_call_start(self):
+            global _ATOM_OOT_TORCH_PROFILE_ACTIVE
+            original_call_start(self)
+            _ATOM_OOT_TORCH_PROFILE_ACTIVE = bool(getattr(self, "_running", False))
+
+        def _wrapped_call_stop(self):
+            global _ATOM_OOT_TORCH_PROFILE_ACTIVE
+            original_call_stop(self)
+            _ATOM_OOT_TORCH_PROFILE_ACTIVE = False
+
+        TorchProfilerWrapper._call_start = _wrapped_call_start
+        TorchProfilerWrapper._call_stop = _wrapped_call_stop
+        TorchProfilerWrapper._atom_step_label_patched = True
+
+    # Use vLLM's existing step boundary in execute_model():
+    #
+    #     with set_forward_context(...),
+    #             record_function_or_nullcontext("gpu_model_runner: forward"):
+    #         model_output = self._model_forward(...)
+    #
+    # This encloses the real model run for eager, piecewise, and full graph
+    # modes, so a single dynamic label hook here is simpler than patching
+    # separate execution paths. Note that gpu_model_runner imported
+    # `record_function_or_nullcontext` by value, so patching only v1.utils is
+    # not sufficient after import; we patch the local symbol in
+    # gpu_model_runner as well.
+    if not getattr(gpu_model_runner_mod, "_atom_step_label_patched", False):
+        original_record_ctx = gpu_model_runner_mod.record_function_or_nullcontext
+        original_utils_record_ctx = v1_utils_mod.record_function_or_nullcontext
+
+        class _DynamicForwardRecordContext:
+            def __init__(self, name: str, original_ctx):
+                self.name = name
+                self.original_ctx = original_ctx
+                self.ctx = nullcontext()
+
+            def __enter__(self):
+                # Only the model-forward range gets a dynamic label, and only
+                # while a torch-profiler session is actually recording.
+                if (
+                    self.name == "gpu_model_runner: forward"
+                    and _ATOM_OOT_TORCH_PROFILE_ACTIVE
+                    and is_forward_context_available()
+                ):
+                    record_label = _build_step_profiler_label()
+                    if record_label is not None:
+                        self.ctx = record_function(record_label)
+                        return self.ctx.__enter__()
+                # Fall back to vLLM's original context factory whenever no
+                # per-step label was installed, so the stock
+                # "gpu_model_runner: forward" range is never silently dropped.
+                self.ctx = self.original_ctx(self.name)
+                return self.ctx.__enter__()
+
+            def __exit__(self, exc_type, exc, tb):
+                return self.ctx.__exit__(exc_type, exc, tb)
+
+        def _wrapped_gpu_record_function_or_nullcontext(name: str):
+            return _DynamicForwardRecordContext(name, original_record_ctx)
+
+        def _wrapped_utils_record_function_or_nullcontext(name: str):
+            return _DynamicForwardRecordContext(name, original_utils_record_ctx)
+
+        gpu_model_runner_mod.record_function_or_nullcontext = (
+            _wrapped_gpu_record_function_or_nullcontext
+        )
+        v1_utils_mod.record_function_or_nullcontext = (
+            _wrapped_utils_record_function_or_nullcontext
+        )
+        gpu_model_runner_mod._atom_step_label_patched = True
+
+
+def _get_step_attn_metadata_list():
+    """Collect one representative attention-metadata object per (micro)batch."""
+    if not is_forward_context_available():
+        return []
+
+    attn_metadata = get_forward_context().attn_metadata
+    if attn_metadata is None:
+        return []
+
+    if isinstance(attn_metadata, list):
+        metadatas = []
+        # In ubatch mode, vLLM stores one metadata dict per microbatch. We need
+        # one representative metadata object per microbatch and aggregate them to
+        # get the full-step request/token counts.
+        for ubatch_attn_metadata in attn_metadata:
+            if not ubatch_attn_metadata:
+                continue
+            metadata = next(iter(ubatch_attn_metadata.values()), None)
+            if metadata is not None:
+                metadatas.append(metadata)
+        return metadatas
+
+    if isinstance(attn_metadata, dict):
+        metadata = next(iter(attn_metadata.values()), None)
+        return [metadata] if metadata is not None else []
+
+    return [attn_metadata]
+
+
+def _build_step_profiler_label() -> str | None:
+    """Build a ``step[bs=.. tok=.. p=.. e=.. d=..]`` trace label, or None."""
+    attn_metadata_list = _get_step_attn_metadata_list()
+    if not attn_metadata_list:
+        return None
+
+    num_actual_tokens = 0
+    num_decodes = 0
+    num_decode_tokens = 0
+    num_prefills = 0
+    num_extends = 0
+    num_extend_tokens = 0
+    num_prefill_tokens = 0
+
+    for attn_metadata in attn_metadata_list:
+        plugin_metadata = getattr(attn_metadata, "plugin_metadata", None)
+        if plugin_metadata is None:
+            continue
+
+        actual_tokens_i = plugin_metadata.num_actual_tokens
+        decodes_i = plugin_metadata.num_decodes
+        decode_tokens_i = plugin_metadata.num_decode_tokens
+        prefills_i = plugin_metadata.num_prefills
+        extends_i = getattr(plugin_metadata, "num_extends", 0)
+        extend_tokens_i = getattr(plugin_metadata, "num_extend_tokens", 0)
+        prefill_tokens_i = getattr(
+            plugin_metadata,
+            "num_prefill_tokens",
+            actual_tokens_i - decode_tokens_i - extend_tokens_i,
+        )
+
+        num_actual_tokens += actual_tokens_i
+        num_decodes += decodes_i
+        num_decode_tokens += decode_tokens_i
+        num_prefills += prefills_i
+        num_extends += extends_i
+        num_extend_tokens += extend_tokens_i
+        num_prefill_tokens += prefill_tokens_i
+
+    if num_actual_tokens <= 0:
+        return None
+
+    total_reqs = num_decodes + num_prefills + num_extends
+    if total_reqs <= 0:
+        return None
+
+    # OOT step naming policy:
+    #   - prefill wins if any prefill tokens exist
+    #   - otherwise extend wins if any extend tokens exist
+    #   - otherwise the step is decode
+    # `bs` is total request count, `tok` is total token count,
+    # and `p/e/d` are prefill/extend/decode token counts.
+    if num_prefills > 0:
+        step = "prefill"
+    elif num_extends > 0:
+        step = "extend"
+    else:
+        step = "decode"
+
+    return (
+        f"{step}[bs={total_reqs} tok={num_actual_tokens} "
+        f"p={num_prefill_tokens} e={num_extend_tokens} d={num_decode_tokens}]"
+    )
+
 
 class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP):
     def __init_subclass__(cls, *args, **kwargs):
@@ -119,7 +310,6 @@ def forward(
             "positions"
         ]
         buf[: positions.numel()].copy_(positions)
-
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,