Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ftn/tools/config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash

DOCKER_RUN="docker run -i -u $(id -u):$(id -g) -v $(pwd):$(pwd) -w $(pwd)"

alias python-soda="$DOCKER_RUN agostini01/soda:v15.08:v15.08 python3"
alias soda-opt="$DOCKER_RUN agostini01/soda:v15.08 soda-opt"
alias opt-16="$DOCKER_RUN agostini01/soda:v15.08 opt"
alias mlir-opt="$DOCKER_RUN agostini01/soda:v15.08 mlir-opt"
alias soda-mlir-translate="$DOCKER_RUN agostini01/soda:v15.08 mlir-translate"
6 changes: 4 additions & 2 deletions ftn/tools/ftn_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore
from ftn.transforms.merge_memref_deref import MergeMemRefDeref
from ftn.transforms.extract_target import ExtractTarget
from ftn.transforms.fpga.target_to_hls import TargetToHLSPass
from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass
# from ftn.transforms.extract_target import ExtractTarget
# from ftn.transforms.isolate_target import IsolateTarget
# from psy.extract_stencil import ExtractStencil
# from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT
Expand All @@ -27,8 +28,9 @@ def register_all_passes(self):
super().register_all_passes()
self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore)
self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref)
self.register_pass("extract-target", lambda: ExtractTarget)
self.register_pass("target-to-hls", lambda: TargetToHLSPass)
self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass)
# self.register_pass("extract-target", lambda: ExtractTarget)
# self.register_pass("isolate-target", lambda: IsolateTarget)
# self.register_pass("convert-to-tt", lambda: ConvertToTT)

Expand Down
29 changes: 29 additions & 0 deletions ftn/tools/mlir_to_ll.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/bin/bash

ODIR=fpga
CLKPERIOD=10
TARGET_BOARD=xc7vx690t-ffg1930-3
MLIR_INPUT=$1
KERNELNAME=$2
LIBDIR=/opt/soda-opt/lib/
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

source $SCRIPT_DIR/config.sh
shopt -s expand_aliases

soda-opt $MLIR_INPUT --lower-affine --canonicalize --lower-all-to-llvm=use-bare-ptr-memref-call-conv | \
soda-mlir-translate --mlir-to-llvmir --opaque-pointers=0 -o model.ll

