Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion magi_compiler/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def _magi_compile_instance(
# module.__class__.forward is the unbound class method β€” never affected by our
# instance-level override, so calling it goes straight to original forward logic.
old_call = module.__class__.forward
module._magi_original_forward = module.forward

@torch.compiler.disable()
def new_call(*args, **kwargs):
Expand All @@ -250,7 +251,7 @@ def new_call(*args, **kwargs):
return old_call(module, *args, **kwargs)

with _isolated_dynamo_config():
return _run_orchestration(state, lambda: old_call(module, *args, **kwargs), args, kwargs)
return _run_orchestration(state, lambda: module.__class__.__call__(module, *args, **kwargs), args, kwargs)

module.forward = new_call
module._magi_compiled = True
Expand Down
10 changes: 8 additions & 2 deletions magi_compiler/magi_backend/magi_compiler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(

if isinstance(obj, torch.nn.Module):
self.original_code_object: CodeType = obj.__class__.forward.__code__
self._target_callable = obj.forward
self._target_callable = getattr(obj, "_magi_original_forward", obj.forward)
elif callable(obj):
self.original_code_object: CodeType = inspect.unwrap(obj).__code__
self._target_callable = obj
Expand Down Expand Up @@ -315,7 +315,13 @@ def dispatch_to_compiled_fwd(self, mode: Literal["jit", "aot"] = "jit"):
assert self.compiled_code is not None
if isinstance(self.obj, torch.nn.Module):
self.obj.__class__.forward.__code__ = self.compiled_code
yield
if hasattr(self.obj, "_magi_original_forward"):
original_forward = self.obj.forward
self.obj.forward = self.obj.__class__.forward.__get__(self.obj, self.obj.__class__)
yield
self.obj.forward = original_forward
else:
yield
self.obj.__class__.forward.__code__ = self.original_code_object
else:
# Function/Method level
Expand Down
74 changes: 74 additions & 0 deletions tests/api_tests/test_magi_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import shutil
import tempfile
import time
from typing import Tuple
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -481,3 +482,76 @@ class ClsSimpleModel(SimpleModel):

assert_close(cls_out, native_out, rtol=1e-3, atol=1e-3)
assert_close(inst_out, native_out, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support for stable timing")
def test_simple_model_timing_class_function_instance_method(self):
    """Timing-parity sanity check across the four magi_compile entrypoints.

    Builds one reference model, compiles numerically identical copies via
    (1) class decoration, (2) function decoration, (3) instance wrapping and
    (4) bound-method wrapping, benchmarks each on CUDA, and asserts the
    slowest/fastest wall-clock ratio stays below 1.2x.
    """

    class SimpleModel(nn.Module):
        # Small 4-layer MLP (dim=32) with GELU and residual adds; kept tiny
        # so the benchmark is dominated by dispatch overhead, not compute.
        def __init__(self):
            super().__init__()
            self.dim = 32
            self.layers = nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(4)])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                res = x  # residual saved before the linear + GELU transform
                x = layer(x)
                x = torch.nn.functional.gelu(x)
                x = x + res
            return x

    device = torch.device("cuda:0")
    seq_len = 16
    test_input = torch.randn(seq_len, 32, device=device)

    # Reference model; its state_dict is copied into every variant below so
    # all four entrypoints execute the same weights on the same input.
    native = SimpleModel().to(device).eval()

    # Entrypoint 1: class-level decoration.
    @magi_compile(dynamic_arg_dims={"x": 0})
    class ClsSimpleModel(SimpleModel):
        pass

    cls_model = ClsSimpleModel().to(device).eval()
    cls_model.load_state_dict(native.state_dict())

    # Entrypoint 2: wrapping an already-constructed instance.
    inst_model = SimpleModel().to(device).eval()
    inst_model.load_state_dict(native.state_dict())
    inst_model = magi_compile(inst_model, dynamic_arg_dims={"x": 0})

    # Entrypoint 3: decorating a plain function that closes over a module.
    func_native = SimpleModel().to(device).eval()
    func_native.load_state_dict(native.state_dict())

    @magi_compile(dynamic_arg_dims={"x": 0})
    def func_entry(x: torch.Tensor) -> torch.Tensor:
        return func_native(x)

    # Entrypoint 4: replacing the bound forward method in place.
    mtd_model = SimpleModel().to(device).eval()
    mtd_model.load_state_dict(native.state_dict())
    mtd_model.forward = magi_compile(mtd_model.forward, dynamic_arg_dims={"x": 0})

    def _bench(callable_obj, label: str, warmup: int = 5, iters: int = 200) -> float:
        # Run `warmup` untimed calls first, then time `iters` forward passes.
        # torch.cuda.synchronize() before start and before reading the clock
        # ensures queued CUDA kernels are fully included in the measurement.
        with torch.no_grad():
            for _ in range(warmup):
                callable_obj(test_input)
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            for _ in range(iters):
                callable_obj(test_input)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print(f"{label}: {elapsed:.4f}s")
        return elapsed

    t_class = _bench(cls_model, "class")
    t_func = _bench(func_entry, "function")
    t_inst = _bench(inst_model, "instance")
    t_mtd = _bench(mtd_model, "method")

    # NOTE(review): the 1.2x spread threshold is a heuristic; presumably a
    # larger divergence indicates one entrypoint missed the compiled path —
    # confirm the margin is wide enough to avoid flakes on shared CI GPUs.
    compiled_times = [t_class, t_func, t_inst, t_mtd]
    max_compiled = max(compiled_times)
    min_compiled = min(compiled_times)
    assert max_compiled / min_compiled < 1.2, (
        "Magi entry timings diverged too much: "
        f"class={t_class:.4f}s, function={t_func:.4f}s, instance={t_inst:.4f}s, method={t_mtd:.4f}s"
    )
Loading