From b6b1dd24715e54df045ec045690d4c68ca733918 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Fri, 1 Aug 2025 00:15:05 +0200 Subject: [PATCH 01/12] Modernise extract-target pass. Note how the location of the bounds was not correct in the previous version --- ftn/tools/ftn_opt.py | 4 +-- ftn/transforms/extract_target.py | 53 ++++++++++++++++---------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/ftn/tools/ftn_opt.py b/ftn/tools/ftn_opt.py index 2e1cec4..d42dfd1 100755 --- a/ftn/tools/ftn_opt.py +++ b/ftn/tools/ftn_opt.py @@ -5,7 +5,7 @@ from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore from ftn.transforms.merge_memref_deref import MergeMemRefDeref -# from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.extract_target import ExtractTarget # from ftn.transforms.isolate_target import IsolateTarget # from psy.extract_stencil import ExtractStencil # from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT @@ -26,7 +26,7 @@ def register_all_passes(self): super().register_all_passes() self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore) self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref) - # self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("extract-target", lambda: ExtractTarget) # self.register_pass("isolate-target", lambda: IsolateTarget) # self.register_pass("convert-to-tt", lambda: ConvertToTT) diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index a84eb8c..84a413a 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -4,7 +4,8 @@ import itertools from xdsl.utils.hints import isa from xdsl.dialects import memref, scf, omp -from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region +from xdsl.context import Context +from xdsl.ir import Operation, SSAValue, OpResult, Attribute, Block, Region from xdsl.pattern_rewriter import 
(RewritePattern, PatternRewriter, op_type_rewrite_pattern, @@ -31,39 +32,35 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): # Grab bounds and info, then at end the terminator for var in op.map_vars: var_op=var.owner - var_op.parent.detach_op(var_op) - arg_types.append(var_op.var_ptr[0].type) - arg_ssa.append(var_op.var_ptr[0]) - if isa(var_op.var_ptr[0].type, builtin.MemRefType): - memref_type=var_op.var_ptr[0].type - src_memref=var_op.var_ptr[0] + arg_types.append(var_op.var_ptr.type) + arg_ssa.append(var_op.var_ptr) + locations[var_op]=loc_idx + if isa(var_op.var_ptr.type, builtin.MemRefType): + memref_type=var_op.var_ptr.type + src_memref=var_op.var_ptr if isa(memref_type.element_type, builtin.MemRefType): assert len(memref_type.shape) == 0 - memref_type=var_op.var_ptr[0].type.element_type - memref_loadop=memref.Load.get(src_memref, []) + memref_type=var_op.var_ptr.type.element_type + memref_loadop=memref.LoadOp.get(src_memref, []) src_memref=memref_loadop.results[0] memref_dim_ops.append(memref_loadop) for idx, s in enumerate(memref_type.shape): assert isa(s, builtin.IntAttr) if (s.data == -1): # Need to pass the dimension shape size in explicitly as it is deferred - const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType()) - dim_size=memref.Dim.from_source_and_index(src_memref, const_op) + const_op=arith.ConstantOp.from_int_and_width(idx, builtin.IndexType()) + dim_size=memref.DimOp.from_source_and_index(src_memref, const_op) memref_dim_ops+=[const_op, dim_size] arg_ssa.append(dim_size.results[0]) arg_types.append(dim_size.results[0].type) + loc_idx+=1 - locations[var_op]=loc_idx - loc_idx+=1 if len(var_op.bounds) > 0: bound_op=var_op.bounds[0].owner - bound_op.parent.detach_op(bound_op) - #self.target_ops+=[bound_op, var_op] - arg_types.append(bound_op.lower[0].type) - arg_ssa.append(bound_op.lower[0]) + arg_types.append(bound_op.lower_bound.type) + arg_ssa.append(bound_op.lower_bound) locations[bound_op]=loc_idx - # 
Add two, as second is the size - loc_idx+=2 + loc_idx+=1 else: pass#self.target_ops+=[var_op] @@ -78,7 +75,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): bound_op=var_op.bounds[0].owner res_types=[] for res in bound_op.results: res_types.append(res.type) - new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], + new_bounds_op=omp.MapBoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], properties={"stride_in_bytes": bound_op.stride_in_bytes}, result_types=res_types) @@ -87,8 +84,8 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): res_types=[] for res in var_op.results: res_types.append(res.type) - mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds], - properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type}, + mapinfo_op=omp.MapInfoOp.build(operands=[new_block.args[locations[var_op]], [], [], map_bounds], + properties={"map_type": var_op.map_type, "name": var_op.var_name, "var_type": var_op.var_type, "map_capture_type": omp.VariableCaptureKindAttr(omp.VariableCaptureKind.BY_REF)}, result_types=res_types) new_mapinfo_ssa.append(mapinfo_op.results[0]) @@ -97,9 +94,9 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): reg=op.region op.detach_region(reg) - new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg]) + new_omp_target_op=omp.TargetOp.build(operands=[[],[],[],[],[],[],[],[],[], new_mapinfo_ssa, [], []], regions=[reg]) new_block.add_op(new_omp_target_op) - new_block.add_op(func.Return()) + new_block.add_op(func.ReturnOp()) new_fn_type=builtin.FunctionType.from_lists(arg_types, []) @@ -110,7 +107,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): self.target_ops=[new_func] - call_fn=func.Call.create(properties={"callee": 
builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) + call_fn=func.CallOp.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) op.parent.insert_ops_before(memref_dim_ops+[call_fn], op) op.parent.detach_op(op) @@ -122,7 +119,7 @@ class ExtractTarget(ModulePass): """ name = 'extract-target' - def apply(self, ctx: MLContext, module: builtin.ModuleOp): + def apply(self, ctx: Context, module: builtin.ModuleOp): rw_target= RewriteTarget() walker = PatternRewriteWalker(GreedyRewritePatternApplier([ rw_target, @@ -130,7 +127,9 @@ def apply(self, ctx: MLContext, module: builtin.ModuleOp): walker.rewrite_module(module) - containing_mod=builtin.ModuleOp([]) + # NOTE: The region receiving the block must be empty. Otherwise, the single block region rule of + # the module will not be satisfied. + containing_mod=builtin.ModuleOp(Region()) module.regions[0].move_blocks(containing_mod.regions[0]) new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")}) From a62cba657260b2f2ab5823dd359d101a3929f5d5 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Fri, 1 Aug 2025 17:43:29 +0200 Subject: [PATCH 02/12] Add the signature of the offloaded function to the module calling the function. 
Otherwise, the function will not be present in the symbol table of the calling module --- ftn/transforms/extract_target.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index 84a413a..468617a 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -1,6 +1,8 @@ from abc import ABC +from ast import Module +from hmac import new from typing import TypeVar, cast -from dataclasses import dataclass +from dataclasses import dataclass, field import itertools from xdsl.utils.hints import isa from xdsl.dialects import memref, scf, omp @@ -14,10 +16,12 @@ from xdsl.passes import ModulePass from xdsl.dialects import builtin, func, llvm, arith from ftn.util.visitor import Visitor +from xdsl.rewriter import InsertPoint +@dataclass class RewriteTarget(RewritePattern): - def __init__(self): - self.target_ops=[] + module : builtin.ModuleOp + target_ops: list[Operation] = field(default_factory=list) @op_type_rewrite_pattern def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): @@ -104,11 +108,13 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): body.add_block(new_block) new_func=func.FuncOp("tt_device", new_fn_type, body) + new_func_signature=func.FuncOp.external("tt_device", new_fn_type.inputs.data, new_fn_type.outputs.data) self.target_ops=[new_func] call_fn=func.CallOp.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) op.parent.insert_ops_before(memref_dim_ops+[call_fn], op) + rewriter.insert_op(new_func_signature, InsertPoint.at_start(self.module.body.block)) op.parent.detach_op(op) @@ -120,7 +126,7 @@ class ExtractTarget(ModulePass): name = 'extract-target' def apply(self, ctx: Context, module: builtin.ModuleOp): - rw_target= RewriteTarget() + rw_target= RewriteTarget(module) walker = PatternRewriteWalker(GreedyRewritePatternApplier([ 
rw_target, ]), apply_recursively=False, walk_reverse=True) From cf99ae9cafc30c7b08a06c2e818f54aa77e4ef07 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Fri, 1 Aug 2025 17:47:53 +0200 Subject: [PATCH 03/12] Fix location index. Now all the examples build --- ftn/transforms/extract_target.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index 468617a..b7073a2 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -39,6 +39,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): arg_types.append(var_op.var_ptr.type) arg_ssa.append(var_op.var_ptr) locations[var_op]=loc_idx + loc_idx+=1 if isa(var_op.var_ptr.type, builtin.MemRefType): memref_type=var_op.var_ptr.type src_memref=var_op.var_ptr From b1c2ac1d07b3d1ea2233c028e7df17574c833ec3 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Fri, 1 Aug 2025 18:23:52 +0200 Subject: [PATCH 04/12] Add upper bound to the offload function arguments. Otherwise the generated IR does not pass the mlir-opt verifier. 
Tested on all the offload examples --- ftn/transforms/extract_target.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index b7073a2..8300612 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -64,8 +64,11 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): bound_op=var_op.bounds[0].owner arg_types.append(bound_op.lower_bound.type) arg_ssa.append(bound_op.lower_bound) + arg_types.append(bound_op.upper_bound.type) + arg_ssa.append(bound_op.upper_bound) locations[bound_op]=loc_idx - loc_idx+=1 + # Adding both lower and upper bound + loc_idx+=2 else: pass#self.target_ops+=[var_op] @@ -80,7 +83,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): bound_op=var_op.bounds[0].owner res_types=[] for res in bound_op.results: res_types.append(res.type) - new_bounds_op=omp.MapBoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], + new_bounds_op=omp.MapBoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [new_block.args[locations[bound_op]+1]], [], [], []], properties={"stride_in_bytes": bound_op.stride_in_bytes}, result_types=res_types) From c6c360f2669c41dff765bedb0f8375abfdb86ac9 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Fri, 1 Aug 2025 23:14:21 +0200 Subject: [PATCH 05/12] Add pass that converts target function to HLS compatible format. Function arguments are dereferenced, since HLS is not compatible with pointers to pointers. OpenMP variables in the omp.map.info operations are forwarded and the omp.target operation is forwarded (operands are forwarded to the block operations). 
Tested with offload/ex1.F90 --- ftn/tools/ftn_opt.py | 2 + ftn/transforms/fpga/target_to_hls.py | 137 +++++++++++++++++++++++++++ 2 files changed, 139 insertions(+) create mode 100644 ftn/transforms/fpga/target_to_hls.py diff --git a/ftn/tools/ftn_opt.py b/ftn/tools/ftn_opt.py index d42dfd1..149a96f 100755 --- a/ftn/tools/ftn_opt.py +++ b/ftn/tools/ftn_opt.py @@ -6,6 +6,7 @@ from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore from ftn.transforms.merge_memref_deref import MergeMemRefDeref from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.fpga.target_to_hls import TargetToHLSPass # from ftn.transforms.isolate_target import IsolateTarget # from psy.extract_stencil import ExtractStencil # from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT @@ -27,6 +28,7 @@ def register_all_passes(self): self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore) self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref) self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("target-to-hls", lambda: TargetToHLSPass) # self.register_pass("isolate-target", lambda: IsolateTarget) # self.register_pass("convert-to-tt", lambda: ConvertToTT) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py new file mode 100644 index 0000000..92a1a42 --- /dev/null +++ b/ftn/transforms/fpga/target_to_hls.py @@ -0,0 +1,137 @@ +from abc import ABC +from ast import FunctionType, Module +from hmac import new +from json import load +from typing import TypeVar, cast +from dataclasses import dataclass, field +import itertools +from xdsl.utils.hints import isa +from xdsl.dialects import memref, scf, omp +from xdsl.context import Context +from xdsl.ir import Operation, SSAValue, OpResult, Attribute, Block, Region + +from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, + op_type_rewrite_pattern, + PatternRewriteWalker, + GreedyRewritePatternApplier) +from xdsl.passes import 
ModulePass +from xdsl.dialects import builtin, func, llvm, arith +from ftn.util.visitor import Visitor +from xdsl.rewriter import InsertPoint + +@dataclass +class TargetFuncToHLS(RewritePattern): + def deref_args(self, func_op: func.FuncOp, rewriter : PatternRewriter): + """Dereference the arguments of a function operation.""" + new_input_types = [] + + for arg in func_op.body.block.args: + if isa(arg.type, builtin.MemRefType): + deref_type = arg.type.element_type + func_op.replace_argument_type(arg, deref_type, rewriter) + new_input_types.append(deref_type) + + def forward_map_info(self, map_info: omp.MapInfoOp, rewriter: PatternRewriter): + map_info.omp_ptr.replace_by(map_info.var_ptr) + + def remove_target(self, target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: PatternRewriter): + """Remove the target operation from the module.""" + for operand in target_op.map_vars: + operand_idx = target_op.operands.index(operand) + block_arg = target_op.region.block.args[operand_idx] + block_arg.replace_by(operand) + + if not isinstance(operand.type, builtin.MemRefType): + for use in operand.uses: + if isinstance(use.operation, memref.LoadOp): + load_op = use.operation + # NOTE: the operand of the load operation is not a memref anymore, we have dereferenced it, + # so we forward it. + load_op.res.replace_by(load_op.memref) + rewriter.erase_op(use.operation) + elif isinstance(use.operation, memref.StoreOp): + rewriter.erase_op(use.operation) + else: + # Adjust the store and load operations that use the dereferenced pointer. + for use in operand.uses: + if isinstance(use.operation, memref.LoadOp): + # The first load was used to load the pointer to the array. The index to retrieve an element from the array is applied + # in the next load. Since we have dereferenced the first pointer, we need to end up with a single load that accesses + # the array directly. 
+ ptr_load_op = use.operation + + # FIXME: this is assuming each dereferencing load only has one use + for ptr_use in ptr_load_op.res.uses: + if isinstance(ptr_use.operation, memref.LoadOp): + array_load_op = ptr_use.operation + array_idx = array_load_op.indices + + new_load_op = memref.LoadOp.get(operand, array_idx) + array_load_op.res.replace_by(ptr_load_op.res) + rewriter.replace_op(ptr_load_op, new_load_op) + rewriter.erase_op(array_load_op) + + elif isinstance(ptr_use.operation, memref.StoreOp): + array_store_op = ptr_use.operation + array_idx = array_store_op.indices + new_store_op = memref.StoreOp.get(array_store_op.value, operand, array_idx) + rewriter.insert_op(new_store_op, InsertPoint.before(ptr_load_op)) + rewriter.erase_op(array_store_op) + rewriter.erase_op(ptr_load_op) + + + + target_op_terminator = target_op.region.block.last_op + assert target_op_terminator + rewriter.erase_op(target_op_terminator) + + target_func_terminator = target_func.body.block.last_op + assert target_func_terminator + for block in reversed(target_op.region.blocks): + rewriter.inline_block(block, InsertPoint.before(target_func_terminator)) + + rewriter.erase_op(target_op) + + def remove_remaining_omp_ops(self, target_func: func.FuncOp, rewriter: PatternRewriter): + """Remove any remaining OpenMP operations in the target function.""" + for op in target_func.walk(): + if isinstance(op, omp.MapInfoOp): + rewriter.erase_op(op) + + for op in target_func.walk(): + if isinstance(op, omp.MapBoundsOp): + rewriter.erase_op(op) + + @op_type_rewrite_pattern + def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): + if "target" not in module.attributes: + return + + target_name = module.attributes["target"].data + target_func = [op for op in module.walk() if isinstance(op, func.FuncOp) and op.sym_name.data == target_name][0] + + self.deref_args(target_func, rewriter) + for map_info in target_func.walk(): + if not isinstance(map_info, omp.MapInfoOp): + continue 
+ + self.forward_map_info(map_info, rewriter) + + target_op = [op for op in target_func.walk() if isinstance(op, omp.TargetOp)][0] + self.remove_target(target_op, target_func, rewriter) + self.remove_remaining_omp_ops(target_func, rewriter) + + +@dataclass(frozen=True) +class TargetToHLSPass(ModulePass): + """ + This is the entry point for the transformation pass which will then apply the rewriter + """ + name = 'target-to-hls' + + def apply(self, ctx: Context, module: builtin.ModuleOp): + walker = PatternRewriteWalker(GreedyRewritePatternApplier([ + TargetFuncToHLS(), + ]), apply_recursively=False, walk_reverse=True) + + walker.rewrite_module(module) \ No newline at end of file From a7cf606d1abdf1fa915814f44b5221552c1e1c29 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Fri, 1 Aug 2025 23:24:59 +0200 Subject: [PATCH 06/12] More modular design to remove the target operation --- ftn/transforms/fpga/target_to_hls.py | 78 +++++++++++++++------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py index 92a1a42..e765b31 100644 --- a/ftn/transforms/fpga/target_to_hls.py +++ b/ftn/transforms/fpga/target_to_hls.py @@ -34,6 +34,45 @@ def deref_args(self, func_op: func.FuncOp, rewriter : PatternRewriter): def forward_map_info(self, map_info: omp.MapInfoOp, rewriter: PatternRewriter): map_info.omp_ptr.replace_by(map_info.var_ptr) + def deref_scalar_memops(self, scalar_ssa: SSAValue, rewriter: PatternRewriter): + for use in scalar_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + load_op = use.operation + # NOTE: the operand of the load operation is not a memref anymore, we have dereferenced it, + # so we forward it. 
+ load_op.res.replace_by(load_op.memref) + rewriter.erase_op(use.operation) + elif isinstance(use.operation, memref.StoreOp): + rewriter.erase_op(use.operation) + + def deref_memref_memops(self, memref_ssa: SSAValue, rewriter: PatternRewriter): + for use in memref_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + # The first load was used to load the pointer to the array. The index to retrieve an element from the array is applied + # in the next load. Since we have dereferenced the first pointer, we need to end up with a single load that accesses + # the array directly. + ptr_load_op = use.operation + + # FIXME: this is assuming each dereferencing load only has one use + for ptr_use in ptr_load_op.res.uses: + if isinstance(ptr_use.operation, memref.LoadOp): + array_load_op = ptr_use.operation + array_idx = array_load_op.indices + + new_load_op = memref.LoadOp.get(memref_ssa, array_idx) + array_load_op.res.replace_by(ptr_load_op.res) + rewriter.replace_op(ptr_load_op, new_load_op) + rewriter.erase_op(array_load_op) + + elif isinstance(ptr_use.operation, memref.StoreOp): + array_store_op = ptr_use.operation + array_idx = array_store_op.indices + new_store_op = memref.StoreOp.get(array_store_op.value, memref_ssa, array_idx) + rewriter.insert_op(new_store_op, InsertPoint.before(ptr_load_op)) + rewriter.erase_op(array_store_op) + rewriter.erase_op(ptr_load_op) + + def remove_target(self, target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: PatternRewriter): """Remove the target operation from the module.""" for operand in target_op.map_vars: @@ -42,44 +81,9 @@ def remove_target(self, target_op: omp.TargetOp, target_func: func.FuncOp, rewri block_arg.replace_by(operand) if not isinstance(operand.type, builtin.MemRefType): - for use in operand.uses: - if isinstance(use.operation, memref.LoadOp): - load_op = use.operation - # NOTE: the operand of the load operation is not a memref anymore, we have dereferenced it, - # so we forward it. 
- load_op.res.replace_by(load_op.memref) - rewriter.erase_op(use.operation) - elif isinstance(use.operation, memref.StoreOp): - rewriter.erase_op(use.operation) + self.deref_scalar_memops(operand, rewriter) else: - # Adjust the store and load operations that use the dereferenced pointer. - for use in operand.uses: - if isinstance(use.operation, memref.LoadOp): - # The first load was used to load the pointer to the array. The index to retrieve an element from the array is applied - # in the next load. Since we have dereferenced the first pointer, we need to end up with a single load that accesses - # the array directly. - ptr_load_op = use.operation - - # FIXME: this is assuming each dereferencing load only has one use - for ptr_use in ptr_load_op.res.uses: - if isinstance(ptr_use.operation, memref.LoadOp): - array_load_op = ptr_use.operation - array_idx = array_load_op.indices - - new_load_op = memref.LoadOp.get(operand, array_idx) - array_load_op.res.replace_by(ptr_load_op.res) - rewriter.replace_op(ptr_load_op, new_load_op) - rewriter.erase_op(array_load_op) - - elif isinstance(ptr_use.operation, memref.StoreOp): - array_store_op = ptr_use.operation - array_idx = array_store_op.indices - new_store_op = memref.StoreOp.get(array_store_op.value, operand, array_idx) - rewriter.insert_op(new_store_op, InsertPoint.before(ptr_load_op)) - rewriter.erase_op(array_store_op) - rewriter.erase_op(ptr_load_op) - - + self.deref_memref_memops(operand, rewriter) target_op_terminator = target_op.region.block.last_op assert target_op_terminator From a56dd6a3c1440f641d9530f55ae4fd15dae75975 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Sat, 2 Aug 2025 00:36:25 +0200 Subject: [PATCH 07/12] Keep only the top-level module to contain the HLS function - necessary to generate LLVM IR --- ftn/transforms/fpga/target_to_hls.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ftn/transforms/fpga/target_to_hls.py 
b/ftn/transforms/fpga/target_to_hls.py index e765b31..57a631c 100644 --- a/ftn/transforms/fpga/target_to_hls.py +++ b/ftn/transforms/fpga/target_to_hls.py @@ -21,6 +21,8 @@ @dataclass class TargetFuncToHLS(RewritePattern): + target_funcs : list[func.FuncOp] = field(default_factory=list) + def deref_args(self, func_op: func.FuncOp, rewriter : PatternRewriter): """Dereference the arguments of a function operation.""" new_input_types = [] @@ -113,6 +115,7 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, target_name = module.attributes["target"].data target_func = [op for op in module.walk() if isinstance(op, func.FuncOp) and op.sym_name.data == target_name][0] + self.target_funcs.append(target_func) self.deref_args(target_func, rewriter) for map_info in target_func.walk(): @@ -133,9 +136,22 @@ class TargetToHLSPass(ModulePass): """ name = 'target-to-hls' + generate : str = "hls" + def apply(self, ctx: Context, module: builtin.ModuleOp): + target_funcs : list[func.FuncOp] = [] walker = PatternRewriteWalker(GreedyRewritePatternApplier([ - TargetFuncToHLS(), + TargetFuncToHLS(target_funcs), ]), apply_recursively=False, walk_reverse=True) - walker.rewrite_module(module) \ No newline at end of file + walker.rewrite_module(module) + + if self.generate == "hls": + # Keep only the top level module to contain the function + for target_func in target_funcs: + target_func.detach() + + module_block = module.body.block + module.body.detach_block(module.body.block) + module_block.erase() + module.body.add_block(Block(target_funcs)) \ No newline at end of file From 58a649cb77928017b3c8596c0f4acb625e70b9b3 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Mon, 4 Aug 2025 20:39:34 +0200 Subject: [PATCH 08/12] Dereference in the HLS pass not necessary anymore, as it is processed earlier in the pipeline. 
Also, remove attributes from the module, since they are incompatible with the downgrading to LLVM v7 --- ftn/transforms/fpga/target_to_hls.py | 65 +++++++++++++++------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py index 57a631c..e001dfa 100644 --- a/ftn/transforms/fpga/target_to_hls.py +++ b/ftn/transforms/fpga/target_to_hls.py @@ -19,24 +19,9 @@ from ftn.util.visitor import Visitor from xdsl.rewriter import InsertPoint -@dataclass -class TargetFuncToHLS(RewritePattern): - target_funcs : list[func.FuncOp] = field(default_factory=list) - - def deref_args(self, func_op: func.FuncOp, rewriter : PatternRewriter): - """Dereference the arguments of a function operation.""" - new_input_types = [] - - for arg in func_op.body.block.args: - if isa(arg.type, builtin.MemRefType): - deref_type = arg.type.element_type - func_op.replace_argument_type(arg, deref_type, rewriter) - new_input_types.append(deref_type) - - def forward_map_info(self, map_info: omp.MapInfoOp, rewriter: PatternRewriter): - map_info.omp_ptr.replace_by(map_info.var_ptr) - - def deref_scalar_memops(self, scalar_ssa: SSAValue, rewriter: PatternRewriter): +class DerefMemrefs: + @staticmethod + def deref_scalar_memops(scalar_ssa: SSAValue, rewriter: PatternRewriter): for use in scalar_ssa.uses: if isinstance(use.operation, memref.LoadOp): load_op = use.operation @@ -47,7 +32,8 @@ def deref_scalar_memops(self, scalar_ssa: SSAValue, rewriter: PatternRewriter): elif isinstance(use.operation, memref.StoreOp): rewriter.erase_op(use.operation) - def deref_memref_memops(self, memref_ssa: SSAValue, rewriter: PatternRewriter): + @staticmethod + def deref_memref_memops(memref_ssa: SSAValue, rewriter: PatternRewriter): for use in memref_ssa.uses: if isinstance(use.operation, memref.LoadOp): # The first load was used to load the pointer to the array. 
The index to retrieve an element from the array is applied @@ -74,19 +60,26 @@ def deref_memref_memops(self, memref_ssa: SSAValue, rewriter: PatternRewriter): rewriter.erase_op(array_store_op) rewriter.erase_op(ptr_load_op) + @staticmethod + def deref_args(func_op: func.FuncOp, rewriter : PatternRewriter): + """Dereference the arguments of a function operation.""" + new_input_types = [] + + for arg in func_op.body.block.args: + if isa(arg.type, builtin.MemRefType): + deref_type = arg.type.element_type + func_op.replace_argument_type(arg, deref_type, rewriter) + new_input_types.append(deref_type) - def remove_target(self, target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: PatternRewriter): +class RemoveOps: + @staticmethod + def remove_target(target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: PatternRewriter): """Remove the target operation from the module.""" for operand in target_op.map_vars: operand_idx = target_op.operands.index(operand) block_arg = target_op.region.block.args[operand_idx] block_arg.replace_by(operand) - if not isinstance(operand.type, builtin.MemRefType): - self.deref_scalar_memops(operand, rewriter) - else: - self.deref_memref_memops(operand, rewriter) - target_op_terminator = target_op.region.block.last_op assert target_op_terminator rewriter.erase_op(target_op_terminator) @@ -98,7 +91,8 @@ def remove_target(self, target_op: omp.TargetOp, target_func: func.FuncOp, rewri rewriter.erase_op(target_op) - def remove_remaining_omp_ops(self, target_func: func.FuncOp, rewriter: PatternRewriter): + @staticmethod + def remove_remaining_omp_ops(target_func: func.FuncOp, rewriter: PatternRewriter): """Remove any remaining OpenMP operations in the target function.""" for op in target_func.walk(): if isinstance(op, omp.MapInfoOp): @@ -108,6 +102,15 @@ def remove_remaining_omp_ops(self, target_func: func.FuncOp, rewriter: PatternRe if isinstance(op, omp.MapBoundsOp): rewriter.erase_op(op) + @staticmethod + def forward_map_info(map_info: 
omp.MapInfoOp, rewriter: PatternRewriter): + map_info.omp_ptr.replace_by(map_info.var_ptr) + + +@dataclass +class TargetFuncToHLS(RewritePattern): + target_funcs : list[func.FuncOp] = field(default_factory=list) + @op_type_rewrite_pattern def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): if "target" not in module.attributes: @@ -117,16 +120,15 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, target_func = [op for op in module.walk() if isinstance(op, func.FuncOp) and op.sym_name.data == target_name][0] self.target_funcs.append(target_func) - self.deref_args(target_func, rewriter) for map_info in target_func.walk(): if not isinstance(map_info, omp.MapInfoOp): continue - self.forward_map_info(map_info, rewriter) + RemoveOps.forward_map_info(map_info, rewriter) target_op = [op for op in target_func.walk() if isinstance(op, omp.TargetOp)][0] - self.remove_target(target_op, target_func, rewriter) - self.remove_remaining_omp_ops(target_func, rewriter) + RemoveOps.remove_target(target_op, target_func, rewriter) + RemoveOps.remove_remaining_omp_ops(target_func, rewriter) @dataclass(frozen=True) @@ -154,4 +156,5 @@ def apply(self, ctx: Context, module: builtin.ModuleOp): module_block = module.body.block module.body.detach_block(module.body.block) module_block.erase() - module.body.add_block(Block(target_funcs)) \ No newline at end of file + module.body.add_block(Block(target_funcs)) + module.attributes = {} \ No newline at end of file From c13fc0612f2d14095c14c1702fa76efda89f3726 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Mon, 4 Aug 2025 20:52:51 +0200 Subject: [PATCH 09/12] Add scripts to generate LLVM IR compatible with Vitis HLS --- ftn/tools/config.sh | 9 +++++++++ ftn/tools/mlir_to_ll.sh | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 ftn/tools/config.sh create mode 100755 ftn/tools/mlir_to_ll.sh diff --git a/ftn/tools/config.sh 
b/ftn/tools/config.sh new file mode 100644 index 0000000..6f9e4a3 --- /dev/null +++ b/ftn/tools/config.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +DOCKER_RUN="docker run -i -u $(id -u):$(id -g) -v $(pwd):$(pwd) -w $(pwd)" + +alias python-soda="$DOCKER_RUN agostini01/soda:v15.08:v15.08 python3" +alias soda-opt="$DOCKER_RUN agostini01/soda:v15.08 soda-opt" +alias opt-16="$DOCKER_RUN agostini01/soda:v15.08 opt" +alias mlir-opt="$DOCKER_RUN agostini01/soda:v15.08 mlir-opt" +alias soda-mlir-translate="$DOCKER_RUN agostini01/soda:v15.08 mlir-translate" \ No newline at end of file diff --git a/ftn/tools/mlir_to_ll.sh b/ftn/tools/mlir_to_ll.sh new file mode 100755 index 0000000..6190145 --- /dev/null +++ b/ftn/tools/mlir_to_ll.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +ODIR=fpga +CLKPERIOD=10 +TARGET_BOARD=xc7vx690t-ffg1930-3 +KERNELNAME=$1 +LIBDIR=/opt/soda-opt/lib/ +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source $SCRIPT_DIR/config.sh +shopt -s expand_aliases + +soda-opt ex1_offload.mlir --lower-affine --canonicalize --lower-all-to-llvm=use-bare-ptr-memref-call-conv | \ +soda-mlir-translate --mlir-to-llvmir --opaque-pointers=0 -o model.ll + +opt-16 model.ll \ + -S \ + -enable-new-pm=0 \ + -load "${LIBDIR}/VhlsLLVMRewriter.so" \ + -mem2arr -strip-debug \ + -instcombine \ + -xlnname \ + -xlnanno -xlntop $KERNELNAME \ + -xlntbgen -xlntbdummynames="$KERNELNAME.dummy.c" \ + -xlntbtclnames="$KERNELNAME.run.tcl" \ + -xlnllvm="$KERNELNAME.opt.ll" \ + -clock-period-ns=$CLKPERIOD -target=$TARGET_BOARD \ + > $KERNELNAME.opt.ll From 10b1893a097291b657fde0a97dd8b5b26f55ac77 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Mon, 4 Aug 2025 21:31:04 +0200 Subject: [PATCH 10/12] Add option to trigger the generation of LLVM IR for Vitis HLS. 
This only works when called after offload and extract-target, target-to-hls, but it is not enforced yet for debuggability --- ftn/tools/mlir_to_ll.sh | 5 +++-- ftn/tools/xftn.py | 25 ++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/ftn/tools/mlir_to_ll.sh b/ftn/tools/mlir_to_ll.sh index 6190145..a6bd274 100755 --- a/ftn/tools/mlir_to_ll.sh +++ b/ftn/tools/mlir_to_ll.sh @@ -3,14 +3,15 @@ ODIR=fpga CLKPERIOD=10 TARGET_BOARD=xc7vx690t-ffg1930-3 -KERNELNAME=$1 +MLIR_INPUT=$1 +KERNELNAME=$2 LIBDIR=/opt/soda-opt/lib/ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source $SCRIPT_DIR/config.sh shopt -s expand_aliases -soda-opt ex1_offload.mlir --lower-affine --canonicalize --lower-all-to-llvm=use-bare-ptr-memref-call-conv | \ +soda-opt $MLIR_INPUT --lower-affine --canonicalize --lower-all-to-llvm=use-bare-ptr-memref-call-conv | \ soda-mlir-translate --mlir-to-llvmir --opaque-pointers=0 -o model.ll opt-16 model.ll \ diff --git a/ftn/tools/xftn.py b/ftn/tools/xftn.py index 99f1460..60f1a4f 100644 --- a/ftn/tools/xftn.py +++ b/ftn/tools/xftn.py @@ -55,6 +55,12 @@ def initialise_argument_parser(): default=[], help="Additional transformation pass for ftn-opt", ) + parser.add_argument( + "-f", + "--fpga-llvmir", + action="store_true", + help="Generate LLVM IR for FPGA compatible with Vitis HLS." 
+ ) parser.add_argument( "-I", "--include-directory", @@ -96,7 +102,7 @@ def initialise_argument_parser(): parser.add_argument( "--stages", default=None, - help="Specify which stages will run (a combination of: flang, pre, ftn, post, mlir, obj, clang) in comma separated list without spaces", + help="Specify which stages will run (a combination of: flang, pre, ftn, post, mlir, obj, clang, fpga-llvmir) in comma separated list without spaces", ) parser.add_argument( "-v", @@ -132,6 +138,7 @@ def enable_disable_stages_by_output_type(options_db, output_type): options_db["run_mlir_to_llvmir_stage"] = options_db["run_postprocess_stage"] options_db["run_create_object_stage"] = output_type == OutputType.OBJECT options_db["run_build_executable_stage"] = output_type == OutputType.EXECUTABLE + options_db["run_fpga_llvmir_stage"] = False def build_options_db_from_args(args): @@ -175,6 +182,9 @@ def build_options_db_from_args(args): options_db["run_mlir_to_llvmir_stage"] = False options_db["run_build_executable_stage"] = False + if options_db["fpga_llvmir"]: + options_db["run_fpga_llvmir_stage"] = True + if options_db["stages"] is not None: # We handle stages to run last, this overrides all other stage selection # through other options or output filename @@ -188,6 +198,7 @@ def build_options_db_from_args(args): options_db["run_mlir_to_llvmir_stage"] = "mlir" in stages_to_run options_db["run_create_object_stage"] = "obj" in stages_to_run options_db["run_build_executable_stage"] = "clang" in stages_to_run + options_db["run_fpga_llvmir_stage"] = "fpga-llvmir" in stages_to_run if "flang" in stages_to_run: stages_to_run.remove("flang") @@ -203,6 +214,8 @@ def build_options_db_from_args(args): stages_to_run.remove("obj") if "clang" in stages_to_run: stages_to_run.remove("clang") + if "fpga-llvmir" in stages_to_run: + stages_to_run.remove("fpga-llvmir") if len(stages_to_run) > 0: for e in stages_to_run: print(f"Unknown stage provided as argument '{e}'") @@ -241,6 +254,9 @@ def 
display_verbose_start_message(options_db): print( f"Stage 'Build executable': {'Enabled' if options_db['run_build_executable_stage'] else 'Disabled'}" ) + print( + f"Stage 'FPGA LLVM-IR': {'Enabled' if options_db['run_fpga_llvmir_stage'] else 'Disabled'}" + ) if options_db["run_build_executable_stage"]: print("") print(f"Also linking against object files {options_db['link-objectfiles']}") @@ -459,6 +475,9 @@ def build_executable(output_tmp_dir, input_fn, executable_fn, options_db): os.system(f"clang {clang_args}") post_stage_check(executable_fn, options_db["verbose"], executable=True) +def generate_llvmir_for_fpga(out_file): + tools_path = os.path.dirname(os.path.abspath(__file__)) + os.system(f"{tools_path}/mlir_to_ll.sh {out_file} tt_device") def print_file_contents(filename): with open(filename) as f: @@ -568,6 +587,10 @@ def main(): if options_db["offload"]: print_verbose_message(options_db, f"Offload MLIR in '{out_file}'") + if options_db["run_fpga_llvmir_stage"]: + print_verbose_message(options_db, "Running FPGA LLVM-IR generation stage",) + generate_llvmir_for_fpga(out_file) + if options_db["cleanup"]: remove_file_if_exists( tmp_dir, From fa1bfb9e1135beebaa0aa39778400edc92bad11e Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Tue, 5 Aug 2025 22:27:30 +0200 Subject: [PATCH 11/12] Add support for omp.parallel as a pipelined loop. 
Tested on ex4.F90 --- ftn/transforms/fpga/target_to_hls.py | 74 ++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py index e001dfa..14c9b94 100644 --- a/ftn/transforms/fpga/target_to_hls.py +++ b/ftn/transforms/fpga/target_to_hls.py @@ -2,6 +2,7 @@ from ast import FunctionType, Module from hmac import new from json import load +from sys import set_coroutine_origin_tracking_depth from typing import TypeVar, cast from dataclasses import dataclass, field import itertools @@ -18,6 +19,7 @@ from xdsl.dialects import builtin, func, llvm, arith from ftn.util.visitor import Visitor from xdsl.rewriter import InsertPoint +from xdsl.dialects.experimental.hls import PragmaPipelineOp class DerefMemrefs: @staticmethod @@ -91,6 +93,69 @@ def remove_target(target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: P rewriter.erase_op(target_op) + @staticmethod + def remove_omp_parallel(parallel_op: omp.ParallelOp, rewriter: PatternRewriter): + ws_loop = None + for op in parallel_op.walk(): + if isinstance(op, omp.WsLoopOp): + ws_loop = op + break + + assert ws_loop + print(ws_loop.body.block) + omp_loop_op = ws_loop.body.block.first_op + ws_loop_block = ws_loop.body.block + ws_loop.body.detach_block(ws_loop_block) + rewriter.inline_block(ws_loop_block, InsertPoint.before(ws_loop)) + rewriter.erase_op(ws_loop) + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop_body = rewriter.move_region_contents_to_new_regions(omp_loop_op.body) + rewriter.replace_op(flat_loop_body.block.last_op, scf.YieldOp()) + flat_lb = arith.IndexCastOp(omp_loop_op.lowerBound[0], builtin.IndexType()) + flat_ub = arith.IndexCastOp(omp_loop_op.upperBound[0], builtin.IndexType()) + flat_step = arith.IndexCastOp(omp_loop_op.step[0], builtin.IndexType()) + rewriter.insert_op(flat_lb, InsertPoint.after(omp_loop_op.lowerBound[0].owner)) + rewriter.insert_op(flat_ub, 
InsertPoint.after(omp_loop_op.upperBound[0].owner)) + rewriter.insert_op(flat_step, InsertPoint.after(omp_loop_op.step[0].owner)) + rewriter.replace_value_with_new_type(flat_loop_body.block.args[0], builtin.IndexType()) + + # Replace the store ops first + for arg_use in flat_loop_body.block.args[0].uses: + if isinstance(arg_use.operation, memref.StoreOp): + store_op = arg_use.operation + alloca_op = store_op.memref.owner + assert isinstance(alloca_op, memref.AllocaOp) + index_alloca = memref.AllocaOp.get(builtin.IndexType(), shape=alloca_op.memref.type.shape) + rewriter.replace_op(alloca_op, index_alloca) + + idx_memref = store_op.memref + for idx_memref_use in idx_memref.uses: + if isinstance(idx_memref_use.operation, memref.LoadOp): + load_op = idx_memref_use.operation + index_load = memref.LoadOp.get(load_op.memref, load_op.indices) + rewriter.replace_op(load_op, index_load) + + # Original type of the block arg of the loop nest op + cast_ind_var = arith.IndexCastOp(index_load.res, builtin.i32) + rewriter.insert_op(cast_ind_var, InsertPoint.after(index_load)) + index_load.res.replace_by_if(cast_ind_var.result, lambda use: use.operation != cast_ind_var) + + flat_loop = scf.ForOp(flat_lb, flat_ub, flat_step, (), flat_loop_body) + #flat_loop = scf.ForOp(omp_loop_op.lowerBound[0], omp_loop_op.upperBound[0], omp_loop_op.step[0], (), flat_loop_body) + rewriter.replace_op(omp_loop_op, flat_loop) + + one = arith.ConstantOp.from_int_and_width(1, 32) + pragma_pipeline = PragmaPipelineOp(one) + rewriter.insert_op([one, pragma_pipeline], InsertPoint.at_start(flat_loop_body.block)) + + parallel_block = parallel_op.region.block + parallel_op.region.detach_block(parallel_block) + rewriter.erase_op(parallel_block.last_op) + rewriter.inline_block(parallel_block, InsertPoint.before(parallel_op)) + rewriter.erase_op(parallel_op) + + @staticmethod def remove_remaining_omp_ops(target_func: func.FuncOp, rewriter: PatternRewriter): """Remove any remaining OpenMP operations in the 
target function.""" @@ -126,6 +191,15 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, RemoveOps.forward_map_info(map_info, rewriter) + omp_parallel = None + for op in target_func.walk(): + if isinstance(op, omp.ParallelOp): + omp_parallel = op + break + + if omp_parallel: + RemoveOps.remove_omp_parallel(omp_parallel, rewriter) + target_op = [op for op in target_func.walk() if isinstance(op, omp.TargetOp)][0] RemoveOps.remove_target(target_op, target_func, rewriter) RemoveOps.remove_remaining_omp_ops(target_func, rewriter) From fea7ad2e38c67d4498277ad4194f0aba5e4943f4 Mon Sep 17 00:00:00 2001 From: Gabriel Rodriguez-Canal Date: Tue, 5 Aug 2025 23:24:52 +0200 Subject: [PATCH 12/12] Add support for SIMD directive as an unrolled loop --- ftn/transforms/fpga/target_to_hls.py | 121 +++++++++++++++++++-------- 1 file changed, 84 insertions(+), 37 deletions(-) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py index 14c9b94..7cf3613 100644 --- a/ftn/transforms/fpga/target_to_hls.py +++ b/ftn/transforms/fpga/target_to_hls.py @@ -19,7 +19,7 @@ from xdsl.dialects import builtin, func, llvm, arith from ftn.util.visitor import Visitor from xdsl.rewriter import InsertPoint -from xdsl.dialects.experimental.hls import PragmaPipelineOp +from xdsl.dialects.experimental.hls import PragmaPipelineOp, PragmaUnrollOp class DerefMemrefs: @staticmethod @@ -74,6 +74,55 @@ def deref_args(func_op: func.FuncOp, rewriter : PatternRewriter): new_input_types.append(deref_type) class RemoveOps: + @staticmethod + def transform_omp_loop_nest_into_scf_for(loop_nest_op : omp.LoopNestOp, rewriter: PatternRewriter): + flat_loop_body = rewriter.move_region_contents_to_new_regions(loop_nest_op.body) + rewriter.replace_op(flat_loop_body.block.last_op, scf.YieldOp()) + flat_lb = arith.IndexCastOp(loop_nest_op.lowerBound[0], builtin.IndexType()) + flat_ub = arith.IndexCastOp(loop_nest_op.upperBound[0], builtin.IndexType()) + flat_step 
= arith.IndexCastOp(loop_nest_op.step[0], builtin.IndexType()) + rewriter.insert_op(flat_lb, InsertPoint.after(loop_nest_op.lowerBound[0].owner)) + rewriter.insert_op(flat_ub, InsertPoint.after(loop_nest_op.upperBound[0].owner)) + rewriter.insert_op(flat_step, InsertPoint.after(loop_nest_op.step[0].owner)) + rewriter.replace_value_with_new_type(flat_loop_body.block.args[0], builtin.IndexType()) + + # TODO: convert between i32 and index where appropriate, since omp.loop_nest operates with i32 and + # scf.for with index. + for arg_use in flat_loop_body.block.args[0].uses: + if isinstance(arg_use.operation, memref.StoreOp): + store_op = arg_use.operation + + ## Replace the store ops first + if isinstance(store_op.memref.type.element_type, builtin.IntegerType): + index_to_i32 = arith.IndexCastOp(store_op.value, builtin.i32) + store_op.value.replace_by_if(index_to_i32.result, lambda use: isinstance(use.operation, memref.StoreOp)) + rewriter.insert_op(index_to_i32, InsertPoint.before(store_op)) + + else: + alloca_op = store_op.memref.owner + assert isinstance(alloca_op, memref.AllocaOp) + index_alloca = memref.AllocaOp.get(builtin.IndexType(), shape=alloca_op.memref.type.shape) + rewriter.replace_op(alloca_op, index_alloca) + + idx_memref = store_op.memref + for idx_memref_use in idx_memref.uses: + if isinstance(idx_memref_use.operation, memref.LoadOp): + load_op = idx_memref_use.operation + index_load = memref.LoadOp.get(load_op.memref, load_op.indices) + rewriter.replace_op(load_op, index_load) + + # Original type of the block arg of the loop nest op + cast_ind_var = arith.IndexCastOp(index_load.res, builtin.i32) + rewriter.insert_op(cast_ind_var, InsertPoint.after(index_load)) + index_load.res.replace_by_if(cast_ind_var.result, lambda use: use.operation != cast_ind_var) + + flat_loop = scf.ForOp(flat_lb, flat_ub, flat_step, (), flat_loop_body) + #flat_loop = scf.ForOp(loop_nest_op.lowerBound[0], omp_loop_op.upperBound[0], omp_loop_op.step[0], (), flat_loop_body) + 
rewriter.replace_op(loop_nest_op, flat_loop) + + return flat_loop + + @staticmethod def remove_target(target_op: omp.TargetOp, target_func: func.FuncOp, rewriter: PatternRewriter): """Remove the target operation from the module.""" @@ -110,44 +159,10 @@ def remove_omp_parallel(parallel_op: omp.ParallelOp, rewriter: PatternRewriter): rewriter.erase_op(ws_loop) if isinstance(omp_loop_op, omp.LoopNestOp): - flat_loop_body = rewriter.move_region_contents_to_new_regions(omp_loop_op.body) - rewriter.replace_op(flat_loop_body.block.last_op, scf.YieldOp()) - flat_lb = arith.IndexCastOp(omp_loop_op.lowerBound[0], builtin.IndexType()) - flat_ub = arith.IndexCastOp(omp_loop_op.upperBound[0], builtin.IndexType()) - flat_step = arith.IndexCastOp(omp_loop_op.step[0], builtin.IndexType()) - rewriter.insert_op(flat_lb, InsertPoint.after(omp_loop_op.lowerBound[0].owner)) - rewriter.insert_op(flat_ub, InsertPoint.after(omp_loop_op.upperBound[0].owner)) - rewriter.insert_op(flat_step, InsertPoint.after(omp_loop_op.step[0].owner)) - rewriter.replace_value_with_new_type(flat_loop_body.block.args[0], builtin.IndexType()) - - # Replace the store ops first - for arg_use in flat_loop_body.block.args[0].uses: - if isinstance(arg_use.operation, memref.StoreOp): - store_op = arg_use.operation - alloca_op = store_op.memref.owner - assert isinstance(alloca_op, memref.AllocaOp) - index_alloca = memref.AllocaOp.get(builtin.IndexType(), shape=alloca_op.memref.type.shape) - rewriter.replace_op(alloca_op, index_alloca) - - idx_memref = store_op.memref - for idx_memref_use in idx_memref.uses: - if isinstance(idx_memref_use.operation, memref.LoadOp): - load_op = idx_memref_use.operation - index_load = memref.LoadOp.get(load_op.memref, load_op.indices) - rewriter.replace_op(load_op, index_load) - - # Original type of the block arg of the loop nest op - cast_ind_var = arith.IndexCastOp(index_load.res, builtin.i32) - rewriter.insert_op(cast_ind_var, InsertPoint.after(index_load)) - 
index_load.res.replace_by_if(cast_ind_var.result, lambda use: use.operation != cast_ind_var) - - flat_loop = scf.ForOp(flat_lb, flat_ub, flat_step, (), flat_loop_body) - #flat_loop = scf.ForOp(omp_loop_op.lowerBound[0], omp_loop_op.upperBound[0], omp_loop_op.step[0], (), flat_loop_body) - rewriter.replace_op(omp_loop_op, flat_loop) - + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) one = arith.ConstantOp.from_int_and_width(1, 32) pragma_pipeline = PragmaPipelineOp(one) - rewriter.insert_op([one, pragma_pipeline], InsertPoint.at_start(flat_loop_body.block)) + rewriter.insert_op([one, pragma_pipeline], InsertPoint.at_start(flat_loop.body.block)) parallel_block = parallel_op.region.block parallel_op.region.detach_block(parallel_block) @@ -155,6 +170,29 @@ def remove_omp_parallel(parallel_op: omp.ParallelOp, rewriter: PatternRewriter): rewriter.inline_block(parallel_block, InsertPoint.before(parallel_op)) rewriter.erase_op(parallel_op) + @staticmethod + def remove_omp_simd(simd_op : omp.SimdOp, rewriter : PatternRewriter): + for priv_var in simd_op.private_vars: + arg_idx = simd_op.operands.index(priv_var) + simd_op.body.block.args[arg_idx].replace_by(priv_var) + + omp_loop_op = simd_op.body.block.first_op + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + simd_factor = simd_op.simdlen.value.data + ssa_simd_factor = arith.ConstantOp.from_int_and_width(simd_factor, 32) + pragma_unroll = PragmaUnrollOp(ssa_simd_factor) + rewriter.insert_op([ssa_simd_factor, pragma_unroll], InsertPoint.at_start(flat_loop.body.block)) + else: + flat_loop = simd_op.body.block.first_op + + assert isinstance(flat_loop, scf.ForOp) + flat_loop.detach() + rewriter.insert_op(flat_loop, InsertPoint.before(simd_op)) + #rewriter.erase_matched_op() #FIXME: this does not work + rewriter.erase_op(simd_op) + @staticmethod def remove_remaining_omp_ops(target_func: func.FuncOp, rewriter: 
PatternRewriter): @@ -200,6 +238,15 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, if omp_parallel: RemoveOps.remove_omp_parallel(omp_parallel, rewriter) + omp_simd = None + for op in target_func.walk(): + if isinstance(op, omp.SimdOp): + omp_simd = op + break + + if omp_simd: + RemoveOps.remove_omp_simd(omp_simd, rewriter) + target_op = [op for op in target_func.walk() if isinstance(op, omp.TargetOp)][0] RemoveOps.remove_target(target_op, target_func, rewriter) RemoveOps.remove_remaining_omp_ops(target_func, rewriter)