Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ xe-forge-skill profile kernel.py --spec spec.yaml
## Writing the Model Class

Every kernel file must contain a `Model` class that wraps the Triton kernel launch. The optimizer uses this class to execute, benchmark, and verify correctness.
The `Model` class can include custom initialization and an optional `get_example_inputs` method to provide a complex combination of inputs if they cannot be generated randomly by their shapes and a common dtype.

### Structure

Expand Down Expand Up @@ -298,6 +299,13 @@ class Model(torch.nn.Module):
num_warps=32,
)
return OUT

# Optional
# If this method is provided, the Xe-Forge will get input tensors by this method.
# If not, Xe-Forge will generate random inputs based on shapes and dtype.
# Don't provide this method (remove it) if you don't need it.
def get_example_inputs(self, input_shapes: list | None = None, device: str = 'xpu'):
pass
```

### Model with Init Arguments
Expand Down
14 changes: 12 additions & 2 deletions src/xe_forge/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
logger = logging.getLogger(__name__)


def _has_callable_attr(obj, attr_name):
"""Check if object has a callable attribute with the given name."""
return hasattr(obj, attr_name) and callable(getattr(obj, attr_name))


@dataclass
class ComparisonResult:
"""Result of comparing original vs optimized kernel performance."""
Expand Down Expand Up @@ -153,7 +158,9 @@ def execute(

# Create inputs if not provided
if inputs is None:
if input_shapes:
if _has_callable_attr(model, "get_example_inputs"):
inputs = model.get_example_inputs(input_shapes, self.device)
elif input_shapes:
inputs = self._create_inputs(
input_shapes, dtype=dtype, input_dtypes=input_dtypes
)
Expand Down Expand Up @@ -302,7 +309,10 @@ def _check_correctness(

# Shared inputs with deterministic seed
set_all_seeds(123)
inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes)
if _has_callable_attr(original_model, "get_example_inputs"):
inputs = original_model.get_example_inputs(input_shapes, self.device)
else:
inputs = self._create_inputs(input_shapes, dtype=dtype, input_dtypes=input_dtypes)

inputs_orig = [inp.clone() for inp in inputs]
inputs_opt = [inp.clone() for inp in inputs]
Expand Down
4 changes: 2 additions & 2 deletions src/xe_forge/core/spec_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def get_flop(

# Substitute dimension values into formula, then evaluate via ai_bench
formula = str(variant.flop_formula)
for dim, value in variant.dims.items():
formula = formula.replace(dim, str(value))
for key in sorted(variant.dims.keys(), key=len, reverse=True):
formula = formula.replace(key, str(variant.dims[key]))
return eval_eq(formula)

def get_rtol(
Expand Down
Loading