Skip to content

Commit f7db6c4

Browse files
committed
fix rebase mess
1 parent bd23099 commit f7db6c4

File tree

3 files changed

+51
-59
lines changed

3 files changed

+51
-59
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
"""Tests for MLIR generation."""

ingress/mlir-gen/mlir_gen/test/conftest.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,36 +21,5 @@ def setup_mlir_environment():
2121

2222
yield
2323

24-
# Cleanup after all tests (if needed)
24+
# todo: cleanup
2525
pass
26-
27-
28-
@pytest.fixture
29-
def mlir_context():
30-
"""
31-
Provide a fresh MLIR context for each test.
32-
"""
33-
from mlir import ir
34-
35-
return ir.Context()
36-
37-
38-
@pytest.fixture
39-
def sample_shapes():
40-
"""
41-
Provide common tensor shapes for testing.
42-
"""
43-
return [
44-
(4, 16),
45-
(8, 8),
46-
(16, 32),
47-
(1, 64),
48-
]
49-
50-
51-
@pytest.fixture
52-
def sample_types():
53-
"""
54-
Provide common element types for testing.
55-
"""
56-
return ["f32", "f64"]

python/examples/mlir/compile_and_run.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import torch
2-
import os
2+
import argparse
33

44
from mlir import ir
55
from mlir.dialects import transform
66
from mlir.dialects.transform import structured
77
from mlir.dialects.transform import interpreter
88
from mlir.execution_engine import ExecutionEngine
99
from mlir.passmanager import PassManager
10-
from mlir.runtime.np_to_memref import (
11-
get_ranked_memref_descriptor,
12-
)
1310

1411
from lighthouse import utils as lh_utils
1512

@@ -50,25 +47,26 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
5047
ir.UnitAttr.get()
5148
)
5249

50+
# For simplicity, use generic matchers without requiring specific types.
51+
anytype = transform.any_op_t()
52+
5353
# Create entry point transformation sequence.
5454
with ir.InsertionPoint(schedule.body):
5555
named_seq = transform.NamedSequenceOp(
56-
"__transform_main",
57-
[transform.AnyOpType.get()],
58-
[],
56+
sym_name="__transform_main",
57+
input_types=[anytype],
58+
result_types=[],
5959
arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
6060
)
6161

6262
# Create the schedule.
6363
with ir.InsertionPoint(named_seq.body):
64-
# For simplicity, use generic transform matchers.
65-
anytype = transform.AnyOpType.get()
66-
6764
# Find the kernel's function op.
6865
func = structured.MatchOp.match_op_names(
6966
named_seq.bodyTarget, ["func.func"]
7067
)
71-
# Use C interface wrappers - required to make function executable after jitting.
68+
# Use C interface wrappers - required to make function executable
69+
# after jitting.
7270
func = transform.apply_registered_pass(
7371
anytype, func, "llvm-request-c-wrappers"
7472
)
@@ -82,12 +80,12 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
8280
anytype, mod, "convert-linalg-to-loops"
8381
)
8482
# Cleanup.
85-
transform.ApplyCommonSubexpressionEliminationOp(mod)
83+
transform.apply_cse(mod)
8684
with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns):
87-
transform.ApplyCanonicalizationPatternsOp()
85+
transform.apply_patterns_canonicalization()
8886

8987
# Terminate the schedule.
90-
transform.YieldOp()
88+
transform.yield_([])
9189
return schedule
9290

9391

@@ -129,7 +127,7 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager:
129127

130128

131129
# The example's entry point.
132-
def main():
130+
def main(args):
133131
### Baseline computation ###
134132
# Create inputs.
135133
a = torch.randn(16, 32, dtype=torch.float32)
@@ -152,28 +150,37 @@ def main():
152150
pm.run(kernel.operation)
153151

154152
### Compilation ###
155-
# External shared libraries, containing MLIR runner utilities, are are generally
156-
# required to execute the compiled module.
153+
# Parse additional libraries if present.
157154
#
158-
# Get paths to MLIR runner shared libraries through an environment variable.
159-
mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":")
155+
# External shared libraries, runtime utilities, might be needed to execute
156+
# the compiled module.
157+
# The execution engine requires full paths to the libraries.
158+
mlir_libs = []
159+
if args.shared_libs:
160+
mlir_libs += args.shared_libs.split(",")
160161

161162
# JIT the kernel.
162163
eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)
164+
165+
# Initialize the JIT engine.
166+
#
167+
# The deferred initialization executes global constructors that might
168+
# have been created by the module during engine creation (for example,
169+
# when `gpu.module` is present) or registered afterwards.
170+
#
171+
# Initialization is not strictly necessary in this case.
172+
# However, it is a good practice to perform it regardless.
173+
eng.initialize()
174+
163175
# Get the kernel function.
164176
add_func = eng.lookup("add")
165177

166178
### Execution ###
167-
# Create corresponding memref descriptors containing input data.
168-
a_mem = get_ranked_memref_descriptor(a.numpy())
169-
b_mem = get_ranked_memref_descriptor(b.numpy())
170-
171179
# Create an empty buffer to hold results.
172180
out = torch.empty_like(out_ref)
173-
out_mem = get_ranked_memref_descriptor(out.numpy())
174181

175182
# Execute the kernel.
176-
args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem])
183+
args = lh_utils.torch_to_packed_args([a, b, out])
177184
add_func(args)
178185

179186
### Verification ###
@@ -185,4 +192,21 @@ def main():
185192

186193

187194
if __name__ == "__main__":
188-
main()
195+
parser = argparse.ArgumentParser()
196+
197+
# External shared libraries, runtime utilities, might be needed to
198+
# execute the compiled module.
199+
# For example, MLIR runner utils libraries such as:
200+
# - libmlir_runner_utils.so
201+
# - libmlir_c_runner_utils.so
202+
#
203+
# Full paths to the libraries should be provided.
204+
# For example:
205+
# --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so
206+
parser.add_argument(
207+
"--shared-libs",
208+
type=str,
209+
help="Comma-separated list of libraries to link dynamically",
210+
)
211+
args = parser.parse_args()
212+
main(args)

0 commit comments

Comments
 (0)