diff --git a/mlir/__init__.py b/mlir/__init__.py index 0df1233..5fe79ab 100644 --- a/mlir/__init__.py +++ b/mlir/__init__.py @@ -1,3 +1,4 @@ from .parser import parse_file, parse_path, parse_string, Parser from . import astnodes +from . import dialects from .visitors import NodeVisitor, NodeTransformer diff --git a/mlir/astnodes.py b/mlir/astnodes.py index c5b3cc8..065bb20 100644 --- a/mlir/astnodes.py +++ b/mlir/astnodes.py @@ -485,11 +485,15 @@ def dump(self, indent: int = 0) -> str: return self.value.dump(indent) + ( (':' + dump_or_value(self.count, indent)) if self.count else '') +class Op(Node): + pass + + @dataclass class Operation(Node): result_list: List[OpResult] - op: "Op" + op: Node location: Optional["Location"] = None def dump(self, indent: int = 0) -> str: @@ -503,10 +507,6 @@ def dump(self, indent: int = 0) -> str: return result -class Op(Node): - pass - - @dataclass class GenericOperation(Op): name: str diff --git a/mlir/builder/builder.py b/mlir/builder/builder.py index d9cdb4e..ed52506 100644 --- a/mlir/builder/builder.py +++ b/mlir/builder/builder.py @@ -1,7 +1,8 @@ """ MLIR IR Builder.""" import mlir.astnodes as mast -import mlir.dialects.standard as std +import mlir.dialects.arith as arith +import mlir.dialects.memref as memref import mlir.dialects.affine as affine import mlir.dialects.func as func from typing import Optional, Tuple, Union, List, Any @@ -436,28 +437,28 @@ def goto_after(self, query: MatchExpressionBase, def addf(self, op_a: mast.SsaId, op_b: mast.SsaId, type: mast.Type, name: Optional[str] = None): - op = std.AddfOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) + op = arith.AddFOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) return self._insert_op_in_block([name], op) def mulf(self, op_a: mast.SsaId, op_b: mast.SsaId, type: mast.Type, name: Optional[str] = None): - op = std.MulfOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) + op = arith.MulFOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) return self._insert_op_in_block([name], op) def dim(self, memref_or_tensor: mast.SsaId, index: mast.SsaId, memref_type: Union[mast.MemRefType, mast.TensorType], name: Optional[str] = None): - op = std.DimOperation(match=0, operand=memref_or_tensor, index=index, + op = memref.DimOperation(match=0, operand=memref_or_tensor, index=index, type=memref_type) return self._insert_op_in_block([name], op) def index_constant(self, value: int, name: Optional[str] = None): - op = std.ConstantOperation(match=0, value=value, type=mast.IndexType()) + op = arith.ConstantOperation(match=0, value=value, type=mast.IndexType()) return self._insert_op_in_block([name], op) def float_constant(self, value: float, type: mast.FloatType, name: Optional[str] = None): - op = std.ConstantOperation(match=0, value=value, type=type) + op = arith.ConstantOperation(match=0, value=value, type=type) return self._insert_op_in_block([name], op) # }}} diff --git a/mlir/dialects/__init__.py b/mlir/dialects/__init__.py index eefd918..9305eb4 100644 --- a/mlir/dialects/__init__.py +++ b/mlir/dialects/__init__.py @@ -1,8 +1,12 @@ from .affine import affine as affine_dialect -from .standard import standard as std_dialect +from .cf import cf as cf_dialect +from .math import math as math_dialect +from .tensor import tensor as tensor_dialect +from .arith import arith as arith_dialect from .scf import scf as scf_dialect from .linalg import linalg from .func import func as func_dialect +from .memref import memref as memref_dialect -STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect] +STANDARD_DIALECTS = [affine_dialect, cf_dialect, math_dialect, tensor_dialect, arith_dialect, scf_dialect, linalg, func_dialect, memref_dialect] diff --git a/mlir/dialects/arith.py b/mlir/dialects/arith.py new file mode 100644 index 0000000..92ed0f3 --- /dev/null +++ b/mlir/dialects/arith.py @@ -0,0 +1,121 @@ +""" Implementation of the arith (Arithmetic) dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation, BinaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + + +# Unary Operations + +class BitcastOperation(UnaryOperation): _opname_ = 'arith.bitcast' +class ExtFOperation(UnaryOperation): _opname_ = 'arith.extf' +class ExtSIOperation(UnaryOperation): _opname_ = 'arith.extsi' +class ExtUIOperation(UnaryOperation): _opname_ = 'arith.extui' +class FPToSIOperation(UnaryOperation): _opname_ = 'arith.fptosi' +class FPToUIOperation(UnaryOperation): _opname_ = 'arith.fptoui' +class NegFOperation(UnaryOperation): _opname_ = 'arith.negf' +class SIToFPOperation(UnaryOperation): _opname_ = 'arith.sitofp' +class UIToFPOperation(UnaryOperation): _opname_ = 'arith.uitofp' + +# Arithmetic Operations +class AddFOperation(BinaryOperation): _opname_ = 'arith.addf' +class AddIOperation(BinaryOperation): _opname_ = 'arith.addi' +class AndIOperation(BinaryOperation): _opname_ = 'arith.andi' +class CeilDivSIOperation(BinaryOperation): _opname_ = 'arith.ceildivsi' +class CeilDivUIOperation(BinaryOperation): _opname_ = 'arith.ceildivui' +class DivFOperation(BinaryOperation): _opname_ = 'arith.divf' +class DivSIOperation(BinaryOperation): _opname_ = 'arith.divsi' +class DivUIOperation(BinaryOperation): _opname_ = 'arith.divui' +class FloorDivSIOperation(BinaryOperation): _opname_ = 'arith.floordivsi' +class MaximumFOperation(BinaryOperation): _opname_ = 'arith.maximumf' +class MaxNumFOperation(BinaryOperation): _opname_ = 'arith.maxnumf' +class MaxSIOperation(BinaryOperation): _opname_ = 'arith.maxsi' +class MaxUIOperation(BinaryOperation): _opname_ = 'arith.maxui' +class MinimumFOperation(BinaryOperation): _opname_ = 'arith.minimumf' +class MinNumFOperation(BinaryOperation): _opname_ = 'arith.minnumf' +class MinSIOperation(BinaryOperation): _opname_ = 'arith.minsi' +class MinUIOperation(BinaryOperation): _opname_ = 'arith.minui' +class MulFOperation(BinaryOperation): _opname_ = 'arith.mulf' +class MulIOperation(BinaryOperation): _opname_ = 'arith.muli' +class MulSIExtendedOp(BinaryOperation): _opname_ = 'arith.mulsi_extended' +class MulUIExtendedOp(BinaryOperation): _opname_ = 'arith.mului_extended' +class OrIOperation(BinaryOperation): _opname_ = 'arith.ori' +class RemFOperation(BinaryOperation): _opname_ = 'arith.remf' +class RemSIOperation(BinaryOperation): _opname_ = 'arith.remsi' +class RemUIOperation(BinaryOperation): _opname_ = 'arith.remui' +class ShLIOperation(BinaryOperation): _opname_ = 'arith.shli' +class ShRSIOperation(BinaryOperation): _opname_ = 'arith.shrsi' +class ShRUIOperation(BinaryOperation): _opname_ = 'arith.shrui' +class SubIOperation(BinaryOperation): _opname_ = 'arith.subi' +class SubFOperation(BinaryOperation): _opname_ = 'arith.subf' +class TruncFOperation(BinaryOperation): _opname_ = 'arith.truncf' +class TruncIOperation(BinaryOperation): _opname_ = 'arith.trunci' +class XorIOperation(BinaryOperation): _opname_ = 'arith.xori' + + +@dataclass +class AddUIExtendedOperation(DialectOp): + lhs_operand: mast.SsaId + rhs_operand: mast.SsaId + sum_type: mast.Type + ovf_type = mast.Type + _syntax_ = 'arith.addui_extended {lhs_operand.ssa_id} , {rhs_operand.ssa_id} : {sum_type.type} , {ovf_type.type}' + + +@dataclass +class CmpiOperation(DialectOp): + comptype: str + operand_a: mast.SsaId + operand_b: mast.SsaId + type: mast.Type + _syntax_ = 'arith.cmpi {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' + + +@dataclass +class CmpfOperation(DialectOp): + comptype: str + operand_a: mast.SsaId + operand_b: mast.SsaId + type: mast.Type + _syntax_ = 'arith.cmpf {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' + + +@dataclass +class ConstantOperation(DialectOp): + value: Literal + type: mast.Type + _syntax_ = ['arith.constant {value.constant_literal} : {type.type}', 'arith.constant {value.constant_literal}'] + + + +@dataclass +class IndexCastOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'arith.index_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +@dataclass +class IndexCastUIOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'arith.index_castui {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +@dataclass +class SelectOperation(DialectOp): + cond: SsaUse + arg_true: SsaUse + arg_false: SsaUse + _syntax_ = 'arith.select {cond.ssa_use} , {arg_true.ssa_use} , {arg_false.ssa_use} : {type.type}' + + +# Inspect current module to get all classes defined above +arith = Dialect('arith', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/cf.py b/mlir/dialects/cf.py new file mode 100644 index 0000000..8ab338f --- /dev/null +++ b/mlir/dialects/cf.py @@ -0,0 +1,32 @@ +""" Implementation of the CF (Control Flow) dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + + +@dataclass +class BrOperation(DialectOp): + block_id: mast.BlockId + args: Optional[List[Tuple[mast.SsaId, mast.Type]]] = None + _syntax_ = ['cf.br {block.block_id}', + 'cf.br {block.block_id} {args.block_arg_list}'] + + +@dataclass +class CondBrOperation(DialectOp): + cond: SsaUse + block_true: mast.BlockId + block_false: mast.BlockId + _syntax_ = ['cf.cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}'] + + +# Inspect current module to get all classes defined above +cf = Dialect('cf', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/func.py b/mlir/dialects/func.py index c27d86a..876e273 100644 --- a/mlir/dialects/func.py +++ b/mlir/dialects/func.py @@ -34,6 +34,20 @@ class ConstantOperation(DialectOp): _syntax_ = ['func.constant {value.symbol_ref_id} : {type.type}'] # Note: The 'func.func' operation is defined as 'function' in mlir.lark. +# not anymore lmfaooooo +@dataclass +class FuncOperation(DialectOp): + name: mast.SymbolRefId + args: Optional[List[mast.NamedArgument]] + result_list: Optional[List[mast.OpResult]] | mast.OpResult + func_mod_attrs: Optional[mast.AttributeDict] + body: Optional[mast.Region] + trail: Optional[mast.Location] = None + + _syntax_ = [ + 'func.func {name.symbol_ref_id} ( {args.optional_arg_list} ) {result_list.optional_fn_result_list} {func_mod_attrs.optional_func_mod_attrs} {body.optional_fn_body}', + 'func.func {name.symbol_ref_id} ( {args.optional_arg_list} ) {result_list.optional_fn_result_list} {func_mod_attrs.optional_func_mod_attrs} {body.optional_fn_body} (loc ({trail.optional_location}))'] + @dataclass class ReturnOperation(DialectOp): diff --git a/mlir/dialects/math.py b/mlir/dialects/math.py new file mode 100644 index 0000000..b574cca --- /dev/null +++ b/mlir/dialects/math.py @@ -0,0 +1,20 @@ +""" Implementation of the math (Mathematics) dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple + + +# Unary Operations +class AbsfOperation(UnaryOperation): _opname_ = 'math.absf' +class CosOperation(UnaryOperation): _opname_ = 'math.cos' +class ExpOperation(UnaryOperation): _opname_ = 'math.exp' +class TanhOperation(UnaryOperation): _opname_ = 'math.tanh' +class CopysignOperation(UnaryOperation): _opname_ = 'math.copysign' + +# Inspect current module to get all classes defined above +math = Dialect('math', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/memref.py b/mlir/dialects/memref.py new file mode 100644 index 0000000..e52f70f --- /dev/null +++ b/mlir/dialects/memref.py @@ -0,0 +1,184 @@ +""" Implementation of the Memref dialect. """ + +import inspect +import sys +from typing import List, Tuple, Optional, Union +from dataclasses import dataclass + +import mlir.astnodes as mast +from mlir.dialect import Dialect, DialectOp, is_op + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + + +# class AssumeAlignmentOperation(DialectOp): pass +# class AtomicRMWOperation(DialectOp): pass + +@dataclass +class AtomicYieldOperation(DialectOp): + result: SsaUse + result_type: mast.Type + _syntax_ = 'memref.atomic_yield {result.ssa_use} : {result_type.type}' + +@dataclass +class CopyOperation(DialectOp): + source: SsaUse + target: SsaUse + source_type: mast.MemRefType + target_type: mast.MemRefType + _syntax_ = 'memref.copy {source.ssa_use} , {target.ssa_use} : {source.memref_type} to {target.memref_type}' + + +# class GenericAtomicRMWOperation(DialectOp): pass + +@dataclass +class LoadOperation(DialectOp): + arg: SsaUse + index: List[SsaUse] + type: mast.MemRefType + _syntax_ = 'memref.load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + + +@dataclass +class AllocOperation(DialectOp): + args: mast.DimAndSymbolList + type: mast.MemRefType + _syntax_ = 'memref.alloc {args.dim_and_symbol_use_list} : {type.memref_type}' + +@dataclass +class AllocaOperation(DialectOp): + args: mast.DimAndSymbolList + type: mast.MemRefType + _syntax_ = 'memref.alloca {args.dim_and_symbol_use_list} : {type.memref_type}' + +# class AllocaScopeOperation(DialectOp): pass +# class AllocaScopeReturnOperation(DialectOp): pass + +@dataclass +class CastOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'memref_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +# class CollapseShapeOperation(DialectOp): pass + +@dataclass +class DeallocOperation(DialectOp): + arg: SsaUse + type: mast.MemRefType + _syntax_ = 'memref.dealloc {arg.ssa_use} : {type.memref_type}' + +@dataclass +class DimOperation(DialectOp): + operand: mast.SsaId + index: mast.SsaId + type: mast.Type + _syntax_ = 'memref.dim {operand.ssa_id} , {index.ssa_id} : {type.type}' + +@dataclass +class DmaStartOperation(DialectOp): + src: SsaUse + src_index: List[SsaUse] + dst: SsaUse + dst_index: List[SsaUse] + size: SsaUse + tag: SsaUse + tag_index: List[SsaUse] + src_type: mast.MemRefType + dst_type: mast.MemRefType + tag_type: mast.MemRefType + stride: Optional[SsaUse] = None + transfer_per_stride: Optional[SsaUse] = None + _syntax_ = [ + 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}', + 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}' + ] + + +@dataclass +class DmaWaitOperation(DialectOp): + tag: SsaUse + tag_index: List[SsaUse] + size: SsaUse + type: mast.MemRefType + _syntax_ = 'dma_wait {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {size.ssa_use} : {type.memref_type}' + +# class ExpandShapeOperation(DialectOp): pass + +@dataclass +class ExtractAlignedPointerAsIndexOperation(DialectOp): + source: SsaUse + source_type: mast.Type + dest_type: mast.Type + _syntax_ = 'memref.extract_aligned_pointer_as_index {source.ssa_use} : {source_type.type} -> {dest_type.type}' + +# class ExtractStridedMetadataOperation(DialectOp): pass + +@dataclass +class GetGlobalOperation(DialectOp): + name: SsaUse + result_type: mast.MemRefType + _syntax_ = 'memref.get_global {name.ssa_use} : {result_type.type}' + + +# class GlobalOperation(DialectOp): pass + +@dataclass +class MemorySpaceCastOperation(DialectOp): + source: SsaUse + source_type: mast.MemRefType + dest_type: mast.MemRefType + _syntax_ = 'memref.memory_space_cast {source.ssa_use} : {source_type.memref_type} to {dest_type.memref_type}' + + +# class PrefetchOperation(DialectOp): pass + +@dataclass +class RankOperation(DialectOp): + operand: SsaUse + op_type: mast.MemRefType + _syntax_ = 'memref.rank {operand.ssa_use} : {op_type.memref_type}' + + +# class ReallocOperation(DialectOp): pass +# class ReinterpretCastOperation(DialectOp): pass +# class ReshapeOperation(DialectOp): pass + +@dataclass +class StoreOperation(DialectOp): + addr: SsaUse + ref: SsaUse + index: List[SsaUse] + type: mast.MemRefType + _syntax_ = 'memref.store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + + +# class TransposeOperation(DialectOp): pass + +@dataclass +class ViewOperation(DialectOp): + operand: SsaUse + offset: SsaUse + src_type: mast.Type + dst_type: mast.Type + sizes: Optional[List[SsaUse]] = None + _syntax_ = ['memref.view {operand.ssa_use} [ {offset.ssa_use} ] [ {sizes.ssa_use_list} ] : {src_type.type} to {dst_type.type}', + 'memref.view {operand.ssa_use} [ {offset.ssa_use} ] [ ] : {src_type.type} to {dst_type.type}'] + + +@dataclass +class SubviewOperation(DialectOp): + operand: SsaUse + offsets: List[SsaUse] + sizes: List[SsaUse] + strides: List[SsaUse] + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'memref.subview {operand.ssa_use} [ {offsets.ssa_use_list} ] [ {sizes.ssa_use_list} ] [ {strides.ssa_use_list} ] : {src_type.type} to {dst_type.type}' + + +# Inspect current module to get all classes defined above +memref = Dialect('memref', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index 8c71848..082dee3 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -16,6 +16,13 @@ class SCFConditionOp(DialectOp): _syntax_ = ['scf.condition ( {condition.ssa_id} ) {args.ssa_id_list} : {out_types.type_list_no_parens}'] +@dataclass +class SCFExecuteRegionOp(DialectOp): + out_type: mast.Type + body: mast.Region + _syntax_ = 'scf.execute_region -> {out_type.type} {body.region}' + + @dataclass class SCFForOp(DialectOp): index: mast.SsaId diff --git a/mlir/dialects/standard.py b/mlir/dialects/standard.py deleted file mode 100644 index 6eb6168..0000000 --- a/mlir/dialects/standard.py +++ /dev/null @@ -1,244 +0,0 @@ -""" Implementation of the Standard dialect. """ - -import inspect -import sys -from mlir.dialect import (Dialect, DialectOp, UnaryOperation, BinaryOperation, - is_op) -import mlir.astnodes as mast -from typing import List, Tuple, Optional, Union -from dataclasses import dataclass - - -Literal = Union[mast.StringLiteral, float, int, bool] -SsaUse = Union[mast.SsaId, Literal] - - -# Terminator Operations -@dataclass -class BrOperation(DialectOp): - block_id: mast.BlockId - args: Optional[List[Tuple[mast.SsaId, mast.Type]]] = None - _syntax_ = ['br {block.block_id}', - 'br {block.block_id} {args.block_arg_list}'] - - -@dataclass -class CondBrOperation(DialectOp): - cond: SsaUse - block_true: mast.BlockId - block_false: mast.BlockId - _syntax_ = ['cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}'] - - -# Core Operations -@dataclass -class DimOperation(DialectOp): - operand: mast.SsaId - index: mast.SsaId - type: mast.Type - _syntax_ = 'dim {operand.ssa_id} , {index.ssa_id} : {type.type}' - - -# Memory Operations -@dataclass -class AllocOperation(DialectOp): - args: mast.DimAndSymbolList - type: mast.MemRefType - _syntax_ = 'alloc {args.dim_and_symbol_use_list} : {type.memref_type}' - - -@dataclass -class AllocStaticOperation(DialectOp): - base: int - type: mast.MemRefType - _syntax_ = 'alloc_static ( {base.integer_literal} ) : {type.memref_type}' - - -@dataclass -class DeallocOperation(DialectOp): - arg: SsaUse - type: mast.MemRefType - _syntax_ = 'dealloc {arg.ssa_use} : {type.memref_type}' - - -@dataclass -class DmaStartOperation(DialectOp): - src: SsaUse - src_index: List[SsaUse] - dst: SsaUse - dst_index: List[SsaUse] - size: SsaUse - tag: SsaUse - tag_index: List[SsaUse] - src_type: mast.MemRefType - dst_type: mast.MemRefType - tag_type: mast.MemRefType - stride: Optional[SsaUse] = None - transfer_per_stride: Optional[SsaUse] = None - _syntax_ = [ - 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}', - 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}' - ] -@dataclass -class DmaWaitOperation(DialectOp): - tag: SsaUse - tag_index: List[SsaUse] - size: SsaUse - type: mast.MemRefType - _syntax_ = 'dma_wait {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {size.ssa_use} : {type.memref_type}' - - -@dataclass -class ExtractElementOperation(DialectOp): - arg: SsaUse - index: List[SsaUse] - type: mast.Type - _syntax_ = 'extract_element {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}' - - -@dataclass -class LoadOperation(DialectOp): - arg: SsaUse - index: List[SsaUse] - type: mast.MemRefType - _syntax_ = 'load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' - - -@dataclass -class SplatOperation(DialectOp): - arg: SsaUse - type: Union[mast.VectorType, mast.TensorType] - _syntax_ = 'splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type) - - -@dataclass -class StoreOperation(DialectOp): - addr: SsaUse - ref: SsaUse - index: List[SsaUse] - type: mast.MemRefType - _syntax_ = 'store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' - - -@dataclass -class TensorLoadOperation(DialectOp): - arg: SsaUse - type: mast.Type - _syntax_ = 'tensor_load {arg.ssa_use} : {type.type}' - - -@dataclass -class TensorStoreOperation(DialectOp): - arg: SsaUse - type: mast.Type - _syntax_ = 'tensor_store {src.ssa_use} , {dst.ssa_use} : {type.memref_type}' - -# Unary Operations -class AbsfOperation(UnaryOperation): _opname_ = 'absf' -class CeilfOperation(UnaryOperation): _opname_ = 'ceilf' -class CosOperation(UnaryOperation): _opname_ = 'cos' -class ExpOperation(UnaryOperation): _opname_ = 'exp' -class NegfOperation(UnaryOperation): _opname_ = 'negf' -class TanhOperation(UnaryOperation): _opname_ = 'tanh' -class CopysignOperation(UnaryOperation): _opname_ = 'copysign' -class SIToFPOperation(UnaryOperation): _opname_ = 'sitofp' - -# Arithmetic Operations -class AddiOperation(BinaryOperation): _opname_ = 'addi' -class AddfOperation(BinaryOperation): _opname_ = 'addf' -class AndOperation(BinaryOperation): _opname_ = 'and' -class DivisOperation(BinaryOperation): _opname_ = 'divis' -class DiviuOperation(BinaryOperation): _opname_ = 'diviu' -class RemisOperation(BinaryOperation): _opname_ = 'remis' -class RemiuOperation(BinaryOperation): _opname_ = 'remiu' -class DivfOperation(BinaryOperation): _opname_ = 'divf' -class MulfOperation(BinaryOperation): _opname_ = 'mulf' -class MulIOperation(BinaryOperation): _opname_ = 'muli' -class SubiOperation(BinaryOperation): _opname_ = 'subi' -class SubfOperation(BinaryOperation): _opname_ = 'subf' -class OrOperation(BinaryOperation): _opname_ = 'or' -class XorOperation(BinaryOperation): _opname_ = 'xor' - - -@dataclass -class CmpiOperation(DialectOp): - comptype: str - operand_a: mast.SsaId - operand_b: mast.SsaId - type: mast.Type - _syntax_ = 'cmpi {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' - - -@dataclass -class CmpfOperation(DialectOp): - comptype: str - operand_a: mast.SsaId - operand_b: mast.SsaId - type: mast.Type - _syntax_ = 'cmpf {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' - - -@dataclass -class ConstantOperation(DialectOp): - value: Literal - type: mast.Type - _syntax_ = 'constant {value.constant_literal} : {type.type}' - - -@dataclass -class IndexCastOperation(DialectOp): - arg: SsaUse - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'index_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' - - -@dataclass -class MemrefCastOperation(DialectOp): - arg: SsaUse - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'memref_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' - - -@dataclass -class TensorCastOperation(DialectOp): - arg: SsaUse - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'tensor_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' - - -@dataclass -class SelectOperation(DialectOp): - cond: SsaUse - arg_true: SsaUse - arg_false: SsaUse - _syntax_ = 'select {cond.ssa_use} , {arg_true.ssa_use} , {arg_false.ssa_use} : {type.type}' - - -@dataclass -class SubviewOperation(DialectOp): - operand: SsaUse - offsets: List[SsaUse] - sizes: List[SsaUse] - strides: List[SsaUse] - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'subview {operand.ssa_use} [ {offsets.ssa_use_list} ] [ {sizes.ssa_use_list} ] [ {strides.ssa_use_list} ] : {src_type.type} to {dst_type.type}' - - -@dataclass -class ViewOperation(DialectOp): - operand: SsaUse - offset: SsaUse - src_type: mast.Type - dst_type: mast.Type - sizes: Optional[List[SsaUse]] = None - _syntax_ = ['view {operand.ssa_use} [ {offset.ssa_use} ] [ {sizes.ssa_use_list} ] : {src_type.type} to {dst_type.type}', - 'view {operand.ssa_use} [ {offset.ssa_use} ] [ ] : {src_type.type} to {dst_type.type}'] - - -# Inspect current module to get all classes defined above -standard = Dialect('standard', ops=[m[1] for m in inspect.getmembers( - sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/tensor.py b/mlir/dialects/tensor.py new file mode 100644 index 0000000..758b758 --- /dev/null +++ b/mlir/dialects/tensor.py @@ -0,0 +1,36 @@ +""" Implementation of the Tensor dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + +@dataclass +class ExtractElementOperation(DialectOp): + arg: SsaUse + index: List[SsaUse] + type: mast.Type + _syntax_ = 'tensor.extract {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}' + + +@dataclass +class SplatOperation(DialectOp): + arg: SsaUse + type: Union[mast.VectorType, mast.TensorType] + _syntax_ = 'tensor.splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type) + +@dataclass +class TensorCastOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'tensor.cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +# Inspect current module to get all classes defined above +tensor = Dialect('tensor', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/lark/mlir.lark b/mlir/lark/mlir.lark index 9258111..409a810 100644 --- a/mlir/lark/mlir.lark +++ b/mlir/lark/mlir.lark @@ -23,7 +23,7 @@ negated_integer_literal : "-" integer_literal ?posneg_integer_literal : integer_literal | negated_integer_literal float_literal : /[-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?/ string_literal : ESCAPED_STRING -constant_literal : bool_literal | integer_literal | float_literal | string_literal +constant_literal : bool_literal | posneg_integer_literal | float_literal | string_literal // Identifier syntax bare_id : (letter| underscore) (letter|digit|underscore|id_chars)* @@ -172,7 +172,7 @@ custom_operation : bare_id "." bare_id optional_ssa_use_list trailing_type // Final operation definition // NOTE: "pymlir_dialect_ops" is defined externally by pyMLIR -operation : optional_op_result_list (pymlir_dialect_ops | custom_operation | generic_operation | module | generic_module | function) optional_trailing_loc +operation : optional_op_result_list (pymlir_dialect_ops | custom_operation | generic_operation | module | generic_module) optional_trailing_loc // ---------------------------------------------------------------------- // Blocks and regions @@ -236,7 +236,7 @@ function_result_list_parens : ("(" ")") | ("(" function_result_list_no_parens ") // Definition module : "module" optional_symbol_ref_id optional_func_mod_attrs region optional_trailing_loc -function : "func.func" symbol_ref_id "(" optional_arg_list ")" optional_fn_result_list optional_func_mod_attrs optional_fn_body optional_trailing_loc +// function : "func.func" symbol_ref_id "(" optional_arg_list ")" optional_fn_result_list optional_func_mod_attrs optional_fn_body optional_trailing_loc generic_module : string_literal "(" optional_arg_list ")" "(" region ")" optional_attr_dict trailing_type optional_trailing_loc // ---------------------------------------------------------------------- @@ -315,9 +315,9 @@ attribute_alias_def : attribute_alias "=" attribute_value // Structure of an MLIR parse-able string definition_list : definition* -function_list : function* +// function_list : function* module_list : (module | generic_module)* -definition_and_function_list : definition_list function_list +definition_and_function_list : definition_list // function_list definition_and_module_list : definition_list module_list mlir_file: definition_and_function_list* -> only_functions_and_definitions_file diff --git a/requirements.txt b/requirements.txt index af4bfeb..cd17ee2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -lark-parser==0.7.8 +lark==0.7.8 parse==1.14.0 pytest diff --git a/setup.py b/setup.py index 2c94ff0..dfe6604 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ }, include_package_data=True, install_requires=[ - 'lark-parser', 'parse' + 'lark', 'parse' ], tests_require=['pytest', 'pytest-cov'], test_suite='pytest', diff --git a/tests/test_builder.py b/tests/test_builder.py index f97eb43..45e1428 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -2,7 +2,7 @@ from mlir.builder import IRBuilder from mlir.builder import Reads, Writes, Isa from mlir.dialects.affine import AffineLoadOp -from mlir.dialects.standard import AddfOperation +from mlir.dialects.arith import AddFOperation def test_saxpy_builder(): @@ -39,14 +39,14 @@ def test_saxpy_builder(): def test_query(): block = parse_string(""" func.func @saxpy(%a : f64, %x : memref, %y : memref) { -%c0 = constant 0 : index -%n = dim %x, %c0 : memref +%c0 = arith.constant 0 : index +%n = memref.dim %x, %c0 : memref affine.for %i = 0 to %n { %xi = affine.load %x[%i+1] : memref - %axi = mulf %a, %xi : f64 + %axi = arith.mulf %a, %xi : f64 %yi = affine.load %y[%i] : memref - %axpyi = addf %yi, %axi : f64 + %axpyi = arith.addf %yi, %axi : f64 affine.store %axpyi, %y[%i] : memref } return @@ -60,11 +60,11 @@ def query(expr): for op in block.body + for_block.body if expr(op))) - assert query(Writes("%c0")).dump() == "%c0 = constant 0 : index" + assert query(Writes("%c0")).dump() == "%c0 = arith.constant 0 : index" assert (query(Reads("%y") & Isa(AffineLoadOp)).dump() == "%yi = affine.load %y [ %i ] : memref") - assert query(Reads(c0)).dump() == "%n = dim %x , %c0 : memref" + assert query(Reads(c0)).dump() == "%n = memref.dim %x , %c0 : memref" def test_build_with_queries(): @@ -92,7 +92,7 @@ def index(expr): with builder.goto_before(Reads(a0) & Reads(a1)): builder.addf(b0, b1, F64) - with builder.goto_after(Reads(b0) & Isa(AddfOperation)): + with builder.goto_after(Reads(b0) & Isa(AddFOperation)): builder.addf(c0, c1, F64) builder.func.ret() diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 4fc7f3f..565352a 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -53,11 +53,11 @@ def test_copy(): def test_dot(): assert_roundtrip_equivalence("""module { func.func @dot(%arg0: memref, %M: index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %1 = view %arg0 [ %c0 ] [ %M ] : memref to memref - %2 = view %arg0 [ %c0 ] [ %M ] : memref to memref - %3 = view %arg0 [ %c0 ] [ ] : memref to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = memref.view %arg0 [ %c0 ] [ %M ] : memref to memref + %2 = memref.view %arg0 [ %c0 ] [ %M ] : memref to memref + %3 = memref.view %arg0 [ %c0 ] [ ] : memref to memref linalg.dot ins ( %1 , %2 : memref , memref ) outs ( %3 : memref ) return } @@ -89,8 +89,8 @@ def test_generic(): func.func @example(%A: memref, %B: memref, %C: memref) { linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins ( %A, %B : memref, memref ) outs ( %C : memref ) { ^bb0 (%a: f64, %b: f64, %c: f64): - %c0 = constant 3.14 : f64 - %d = addf %a , %b : f64 + %c0 = arith.constant 3.14 : f64 + %d = arith.addf %a , %b : f64 linalg.yield %d : f64 } return @@ -103,12 +103,12 @@ def test_indexed_generic(): func.func @indexed_generic_region(%arg0: memref>, %arg1: memref>, %arg2: memref>) { linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins ( %arg0 : memref> ) outs ( %arg1, %arg2 : memref>, memref> ) { ^bb0 (%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): - %result_1 = mulf %a , %b : f32 - %ij = addi %i , %j : index - %ijk = addi %ij , %k : index - %ijk_int = index_cast %ijk : index to i32 - %ijk_float = sitofp %ijk_int : (i32) -> f32 - %result_2 = addf %c , %ijk_float : f32 + %result_1 = arith.mulf %a , %b : f32 + %ij = arith.addi %i , %j : index + %ijk = arith.addi %ij , %k : index + %ijk_int = arith.index_cast %ijk : index to i32 + %ijk_float = arith.sitofp %ijk_int : (i32) -> f32 + %result_2 = arith.addf %c , %ijk_float : f32 linalg.yield %result_1, %result_2 : f32, f32 } return @@ -119,7 +119,7 @@ def test_reduce(): assert_roundtrip_equivalence("""module { func.func @reduce(%arg0: tensor<16x32x64xf32>, %arg1: tensor<16x64xf32>) { %reduce = linalg.reduce ins ( %arg0 : tensor<16x32x64xf32> ) outs ( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) { - %0 = arith.addf %out, %in : f32 + %0 = arith.addf %out , %in : f32 linalg.yield %0 : f32 } return @@ -130,17 +130,17 @@ def test_reduce(): def test_view(): assert_roundtrip_equivalence("""module { func.func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { - %c0 = constant 0 : index - %0 = muli %arg0 , %arg0 : index - %1 = alloc (%0) : memref + %c0 = arith.constant 0 : index + %0 = arith.muli %arg0 , %arg0 : index + %1 = memref.alloc (%0) : memref %2 = linalg.range %arg0 : %arg1 : %arg 2 : !linalg.range - %3 = view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref + %3 = memref.view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref %4 = linalg.slice %3 [ %2, %2 ] : memref , !linalg.range, !linalg.range , memref %5 = linalg.slice %3 [ %2, %arg2 ] : memref , !linalg.range, index , memref> %6 = linalg.slice %3 [ %arg2, %2 ] : memref , index, !linalg.range , memref> %7 = linalg.slice %3 [ %arg2, %arg3 ] : memref , index, index , memref - %8 = view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref> - dealloc %1 : memref + %8 = memref.view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref> + memref.dealloc %1 : memref return } }""") @@ -149,11 +149,11 @@ def test_view(): def test_matmul(): assert_roundtrip_equivalence("""module { func.func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %A = view %arg0 [ %c0 ] [ %M, %K ] : memref to memref - %B = view %arg0 [ %c0 ] [ %K, %N ] : memref to memref - %C = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %A = memref.view %arg0 [ %c0 ] [ %M, %K ] : memref to memref + %B = memref.view %arg0 [ %c0 ] [ %K, %N ] : memref to memref + %C = memref.view %arg0 [ %c0 ] [ %M, %N ] : memref to memref linalg.matmul ins ( %A , %B : memref , memref ) outs ( %C : memref ) linalg.matmul ins ( %A , %B : memref , memref ) outs ( %C : memref ) -> memref return @@ -164,11 +164,11 @@ def test_matmul(): def test_matvec(): assert_roundtrip_equivalence("""module { func.func @matvec(%arg0: memref, %M: index, %N: index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %2 = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref - %3 = view %arg0 [ %c0 ] [ %M ] : memref to memref - %4 = view %arg0 [ %c0 ] [ %N ] : memref to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %2 = memref.view %arg0 [ %c0 ] [ %M, %N ] : memref to memref + %3 = memref.view %arg0 [ %c0 ] [ %M ] : memref to memref + %4 = memref.view %arg0 [ %c0 ] [ %N ] : memref to memref linalg.matvec ins ( %2 , %3 : memref , memref ) outs ( %4 : memref ) return } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 82e54a7..1154880 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -26,7 +26,7 @@ def test_function_no_args(): """ code = '''module { func.func @toy_func() -> index { - %0 = constant 0 : index + %0 = arith.constant 0 : index return %0 : index } }''' @@ -67,15 +67,15 @@ def test_affine_expr_roundtrip(): def test_loop_dialect_roundtrip(): src = """module { func.func @for(%outer: index, %A: memref, %B: memref, %C: memref, %result: memref) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %d0 = dim %A , %c0 : memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = memref.dim %A , %c0 : memref %b0 = affine.min affine_map<()[s0, s1] -> (1024, s0 - s1)> ()[%d0, %outer] scf.for %i0 = %c0 to %b0 step %c1 { - %B_elem = load %B [ %i0 ] : memref - %C_elem = load %C [ %i0 ] : memref - %sum_elem = addf %B_elem , %C_elem : f32 - store %sum_elem , %result [ %i0 ] : memref + %B_elem = memref.load %B [ %i0 ] : memref + %C_elem = memref.load %C [ %i0 ] : memref + %sum_elem = arith.addf %B_elem , %C_elem : f32 + memref.store %sum_elem , %result [ %i0 ] : memref } return } diff --git a/tests/test_scf.py b/tests/test_scf.py index 7f267ac..17afa13 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -11,8 +11,8 @@ def test_scf_for(): assert_roundtrip_equivalence("""module { func.func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index, %sum_0: f32) -> (f32) { %sum = scf.for %iv = %lb to %ub step %step iter_args ( %sum_iter = %sum_0 ) -> ( f32 ) { - %t = load %buffer [ %iv ] : memref<1024xf32> - %sum_next = arith.addf %sum_iter, %t : f32 + %t = memref.load %buffer [ %iv ] : memref<1024xf32> + %sum_next = arith.addf %sum_iter , %t : f32 scf.yield %sum_next : f32 } return %sum : f32 diff --git a/tests/test_syntax.py b/tests/test_syntax.py index 087c662..028ea56 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -46,7 +46,7 @@ def test_trailing_loc(parser: Optional[Parser] = None): code = ''' module { func.func @myfunc() { - %c:2 = addf %a, %b : f32 loc("test_syntax.py":36:59) + %c:2 = arith.addf %a, %b : f32 loc("test_syntax.py":36:59) } } loc("hi.mlir":30:1) ''' @@ -86,12 +86,12 @@ def test_functions(parser: Optional[Parser] = None): code = ''' module { func.func @myfunc_a() { - %c:2 = addf %a, %b : f32 + %c:2 = arith.addf %a, %b : f32 } func.func @myfunc_b() { - %d:2 = addf %a, %b : f64 + %d:2 = arith.addf %a, %b : f64 ^e: - %f:2 = addf %d, %d : f64 + %f:2 = arith.addf %d, %d : f64 } }''' parser = parser or Parser() @@ -135,17 +135,17 @@ def test_affine(parser: Optional[Parser] = None): %0 = affine.min affine_map<(d0)[s0] -> (1000, d0 + 512, s0)> (%arg0)[%arg1] } func.func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) { - %c0 = constant 1 : index - %c1 = constant 0 : index - %b = alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)> - %0 = alloc(%arg0, %arg1) : memref + %c0 = arith.constant 1 : index + %c1 = arith.constant 0 : index + %b = memref.alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)> + %0 = memref.alloc(%arg0, %arg1) : memref affine.for %arg3 = %arg1 to %arg2 step 768 { - %13 = dim %0, %c1 : memref + %13 = memref.dim %0, %c1 : memref affine.for %arg4 = 0 to %13 step 264 { - %18 = dim %0, %c0 : memref - %20 = subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref + %18 = memref.dim %0, %c0 : memref + %20 = memref.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref to memref (d0 * s1 + d1 * s2 + s0)> - %24 = dim %20, %c0 : memref (d0 * s1 + d1 * s2 + s0)> + %24 = memref.dim %20, %c0 : memref (d0 * s1 + d1 * s2 + s0)> affine.for %arg5 = 0 to %24 step 768 { "foo"() : () -> () } @@ -210,8 +210,8 @@ def test_generic_dialect_std(parser: Optional[Parser] = None): "module"() ( { "func.func"() ( { ^bb0(%arg0: i32, %arg1: i32): // no predecessors - %0 = "std.addi"(%arg1, %arg0) : (i32, i32) -> i32 - "std.return"(%0) : (i32) -> () + %0 = "arith.addi"(%arg1, %arg0) : (i32, i32) -> i32 + "return"(%0) : (i32) -> () }) {sym_name = "mlir_entry", type = (i32, i32) -> i32} : () -> () }) : () -> () ''' @@ -224,13 +224,13 @@ def test_generic_dialect_std_cond_br(parser: Optional[Parser] = None): "module"() ( { "func.func"() ( { ^bb0(%arg0: i32): // no predecessors - %c1_i32 = "std.constant"() {value = 1 : i32} : () -> i32 - %0 = "std.cmpi"(%arg0, %c1_i32) {predicate = 3 : i64} : (i32, i32) -> i1 - "std.cond_br"(%0)[^bb1, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () + %c1_i32 = "arith.constant"() {value = 1 : i32} : () -> i32 + %0 = "arith.cmpi"(%arg0, %c1_i32) {predicate = 3 : i64} : (i32, i32) -> i1 + "cf.cond_br"(%0)[^bb1, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () ^bb1: // pred: ^bb0 - "std.return"(%c1_i32) : (i32) -> () + "return"(%c1_i32) : (i32) -> () ^bb2: // pred: ^bb0 - "std.return"(%c1_i32) : (i32) -> () + "return"(%c1_i32) : (i32) -> () }) {sym_name = "mlir_entry", type = (i32) -> i32} : () -> () }) : () -> () ''' @@ -259,26 +259,26 @@ def test_generic_dialect_generic_op(parser: Optional[Parser] = None): "func.func"() ( { ^bb0(%arg0: i32, %arg1: i32): // no predecessors %0 = "generic_op_with_region"(%arg0, %arg1) ( { - %1 = "std.addi"(%arg1, %arg0) : (i32, i32) -> i32 - "std.return"(%1) : (i32) -> () + %1 = "arith.addi"(%arg1, %arg0) : (i32, i32) -> i32 + "return"(%1) : (i32) -> () }) : (i32, i32) -> i32 %2 = "generic_op_with_regions"(%0, %arg0) ( { - %3 = "std.subi"(%0, %arg0) : (i32, i32) -> i32 - "std.return"(%3) : (i32) -> () + %3 = "arith.subi"(%0, %arg0) : (i32, i32) -> i32 + "return"(%3) : (i32) -> () }, { - %4 = "std.addi"(%3, %arg0) : (i32, i32) -> i32 - "std.return"(%4) : (i32) -> () + %4 = "arith.addi"(%3, %arg0) : (i32, i32) -> i32 + "return"(%4) : (i32) -> () }) : (i32, i32) -> i32 %5 = "generic_op_with_region_and_attr"(%2, %arg0) ( { - %6 = "std.subi"(%2, %arg0) : (i32, i32) -> i32 - "std.return"(%6) : (i32) -> () + %6 = "arith.subi"(%2, %arg0) : (i32, i32) -> i32 + "return"(%6) : (i32) -> () }) {attr = "string attribute"} : (i32, i32) -> i32 %7 = "generic_op_with_region_and_successor"(%5, %arg0)[^bb1] ( { - %8 = "std.addi"(%5, %arg0) : (i32, i32) -> i32 - "std.br"(%8)[^bb1] : (i32) -> () + %8 = "arith.addi"(%5, %arg0) : (i32, i32) -> i32 + "cf.br"(%8)[^bb1] : (i32) -> () }) {attr = "string attribute"} : (i32, i32) -> i32 ^bb1(%ret: i32): - "std.return"(%ret) : (i32) -> () + "return"(%ret) : (i32) -> () }) {sym_name = "mlir_entry", type = (i32, i32) -> i32} : () -> () }) : () -> () ''' diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 7fe2b40..7e105c3 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -14,11 +14,11 @@ def parser(parser: Optional[Parser] = None) -> Parser: _code = ''' module { func.func @test0(%arg0: index, %arg1: index) { - %0 = alloc() : memref<100x100xf32> - %1 = alloc() : memref<100x100xf32, 2> - %2 = alloc() : memref<1xi32> - %c0 = constant 0 : index - %c64 = constant 64 : index + %0 = memref.alloc() : memref<100x100xf32> + %1 = memref.alloc() : memref<100x100xf32, 2> + %2 = memref.alloc() : memref<1xi32> + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index affine.for %arg2 = 0 to 10 { affine.for %arg3 = 0 to 10 { affine.dma_start %0[%arg2, %arg3], %1[%arg2, %arg3], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> @@ -30,10 +30,10 @@ def parser(parser: Optional[Parser] = None) -> Parser: func.func @test1(%arg0: index, %arg1: index) { affine.for %arg2 = 0 to 10 { affine.for %arg3 = 0 to 10 { - %c0 = constant 0 : index - %c64 = constant 64 : index - %c128 = constant 128 : index - %c256 = constant 256 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index affine.dma_start %0[%arg2, %arg3], %1[%arg2, %arg3], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> affine.dma_wait %2[%c0], %c64 : memref<1xi32> } @@ -41,7 +41,7 @@ def parser(parser: Optional[Parser] = None) -> Parser: return } func.func @test2(%arg0: index, %arg1: index) { - %0 = alloc() : memref<100x100xf32> + %0 = memref.alloc() : memref<100x100xf32> } } '''