From def18139264933b68c964516d08b9898826b860a Mon Sep 17 00:00:00 2001 From: songh11 <1020258195@qq.com> Date: Tue, 31 Mar 2026 17:27:19 +0800 Subject: [PATCH] [fix] compile state with instance-level forward wrapping --- magi_compiler/_api.py | 3 +- .../magi_backend/magi_compiler_base.py | 10 ++- tests/api_tests/test_magi_compile.py | 74 +++++++++++++++++++ 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/magi_compiler/_api.py b/magi_compiler/_api.py index 1b811a4..2f51eec 100644 --- a/magi_compiler/_api.py +++ b/magi_compiler/_api.py @@ -241,6 +241,7 @@ def _magi_compile_instance( # module.__class__.forward is the unbound class method — never affected by our # instance-level override, so calling it goes straight to original forward logic. old_call = module.__class__.forward + module._magi_original_forward = module.forward @torch.compiler.disable() def new_call(*args, **kwargs): @@ -250,7 +251,7 @@ def new_call(*args, **kwargs): return old_call(module, *args, **kwargs) with _isolated_dynamo_config(): - return _run_orchestration(state, lambda: old_call(module, *args, **kwargs), args, kwargs) + return _run_orchestration(state, lambda: module.__class__.__call__(module, *args, **kwargs), args, kwargs) module.forward = new_call module._magi_compiled = True diff --git a/magi_compiler/magi_backend/magi_compiler_base.py b/magi_compiler/magi_backend/magi_compiler_base.py index 1fb0732..3f6230e 100644 --- a/magi_compiler/magi_backend/magi_compiler_base.py +++ b/magi_compiler/magi_backend/magi_compiler_base.py @@ -97,7 +97,7 @@ def __init__( if isinstance(obj, torch.nn.Module): self.original_code_object: CodeType = obj.__class__.forward.__code__ - self._target_callable = obj.forward + self._target_callable = getattr(obj, "_magi_original_forward", obj.forward) elif callable(obj): self.original_code_object: CodeType = inspect.unwrap(obj).__code__ self._target_callable = obj @@ -315,7 +315,13 @@ def dispatch_to_compiled_fwd(self, mode: Literal["jit", "aot"] = "jit"): assert self.compiled_code is not None if isinstance(self.obj, torch.nn.Module): self.obj.__class__.forward.__code__ = self.compiled_code - yield + if hasattr(self.obj, "_magi_original_forward"): + original_forward = self.obj.forward + self.obj.forward = self.obj.__class__.forward.__get__(self.obj, self.obj.__class__) + yield + self.obj.forward = original_forward + else: + yield self.obj.__class__.forward.__code__ = self.original_code_object else: # Function/Method level diff --git a/tests/api_tests/test_magi_compile.py b/tests/api_tests/test_magi_compile.py index 78564bb..63ddd78 100644 --- a/tests/api_tests/test_magi_compile.py +++ b/tests/api_tests/test_magi_compile.py @@ -18,6 +18,7 @@ import shutil import tempfile +import time from typing import Tuple from unittest.mock import MagicMock, patch @@ -481,3 +482,76 @@ class ClsSimpleModel(SimpleModel): assert_close(cls_out, native_out, rtol=1e-3, atol=1e-3) assert_close(inst_out, native_out, rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support for stable timing") + def test_simple_model_timing_class_function_instance_method(self): + """Lightweight timing sanity: Class / Function / Instance / Method entrypoints.""" + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.dim = 32 + self.layers = nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(4)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + res = x + x = layer(x) + x = torch.nn.functional.gelu(x) + x = x + res + return x + + device = torch.device("cuda:0") + seq_len = 16 + test_input = torch.randn(seq_len, 32, device=device) + + native = SimpleModel().to(device).eval() + + @magi_compile(dynamic_arg_dims={"x": 0}) + class ClsSimpleModel(SimpleModel): + pass + + cls_model = ClsSimpleModel().to(device).eval() + cls_model.load_state_dict(native.state_dict()) + + inst_model = SimpleModel().to(device).eval() + inst_model.load_state_dict(native.state_dict()) + inst_model = magi_compile(inst_model, dynamic_arg_dims={"x": 0}) + + func_native = SimpleModel().to(device).eval() + func_native.load_state_dict(native.state_dict()) + + @magi_compile(dynamic_arg_dims={"x": 0}) + def func_entry(x: torch.Tensor) -> torch.Tensor: + return func_native(x) + + mtd_model = SimpleModel().to(device).eval() + mtd_model.load_state_dict(native.state_dict()) + mtd_model.forward = magi_compile(mtd_model.forward, dynamic_arg_dims={"x": 0}) + + def _bench(callable_obj, label: str, warmup: int = 5, iters: int = 200) -> float: + with torch.no_grad(): + for _ in range(warmup): + callable_obj(test_input) + torch.cuda.synchronize() + start = time.perf_counter() + with torch.no_grad(): + for _ in range(iters): + callable_obj(test_input) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + print(f"{label}: {elapsed:.4f}s") + return elapsed + + t_class = _bench(cls_model, "class") + t_func = _bench(func_entry, "function") + t_inst = _bench(inst_model, "instance") + t_mtd = _bench(mtd_model, "method") + + compiled_times = [t_class, t_func, t_inst, t_mtd] + max_compiled = max(compiled_times) + min_compiled = min(compiled_times) + assert max_compiled / min_compiled < 1.2, ( + "Magi entry timings diverged too much: " + f"class={t_class:.4f}s, function={t_func:.4f}s, instance={t_inst:.4f}s, method={t_mtd:.4f}s" + )