From a3cfb76f98476566a5e567aff1018286ce5549f1 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 3 Jun 2026 15:51:29 +0800 Subject: [PATCH 1/3] Optional get_example_inputs method of model --- README.md | 7 +++++++ src/xe_forge/core/executor.py | 9 +++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6f05e4a..65b5f54 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,13 @@ class Model(torch.nn.Module): num_warps=32, ) return OUT + + # Optional + # If this method is provided, the Xe-Forge will get input tensors by this method. + # If not, Xe-Forge will generate random inputs based on shapes and dtype. + # Don't provide this method (remove it) if you don't need it. + def get_example_inputs(self, input_shapes: list | None = None): + pass ``` ### Model with Init Arguments diff --git a/src/xe_forge/core/executor.py b/src/xe_forge/core/executor.py index ccdc879..2aea9b1 100644 --- a/src/xe_forge/core/executor.py +++ b/src/xe_forge/core/executor.py @@ -151,7 +151,9 @@ def execute( # Create inputs if not provided if inputs is None: - if input_shapes: + if hasattr(model, "get_example_inputs"): + inputs = model.get_example_inputs(input_shapes, self.device) + elif input_shapes: inputs = self._create_inputs(input_shapes, dtype=dtype) else: return ExecutionResult( @@ -297,7 +299,10 @@ def _check_correctness( # Shared inputs with deterministic seed set_all_seeds(123) - inputs = self._create_inputs(input_shapes, dtype=dtype) + if hasattr(original_model, "get_example_inputs"): + inputs = original_model.get_example_inputs(input_shapes, self.device) + else: + inputs = self._create_inputs(input_shapes, dtype=dtype) inputs_orig = [inp.clone() for inp in inputs] inputs_opt = [inp.clone() for inp in inputs] From 0172f5ff601fba1efd8b70ed2b4602eb7f7c5bb7 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 8 Jun 2026 10:16:33 +0800 Subject: [PATCH 2/3] Check get_example_inputs is callable and update README --- README.md | 1 + src/xe_forge/core/executor.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7cf8d7e..9cf8ab3 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,7 @@ xe-forge-skill profile kernel.py --spec spec.yaml ## Writing the Model Class Every kernel file must contain a `Model` class that wraps the Triton kernel launch. The optimizer uses this class to execute, benchmark, and verify correctness. +The `Model` class can include custom initialization and an optional `get_example_inputs` method to provide a complex combination of inputs if they cannot be generated randomly by their shapes and a common dtype. ### Structure diff --git a/src/xe_forge/core/executor.py b/src/xe_forge/core/executor.py index 3b9073d..b0fb327 100644 --- a/src/xe_forge/core/executor.py +++ b/src/xe_forge/core/executor.py @@ -27,6 +27,11 @@ logger = logging.getLogger(__name__) +def _has_callable_attr(obj, attr_name): + """Check if object has a callable attribute with the given name.""" + return hasattr(obj, attr_name) and callable(getattr(obj, attr_name)) + + @dataclass class ComparisonResult: """Result of comparing original vs optimized kernel performance.""" @@ -153,7 +158,7 @@ def execute( # Create inputs if not provided if inputs is None: - if hasattr(model, "get_example_inputs"): + if _has_callable_attr(model, "get_example_inputs"): inputs = model.get_example_inputs(input_shapes, self.device) elif input_shapes: inputs = self._create_inputs( @@ -304,7 +309,7 @@ def _check_correctness( # Shared inputs with deterministic seed set_all_seeds(123) - if hasattr(original_model, "get_example_inputs"): + if _has_callable_attr(original_model, "get_example_inputs"): inputs = original_model.get_example_inputs(input_shapes, self.device) else: inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes) From 57826cbd0468622ee7244cccf2065aaf6408c76f Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 8 Jun 2026 13:03:07 +0800 Subject: [PATCH 3/3] Refine code --- README.md | 2 +- src/xe_forge/core/spec_loader.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9cf8ab3..0c0530a 100644 --- a/README.md +++ b/README.md @@ -304,7 +304,7 @@ class Model(torch.nn.Module): # If this method is provided, the Xe-Forge will get input tensors by this method. # If not, Xe-Forge will generate random inputs based on shapes and dtype. # Don't provide this method (remove it) if you don't need it. - def get_example_inputs(self, input_shapes: list | None = None): + def get_example_inputs(self, input_shapes: list | None = None, device: str = 'xpu'): pass ``` diff --git a/src/xe_forge/core/spec_loader.py b/src/xe_forge/core/spec_loader.py index 53d76a9..f407ff2 100644 --- a/src/xe_forge/core/spec_loader.py +++ b/src/xe_forge/core/spec_loader.py @@ -234,8 +234,8 @@ def get_flop( # Substitute dimension values into formula, then evaluate via ai_bench formula = str(variant.flop_formula) - for dim, value in variant.dims.items(): - formula = formula.replace(dim, str(value)) + for key in sorted(variant.dims.keys(), key=len, reverse=True): + formula = formula.replace(key, str(variant.dims[key])) return eval_eq(formula) def get_rtol(