diff --git a/ftn/tools/config.sh b/ftn/tools/config.sh new file mode 100644 index 0000000..6f9e4a3 --- /dev/null +++ b/ftn/tools/config.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +DOCKER_RUN="docker run -i -u $(id -u):$(id -g) -v $(pwd):$(pwd) -w $(pwd)" + +alias python-soda="$DOCKER_RUN agostini01/soda:v15.08:v15.08 python3" +alias soda-opt="$DOCKER_RUN agostini01/soda:v15.08 soda-opt" +alias opt-16="$DOCKER_RUN agostini01/soda:v15.08 opt" +alias mlir-opt="$DOCKER_RUN agostini01/soda:v15.08 mlir-opt" +alias soda-mlir-translate="$DOCKER_RUN agostini01/soda:v15.08 mlir-translate" \ No newline at end of file diff --git a/ftn/tools/ftn_opt.py b/ftn/tools/ftn_opt.py index 931f854..1b5c08f 100755 --- a/ftn/tools/ftn_opt.py +++ b/ftn/tools/ftn_opt.py @@ -5,8 +5,9 @@ from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore from ftn.transforms.merge_memref_deref import MergeMemRefDeref +from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.fpga.target_to_hls import TargetToHLSPass from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass -# from ftn.transforms.extract_target import ExtractTarget # from ftn.transforms.isolate_target import IsolateTarget # from psy.extract_stencil import ExtractStencil # from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT @@ -27,8 +28,9 @@ def register_all_passes(self): super().register_all_passes() self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore) self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref) + self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("target-to-hls", lambda: TargetToHLSPass) self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass) - # self.register_pass("extract-target", lambda: ExtractTarget) # self.register_pass("isolate-target", lambda: IsolateTarget) # self.register_pass("convert-to-tt", lambda: ConvertToTT) diff --git a/ftn/tools/mlir_to_ll.sh b/ftn/tools/mlir_to_ll.sh new file mode 100755 index 0000000..a6bd274 --- /dev/null +++ b/ftn/tools/mlir_to_ll.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +ODIR=fpga +CLKPERIOD=10 +TARGET_BOARD=xc7vx690t-ffg1930-3 +MLIR_INPUT=$1 +KERNELNAME=$2 +LIBDIR=/opt/soda-opt/lib/ +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source $SCRIPT_DIR/config.sh +shopt -s expand_aliases + +soda-opt $MLIR_INPUT --lower-affine --canonicalize --lower-all-to-llvm=use-bare-ptr-memref-call-conv | \ +soda-mlir-translate --mlir-to-llvmir --opaque-pointers=0 -o model.ll + +opt-16 model.ll \ + -S \ + -enable-new-pm=0 \ + -load "${LIBDIR}/VhlsLLVMRewriter.so" \ + -mem2arr -strip-debug \ + -instcombine \ + -xlnname \ + -xlnanno -xlntop $KERNELNAME \ + -xlntbgen -xlntbdummynames="$KERNELNAME.dummy.c" \ + -xlntbtclnames="$KERNELNAME.run.tcl" \ + -xlnllvm="$KERNELNAME.opt.ll" \ + -clock-period-ns=$CLKPERIOD -target=$TARGET_BOARD \ + > $KERNELNAME.opt.ll diff --git a/ftn/tools/xftn.py b/ftn/tools/xftn.py index 99f1460..60f1a4f 100644 --- a/ftn/tools/xftn.py +++ b/ftn/tools/xftn.py @@ -55,6 +55,12 @@ def initialise_argument_parser(): default=[], help="Additional transformation pass for ftn-opt", ) + parser.add_argument( + "-f", + "--fpga-llvmir", + action="store_true", + help="Generate LLVM IR for FPGA compatible with Vitis HLS." + ) parser.add_argument( "-I", "--include-directory", @@ -96,7 +102,7 @@ def initialise_argument_parser(): parser.add_argument( "--stages", default=None, - help="Specify which stages will run (a combination of: flang, pre, ftn, post, mlir, obj, clang) in comma separated list without spaces", + help="Specify which stages will run (a combination of: flang, pre, ftn, post, mlir, obj, clang, fpga-llvmir) in comma separated list without spaces", ) parser.add_argument( "-v", @@ -132,6 +138,7 @@ def enable_disable_stages_by_output_type(options_db, output_type): options_db["run_mlir_to_llvmir_stage"] = options_db["run_postprocess_stage"] options_db["run_create_object_stage"] = output_type == OutputType.OBJECT options_db["run_build_executable_stage"] = output_type == OutputType.EXECUTABLE + options_db["run_fpga_llvmir_stage"] = False def build_options_db_from_args(args): @@ -175,6 +182,9 @@ def build_options_db_from_args(args): options_db["run_mlir_to_llvmir_stage"] = False options_db["run_build_executable_stage"] = False + if options_db["fpga_llvmir"]: + options_db["run_fpga_llvmir_stage"] = True + if options_db["stages"] is not None: # We handle stages to run last, this overrides all other stage selection # through other options or output filename @@ -188,6 +198,7 @@ def build_options_db_from_args(args): options_db["run_mlir_to_llvmir_stage"] = "mlir" in stages_to_run options_db["run_create_object_stage"] = "obj" in stages_to_run options_db["run_build_executable_stage"] = "clang" in stages_to_run + options_db["run_fpga_llvmir_stage"] = "fpga-llvmir" in stages_to_run if "flang" in stages_to_run: stages_to_run.remove("flang") @@ -203,6 +214,8 @@ def build_options_db_from_args(args): stages_to_run.remove("obj") if "clang" in stages_to_run: stages_to_run.remove("clang") + if "fpga-llvmir" in stages_to_run: + stages_to_run.remove("fpga-llvmir") if len(stages_to_run) > 0: for e in stages_to_run: print(f"Unknown stage provided as argument '{e}'") @@ -241,6 +254,9 @@ def display_verbose_start_message(options_db): print( f"Stage 'Build executable': {'Enabled' if options_db['run_build_executable_stage'] else 'Disabled'}" ) + print( + f"Stage 'FPGA LLVM-IR': {'Enabled' if options_db['run_fpga_llvmir_stage'] else 'Disabled'}" + ) if options_db["run_build_executable_stage"]: print("") print(f"Also linking against object files {options_db['link-objectfiles']}") @@ -459,6 +475,9 @@ def build_executable(output_tmp_dir, input_fn, executable_fn, options_db): os.system(f"clang {clang_args}") post_stage_check(executable_fn, options_db["verbose"], executable=True) +def generate_llvmir_for_fpga(out_file): + tools_path = os.path.dirname(os.path.abspath(__file__)) + os.system(f"{tools_path}/mlir_to_ll.sh {out_file} tt_device") def print_file_contents(filename): with open(filename) as f: @@ -568,6 +587,10 @@ def main(): if options_db["offload"]: print_verbose_message(options_db, f"Offload MLIR in '{out_file}'") + if options_db["run_fpga_llvmir_stage"]: + print_verbose_message(options_db, "Running FPGA LLVM-IR generation stage",) + generate_llvmir_for_fpga(out_file) + if options_db["cleanup"]: remove_file_if_exists( tmp_dir, diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index a84eb8c..8300612 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -1,10 +1,13 @@ from abc import ABC +from ast import Module +from hmac import new from typing import TypeVar, cast -from dataclasses import dataclass +from dataclasses import dataclass, field import itertools from xdsl.utils.hints import isa from xdsl.dialects import memref, scf, omp -from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region +from xdsl.context import Context +from xdsl.ir import Operation, SSAValue, OpResult, Attribute, Block, Region from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, op_type_rewrite_pattern, @@ -13,10 +16,12 @@ from xdsl.passes import ModulePass from xdsl.dialects import builtin, func, llvm, arith from ftn.util.visitor import Visitor +from xdsl.rewriter import InsertPoint +@dataclass class RewriteTarget(RewritePattern): - def __init__(self): - self.target_ops=[] + module : builtin.ModuleOp + target_ops: list[Operation] = field(default_factory=list) @op_type_rewrite_pattern def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): @@ -31,38 +36,38 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): # Grab bounds and info, then at end the terminator for var in op.map_vars: var_op=var.owner - var_op.parent.detach_op(var_op) - arg_types.append(var_op.var_ptr[0].type) - arg_ssa.append(var_op.var_ptr[0]) - if isa(var_op.var_ptr[0].type, builtin.MemRefType): - memref_type=var_op.var_ptr[0].type - src_memref=var_op.var_ptr[0] + arg_types.append(var_op.var_ptr.type) + arg_ssa.append(var_op.var_ptr) + locations[var_op]=loc_idx + loc_idx+=1 + if isa(var_op.var_ptr.type, builtin.MemRefType): + memref_type=var_op.var_ptr.type + src_memref=var_op.var_ptr if isa(memref_type.element_type, builtin.MemRefType): assert len(memref_type.shape) == 0 - memref_type=var_op.var_ptr[0].type.element_type - memref_loadop=memref.Load.get(src_memref, []) + memref_type=var_op.var_ptr.type.element_type + memref_loadop=memref.LoadOp.get(src_memref, []) src_memref=memref_loadop.results[0] memref_dim_ops.append(memref_loadop) for idx, s in enumerate(memref_type.shape): assert isa(s, builtin.IntAttr) if (s.data == -1): # Need to pass the dimension shape size in explicitly as it is deferred - const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType()) - dim_size=memref.Dim.from_source_and_index(src_memref, const_op) + const_op=arith.ConstantOp.from_int_and_width(idx, builtin.IndexType()) + dim_size=memref.DimOp.from_source_and_index(src_memref, const_op) memref_dim_ops+=[const_op, dim_size] arg_ssa.append(dim_size.results[0]) arg_types.append(dim_size.results[0].type) + loc_idx+=1 - locations[var_op]=loc_idx - loc_idx+=1 if len(var_op.bounds) > 0: bound_op=var_op.bounds[0].owner - bound_op.parent.detach_op(bound_op) - #self.target_ops+=[bound_op, var_op] - arg_types.append(bound_op.lower[0].type) - arg_ssa.append(bound_op.lower[0]) + arg_types.append(bound_op.lower_bound.type) + arg_ssa.append(bound_op.lower_bound) + arg_types.append(bound_op.upper_bound.type) + arg_ssa.append(bound_op.upper_bound) locations[bound_op]=loc_idx - # Add two, as second is the size + # Adding both lower and upper bound loc_idx+=2 else: pass#self.target_ops+=[var_op] @@ -78,7 +83,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): bound_op=var_op.bounds[0].owner res_types=[] for res in bound_op.results: res_types.append(res.type) - new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], + new_bounds_op=omp.MapBoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [new_block.args[locations[bound_op]+1]], [], [], []], properties={"stride_in_bytes": bound_op.stride_in_bytes}, result_types=res_types) @@ -87,8 +92,8 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): res_types=[] for res in var_op.results: res_types.append(res.type) - mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds], - properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type}, + mapinfo_op=omp.MapInfoOp.build(operands=[new_block.args[locations[var_op]], [], [], map_bounds], + properties={"map_type": var_op.map_type, "name": var_op.var_name, "var_type": var_op.var_type, "map_capture_type": omp.VariableCaptureKindAttr(omp.VariableCaptureKind.BY_REF)}, result_types=res_types) new_mapinfo_ssa.append(mapinfo_op.results[0]) @@ -97,9 +102,9 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): reg=op.region op.detach_region(reg) - new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg]) + new_omp_target_op=omp.TargetOp.build(operands=[[],[],[],[],[],[],[],[],[], new_mapinfo_ssa, [], []], regions=[reg]) new_block.add_op(new_omp_target_op) - new_block.add_op(func.Return()) + new_block.add_op(func.ReturnOp()) new_fn_type=builtin.FunctionType.from_lists(arg_types, []) @@ -107,11 +112,13 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): body.add_block(new_block) new_func=func.FuncOp("tt_device", new_fn_type, body) + new_func_signature=func.FuncOp.external("tt_device", new_fn_type.inputs.data, new_fn_type.outputs.data) self.target_ops=[new_func] - call_fn=func.Call.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) + call_fn=func.CallOp.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) op.parent.insert_ops_before(memref_dim_ops+[call_fn], op) + rewriter.insert_op(new_func_signature, InsertPoint.at_start(self.module.body.block)) op.parent.detach_op(op) @@ -122,15 +129,17 @@ class ExtractTarget(ModulePass): """ name = 'extract-target' - def apply(self, ctx: MLContext, module: builtin.ModuleOp): - rw_target= RewriteTarget() + def apply(self, ctx: Context, module: builtin.ModuleOp): + rw_target= RewriteTarget(module) walker = PatternRewriteWalker(GreedyRewritePatternApplier([ rw_target, ]), apply_recursively=False, walk_reverse=True) walker.rewrite_module(module) - containing_mod=builtin.ModuleOp([]) + # NOTE: The region recieving the block must be empty. Otherwise, the single block region rule of + # the module will not be satisfied. + containing_mod=builtin.ModuleOp(Region()) module.regions[0].move_blocks(containing_mod.regions[0]) new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")}) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py new file mode 100644 index 0000000..7cf3613 --- /dev/null +++ b/ftn/transforms/fpga/target_to_hls.py @@ -0,0 +1,281 @@ +from abc import ABC +from ast import FunctionType, Module +from hmac import new +from json import load +from sys import set_coroutine_origin_tracking_depth +from typing import TypeVar, cast +from dataclasses import dataclass, field +import itertools +from xdsl.utils.hints import isa +from xdsl.dialects import memref, scf, omp +from xdsl.context import Context +from xdsl.ir import Operation, SSAValue, OpResult, Attribute, Block, Region + +from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, + op_type_rewrite_pattern, + PatternRewriteWalker, + GreedyRewritePatternApplier) +from xdsl.passes import ModulePass +from xdsl.dialects import builtin, func, llvm, arith +from ftn.util.visitor import Visitor +from xdsl.rewriter import InsertPoint +from xdsl.dialects.experimental.hls import PragmaPipelineOp, PragmaUnrollOp + +class DerefMemrefs: + @staticmethod + def deref_scalar_memops(scalar_ssa: SSAValue, rewriter: PatternRewriter): + for use in scalar_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + load_op = use.operation + # NOTE: the operand of the load operation is not a memref anymore, we have dereferenced it, + # so we forward it. + load_op.res.replace_by(load_op.memref) + rewriter.erase_op(use.operation) + elif isinstance(use.operation, memref.StoreOp): + rewriter.erase_op(use.operation) + + @staticmethod + def deref_memref_memops(memref_ssa: SSAValue, rewriter: PatternRewriter): + for use in memref_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + # The first load was used to load the pointer to the array. The index to retrieve an element from the array is applied + # in the next load. Since we have dereferenced the first pointer, we need to end up with a single load that accesses + # the array directly. + ptr_load_op = use.operation + + # FIXME: this is assuming each dereferencing load only has one use + for ptr_use in ptr_load_op.res.uses: + if isinstance(ptr_use.operation, memref.LoadOp): + array_load_op = ptr_use.operation + array_idx = array_load_op.indices + + new_load_op = memref.LoadOp.get(memref_ssa, array_idx) + array_load_op.res.replace_by(ptr_load_op.res) + rewriter.replace_op(ptr_load_op, new_load_op) + rewriter.erase_op(array_load_op) + + elif isinstance(ptr_use.operation, memref.StoreOp): + array_store_op = ptr_use.operation + array_idx = array_store_op.indices + new_store_op = memref.StoreOp.get(array_store_op.value, memref_ssa, array_idx) + rewriter.insert_op(new_store_op, InsertPoint.before(ptr_load_op)) + rewriter.erase_op(array_store_op) + rewriter.erase_op(ptr_load_op) + + @staticmethod + def deref_args(func_op: func.FuncOp, rewriter : PatternRewriter): + """Dereference the arguments of a function operation.""" + new_input_types = [] + + for arg in func_op.body.block.args: + if isa(arg.type, builtin.MemRefType): + deref_type = arg.type.element_type + func_op.replace_argument_type(arg, deref_type, rewriter) + new_input_types.append(deref_type) + +class RemoveOps: + @staticmethod + def transform_omp_loop_nest_into_scf_for(loop_nest_op : omp.LoopNestOp, rewriter: PatternRewriter): + flat_loop_body = rewriter.move_region_contents_to_new_regions(loop_nest_op.body) + rewriter.replace_op(flat_loop_body.block.last_op, scf.YieldOp()) + flat_lb = arith.IndexCastOp(loop_nest_op.lowerBound[0], builtin.IndexType()) + flat_ub = arith.IndexCastOp(loop_nest_op.upperBound[0], builtin.IndexType()) + flat_step = arith.IndexCastOp(loop_nest_op.step[0], builtin.IndexType()) + rewriter.insert_op(flat_lb, InsertPoint.after(loop_nest_op.lowerBound[0].owner)) + rewriter.insert_op(flat_ub, InsertPoint.after(loop_nest_op.upperBound[0].owner)) + rewriter.insert_op(flat_step, InsertPoint.after(loop_nest_op.step[0].owner)) + rewriter.replace_value_with_new_type(flat_loop_body.block.args[0], builtin.IndexType()) + + # TODO: convert between i32 and index where appropriate, since omp.loop_nest operates with i32 and + # scf.for with index. + for arg_use in flat_loop_body.block.args[0].uses: + if isinstance(arg_use.operation, memref.StoreOp): + store_op = arg_use.operation + + ## Replace the store ops first + if isinstance(store_op.memref.type.element_type, builtin.IntegerType): + index_to_i32 = arith.IndexCastOp(store_op.value, builtin.i32) + store_op.value.replace_by_if(index_to_i32.result, lambda use: isinstance(use.operation, memref.StoreOp)) + rewriter.insert_op(index_to_i32, InsertPoint.before(store_op)) + + else: + alloca_op = store_op.memref.owner + assert isinstance(alloca_op, memref.AllocaOp) + index_alloca = memref.AllocaOp.get(builtin.IndexType(), shape=alloca_op.memref.type.shape) + rewriter.replace_op(alloca_op, index_alloca) + + idx_memref = store_op.memref + for idx_memref_use in idx_memref.uses: + if isinstance(idx_memref_use.operation, memref.LoadOp): + load_op = idx_memref_use.operation + index_load = memref.LoadOp.get(load_op.memref, load_op.indices) + rewriter.replace_op(load_op, index_load) + + # Original type of the block arg of the loop nest op + cast_ind_var = arith.IndexCastOp(index_load.res, builtin.i32) + rewriter.insert_op(cast_ind_var, InsertPoint.after(index_load)) + index_load.res.replace_by_if(cast_ind_var.result, lambda use: use.operation != cast_ind_var) + + flat_loop = scf.ForOp(flat_lb, flat_ub, flat_step, (), flat_loop_body) + #flat_loop = scf.ForOp(loop_nest_op.lowerBound[0], omp_loop_op.upperBound[0], omp_loop_op.step[0], (), flat_loop_body) + rewriter.replace_op(loop_nest_op, flat_loop) + + return flat_loop + + + @staticmethod + def remove_target(target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: PatternRewriter): + """Remove the target operation from the module.""" + for operand in target_op.map_vars: + operand_idx = target_op.operands.index(operand) + block_arg = target_op.region.block.args[operand_idx] + block_arg.replace_by(operand) + + target_op_terminator = target_op.region.block.last_op + assert target_op_terminator + rewriter.erase_op(target_op_terminator) + + target_func_terminator = target_func.body.block.last_op + assert target_func_terminator + for block in reversed(target_op.region.blocks): + rewriter.inline_block(block, InsertPoint.before(target_func_terminator)) + + rewriter.erase_op(target_op) + + @staticmethod + def remove_omp_parallel(parallel_op: omp.ParallelOp, rewriter: PatternRewriter): + ws_loop = None + for op in parallel_op.walk(): + if isinstance(op, omp.WsLoopOp): + ws_loop = op + break + + assert ws_loop + print(ws_loop.body.block) + omp_loop_op = ws_loop.body.block.first_op + ws_loop_block = ws_loop.body.block + ws_loop.body.detach_block(ws_loop_block) + rewriter.inline_block(ws_loop_block, InsertPoint.before(ws_loop)) + rewriter.erase_op(ws_loop) + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + one = arith.ConstantOp.from_int_and_width(1, 32) + pragma_pipeline = PragmaPipelineOp(one) + rewriter.insert_op([one, pragma_pipeline], InsertPoint.at_start(flat_loop.body.block)) + + parallel_block = parallel_op.region.block + parallel_op.region.detach_block(parallel_block) + rewriter.erase_op(parallel_block.last_op) + rewriter.inline_block(parallel_block, InsertPoint.before(parallel_op)) + rewriter.erase_op(parallel_op) + + @staticmethod + def remove_omp_simd(simd_op : omp.SimdOp, rewriter : PatternRewriter): + for priv_var in simd_op.private_vars: + arg_idx = simd_op.operands.index(priv_var) + simd_op.body.block.args[arg_idx].replace_by(priv_var) + + omp_loop_op = simd_op.body.block.first_op + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + simd_factor = simd_op.simdlen.value.data + ssa_simd_factor = arith.ConstantOp.from_int_and_width(simd_factor, 32) + pragma_unroll = PragmaUnrollOp(ssa_simd_factor) + rewriter.insert_op([ssa_simd_factor, pragma_unroll], InsertPoint.at_start(flat_loop.body.block)) + else: + flat_loop = simd_op.body.block.first_op + + assert isinstance(flat_loop, scf.ForOp) + flat_loop.detach() + rewriter.insert_op(flat_loop, InsertPoint.before(simd_op)) + #rewriter.erase_matched_op() #FIXME: this does not work + rewriter.erase_op(simd_op) + + + @staticmethod + def remove_remaining_omp_ops(target_func: func.FuncOp, rewriter: PatternRewriter): + """Remove any remaining OpenMP operations in the target function.""" + for op in target_func.walk(): + if isinstance(op, omp.MapInfoOp): + rewriter.erase_op(op) + + for op in target_func.walk(): + if isinstance(op, omp.MapBoundsOp): + rewriter.erase_op(op) + + @staticmethod + def forward_map_info(map_info: omp.MapInfoOp, rewriter: PatternRewriter): + map_info.omp_ptr.replace_by(map_info.var_ptr) + + +@dataclass +class TargetFuncToHLS(RewritePattern): + target_funcs : list[func.FuncOp] = field(default_factory=list) + + @op_type_rewrite_pattern + def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): + if "target" not in module.attributes: + return + + target_name = module.attributes["target"].data + target_func = [op for op in module.walk() if isinstance(op, func.FuncOp) and op.sym_name.data == target_name][0] + self.target_funcs.append(target_func) + + for map_info in target_func.walk(): + if not isinstance(map_info, omp.MapInfoOp): + continue + + RemoveOps.forward_map_info(map_info, rewriter) + + omp_parallel = None + for op in target_func.walk(): + if isinstance(op, omp.ParallelOp): + omp_parallel = op + break + + if omp_parallel: + RemoveOps.remove_omp_parallel(omp_parallel, rewriter) + + omp_simd = None + for op in target_func.walk(): + if isinstance(op, omp.SimdOp): + omp_simd = op + break + + if omp_simd: + RemoveOps.remove_omp_simd(omp_simd, rewriter) + + target_op = [op for op in target_func.walk() if isinstance(op, omp.TargetOp)][0] + RemoveOps.remove_target(target_op, target_func, rewriter) + RemoveOps.remove_remaining_omp_ops(target_func, rewriter) + + +@dataclass(frozen=True) +class TargetToHLSPass(ModulePass): + """ + This is the entry point for the transformation pass which will then apply the rewriter + """ + name = 'target-to-hls' + + generate : str = "hls" + + def apply(self, ctx: Context, module: builtin.ModuleOp): + target_funcs : list[func.FuncOp] = [] + walker = PatternRewriteWalker(GreedyRewritePatternApplier([ + TargetFuncToHLS(target_funcs), + ]), apply_recursively=False, walk_reverse=True) + + walker.rewrite_module(module) + + if self.generate == "hls": + # Keep only the top level module to contain the function + for target_func in target_funcs: + target_func.detach() + + module_block = module.body.block + module.body.detach_block(module.body.block) + module_block.erase() + module.body.add_block(Block(target_funcs)) + module.attributes = {} \ No newline at end of file