From 9fd974a5a7a67bba5613e842a41038812a486ba0 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Tue, 17 Mar 2026 22:24:04 +0800 Subject: [PATCH 1/9] [plugin][profiler] refine OOT profiler with record function Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 96 +++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 10029a79f..b5ac73be8 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -1,13 +1,16 @@ from collections.abc import Iterable +from contextlib import nullcontext import importlib import torch import torch.nn as nn +from torch.profiler import record_function from aiter.dist.parallel_state import ( get_pp_group, get_tp_group, ) from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.model_executor.models.interfaces import ( SupportsPP, SupportsQuant, @@ -58,6 +61,71 @@ def _prepare_env(atom_config) -> None: init_aiter_dist(config=atom_config) +def _get_step_attn_metadata(): + if not is_forward_context_available(): + return None + + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is None: + return None + + if isinstance(attn_metadata, list): + # In ubatch mode, vLLM stores one metadata dict per microbatch. We need + # the first actual per-layer metadata object, not the outer list itself. + # Keep the empty-dict guard for robustness if a placeholder slips through. + for ubatch_attn_metadata in attn_metadata: + if not ubatch_attn_metadata: + continue + return next(iter(ubatch_attn_metadata.values()), None) + return None + + if isinstance(attn_metadata, dict): + return next(iter(attn_metadata.values()), None) + + return attn_metadata + + +def _build_step_profiler_label() -> str | None: + attn_metadata = _get_step_attn_metadata() + if attn_metadata is None: + return None + + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + if plugin_metadata is None: + return None + + num_actual_tokens = plugin_metadata.num_actual_tokens + num_decodes = plugin_metadata.num_decodes + num_decode_tokens = plugin_metadata.num_decode_tokens + num_prefills = plugin_metadata.num_prefills + num_extends = getattr(plugin_metadata, "num_extends", 0) + num_extend_tokens = getattr(plugin_metadata, "num_extend_tokens", 0) + num_prefill_tokens = getattr( + plugin_metadata, + "num_prefill_tokens", + num_actual_tokens - num_decode_tokens - num_extend_tokens, + ) + + if num_actual_tokens <= 0: + return None + + total_reqs = num_decodes + num_prefills + num_extends + if total_reqs <= 0: + return None + + # Shorthand label format: + # d = decode-only step, p = step containing prefill/extend work. + # req/tok = total requests/tokens in this step. + # dec/pre/ext each carry request count followed by token count. + step = "p" if (num_prefills > 0 or num_extends > 0) else "d" + return ( + f"{step}[req{total_reqs}, tok{num_actual_tokens}, " + f"dec{num_decodes}, tok{num_decode_tokens}, " + f"pre{num_prefills}, tok{num_prefill_tokens}, " + f"ext{num_extends}, tok{num_extend_tokens}]" + ) + + class ATOMModelBase(nn.Module, VllmModel, SupportsQuant, SupportsPP): def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) @@ -72,6 +140,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.quant_config = vllm_config.quant_config + profiler_config = getattr(vllm_config, "profiler_config", None) + self.enable_torch_profile = ( + profiler_config is not None + and bool(getattr(profiler_config, "torch_profiler_dir", "")) + ) # Weights to skip in `self.load_weights` self.skip_prefixes: list[str] = [] @@ -120,14 +193,25 @@ def forward( ] buf[: positions.numel()].copy_(positions) - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, + record_label = ( + _build_step_profiler_label() + if self.enable_torch_profile + else None ) + with ( + record_function(record_label) + if record_label is not None + else nullcontext() + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) From 2b572b56ff77a888b5af7c82e6b8a068fb7f2413 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 14:36:40 +0800 Subject: [PATCH 2/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 152 ++++++++++++++++++++---------- 1 file changed, 104 insertions(+), 48 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index b5ac73be8..8e00d1e99 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -1,5 +1,6 @@ from collections.abc import Iterable from contextlib import nullcontext +from functools import wraps import importlib import torch @@ -60,51 +61,123 @@ def _prepare_env(atom_config) -> None: logger.info("Init aiter dist for using aiter custom collective ops") init_aiter_dist(config=atom_config) + _patch_vllm_profile_labels() -def _get_step_attn_metadata(): + +def _is_torch_profile_enabled(vllm_config: VllmConfig) -> bool: + profiler_config = getattr(vllm_config, "profiler_config", None) + return profiler_config is not None and bool( + getattr(profiler_config, "torch_profiler_dir", "") + ) + + +def _patch_vllm_profile_labels() -> None: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + # `GPUModelRunner._model_forward()` stays in Python for eager, piecewise, + # full cudagraph replay, and ubatch modes, so one patch here covers all of + # those execution paths without modifying vLLM source files. + if not getattr(GPUModelRunner, "_atom_step_label_patched", False): + original_model_forward = GPUModelRunner._model_forward + + @wraps(original_model_forward) + def _wrapped_model_forward( + self, + input_ids=None, + positions=None, + intermediate_tensors=None, + inputs_embeds=None, + **model_kwargs, + ): + record_label = None + if _is_torch_profile_enabled( + self.vllm_config + ) and is_forward_context_available(): + record_label = _build_step_profiler_label() + + with ( + record_function(record_label) + if record_label is not None + else nullcontext() + ): + return original_model_forward( + self, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + GPUModelRunner._model_forward = _wrapped_model_forward + GPUModelRunner._atom_step_label_patched = True + + +def _get_step_attn_metadata_list(): if not is_forward_context_available(): - return None + return [] attn_metadata = get_forward_context().attn_metadata if attn_metadata is None: - return None + return [] if isinstance(attn_metadata, list): + metadatas = [] # In ubatch mode, vLLM stores one metadata dict per microbatch. We need - # the first actual per-layer metadata object, not the outer list itself. - # Keep the empty-dict guard for robustness if a placeholder slips through. + # one representative metadata object per microbatch and aggregate them to + # get the full-step request/token counts. for ubatch_attn_metadata in attn_metadata: if not ubatch_attn_metadata: continue - return next(iter(ubatch_attn_metadata.values()), None) - return None + metadata = next(iter(ubatch_attn_metadata.values()), None) + if metadata is not None: + metadatas.append(metadata) + return metadatas if isinstance(attn_metadata, dict): - return next(iter(attn_metadata.values()), None) + metadata = next(iter(attn_metadata.values()), None) + return [metadata] if metadata is not None else [] - return attn_metadata + return [attn_metadata] def _build_step_profiler_label() -> str | None: - attn_metadata = _get_step_attn_metadata() - if attn_metadata is None: + attn_metadata_list = _get_step_attn_metadata_list() + if not attn_metadata_list: return None - plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) - if plugin_metadata is None: - return None + num_actual_tokens = 0 + num_decodes = 0 + num_decode_tokens = 0 + num_prefills = 0 + num_extends = 0 + num_extend_tokens = 0 + num_prefill_tokens = 0 + + for attn_metadata in attn_metadata_list: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + if plugin_metadata is None: + continue + + actual_tokens_i = plugin_metadata.num_actual_tokens + decodes_i = plugin_metadata.num_decodes + decode_tokens_i = plugin_metadata.num_decode_tokens + prefills_i = plugin_metadata.num_prefills + extends_i = getattr(plugin_metadata, "num_extends", 0) + extend_tokens_i = getattr(plugin_metadata, "num_extend_tokens", 0) + prefill_tokens_i = getattr( + plugin_metadata, + "num_prefill_tokens", + actual_tokens_i - decode_tokens_i - extend_tokens_i, + ) - num_actual_tokens = plugin_metadata.num_actual_tokens - num_decodes = plugin_metadata.num_decodes - num_decode_tokens = plugin_metadata.num_decode_tokens - num_prefills = plugin_metadata.num_prefills - num_extends = getattr(plugin_metadata, "num_extends", 0) - num_extend_tokens = getattr(plugin_metadata, "num_extend_tokens", 0) - num_prefill_tokens = getattr( - plugin_metadata, - "num_prefill_tokens", - num_actual_tokens - num_decode_tokens - num_extend_tokens, - ) + num_actual_tokens += actual_tokens_i + num_decodes += decodes_i + num_decode_tokens += decode_tokens_i + num_prefills += prefills_i + num_extends += extends_i + num_extend_tokens += extend_tokens_i + num_prefill_tokens += prefill_tokens_i if num_actual_tokens <= 0: return None @@ -140,11 +213,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.quant_config = vllm_config.quant_config - profiler_config = getattr(vllm_config, "profiler_config", None) - self.enable_torch_profile = ( - profiler_config is not None - and bool(getattr(profiler_config, "torch_profiler_dir", "")) - ) # Weights to skip in `self.load_weights` self.skip_prefixes: list[str] = [] @@ -192,26 +260,14 @@ def forward( "positions" ] buf[: positions.numel()].copy_(positions) - - record_label = ( - _build_step_profiler_label() - if self.enable_torch_profile - else None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) - with ( - record_function(record_label) - if record_label is not None - else nullcontext() - ): - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) From d65cd7417a0415dc13fcdd83032ea0c403277e5f Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 15:20:42 +0800 Subject: [PATCH 3/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 8e00d1e99..4ff884ffb 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -186,16 +186,22 @@ def _build_step_profiler_label() -> str | None: if total_reqs <= 0: return None - # Shorthand label format: - # d = decode-only step, p = step containing prefill/extend work. - # req/tok = total requests/tokens in this step. - # dec/pre/ext each carry request count followed by token count. - step = "p" if (num_prefills > 0 or num_extends > 0) else "d" + # OOT step naming policy: + # - prefill wins if any prefill tokens exist + # - otherwise extend wins if any extend tokens exist + # - otherwise the step is decode + # `bs` is total request count, `tok` is total token count, + # and `p/e/d` are prefill/extend/decode token counts. + if num_prefills > 0: + step = "prefill" + elif num_extends > 0: + step = "extend" + else: + step = "decode" + return ( - f"{step}[req{total_reqs}, tok{num_actual_tokens}, " - f"dec{num_decodes}, tok{num_decode_tokens}, " - f"pre{num_prefills}, tok{num_prefill_tokens}, " - f"ext{num_extends}, tok{num_extend_tokens}]" + f"{step}[bs={total_reqs} tok={num_actual_tokens} " + f"p={num_prefill_tokens} e={num_extend_tokens} d={num_decode_tokens}]" ) From 99e5c7987cebc6f94b10d2581d7be5c843ea7560 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 16:37:48 +0800 Subject: [PATCH 4/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 4ff884ffb..40def6262 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -75,8 +75,8 @@ def _patch_vllm_profile_labels() -> None: from vllm.v1.worker.gpu_model_runner import GPUModelRunner # `GPUModelRunner._model_forward()` stays in Python for eager, piecewise, - # full cudagraph replay, and ubatch modes, so one patch here covers all of - # those execution paths without modifying vLLM source files. + # full cudagraph replay, and ubatch modes. We patch here so OOT can add + # step labels without modifying vLLM source files. if not getattr(GPUModelRunner, "_atom_step_label_patched", False): original_model_forward = GPUModelRunner._model_forward @@ -90,11 +90,12 @@ def _wrapped_model_forward( **model_kwargs, ): record_label = None - if _is_torch_profile_enabled( - self.vllm_config - ) and is_forward_context_available(): + prof_enabled = _is_torch_profile_enabled(self.vllm_config) + if prof_enabled and is_forward_context_available(): record_label = _build_step_profiler_label() + print('[zejun] record_label = ', record_label, flush=True) + with ( record_function(record_label) if record_label is not None From 89b431c4b7b719d0bc3d587269f076d16ba067d7 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 17:53:31 +0800 Subject: [PATCH 5/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 100 +++++++++++++++++------------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 40def6262..de0abc162 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -1,6 +1,5 @@ from collections.abc import Iterable from contextlib import nullcontext -from functools import wraps import importlib import torch @@ -10,7 +9,7 @@ get_pp_group, get_tp_group, ) -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config_or_none from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.model_executor.models.interfaces import ( SupportsPP, @@ -72,46 +71,63 @@ def _is_torch_profile_enabled(vllm_config: VllmConfig) -> bool: def _patch_vllm_profile_labels() -> None: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - - # `GPUModelRunner._model_forward()` stays in Python for eager, piecewise, - # full cudagraph replay, and ubatch modes. We patch here so OOT can add - # step labels without modifying vLLM source files. - if not getattr(GPUModelRunner, "_atom_step_label_patched", False): - original_model_forward = GPUModelRunner._model_forward - - @wraps(original_model_forward) - def _wrapped_model_forward( - self, - input_ids=None, - positions=None, - intermediate_tensors=None, - inputs_embeds=None, - **model_kwargs, - ): - record_label = None - prof_enabled = _is_torch_profile_enabled(self.vllm_config) - if prof_enabled and is_forward_context_available(): - record_label = _build_step_profiler_label() - - print('[zejun] record_label = ', record_label, flush=True) - - with ( - record_function(record_label) - if record_label is not None - else nullcontext() - ): - return original_model_forward( - self, - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - - GPUModelRunner._model_forward = _wrapped_model_forward - GPUModelRunner._atom_step_label_patched = True + from vllm.v1.worker import gpu_model_runner as gpu_model_runner_mod + from vllm.v1 import utils as v1_utils_mod + + # Use vLLM's existing step boundary in execute_model(): + # + # with set_forward_context(...), + # record_function_or_nullcontext("gpu_model_runner: forward"): + # model_output = self._model_forward(...) + # + # This encloses the real model run for eager, piecewise, and full graph + # modes, so a single dynamic label hook here is simpler than patching + # separate execution paths. Note that gpu_model_runner imported + # `record_function_or_nullcontext` by value, so patching only v1.utils is + # not sufficient after import; we patch the local symbol in + # gpu_model_runner as well. + if not getattr(gpu_model_runner_mod, "_atom_step_label_patched", False): + original_record_ctx = gpu_model_runner_mod.record_function_or_nullcontext + original_utils_record_ctx = v1_utils_mod.record_function_or_nullcontext + + class _DynamicForwardRecordContext: + def __init__(self, name: str, original_ctx): + self.name = name + self.original_ctx = original_ctx + self.ctx = nullcontext() + + def __enter__(self): + if self.name == "gpu_model_runner: forward": + vllm_config = get_current_vllm_config_or_none() + if ( + vllm_config is not None + and _is_torch_profile_enabled(vllm_config) + and is_forward_context_available() + ): + record_label = _build_step_profiler_label() + if record_label is not None: + self.ctx = record_function(record_label) + return self.ctx.__enter__() + + self.ctx = self.original_ctx(self.name) + return self.ctx.__enter__() + + def __exit__(self, exc_type, exc, tb): + return self.ctx.__exit__(exc_type, exc, tb) + + def _wrapped_gpu_record_function_or_nullcontext(name: str): + return _DynamicForwardRecordContext(name, original_record_ctx) + + def _wrapped_utils_record_function_or_nullcontext(name: str): + return _DynamicForwardRecordContext(name, original_utils_record_ctx) + + gpu_model_runner_mod.record_function_or_nullcontext = ( + _wrapped_gpu_record_function_or_nullcontext + ) + v1_utils_mod.record_function_or_nullcontext = ( + _wrapped_utils_record_function_or_nullcontext + ) + gpu_model_runner_mod._atom_step_label_patched = True def _get_step_attn_metadata_list(): From 9d570339c57a6e5978d2a3d2fab1f13d5bd54bcf Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 18:02:16 +0800 Subject: [PATCH 6/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index de0abc162..0810e95bc 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -106,6 +106,7 @@ def __enter__(self): ): record_label = _build_step_profiler_label() if record_label is not None: + print('[zejun] record_label = ', record_label, flush=True) self.ctx = record_function(record_label) return self.ctx.__enter__() From fcb8cdc40bd48870fe6ce7f13cf53aea3eb9a86a Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 18:11:34 +0800 Subject: [PATCH 7/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 0810e95bc..77617d1d4 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -99,12 +99,16 @@ def __init__(self, name: str, original_ctx): def __enter__(self): if self.name == "gpu_model_runner: forward": vllm_config = get_current_vllm_config_or_none() + print('[zejun] vllm_config = ', vllm_config, flush=True) + print('[zejun] _is_torch_profile_enabled(vllm_config) = ', _is_torch_profile_enabled(vllm_config), flush=True) + print('[zejun] is_forward_context_available() = ', is_forward_context_available(), flush=True) if ( vllm_config is not None and _is_torch_profile_enabled(vllm_config) and is_forward_context_available() ): record_label = _build_step_profiler_label() + print('[zejun] record_label = ', record_label, flush=True) if record_label is not None: print('[zejun] record_label = ', record_label, flush=True) self.ctx = record_function(record_label) From 42e5334fa0d669b6bb2e8111421f4f35d5765a2d Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 19:00:10 +0800 Subject: [PATCH 8/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 33 ++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 77617d1d4..89d1b7d5c 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -9,7 +9,7 @@ get_pp_group, get_tp_group, ) -from vllm.config import VllmConfig, get_current_vllm_config_or_none +from vllm.config import VllmConfig from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.model_executor.models.interfaces import ( SupportsPP, @@ -28,6 +28,7 @@ import logging logger = logging.getLogger("atom") +_ATOM_OOT_TORCH_PROFILE_ACTIVE = False _ATOM_MODEL_CLASSES: dict[str, str] = { @@ -71,9 +72,30 @@ def _is_torch_profile_enabled(vllm_config: VllmConfig) -> bool: def _patch_vllm_profile_labels() -> None: + from vllm.profiler.wrapper import TorchProfilerWrapper from vllm.v1.worker import gpu_model_runner as gpu_model_runner_mod from vllm.v1 import utils as v1_utils_mod + global _ATOM_OOT_TORCH_PROFILE_ACTIVE + + if not getattr(TorchProfilerWrapper, "_atom_step_label_patched", False): + original_call_start = TorchProfilerWrapper._call_start + original_call_stop = TorchProfilerWrapper._call_stop + + def _wrapped_call_start(self): + global _ATOM_OOT_TORCH_PROFILE_ACTIVE + original_call_start(self) + _ATOM_OOT_TORCH_PROFILE_ACTIVE = bool(getattr(self, "_running", False)) + + def _wrapped_call_stop(self): + global _ATOM_OOT_TORCH_PROFILE_ACTIVE + original_call_stop(self) + _ATOM_OOT_TORCH_PROFILE_ACTIVE = False + + TorchProfilerWrapper._call_start = _wrapped_call_start + TorchProfilerWrapper._call_stop = _wrapped_call_stop + TorchProfilerWrapper._atom_step_label_patched = True + # Use vLLM's existing step boundary in execute_model(): # # with set_forward_context(...), @@ -98,19 +120,12 @@ def __init__(self, name: str, original_ctx): def __enter__(self): if self.name == "gpu_model_runner: forward": - vllm_config = get_current_vllm_config_or_none() - print('[zejun] vllm_config = ', vllm_config, flush=True) - print('[zejun] _is_torch_profile_enabled(vllm_config) = ', _is_torch_profile_enabled(vllm_config), flush=True) - print('[zejun] is_forward_context_available() = ', is_forward_context_available(), flush=True) if ( - vllm_config is not None - and _is_torch_profile_enabled(vllm_config) + _ATOM_OOT_TORCH_PROFILE_ACTIVE and is_forward_context_available() ): record_label = _build_step_profiler_label() - print('[zejun] record_label = ', record_label, flush=True) if record_label is not None: - print('[zejun] record_label = ', record_label, flush=True) self.ctx = record_function(record_label) return self.ctx.__enter__() From 09a988a6de5234d0e02e7b6c8103b566b03bf086 Mon Sep 17 00:00:00 2001 From: zejunchen-zejun Date: Wed, 18 Mar 2026 19:10:09 +0800 Subject: [PATCH 9/9] add Signed-off-by: zejunchen-zejun --- atom/plugin/vllm/model_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 89d1b7d5c..139f05526 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -125,6 +125,7 @@ def __enter__(self): and is_forward_context_available() ): record_label = _build_step_profiler_label() + print('[zejun] record_label = ', record_label, flush=True) if record_label is not None: self.ctx = record_function(record_label) return self.ctx.__enter__()