11import torch
2- import os
2+ import argparse
33
44from mlir import ir
55from mlir .dialects import transform
66from mlir .dialects .transform import structured
77from mlir .dialects .transform import interpreter
88from mlir .execution_engine import ExecutionEngine
99from mlir .passmanager import PassManager
10- from mlir .runtime .np_to_memref import (
11- get_ranked_memref_descriptor ,
12- )
1310
1411from lighthouse import utils as lh_utils
1512
@@ -50,25 +47,26 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
5047 ir .UnitAttr .get ()
5148 )
5249
50+ # For simplicity, use generic matchers without requiring specific types.
51+ anytype = transform .any_op_t ()
52+
5353 # Create entry point transformation sequence.
5454 with ir .InsertionPoint (schedule .body ):
5555 named_seq = transform .NamedSequenceOp (
56- "__transform_main" ,
57- [ transform . AnyOpType . get () ],
58- [],
56+ sym_name = "__transform_main" ,
57+ input_types = [ anytype ],
58+ result_types = [],
5959 arg_attrs = [{"transform.readonly" : ir .UnitAttr .get ()}],
6060 )
6161
6262 # Create the schedule.
6363 with ir .InsertionPoint (named_seq .body ):
64- # For simplicity, use generic transform matchers.
65- anytype = transform .AnyOpType .get ()
66-
6764 # Find the kernel's function op.
6865 func = structured .MatchOp .match_op_names (
6966 named_seq .bodyTarget , ["func.func" ]
7067 )
71- # Use C interface wrappers - required to make function executable after jitting.
68+ # Use C interface wrappers - required to make function executable
69+ # after jitting.
7270 func = transform .apply_registered_pass (
7371 anytype , func , "llvm-request-c-wrappers"
7472 )
@@ -82,12 +80,12 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
8280 anytype , mod , "convert-linalg-to-loops"
8381 )
8482 # Cleanup.
85- transform .ApplyCommonSubexpressionEliminationOp (mod )
83+ transform .apply_cse (mod )
8684 with ir .InsertionPoint (transform .ApplyPatternsOp (mod ).patterns ):
87- transform .ApplyCanonicalizationPatternsOp ()
85+ transform .apply_patterns_canonicalization ()
8886
8987 # Terminate the schedule.
90- transform .YieldOp ( )
88+ transform .yield_ ([] )
9189 return schedule
9290
9391
@@ -129,7 +127,7 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager:
129127
130128
131129# The example's entry point.
132- def main ():
130+ def main (args ):
133131 ### Baseline computation ###
134132 # Create inputs.
135133 a = torch .randn (16 , 32 , dtype = torch .float32 )
@@ -152,28 +150,37 @@ def main():
152150 pm .run (kernel .operation )
153151
154152 ### Compilation ###
155- # External shared libraries, containing MLIR runner utilities, are are generally
156- # required to execute the compiled module.
153+ # Parse additional libraries if present.
157154 #
158- # Get paths to MLIR runner shared libraries through an environment variable.
159- mlir_libs = os .environ .get ("LIGHTHOUSE_SHARED_LIBS" , default = "" ).split (":" )
155+ # External shared libraries, runtime utilities, might be needed to execute
156+ # the compiled module.
157+ # The execution engine requires full paths to the libraries.
158+ mlir_libs = []
159+ if args .shared_libs :
160+ mlir_libs += args .shared_libs .split ("," )
160161
161162 # JIT the kernel.
162163 eng = ExecutionEngine (kernel , opt_level = 2 , shared_libs = mlir_libs )
164+
165+ # Initialize the JIT engine.
166+ #
167+ # The deferred initialization executes global constructors that might
168+ # have been created by the module during engine creation (for example,
169+ # when `gpu.module` is present) or registered afterwards.
170+ #
171+ # Initialization is not strictly necessary in this case.
172+ # However, it is a good practice to perform it regardless.
173+ eng .initialize ()
174+
163175 # Get the kernel function.
164176 add_func = eng .lookup ("add" )
165177
166178 ### Execution ###
167- # Create corresponding memref descriptors containing input data.
168- a_mem = get_ranked_memref_descriptor (a .numpy ())
169- b_mem = get_ranked_memref_descriptor (b .numpy ())
170-
171179 # Create an empty buffer to hold results.
172180 out = torch .empty_like (out_ref )
173- out_mem = get_ranked_memref_descriptor (out .numpy ())
174181
175182 # Execute the kernel.
176- args = lh_utils .memrefs_to_packed_args ([ a_mem , b_mem , out_mem ])
183+ args = lh_utils .torch_to_packed_args ([ a , b , out ])
177184 add_func (args )
178185
179186 ### Verification ###
@@ -185,4 +192,21 @@ def main():
185192
186193
187194if __name__ == "__main__" :
188- main ()
195+ parser = argparse .ArgumentParser ()
196+
197+ # External shared libraries, runtime utilities, might be needed to
198+ # execute the compiled module.
199+ # For example, MLIR runner utils libraries such as:
200+ # - libmlir_runner_utils.so
201+ # - libmlir_c_runner_utils.so
202+ #
203+ # Full paths to the libraries should be provided.
204+ # For example:
205+ # --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so
206+ parser .add_argument (
207+ "--shared-libs" ,
208+ type = str ,
209+ help = "Comma-separated list of libraries to link dynamically" ,
210+ )
211+ args = parser .parse_args ()
212+ main (args )
0 commit comments