diff --git a/README.md b/README.md index 55cbe4f..0c0530a 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,7 @@ xe-forge-skill profile kernel.py --spec spec.yaml ## Writing the Model Class Every kernel file must contain a `Model` class that wraps the Triton kernel launch. The optimizer uses this class to execute, benchmark, and verify correctness. +The `Model` class can include custom initialization and an optional `get_example_inputs` method to provide a complex combination of inputs if they cannot be generated randomly by their shapes and a common dtype. ### Structure @@ -298,6 +299,13 @@ class Model(torch.nn.Module): num_warps=32, ) return OUT + + # Optional + # If this method is provided, the Xe-Forge will get input tensors by this method. + # If not, Xe-Forge will generate random inputs based on shapes and dtype. + # Don't provide this method (remove it) if you don't need it. + def get_example_inputs(self, input_shapes: list | None = None, device: str = 'xpu'): + pass ``` ### Model with Init Arguments diff --git a/src/xe_forge/core/executor.py b/src/xe_forge/core/executor.py index 8debed4..b0fb327 100644 --- a/src/xe_forge/core/executor.py +++ b/src/xe_forge/core/executor.py @@ -27,6 +27,11 @@ logger = logging.getLogger(__name__) +def _has_callable_attr(obj, attr_name): + """Check if object has a callable attribute with the given name.""" + return hasattr(obj, attr_name) and callable(getattr(obj, attr_name)) + + @dataclass class ComparisonResult: """Result of comparing original vs optimized kernel performance.""" @@ -153,7 +158,9 @@ def execute( # Create inputs if not provided if inputs is None: - if input_shapes: + if _has_callable_attr(model, "get_example_inputs"): + inputs = model.get_example_inputs(input_shapes, self.device) + elif input_shapes: inputs = self._create_inputs( input_shapes, dtype=dtype, input_dtypes=input_dtypes ) @@ -302,7 +309,10 @@ def _check_correctness( # Shared inputs with deterministic seed set_all_seeds(123) - inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes) + if _has_callable_attr(original_model, "get_example_inputs"): + inputs = original_model.get_example_inputs(input_shapes, self.device) + else: + inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes) inputs_orig = [inp.clone() for inp in inputs] inputs_opt = [inp.clone() for inp in inputs] diff --git a/src/xe_forge/core/spec_loader.py b/src/xe_forge/core/spec_loader.py index 53d76a9..f407ff2 100644 --- a/src/xe_forge/core/spec_loader.py +++ b/src/xe_forge/core/spec_loader.py @@ -234,8 +234,8 @@ def get_flop( # Substitute dimension values into formula, then evaluate via ai_bench formula = str(variant.flop_formula) - for dim, value in variant.dims.items(): - formula = formula.replace(dim, str(value)) + for key in sorted(variant.dims.keys(), key=len, reverse=True): + formula = formula.replace(key, str(variant.dims[key])) return eval_eq(formula) def get_rtol(