Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 185 additions & 1 deletion atom/plugin/vllm/model_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections.abc import Iterable
from contextlib import nullcontext

import importlib
import torch
import torch.nn as nn
from torch.profiler import record_function
from aiter.dist.parallel_state import (
get_pp_group,
get_tp_group,
)
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.model_executor.models.interfaces import (
SupportsPP,
SupportsQuant,
Expand All @@ -25,6 +28,7 @@
import logging

logger = logging.getLogger("atom")
_ATOM_OOT_TORCH_PROFILE_ACTIVE = False


_ATOM_MODEL_CLASSES: dict[str, str] = {
Expand Down Expand Up @@ -57,6 +61,187 @@ def _prepare_env(atom_config) -> None:
logger.info("Init aiter dist for using aiter custom collective ops")
init_aiter_dist(config=atom_config)

_patch_vllm_profile_labels()


def _is_torch_profile_enabled(vllm_config: VllmConfig) -> bool:
    """Return True when torch profiling is configured on *vllm_config*.

    Profiling counts as enabled only when the config carries a
    ``profiler_config`` whose ``torch_profiler_dir`` is non-empty.
    """
    profiler_config = getattr(vllm_config, "profiler_config", None)
    if profiler_config is None:
        return False
    return bool(getattr(profiler_config, "torch_profiler_dir", ""))


def _patch_vllm_profile_labels() -> None:
    """Monkey-patch vLLM so torch-profiler traces carry descriptive step labels.

    Two patches are applied, each guarded by a sentinel attribute so calling
    this function repeatedly is idempotent:

    1. ``TorchProfilerWrapper._call_start`` / ``_call_stop`` are wrapped to
       mirror the profiler's running state into the module-level flag
       ``_ATOM_OOT_TORCH_PROFILE_ACTIVE``.
    2. ``record_function_or_nullcontext`` is replaced so the
       ``"gpu_model_runner: forward"`` region is relabelled with a per-step
       summary (prefill/extend/decode plus request/token counts) while a
       torch profile is active.

    Fix: removed a leftover ``print('[zejun] ...')`` debug statement that
    wrote the label to stdout on every profiled forward step.
    """
    from vllm.profiler.wrapper import TorchProfilerWrapper
    from vllm.v1.worker import gpu_model_runner as gpu_model_runner_mod
    from vllm.v1 import utils as v1_utils_mod

    global _ATOM_OOT_TORCH_PROFILE_ACTIVE

    if not getattr(TorchProfilerWrapper, "_atom_step_label_patched", False):
        original_call_start = TorchProfilerWrapper._call_start
        original_call_stop = TorchProfilerWrapper._call_stop

        def _wrapped_call_start(self):
            global _ATOM_OOT_TORCH_PROFILE_ACTIVE
            original_call_start(self)
            # Only treat the profile as active if the wrapper reports it is
            # actually running after start.
            _ATOM_OOT_TORCH_PROFILE_ACTIVE = bool(getattr(self, "_running", False))

        def _wrapped_call_stop(self):
            global _ATOM_OOT_TORCH_PROFILE_ACTIVE
            original_call_stop(self)
            _ATOM_OOT_TORCH_PROFILE_ACTIVE = False

        TorchProfilerWrapper._call_start = _wrapped_call_start
        TorchProfilerWrapper._call_stop = _wrapped_call_stop
        TorchProfilerWrapper._atom_step_label_patched = True

    # Use vLLM's existing step boundary in execute_model():
    #
    #     with set_forward_context(...),
    #             record_function_or_nullcontext("gpu_model_runner: forward"):
    #         model_output = self._model_forward(...)
    #
    # This encloses the real model run for eager, piecewise, and full graph
    # modes, so a single dynamic label hook here is simpler than patching
    # separate execution paths. Note that gpu_model_runner imported
    # `record_function_or_nullcontext` by value, so patching only v1.utils is
    # not sufficient after import; we patch the local symbol in
    # gpu_model_runner as well.
    if not getattr(gpu_model_runner_mod, "_atom_step_label_patched", False):
        original_record_ctx = gpu_model_runner_mod.record_function_or_nullcontext
        original_utils_record_ctx = v1_utils_mod.record_function_or_nullcontext

        class _DynamicForwardRecordContext:
            """Context manager that swaps in a step-summary profiler label."""

            def __init__(self, name: str, original_ctx):
                self.name = name
                self.original_ctx = original_ctx
                self.ctx = nullcontext()

            def __enter__(self):
                if self.name == "gpu_model_runner: forward":
                    if (
                        _ATOM_OOT_TORCH_PROFILE_ACTIVE
                        and is_forward_context_available()
                    ):
                        record_label = _build_step_profiler_label()
                        if record_label is not None:
                            self.ctx = record_function(record_label)
                            return self.ctx.__enter__()

                # No active profile or no usable label: fall back to the
                # original context factory for this name.
                self.ctx = self.original_ctx(self.name)
                return self.ctx.__enter__()

            def __exit__(self, exc_type, exc, tb):
                return self.ctx.__exit__(exc_type, exc, tb)

        def _wrapped_gpu_record_function_or_nullcontext(name: str):
            return _DynamicForwardRecordContext(name, original_record_ctx)

        def _wrapped_utils_record_function_or_nullcontext(name: str):
            return _DynamicForwardRecordContext(name, original_utils_record_ctx)

        gpu_model_runner_mod.record_function_or_nullcontext = (
            _wrapped_gpu_record_function_or_nullcontext
        )
        v1_utils_mod.record_function_or_nullcontext = (
            _wrapped_utils_record_function_or_nullcontext
        )
        gpu_model_runner_mod._atom_step_label_patched = True


def _get_step_attn_metadata_list():
    """Collect one representative attention-metadata object per (micro)batch.

    Returns an empty list when no forward context is available or the
    context carries no attention metadata. In ubatch mode vLLM stores one
    metadata dict per microbatch; one representative entry is taken from
    each non-empty dict so callers can aggregate full-step request/token
    counts.
    """
    if not is_forward_context_available():
        return []

    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is None:
        return []

    if isinstance(attn_metadata, list):
        # Ubatch mode: pick the first metadata value of each microbatch
        # dict, skipping empty dicts and None placeholders.
        representatives = []
        for ubatch_metadata in attn_metadata:
            candidate = (
                next(iter(ubatch_metadata.values()), None)
                if ubatch_metadata
                else None
            )
            if candidate is not None:
                representatives.append(candidate)
        return representatives

    if isinstance(attn_metadata, dict):
        first = next(iter(attn_metadata.values()), None)
        return [] if first is None else [first]

    # A single bare metadata object.
    return [attn_metadata]


def _build_step_profiler_label() -> str | None:
    """Build a torch-profiler label summarizing the current step.

    Aggregates plugin metadata across all (micro)batches and returns a label
    such as ``"decode[bs=8 tok=8 p=0 e=0 d=8]"``, where ``bs`` is the total
    request count, ``tok`` the total token count, and ``p``/``e``/``d`` the
    prefill/extend/decode token counts. Returns None when no usable
    metadata, no tokens, or no requests are present.
    """
    attn_metadata_list = _get_step_attn_metadata_list()
    if not attn_metadata_list:
        return None

    totals = {
        "tokens": 0,
        "decodes": 0,
        "decode_tokens": 0,
        "prefills": 0,
        "extends": 0,
        "extend_tokens": 0,
        "prefill_tokens": 0,
    }

    for attn_metadata in attn_metadata_list:
        plugin_md = getattr(attn_metadata, "plugin_metadata", None)
        if plugin_md is None:
            continue

        tokens = plugin_md.num_actual_tokens
        decode_tokens = plugin_md.num_decode_tokens
        extend_tokens = getattr(plugin_md, "num_extend_tokens", 0)

        totals["tokens"] += tokens
        totals["decodes"] += plugin_md.num_decodes
        totals["decode_tokens"] += decode_tokens
        totals["prefills"] += plugin_md.num_prefills
        totals["extends"] += getattr(plugin_md, "num_extends", 0)
        totals["extend_tokens"] += extend_tokens
        # When the attribute is absent, infer prefill tokens as whatever is
        # left after decode and extend tokens.
        totals["prefill_tokens"] += getattr(
            plugin_md,
            "num_prefill_tokens",
            tokens - decode_tokens - extend_tokens,
        )

    if totals["tokens"] <= 0:
        return None

    total_reqs = totals["decodes"] + totals["prefills"] + totals["extends"]
    if total_reqs <= 0:
        return None

    # OOT step naming policy, by request counts:
    # any prefill request makes the step "prefill"; otherwise any extend
    # request makes it "extend"; otherwise the step is "decode".
    if totals["prefills"] > 0:
        step = "prefill"
    elif totals["extends"] > 0:
        step = "extend"
    else:
        step = "decode"

    return (
        f"{step}[bs={total_reqs} tok={totals['tokens']} "
        f"p={totals['prefill_tokens']} e={totals['extend_tokens']} "
        f"d={totals['decode_tokens']}]"
    )


class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP):
def __init_subclass__(cls, *args, **kwargs):
Expand Down Expand Up @@ -119,7 +304,6 @@ def forward(
"positions"
]
buf[: positions.numel()].copy_(positions)

hidden_states = self.model(
input_ids=input_ids,
positions=positions,
Expand Down
Loading