opt-16 model.ll \
-S \
-enable-new-pm=0 \
-load "${LIBDIR}/VhlsLLVMRewriter.so" \
-mem2arr -strip-debug \
-instcombine \
-xlnname \
-xlnanno -xlntop $KERNELNAME \
-xlntbgen -xlntbdummynames="$KERNELNAME.dummy.c" \
-xlntbtclnames="$KERNELNAME.run.tcl" \
-xlnllvm="$KERNELNAME.opt.ll" \
-clock-period-ns=$CLKPERIOD -target=$TARGET_BOARD \
> $KERNELNAME.opt.ll
25 changes: 24 additions & 1 deletion ftn/tools/xftn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def initialise_argument_parser():
default=[],
help="Additional transformation pass for ftn-opt",
)
parser.add_argument(
"-f",
"--fpga-llvmir",
action="store_true",
help="Generate LLVM IR for FPGA compatible with Vitis HLS."
)
parser.add_argument(
"-I",
"--include-directory",
Expand Down Expand Up @@ -96,7 +102,7 @@ def initialise_argument_parser():
parser.add_argument(
"--stages",
default=None,
help="Specify which stages will run (a combination of: flang, pre, ftn, post, mlir, obj, clang) in comma separated list without spaces",
help="Specify which stages will run (a combination of: flang, pre, ftn, post, mlir, obj, clang, fpga-llvmir) in comma separated list without spaces",
)
parser.add_argument(
"-v",
Expand Down Expand Up @@ -132,6 +138,7 @@ def enable_disable_stages_by_output_type(options_db, output_type):
options_db["run_mlir_to_llvmir_stage"] = options_db["run_postprocess_stage"]
options_db["run_create_object_stage"] = output_type == OutputType.OBJECT
options_db["run_build_executable_stage"] = output_type == OutputType.EXECUTABLE
options_db["run_fpga_llvmir_stage"] = False


def build_options_db_from_args(args):
Expand Down Expand Up @@ -175,6 +182,9 @@ def build_options_db_from_args(args):
options_db["run_mlir_to_llvmir_stage"] = False
options_db["run_build_executable_stage"] = False

if options_db["fpga_llvmir"]:
options_db["run_fpga_llvmir_stage"] = True

if options_db["stages"] is not None:
# We handle stages to run last, this overrides all other stage selection
# through other options or output filename
Expand All @@ -188,6 +198,7 @@ def build_options_db_from_args(args):
options_db["run_mlir_to_llvmir_stage"] = "mlir" in stages_to_run
options_db["run_create_object_stage"] = "obj" in stages_to_run
options_db["run_build_executable_stage"] = "clang" in stages_to_run
options_db["run_fpga_llvmir_stage"] = "fpga-llvmir" in stages_to_run

if "flang" in stages_to_run:
stages_to_run.remove("flang")
Expand All @@ -203,6 +214,8 @@ def build_options_db_from_args(args):
stages_to_run.remove("obj")
if "clang" in stages_to_run:
stages_to_run.remove("clang")
if "fpga-llvmir" in stages_to_run:
stages_to_run.remove("fpga-llvmir")
if len(stages_to_run) > 0:
for e in stages_to_run:
print(f"Unknown stage provided as argument '{e}'")
Expand Down Expand Up @@ -241,6 +254,9 @@ def display_verbose_start_message(options_db):
print(
f"Stage 'Build executable': {'Enabled' if options_db['run_build_executable_stage'] else 'Disabled'}"
)
print(
f"Stage 'FPGA LLVM-IR': {'Enabled' if options_db['run_fpga_llvmir_stage'] else 'Disabled'}"
)
if options_db["run_build_executable_stage"]:
print("")
print(f"Also linking against object files {options_db['link-objectfiles']}")
Expand Down Expand Up @@ -459,6 +475,9 @@ def build_executable(output_tmp_dir, input_fn, executable_fn, options_db):
os.system(f"clang {clang_args}")
post_stage_check(executable_fn, options_db["verbose"], executable=True)

def generate_llvmir_for_fpga(out_file):
tools_path = os.path.dirname(os.path.abspath(__file__))
os.system(f"{tools_path}/mlir_to_ll.sh {out_file} tt_device")

def print_file_contents(filename):
with open(filename) as f:
Expand Down Expand Up @@ -568,6 +587,10 @@ def main():
if options_db["offload"]:
print_verbose_message(options_db, f"Offload MLIR in '{out_file}'")

if options_db["run_fpga_llvmir_stage"]:
print_verbose_message(options_db, "Running FPGA LLVM-IR generation stage",)
generate_llvmir_for_fpga(out_file)

if options_db["cleanup"]:
remove_file_if_exists(
tmp_dir,
Expand Down
69 changes: 39 additions & 30 deletions ftn/transforms/extract_target.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import ABC
from ast import Module
from hmac import new
from typing import TypeVar, cast
from dataclasses import dataclass
from dataclasses import dataclass, field
import itertools
from xdsl.utils.hints import isa
from xdsl.dialects import memref, scf, omp
from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region
from xdsl.context import Context
from xdsl.ir import Operation, SSAValue, OpResult, Attribute, Block, Region

from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter,
op_type_rewrite_pattern,
Expand All @@ -13,10 +16,12 @@
from xdsl.passes import ModulePass
from xdsl.dialects import builtin, func, llvm, arith
from ftn.util.visitor import Visitor
from xdsl.rewriter import InsertPoint

@dataclass
class RewriteTarget(RewritePattern):
def __init__(self):
self.target_ops=[]
module : builtin.ModuleOp
target_ops: list[Operation] = field(default_factory=list)

@op_type_rewrite_pattern
def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /):
Expand All @@ -31,38 +36,38 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /):
# Grab bounds and info, then at end the terminator
for var in op.map_vars:
var_op=var.owner
var_op.parent.detach_op(var_op)
arg_types.append(var_op.var_ptr[0].type)
arg_ssa.append(var_op.var_ptr[0])
if isa(var_op.var_ptr[0].type, builtin.MemRefType):
memref_type=var_op.var_ptr[0].type
src_memref=var_op.var_ptr[0]
arg_types.append(var_op.var_ptr.type)
arg_ssa.append(var_op.var_ptr)
locations[var_op]=loc_idx
loc_idx+=1
if isa(var_op.var_ptr.type, builtin.MemRefType):
memref_type=var_op.var_ptr.type
src_memref=var_op.var_ptr
if isa(memref_type.element_type, builtin.MemRefType):
assert len(memref_type.shape) == 0
memref_type=var_op.var_ptr[0].type.element_type
memref_loadop=memref.Load.get(src_memref, [])
memref_type=var_op.var_ptr.type.element_type
memref_loadop=memref.LoadOp.get(src_memref, [])
src_memref=memref_loadop.results[0]
memref_dim_ops.append(memref_loadop)
for idx, s in enumerate(memref_type.shape):
assert isa(s, builtin.IntAttr)
if (s.data == -1):
# Need to pass the dimension shape size in explicitly as it is deferred
const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType())
dim_size=memref.Dim.from_source_and_index(src_memref, const_op)
const_op=arith.ConstantOp.from_int_and_width(idx, builtin.IndexType())
dim_size=memref.DimOp.from_source_and_index(src_memref, const_op)
memref_dim_ops+=[const_op, dim_size]
arg_ssa.append(dim_size.results[0])
arg_types.append(dim_size.results[0].type)
loc_idx+=1

locations[var_op]=loc_idx
loc_idx+=1
if len(var_op.bounds) > 0:
bound_op=var_op.bounds[0].owner
bound_op.parent.detach_op(bound_op)
#self.target_ops+=[bound_op, var_op]
arg_types.append(bound_op.lower[0].type)
arg_ssa.append(bound_op.lower[0])
arg_types.append(bound_op.lower_bound.type)
arg_ssa.append(bound_op.lower_bound)
arg_types.append(bound_op.upper_bound.type)
arg_ssa.append(bound_op.upper_bound)
locations[bound_op]=loc_idx
# Add two, as second is the size
# Adding both lower and upper bound
loc_idx+=2
else:
pass#self.target_ops+=[var_op]
Expand All @@ -78,7 +83,7 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /):
bound_op=var_op.bounds[0].owner
res_types=[]
for res in bound_op.results: res_types.append(res.type)
new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []],
new_bounds_op=omp.MapBoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [new_block.args[locations[bound_op]+1]], [], [], []],
properties={"stride_in_bytes": bound_op.stride_in_bytes},
result_types=res_types)

Expand All @@ -87,8 +92,8 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /):

res_types=[]
for res in var_op.results: res_types.append(res.type)
mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds],
properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type},
mapinfo_op=omp.MapInfoOp.build(operands=[new_block.args[locations[var_op]], [], [], map_bounds],
properties={"map_type": var_op.map_type, "name": var_op.var_name, "var_type": var_op.var_type, "map_capture_type": omp.VariableCaptureKindAttr(omp.VariableCaptureKind.BY_REF)},
result_types=res_types)
new_mapinfo_ssa.append(mapinfo_op.results[0])

Expand All @@ -97,21 +102,23 @@ def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /):
reg=op.region
op.detach_region(reg)

new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg])
new_omp_target_op=omp.TargetOp.build(operands=[[],[],[],[],[],[],[],[],[], new_mapinfo_ssa, [], []], regions=[reg])
new_block.add_op(new_omp_target_op)
new_block.add_op(func.Return())
new_block.add_op(func.ReturnOp())

new_fn_type=builtin.FunctionType.from_lists(arg_types, [])

body = Region()
body.add_block(new_block)

new_func=func.FuncOp("tt_device", new_fn_type, body)
new_func_signature=func.FuncOp.external("tt_device", new_fn_type.inputs.data, new_fn_type.outputs.data)

self.target_ops=[new_func]

call_fn=func.Call.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[])
call_fn=func.CallOp.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[])
op.parent.insert_ops_before(memref_dim_ops+[call_fn], op)
rewriter.insert_op(new_func_signature, InsertPoint.at_start(self.module.body.block))

op.parent.detach_op(op)

Expand All @@ -122,15 +129,17 @@ class ExtractTarget(ModulePass):
"""
name = 'extract-target'

def apply(self, ctx: MLContext, module: builtin.ModuleOp):
rw_target= RewriteTarget()
def apply(self, ctx: Context, module: builtin.ModuleOp):
rw_target= RewriteTarget(module)
walker = PatternRewriteWalker(GreedyRewritePatternApplier([
rw_target,
]), apply_recursively=False, walk_reverse=True)

walker.rewrite_module(module)

containing_mod=builtin.ModuleOp([])
# NOTE: The region recieving the block must be empty. Otherwise, the single block region rule of
# the module will not be satisfied.
containing_mod=builtin.ModuleOp(Region())
module.regions[0].move_blocks(containing_mod.regions[0])

new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")})
Expand Down
Loading