From 3f40f4592f06080f88e13ee2dce7d65ab6d69e0b Mon Sep 17 00:00:00 2001 From: 3ulalia <179992797+3ulalia@users.noreply.github.com> Date: Tue, 8 Apr 2025 18:35:58 -0400 Subject: [PATCH 1/8] Change LoadOperation and StoreOperation to use MemRef dialect Currently, LoadOperation and StoreOperation expect the instructions to just be "load" and "store", respectively. This is not possible in current MLIR - indeed, I don't know if it was ever possible. Regardless, since the operations involve MemRef types anyway, this isn't a significant change, and just makes it compliant with modern MLIR. Specifically, this change adds the MemRef dialect, whose currently supported operations are `load` and `store`. (more to come!) --- mlir/dialects/__init__.py | 3 ++- mlir/dialects/memref.py | 31 +++++++++++++++++++++++++++++++ mlir/dialects/standard.py | 17 ----------------- 3 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 mlir/dialects/memref.py diff --git a/mlir/dialects/__init__.py b/mlir/dialects/__init__.py index eefd918..9056c81 100644 --- a/mlir/dialects/__init__.py +++ b/mlir/dialects/__init__.py @@ -3,6 +3,7 @@ from .scf import scf as scf_dialect from .linalg import linalg from .func import func as func_dialect +from .memref import memref as memref_dialect -STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect] +STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect, memref_dialect] diff --git a/mlir/dialects/memref.py b/mlir/dialects/memref.py new file mode 100644 index 0000000..1131133 --- /dev/null +++ b/mlir/dialects/memref.py @@ -0,0 +1,31 @@ +""" Implementation of the Memref dialect. """ + +import inspect +import sys +from typing import List, Tuple, Optional, Union +from dataclasses import dataclass + +import mlir.astnodes as mast +from mlir.dialect import Dialect, DialectOp, is_op + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + +@dataclass +class LoadOperation(DialectOp): + arg: SsaUse + index: List[SsaUse] + type: mast.MemRefType + _syntax_ = 'memref.load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + +@dataclass +class StoreOperation(DialectOp): + addr: SsaUse + ref: SsaUse + index: List[SsaUse] + type: mast.MemRefType + _syntax_ = 'memref.store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + +# Inspect current module to get all classes defined above +memref = Dialect('memref', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/standard.py b/mlir/dialects/standard.py index 6eb6168..8050310 100644 --- a/mlir/dialects/standard.py +++ b/mlir/dialects/standard.py @@ -96,14 +96,6 @@ class ExtractElementOperation(DialectOp): _syntax_ = 'extract_element {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}' -@dataclass -class LoadOperation(DialectOp): - arg: SsaUse - index: List[SsaUse] - type: mast.MemRefType - _syntax_ = 'load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' - - @dataclass class SplatOperation(DialectOp): arg: SsaUse @@ -111,15 +103,6 @@ class SplatOperation(DialectOp): _syntax_ = 'splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type) -@dataclass -class StoreOperation(DialectOp): - addr: SsaUse - ref: SsaUse - index: List[SsaUse] - type: mast.MemRefType - _syntax_ = 'store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' - - @dataclass class TensorLoadOperation(DialectOp): arg: SsaUse From e44e4275f4981594e2a294818edad8f39a2456b2 Mon Sep 17 00:00:00 2001 From: qubitter <11893617+qubitter@users.noreply.github.com> Date: Tue, 20 May 2025 12:42:43 -0400 Subject: [PATCH 2/8] make type hints slightly more reliable --- mlir/astnodes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/astnodes.py b/mlir/astnodes.py index c5b3cc8..065bb20 100644 --- a/mlir/astnodes.py +++ b/mlir/astnodes.py @@ -485,11 +485,15 @@ def dump(self, indent: int = 0) -> str: return self.value.dump(indent) + ( (':' + dump_or_value(self.count, indent)) if self.count else '') +class Op(Node): + pass + + @dataclass class Operation(Node): result_list: List[OpResult] - op: "Op" + op: Node location: Optional["Location"] = None def dump(self, indent: int = 0) -> str: @@ -503,10 +507,6 @@ def dump(self, indent: int = 0) -> str: return result -class Op(Node): - pass - - @dataclass class GenericOperation(Op): name: str From d341cd472ffe1b959e5ae0d0a98b5f507153dcd6 Mon Sep 17 00:00:00 2001 From: qubitter <11893617+qubitter@users.noreply.github.com> Date: Tue, 20 May 2025 12:43:53 -0400 Subject: [PATCH 3/8] Modernize dialect structure to cohere with modern MLIR This commit removes the outdated model of a "standard dialect" and replaces it with individual, specialized dialects. This includes CF, Math, and Tensor, as well as increasing the size of Memref and SCF. This will bring pymlir in line with modern MLIR standards. --- mlir/__init__.py | 1 + mlir/dialects/__init__.py | 7 +- mlir/dialects/arith.py | 120 ++++++++++++++++++++ mlir/dialects/cf.py | 32 ++++++ mlir/dialects/math.py | 20 ++++ mlir/dialects/memref.py | 153 +++++++++++++++++++++++++ mlir/dialects/scf.py | 7 ++ mlir/dialects/standard.py | 227 -------------------------------------- mlir/dialects/tensor.py | 36 ++++++ 9 files changed, 374 insertions(+), 229 deletions(-) create mode 100644 mlir/dialects/arith.py create mode 100644 mlir/dialects/cf.py create mode 100644 mlir/dialects/math.py delete mode 100644 mlir/dialects/standard.py create mode 100644 mlir/dialects/tensor.py diff --git a/mlir/__init__.py b/mlir/__init__.py index 0df1233..5fe79ab 100644 --- a/mlir/__init__.py +++ b/mlir/__init__.py @@ -1,3 +1,4 @@ from .parser import parse_file, parse_path, parse_string, Parser from . import astnodes +from . import dialects from .visitors import NodeVisitor, NodeTransformer diff --git a/mlir/dialects/__init__.py b/mlir/dialects/__init__.py index 9056c81..9305eb4 100644 --- a/mlir/dialects/__init__.py +++ b/mlir/dialects/__init__.py @@ -1,9 +1,12 @@ from .affine import affine as affine_dialect -from .standard import standard as std_dialect +from .cf import cf as cf_dialect +from .math import math as math_dialect +from .tensor import tensor as tensor_dialect +from .arith import arith as arith_dialect from .scf import scf as scf_dialect from .linalg import linalg from .func import func as func_dialect from .memref import memref as memref_dialect -STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect, memref_dialect] +STANDARD_DIALECTS = [affine_dialect, cf_dialect, math_dialect, tensor_dialect, arith_dialect, scf_dialect, linalg, func_dialect, memref_dialect] diff --git a/mlir/dialects/arith.py b/mlir/dialects/arith.py new file mode 100644 index 0000000..5a2b5db --- /dev/null +++ b/mlir/dialects/arith.py @@ -0,0 +1,120 @@ +""" Implementation of the arith (Arithmetic) dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation, BinaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + + +# Unary Operations + +class BitcastOperation(UnaryOperation): _opname_ = 'arith.bitcast' +class ExtFOperation(UnaryOperation): _opname_ = 'arith.extf' +class ExtSIOperation(UnaryOperation): _opname_ = 'arith.extsi' +class ExtUIOperation(UnaryOperation): _opname_ = 'arith.extui' +class FPToSIOperation(UnaryOperation): _opname_ = 'arith.fptosi' +class FPToUIOperation(UnaryOperation): _opname_ = 'arith.fptoui' +class NegFOperation(UnaryOperation): _opname_ = 'arith.negf' +class SIToFPOperation(UnaryOperation): _opname_ = 'arith.sitofp' +class UIToFPOperation(UnaryOperation): _opname_ = 'arith.uitofp' + +# Arithmetic Operations +class AddFOperation(BinaryOperation): _opname_ = 'arith.addf' +class AddIOperation(BinaryOperation): _opname_ = 'arith.addi' +class AndIOperation(BinaryOperation): _opname_ = 'arith.andi' +class CeilDivSIOperation(BinaryOperation): _opname_ = 'arith.ceildivsi' +class CeilDivUIOperation(BinaryOperation): _opname_ = 'arith.ceildivui' +class DivFOperation(BinaryOperation): _opname_ = 'arith.divf' +class DivSIOperation(BinaryOperation): _opname_ = 'arith.divsi' +class DivUIOperation(BinaryOperation): _opname_ = 'arith.divui' +class FloorDivSIOperation(BinaryOperation): _opname_ = 'arith.floordivsi' +class MaximumFOperation(BinaryOperation): _opname_ = 'arith.maximumf' +class MaxNumFOperation(BinaryOperation): _opname_ = 'arith.maxnumf' +class MaxSIOperation(BinaryOperation): _opname_ = 'arith.maxsi' +class MaxUIOperation(BinaryOperation): _opname_ = 'arith.maxui' +class MinimumFOperation(BinaryOperation): _opname_ = 'arith.minimumf' +class MinNumFOperation(BinaryOperation): _opname_ = 'arith.minnumf' +class MinSIOperation(BinaryOperation): _opname_ = 'arith.minsi' +class MinUIOperation(BinaryOperation): _opname_ = 'arith.minui' +class MulFOperation(BinaryOperation): _opname_ = 'arith.mulf' +class MulIOperation(BinaryOperation): _opname_ = 'arith.muli' +class MulSIExtendedOp(BinaryOperation): _opname_ = 'arith.mulsi_extended' +class MulUIExtendedOp(BinaryOperation): _opname_ = 'arith.mului_extended' +class OrIOperation(BinaryOperation): _opname_ = 'arith.ori' +class RemFOperation(BinaryOperation): _opname_ = 'arith.remf' +class RemSIOperation(BinaryOperation): _opname_ = 'arith.remsi' +class RemUIOperation(BinaryOperation): _opname_ = 'arith.remui' +class ShLIOperation(BinaryOperation): _opname_ = 'arith.shli' +class ShRSIOperation(BinaryOperation): _opname_ = 'arith.shrsi' +class ShRUIOperation(BinaryOperation): _opname_ = 'arith.shrui' +class SubIOperation(BinaryOperation): _opname_ = 'arith.subi' +class SubFOperation(BinaryOperation): _opname_ = 'arith.subf' +class TruncFOperation(BinaryOperation): _opname_ = 'arith.truncf' +class TruncIOperation(BinaryOperation): _opname_ = 'arith.trunci' +class XorIOperation(BinaryOperation): _opname_ = 'arith.xori' + + +@dataclass +class AddUIExtendedOperation(DialectOp): + lhs_operand: mast.SsaId + rhs_operand: mast.SsaId + sum_type: mast.Type + ovf_type = mast.Type + _syntax_ = 'arith.addui_extended {lhs_operand.ssa_id} , {rhs_operand.ssa_id} : {sum_type.type} , {ovf_type.type}' + + +@dataclass +class CmpiOperation(DialectOp): + comptype: str + operand_a: mast.SsaId + operand_b: mast.SsaId + type: mast.Type + _syntax_ = 'arith.cmpi {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' + + +@dataclass +class CmpfOperation(DialectOp): + comptype: str + operand_a: mast.SsaId + operand_b: mast.SsaId + type: mast.Type + _syntax_ = 'arith.cmpf {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' + + +@dataclass +class ConstantOperation(DialectOp): + value: Literal + type: mast.Type + _syntax_ = 'arith.constant {value.constant_literal} : {type.type}' + + +@dataclass +class IndexCastOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'arith.index_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +@dataclass +class IndexCastUIOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'arith.index_castui {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +@dataclass +class SelectOperation(DialectOp): + cond: SsaUse + arg_true: SsaUse + arg_false: SsaUse + _syntax_ = 'arith.select {cond.ssa_use} , {arg_true.ssa_use} , {arg_false.ssa_use} : {type.type}' + + +# Inspect current module to get all classes defined above +arith = Dialect('arith', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/cf.py b/mlir/dialects/cf.py new file mode 100644 index 0000000..8ab338f --- /dev/null +++ b/mlir/dialects/cf.py @@ -0,0 +1,32 @@ +""" Implementation of the CF (Control Flow) dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + + +@dataclass +class BrOperation(DialectOp): + block_id: mast.BlockId + args: Optional[List[Tuple[mast.SsaId, mast.Type]]] = None + _syntax_ = ['cf.br {block.block_id}', + 'cf.br {block.block_id} {args.block_arg_list}'] + + +@dataclass +class CondBrOperation(DialectOp): + cond: SsaUse + block_true: mast.BlockId + block_false: mast.BlockId + _syntax_ = ['cf.cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}'] + + +# Inspect current module to get all classes defined above +cf = Dialect('cf', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/math.py b/mlir/dialects/math.py new file mode 100644 index 0000000..b574cca --- /dev/null +++ b/mlir/dialects/math.py @@ -0,0 +1,20 @@ +""" Implementation of the math (Mathematics) dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple + + +# Unary Operations +class AbsfOperation(UnaryOperation): _opname_ = 'math.absf' +class CosOperation(UnaryOperation): _opname_ = 'math.cos' +class ExpOperation(UnaryOperation): _opname_ = 'math.exp' +class TanhOperation(UnaryOperation): _opname_ = 'math.tanh' +class CopysignOperation(UnaryOperation): _opname_ = 'math.copysign' + +# Inspect current module to get all classes defined above +math = Dialect('math', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/memref.py b/mlir/dialects/memref.py index 1131133..e52f70f 100644 --- a/mlir/dialects/memref.py +++ b/mlir/dialects/memref.py @@ -11,6 +11,27 @@ Literal = Union[mast.StringLiteral, float, int, bool] SsaUse = Union[mast.SsaId, Literal] + +# class AssumeAlignmentOperation(DialectOp): pass +# class AtomicRMWOperation(DialectOp): pass + +@dataclass +class AtomicYieldOperation(DialectOp): + result: SsaUse + result_type: mast.Type + _syntax_ = 'memref.atomic_yield {result.ssa_use} : {result_type.type}' + +@dataclass +class CopyOperation(DialectOp): + source: SsaUse + target: SsaUse + source_type: mast.MemRefType + target_type: mast.MemRefType + _syntax_ = 'memref.copy {source.ssa_use} , {target.ssa_use} : {source.memref_type} to {target.memref_type}' + + +# class GenericAtomicRMWOperation(DialectOp): pass + @dataclass class LoadOperation(DialectOp): arg: SsaUse @@ -18,6 +39,113 @@ class LoadOperation(DialectOp): type: mast.MemRefType _syntax_ = 'memref.load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + +@dataclass +class AllocOperation(DialectOp): + args: mast.DimAndSymbolList + type: mast.MemRefType + _syntax_ = 'memref.alloc {args.dim_and_symbol_use_list} : {type.memref_type}' + +@dataclass +class AllocaOperation(DialectOp): + args: mast.DimAndSymbolList + type: mast.MemRefType + _syntax_ = 'memref.alloca {args.dim_and_symbol_use_list} : {type.memref_type}' + +# class AllocaScopeOperation(DialectOp): pass +# class AllocaScopeReturnOperation(DialectOp): pass + +@dataclass +class CastOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'memref_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +# class CollapseShapeOperation(DialectOp): pass + +@dataclass +class DeallocOperation(DialectOp): + arg: SsaUse + type: mast.MemRefType + _syntax_ = 'memref.dealloc {arg.ssa_use} : {type.memref_type}' + +@dataclass +class DimOperation(DialectOp): + operand: mast.SsaId + index: mast.SsaId + type: mast.Type + _syntax_ = 'memref.dim {operand.ssa_id} , {index.ssa_id} : {type.type}' + +@dataclass +class DmaStartOperation(DialectOp): + src: SsaUse + src_index: List[SsaUse] + dst: SsaUse + dst_index: List[SsaUse] + size: SsaUse + tag: SsaUse + tag_index: List[SsaUse] + src_type: mast.MemRefType + dst_type: mast.MemRefType + tag_type: mast.MemRefType + stride: Optional[SsaUse] = None + transfer_per_stride: Optional[SsaUse] = None + _syntax_ = [ + 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}', + 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}' + ] + + +@dataclass +class DmaWaitOperation(DialectOp): + tag: SsaUse + tag_index: List[SsaUse] + size: SsaUse + type: mast.MemRefType + _syntax_ = 'dma_wait {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {size.ssa_use} : {type.memref_type}' + +# class ExpandShapeOperation(DialectOp): pass + +@dataclass +class ExtractAlignedPointerAsIndexOperation(DialectOp): + source: SsaUse + source_type: mast.Type + dest_type: mast.Type + _syntax_ = 'memref.extract_aligned_pointer_as_index {source.ssa_use} : {source_type.type} -> {dest_type.type}' + +# class ExtractStridedMetadataOperation(DialectOp): pass + +@dataclass +class GetGlobalOperation(DialectOp): + name: SsaUse + result_type: mast.MemRefType + _syntax_ = 'memref.get_global {name.ssa_use} : {result_type.type}' + + +# class GlobalOperation(DialectOp): pass + +@dataclass +class MemorySpaceCastOperation(DialectOp): + source: SsaUse + source_type: mast.MemRefType + dest_type: mast.MemRefType + _syntax_ = 'memref.memory_space_cast {source.ssa_use} : {source_type.memref_type} to {dest_type.memref_type}' + + +# class PrefetchOperation(DialectOp): pass + +@dataclass +class RankOperation(DialectOp): + operand: SsaUse + op_type: mast.MemRefType + _syntax_ = 'memref.rank {operand.ssa_use} : {op_type.memref_type}' + + +# class ReallocOperation(DialectOp): pass +# class ReinterpretCastOperation(DialectOp): pass +# class ReshapeOperation(DialectOp): pass + @dataclass class StoreOperation(DialectOp): addr: SsaUse @@ -26,6 +154,31 @@ class StoreOperation(DialectOp): type: mast.MemRefType _syntax_ = 'memref.store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + +# class TransposeOperation(DialectOp): pass + +@dataclass +class ViewOperation(DialectOp): + operand: SsaUse + offset: SsaUse + src_type: mast.Type + dst_type: mast.Type + sizes: Optional[List[SsaUse]] = None + _syntax_ = ['memref.view {operand.ssa_use} [ {offset.ssa_use} ] [ {sizes.ssa_use_list} ] : {src_type.type} to {dst_type.type}', + 'memref.view {operand.ssa_use} [ {offset.ssa_use} ] [ ] : {src_type.type} to {dst_type.type}'] + + +@dataclass +class SubviewOperation(DialectOp): + operand: SsaUse + offsets: List[SsaUse] + sizes: List[SsaUse] + strides: List[SsaUse] + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'memref.subview {operand.ssa_use} [ {offsets.ssa_use_list} ] [ {sizes.ssa_use_list} ] [ {strides.ssa_use_list} ] : {src_type.type} to {dst_type.type}' + + # Inspect current module to get all classes defined above memref = Dialect('memref', ops=[m[1] for m in inspect.getmembers( sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index 8c71848..082dee3 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -16,6 +16,13 @@ class SCFConditionOp(DialectOp): _syntax_ = ['scf.condition ( {condition.ssa_id} ) {args.ssa_id_list} : {out_types.type_list_no_parens}'] +@dataclass +class SCFExecuteRegionOp(DialectOp): + out_type: mast.Type + body: mast.Region + _syntax_ = 'scf.execute_region -> {out_type.type} {body.region}' + + @dataclass class SCFForOp(DialectOp): index: mast.SsaId diff --git a/mlir/dialects/standard.py b/mlir/dialects/standard.py deleted file mode 100644 index 8050310..0000000 --- a/mlir/dialects/standard.py +++ /dev/null @@ -1,227 +0,0 @@ -""" Implementation of the Standard dialect. """ - -import inspect -import sys -from mlir.dialect import (Dialect, DialectOp, UnaryOperation, BinaryOperation, - is_op) -import mlir.astnodes as mast -from typing import List, Tuple, Optional, Union -from dataclasses import dataclass - - -Literal = Union[mast.StringLiteral, float, int, bool] -SsaUse = Union[mast.SsaId, Literal] - - -# Terminator Operations -@dataclass -class BrOperation(DialectOp): - block_id: mast.BlockId - args: Optional[List[Tuple[mast.SsaId, mast.Type]]] = None - _syntax_ = ['br {block.block_id}', - 'br {block.block_id} {args.block_arg_list}'] - - -@dataclass -class CondBrOperation(DialectOp): - cond: SsaUse - block_true: mast.BlockId - block_false: mast.BlockId - _syntax_ = ['cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}'] - - -# Core Operations -@dataclass -class DimOperation(DialectOp): - operand: mast.SsaId - index: mast.SsaId - type: mast.Type - _syntax_ = 'dim {operand.ssa_id} , {index.ssa_id} : {type.type}' - - -# Memory Operations -@dataclass -class AllocOperation(DialectOp): - args: mast.DimAndSymbolList - type: mast.MemRefType - _syntax_ = 'alloc {args.dim_and_symbol_use_list} : {type.memref_type}' - - -@dataclass -class AllocStaticOperation(DialectOp): - base: int - type: mast.MemRefType - _syntax_ = 'alloc_static ( {base.integer_literal} ) : {type.memref_type}' - - -@dataclass -class DeallocOperation(DialectOp): - arg: SsaUse - type: mast.MemRefType - _syntax_ = 'dealloc {arg.ssa_use} : {type.memref_type}' - - -@dataclass -class DmaStartOperation(DialectOp): - src: SsaUse - src_index: List[SsaUse] - dst: SsaUse - dst_index: List[SsaUse] - size: SsaUse - tag: SsaUse - tag_index: List[SsaUse] - src_type: mast.MemRefType - dst_type: mast.MemRefType - tag_type: mast.MemRefType - stride: Optional[SsaUse] = None - transfer_per_stride: Optional[SsaUse] = None - _syntax_ = [ - 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}', - 'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}' - ] -@dataclass -class DmaWaitOperation(DialectOp): - tag: SsaUse - tag_index: List[SsaUse] - size: SsaUse - type: mast.MemRefType - _syntax_ = 'dma_wait {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {size.ssa_use} : {type.memref_type}' - - -@dataclass -class ExtractElementOperation(DialectOp): - arg: SsaUse - index: List[SsaUse] - type: mast.Type - _syntax_ = 'extract_element {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}' - - -@dataclass -class SplatOperation(DialectOp): - arg: SsaUse - type: Union[mast.VectorType, mast.TensorType] - _syntax_ = 'splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type) - - -@dataclass -class TensorLoadOperation(DialectOp): - arg: SsaUse - type: mast.Type - _syntax_ = 'tensor_load {arg.ssa_use} : {type.type}' - - -@dataclass -class TensorStoreOperation(DialectOp): - arg: SsaUse - type: mast.Type - _syntax_ = 'tensor_store {src.ssa_use} , {dst.ssa_use} : {type.memref_type}' - -# Unary Operations -class AbsfOperation(UnaryOperation): _opname_ = 'absf' -class CeilfOperation(UnaryOperation): _opname_ = 'ceilf' -class CosOperation(UnaryOperation): _opname_ = 'cos' -class ExpOperation(UnaryOperation): _opname_ = 'exp' -class NegfOperation(UnaryOperation): _opname_ = 'negf' -class TanhOperation(UnaryOperation): _opname_ = 'tanh' -class CopysignOperation(UnaryOperation): _opname_ = 'copysign' -class SIToFPOperation(UnaryOperation): _opname_ = 'sitofp' - -# Arithmetic Operations -class AddiOperation(BinaryOperation): _opname_ = 'addi' -class AddfOperation(BinaryOperation): _opname_ = 'addf' -class AndOperation(BinaryOperation): _opname_ = 'and' -class DivisOperation(BinaryOperation): _opname_ = 'divis' -class DiviuOperation(BinaryOperation): _opname_ = 'diviu' -class RemisOperation(BinaryOperation): _opname_ = 'remis' -class RemiuOperation(BinaryOperation): _opname_ = 'remiu' -class DivfOperation(BinaryOperation): _opname_ = 'divf' -class MulfOperation(BinaryOperation): _opname_ = 'mulf' -class MulIOperation(BinaryOperation): _opname_ = 'muli' -class SubiOperation(BinaryOperation): _opname_ = 'subi' -class SubfOperation(BinaryOperation): _opname_ = 'subf' -class OrOperation(BinaryOperation): _opname_ = 'or' -class XorOperation(BinaryOperation): _opname_ = 'xor' - - -@dataclass -class CmpiOperation(DialectOp): - comptype: str - operand_a: mast.SsaId - operand_b: mast.SsaId - type: mast.Type - _syntax_ = 'cmpi {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' - - -@dataclass -class CmpfOperation(DialectOp): - comptype: str - operand_a: mast.SsaId - operand_b: mast.SsaId - type: mast.Type - _syntax_ = 'cmpf {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}' - - -@dataclass -class ConstantOperation(DialectOp): - value: Literal - type: mast.Type - _syntax_ = 'constant {value.constant_literal} : {type.type}' - - -@dataclass -class IndexCastOperation(DialectOp): - arg: SsaUse - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'index_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' - - -@dataclass -class MemrefCastOperation(DialectOp): - arg: SsaUse - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'memref_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' - - -@dataclass -class TensorCastOperation(DialectOp): - arg: SsaUse - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'tensor_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' - - -@dataclass -class SelectOperation(DialectOp): - cond: SsaUse - arg_true: SsaUse - arg_false: SsaUse - _syntax_ = 'select {cond.ssa_use} , {arg_true.ssa_use} , {arg_false.ssa_use} : {type.type}' - - -@dataclass -class SubviewOperation(DialectOp): - operand: SsaUse - offsets: List[SsaUse] - sizes: List[SsaUse] - strides: List[SsaUse] - src_type: mast.Type - dst_type: mast.Type - _syntax_ = 'subview {operand.ssa_use} [ {offsets.ssa_use_list} ] [ {sizes.ssa_use_list} ] [ {strides.ssa_use_list} ] : {src_type.type} to {dst_type.type}' - - -@dataclass -class ViewOperation(DialectOp): - operand: SsaUse - offset: SsaUse - src_type: mast.Type - dst_type: mast.Type - sizes: Optional[List[SsaUse]] = None - _syntax_ = ['view {operand.ssa_use} [ {offset.ssa_use} ] [ {sizes.ssa_use_list} ] : {src_type.type} to {dst_type.type}', - 'view {operand.ssa_use} [ {offset.ssa_use} ] [ ] : {src_type.type} to {dst_type.type}'] - - -# Inspect current module to get all classes defined above -standard = Dialect('standard', ops=[m[1] for m in inspect.getmembers( - sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/tensor.py b/mlir/dialects/tensor.py new file mode 100644 index 0000000..758b758 --- /dev/null +++ b/mlir/dialects/tensor.py @@ -0,0 +1,36 @@ +""" Implementation of the Tensor dialect. """ + +import inspect +import sys +from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation +import mlir.astnodes as mast +from dataclasses import dataclass +from typing import Optional, List, Tuple, Union + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + +@dataclass +class ExtractElementOperation(DialectOp): + arg: SsaUse + index: List[SsaUse] + type: mast.Type + _syntax_ = 'tensor.extract {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}' + + +@dataclass +class SplatOperation(DialectOp): + arg: SsaUse + type: Union[mast.VectorType, mast.TensorType] + _syntax_ = 'tensor.splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type) + +@dataclass +class TensorCastOperation(DialectOp): + arg: SsaUse + src_type: mast.Type + dst_type: mast.Type + _syntax_ = 'tensor.cast {arg.ssa_use} : {src_type.type} to {dst_type.type}' + +# Inspect current module to get all classes defined above +tensor = Dialect('tensor', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) From 82b9f1a3e80a19c8323cb1e67dba78f1a30fc726 Mon Sep 17 00:00:00 2001 From: qubitter <11893617+qubitter@users.noreply.github.com> Date: Tue, 20 May 2025 12:46:36 -0400 Subject: [PATCH 4/8] update tests for modernization --- mlir/builder/builder.py | 13 ++++----- tests/test_builder.py | 16 +++++------ tests/test_linalg.py | 60 ++++++++++++++++++++--------------------- tests/test_roundtrip.py | 16 +++++------ tests/test_scf.py | 4 +-- tests/test_syntax.py | 60 ++++++++++++++++++++--------------------- tests/test_visitors.py | 20 +++++++------- 7 files changed, 95 insertions(+), 94 deletions(-) diff --git a/mlir/builder/builder.py b/mlir/builder/builder.py index d9cdb4e..ed52506 100644 --- a/mlir/builder/builder.py +++ b/mlir/builder/builder.py @@ -1,7 +1,8 @@ """ MLIR IR Builder.""" import mlir.astnodes as mast -import mlir.dialects.standard as std +import mlir.dialects.arith as arith +import mlir.dialects.memref as memref import mlir.dialects.affine as affine import mlir.dialects.func as func from typing import Optional, Tuple, Union, List, Any @@ -436,28 +437,28 @@ def goto_after(self, query: MatchExpressionBase, def addf(self, op_a: mast.SsaId, op_b: mast.SsaId, type: mast.Type, name: Optional[str] = None): - op = std.AddfOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) + op = arith.AddFOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) return self._insert_op_in_block([name], op) def mulf(self, op_a: mast.SsaId, op_b: mast.SsaId, type: mast.Type, name: Optional[str] = None): - op = std.MulfOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) + op = arith.MulFOperation(match=0, operand_a=op_a, operand_b=op_b, type=type) return self._insert_op_in_block([name], op) def dim(self, memref_or_tensor: mast.SsaId, index: mast.SsaId, memref_type: Union[mast.MemRefType, mast.TensorType], name: Optional[str] = None): - op = std.DimOperation(match=0, operand=memref_or_tensor, index=index, + op = memref.DimOperation(match=0, operand=memref_or_tensor, index=index, type=memref_type) return self._insert_op_in_block([name], op) def index_constant(self, value: int, name: Optional[str] = None): - op = std.ConstantOperation(match=0, value=value, type=mast.IndexType()) + op = arith.ConstantOperation(match=0, value=value, type=mast.IndexType()) return self._insert_op_in_block([name], op) def float_constant(self, value: float, type: mast.FloatType, name: Optional[str] = None): - op = std.ConstantOperation(match=0, value=value, type=type) + op = arith.ConstantOperation(match=0, value=value, type=type) return self._insert_op_in_block([name], op) # }}} diff --git a/tests/test_builder.py b/tests/test_builder.py index f97eb43..45e1428 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -2,7 +2,7 @@ from mlir.builder import IRBuilder from mlir.builder import Reads, Writes, Isa from mlir.dialects.affine import AffineLoadOp -from mlir.dialects.standard import AddfOperation +from mlir.dialects.arith import AddFOperation def test_saxpy_builder(): @@ -39,14 +39,14 @@ def test_saxpy_builder(): def test_query(): block = parse_string(""" func.func @saxpy(%a : f64, %x : memref, %y : memref) { -%c0 = constant 0 : index -%n = dim %x, %c0 : memref +%c0 = arith.constant 0 : index +%n = memref.dim %x, %c0 : memref affine.for %i = 0 to %n { %xi = affine.load %x[%i+1] : memref - %axi = mulf %a, %xi : f64 + %axi = arith.mulf %a, %xi : f64 %yi = affine.load %y[%i] : memref - %axpyi = addf %yi, %axi : f64 + %axpyi = arith.addf %yi, %axi : f64 affine.store %axpyi, %y[%i] : memref } return @@ -60,11 +60,11 @@ def query(expr): for op in block.body + for_block.body if expr(op))) - assert query(Writes("%c0")).dump() == "%c0 = constant 0 : index" + assert query(Writes("%c0")).dump() == "%c0 = arith.constant 0 : index" assert (query(Reads("%y") & Isa(AffineLoadOp)).dump() == "%yi = affine.load %y [ %i ] : memref") - assert query(Reads(c0)).dump() == "%n = dim %x , %c0 : memref" + assert query(Reads(c0)).dump() == "%n = memref.dim %x , %c0 : memref" def test_build_with_queries(): @@ -92,7 +92,7 @@ def index(expr): with builder.goto_before(Reads(a0) & Reads(a1)): builder.addf(b0, b1, F64) - with builder.goto_after(Reads(b0) & Isa(AddfOperation)): + with builder.goto_after(Reads(b0) & Isa(AddFOperation)): builder.addf(c0, c1, F64) builder.func.ret() diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 4fc7f3f..565352a 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -53,11 +53,11 @@ def test_copy(): def test_dot(): assert_roundtrip_equivalence("""module { func.func @dot(%arg0: memref, %M: index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %1 = view %arg0 [ %c0 ] [ %M ] : memref to memref - %2 = view %arg0 [ %c0 ] [ %M ] : memref to memref - %3 = view %arg0 [ %c0 ] [ ] : memref to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = memref.view %arg0 [ %c0 ] [ %M ] : memref to memref + %2 = memref.view %arg0 [ %c0 ] [ %M ] : memref to memref + %3 = memref.view %arg0 [ %c0 ] [ ] : memref to memref linalg.dot ins ( %1 , %2 : memref , memref ) outs ( %3 : memref ) return } @@ -89,8 +89,8 @@ def test_generic(): func.func @example(%A: memref, %B: memref, %C: memref) { linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins ( %A, %B : memref, memref ) outs ( %C : memref ) { ^bb0 (%a: f64, %b: f64, %c: f64): - %c0 = constant 3.14 : f64 - %d = addf %a , %b : f64 + %c0 = arith.constant 3.14 : f64 + %d = arith.addf %a , %b : f64 linalg.yield %d : f64 } return @@ -103,12 +103,12 @@ def test_indexed_generic(): func.func @indexed_generic_region(%arg0: memref>, %arg1: memref>, %arg2: memref>) { linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins ( %arg0 : memref> ) outs ( %arg1, %arg2 : memref>, memref> ) { ^bb0 (%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): - %result_1 = mulf %a , %b : f32 - %ij = addi %i , %j : index - %ijk = addi %ij , %k : index - %ijk_int = index_cast %ijk : index to i32 - %ijk_float = sitofp %ijk_int : (i32) -> f32 - %result_2 = addf %c , %ijk_float : f32 + %result_1 = arith.mulf %a , %b : f32 + %ij = arith.addi %i , %j : index + %ijk = arith.addi %ij , %k : index + %ijk_int = arith.index_cast %ijk : index to i32 + %ijk_float = arith.sitofp %ijk_int : (i32) -> f32 + %result_2 = arith.addf %c , %ijk_float : f32 linalg.yield %result_1, %result_2 : f32, f32 } return @@ -119,7 +119,7 @@ def test_reduce(): assert_roundtrip_equivalence("""module { func.func @reduce(%arg0: tensor<16x32x64xf32>, %arg1: tensor<16x64xf32>) { %reduce = linalg.reduce ins ( %arg0 : tensor<16x32x64xf32> ) outs ( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) { - %0 = arith.addf %out, %in : f32 + %0 = arith.addf %out , %in : f32 linalg.yield %0 : f32 } return @@ -130,17 +130,17 @@ def test_reduce(): def test_view(): assert_roundtrip_equivalence("""module { func.func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { - %c0 = constant 0 : index - %0 = muli %arg0 , %arg0 : index - %1 = alloc (%0) : memref + %c0 = arith.constant 0 : index + %0 = arith.muli %arg0 , %arg0 : index + %1 = memref.alloc (%0) : memref %2 = linalg.range %arg0 : %arg1 : %arg 2 : !linalg.range - %3 = view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref + %3 = memref.view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref %4 = linalg.slice %3 [ %2, %2 ] : memref , !linalg.range, !linalg.range , memref %5 = linalg.slice %3 [ %2, %arg2 ] : memref , !linalg.range, index , memref> %6 = linalg.slice %3 [ %arg2, %2 ] : memref , index, !linalg.range , memref> %7 = linalg.slice %3 [ %arg2, %arg3 ] : memref , index, index , memref - %8 = view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref> - dealloc %1 : memref + %8 = memref.view %1 [ %c0 ] [ %arg0, %arg0 ] : memref to memref> + memref.dealloc %1 : memref return } }""") @@ -149,11 +149,11 @@ def test_view(): def test_matmul(): assert_roundtrip_equivalence("""module { func.func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %A = view %arg0 [ %c0 ] [ %M, %K ] : memref to memref - %B = view %arg0 [ %c0 ] [ %K, %N ] : memref to memref - %C = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %A = memref.view %arg0 [ %c0 ] [ %M, %K ] : memref to memref + %B = memref.view %arg0 [ %c0 ] [ %K, %N ] : memref to memref + %C = memref.view %arg0 [ %c0 ] [ %M, %N ] : memref to memref linalg.matmul ins ( %A , %B : memref , memref ) outs ( %C : memref ) linalg.matmul ins ( %A , %B : memref , memref ) outs ( %C : memref ) -> memref return @@ -164,11 +164,11 @@ def test_matmul(): def test_matvec(): assert_roundtrip_equivalence("""module { func.func @matvec(%arg0: memref, %M: index, %N: index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %2 = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref - %3 = view %arg0 [ %c0 ] [ %M ] : memref to memref - %4 = view %arg0 [ %c0 ] [ %N ] : memref to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %2 = memref.view %arg0 [ %c0 ] [ %M, %N ] : memref to memref + %3 = memref.view %arg0 [ %c0 ] [ %M ] : memref to memref + %4 = memref.view %arg0 [ %c0 ] [ %N ] : memref to memref linalg.matvec ins ( %2 , %3 : memref , memref ) outs ( %4 : memref ) return } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 82e54a7..1154880 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -26,7 +26,7 @@ def test_function_no_args(): """ code = '''module { func.func @toy_func() -> index { - %0 = constant 0 : index + %0 = arith.constant 0 : index return %0 : index } }''' @@ -67,15 +67,15 @@ def test_affine_expr_roundtrip(): def test_loop_dialect_roundtrip(): src = """module { func.func @for(%outer: index, %A: memref, %B: memref, %C: memref, %result: memref) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %d0 = dim %A , %c0 : memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = memref.dim %A , %c0 : memref %b0 = affine.min affine_map<()[s0, s1] -> (1024, s0 - s1)> ()[%d0, %outer] scf.for %i0 = %c0 to %b0 step %c1 { - %B_elem = load %B [ %i0 ] : memref - %C_elem = load %C [ %i0 ] : memref - %sum_elem = addf %B_elem , %C_elem : f32 - store %sum_elem , %result [ %i0 ] : memref + %B_elem = memref.load %B [ %i0 ] : memref + %C_elem = memref.load %C [ %i0 ] : memref + %sum_elem = arith.addf %B_elem , %C_elem : f32 + memref.store %sum_elem , %result [ %i0 ] : memref } return } diff --git a/tests/test_scf.py b/tests/test_scf.py index 7f267ac..17afa13 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -11,8 +11,8 @@ def test_scf_for(): assert_roundtrip_equivalence("""module { func.func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index, %sum_0: f32) -> (f32) { %sum = scf.for %iv = %lb to %ub step %step iter_args ( %sum_iter = %sum_0 ) -> ( f32 ) { - %t = load %buffer [ %iv ] : memref<1024xf32> - %sum_next = arith.addf %sum_iter, %t : f32 + %t = memref.load %buffer [ %iv ] : memref<1024xf32> + %sum_next = arith.addf %sum_iter , %t : f32 scf.yield %sum_next : f32 } return %sum : f32 diff --git a/tests/test_syntax.py b/tests/test_syntax.py index 087c662..028ea56 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -46,7 +46,7 @@ def test_trailing_loc(parser: Optional[Parser] = None): code = ''' module { func.func @myfunc() { - %c:2 = addf %a, %b : f32 loc("test_syntax.py":36:59) + %c:2 = arith.addf %a, %b : f32 loc("test_syntax.py":36:59) } } loc("hi.mlir":30:1) ''' @@ -86,12 +86,12 @@ def test_functions(parser: Optional[Parser] = None): code = ''' module { func.func @myfunc_a() { - %c:2 = addf %a, %b : f32 + %c:2 = arith.addf %a, %b : f32 } func.func @myfunc_b() { - %d:2 = addf %a, %b : f64 + %d:2 = arith.addf %a, %b : f64 ^e: - %f:2 = addf %d, %d : f64 + %f:2 = arith.addf %d, %d : f64 } }''' parser = parser or Parser() @@ -135,17 +135,17 @@ def test_affine(parser: Optional[Parser] = None): %0 = affine.min affine_map<(d0)[s0] -> (1000, d0 + 512, s0)> (%arg0)[%arg1] } func.func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) { - %c0 = constant 1 : index - %c1 = constant 0 : index - %b = alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)> - %0 = alloc(%arg0, %arg1) : memref + %c0 = arith.constant 1 : index + %c1 = arith.constant 0 : index + %b = memref.alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)> + %0 = memref.alloc(%arg0, %arg1) : memref affine.for %arg3 = %arg1 to %arg2 step 768 { - %13 = dim %0, %c1 : memref + %13 = memref.dim %0, %c1 : memref affine.for %arg4 = 0 to %13 step 264 { - %18 = dim %0, %c0 : memref - %20 = subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref + %18 = memref.dim %0, %c0 : memref + %20 = memref.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref to memref (d0 * s1 + d1 * s2 + s0)> - %24 = dim %20, %c0 : memref (d0 * s1 + d1 * s2 + s0)> + %24 = memref.dim %20, %c0 : memref (d0 * s1 + d1 * s2 + s0)> affine.for %arg5 = 0 to %24 step 768 { "foo"() : () -> () } @@ -210,8 +210,8 @@ def test_generic_dialect_std(parser: Optional[Parser] = None): "module"() ( { "func.func"() ( { ^bb0(%arg0: i32, %arg1: i32): // no predecessors - %0 = "std.addi"(%arg1, %arg0) : (i32, i32) -> i32 - "std.return"(%0) : (i32) -> () + %0 = "arith.addi"(%arg1, %arg0) : (i32, i32) -> i32 + "return"(%0) : (i32) -> () }) {sym_name = "mlir_entry", type = (i32, i32) -> i32} : () -> () }) : () -> () ''' @@ -224,13 +224,13 @@ def test_generic_dialect_std_cond_br(parser: Optional[Parser] = None): "module"() ( { "func.func"() ( { ^bb0(%arg0: i32): // no predecessors - %c1_i32 = "std.constant"() {value = 1 : i32} : () -> i32 - %0 = "std.cmpi"(%arg0, %c1_i32) {predicate = 3 : i64} : (i32, i32) -> i1 - "std.cond_br"(%0)[^bb1, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () + %c1_i32 = "arith.constant"() {value = 1 : i32} : () -> i32 + %0 = "arith.cmpi"(%arg0, %c1_i32) {predicate = 3 : i64} : (i32, i32) -> i1 + "cf.cond_br"(%0)[^bb1, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () ^bb1: // pred: ^bb0 - "std.return"(%c1_i32) : (i32) -> () + "return"(%c1_i32) : (i32) -> () ^bb2: // pred: ^bb0 - "std.return"(%c1_i32) : (i32) -> () + "return"(%c1_i32) : (i32) -> () }) {sym_name = "mlir_entry", type = (i32) -> i32} : () -> () }) : () -> () ''' @@ -259,26 +259,26 @@ def test_generic_dialect_generic_op(parser: Optional[Parser] = None): "func.func"() ( { ^bb0(%arg0: i32, %arg1: i32): // no predecessors %0 = "generic_op_with_region"(%arg0, %arg1) ( { - %1 = "std.addi"(%arg1, %arg0) : (i32, i32) -> i32 - "std.return"(%1) : (i32) -> () + %1 = "arith.addi"(%arg1, %arg0) : (i32, i32) -> i32 + "return"(%1) : (i32) -> () }) : (i32, i32) -> i32 %2 = "generic_op_with_regions"(%0, %arg0) ( { - %3 = "std.subi"(%0, %arg0) : (i32, i32) -> i32 - "std.return"(%3) : (i32) -> () + %3 = "arith.subi"(%0, %arg0) : (i32, i32) -> i32 + "return"(%3) : (i32) -> () }, { - %4 = "std.addi"(%3, %arg0) : (i32, i32) -> i32 - "std.return"(%4) : (i32) -> () + %4 = "arith.addi"(%3, %arg0) : (i32, i32) -> i32 + "return"(%4) : (i32) -> () }) : (i32, i32) -> i32 %5 = "generic_op_with_region_and_attr"(%2, %arg0) ( { - %6 = "std.subi"(%2, %arg0) : (i32, i32) -> i32 - "std.return"(%6) : (i32) -> () + %6 = "arith.subi"(%2, %arg0) : (i32, i32) -> i32 + "return"(%6) : (i32) -> () }) {attr = "string attribute"} : (i32, i32) -> i32 %7 = "generic_op_with_region_and_successor"(%5, %arg0)[^bb1] ( { - %8 = "std.addi"(%5, %arg0) : (i32, i32) -> i32 - "std.br"(%8)[^bb1] : (i32) -> () + %8 = "arith.addi"(%5, %arg0) : (i32, i32) -> i32 + "cf.br"(%8)[^bb1] : (i32) -> () }) {attr = "string attribute"} : (i32, i32) -> i32 ^bb1(%ret: i32): - "std.return"(%ret) : (i32) -> () + "return"(%ret) : (i32) -> () }) {sym_name = "mlir_entry", type = (i32, i32) -> i32} : () -> () }) : () -> () ''' diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 7fe2b40..7e105c3 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -14,11 +14,11 @@ def parser(parser: Optional[Parser] = None) -> Parser: _code = ''' module { func.func @test0(%arg0: index, %arg1: index) { - %0 = alloc() : memref<100x100xf32> - %1 = alloc() : memref<100x100xf32, 2> - %2 = alloc() : memref<1xi32> - %c0 = constant 0 : index - %c64 = constant 64 : index + %0 = memref.alloc() : memref<100x100xf32> + %1 = memref.alloc() : memref<100x100xf32, 2> + %2 = memref.alloc() : memref<1xi32> + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index affine.for %arg2 = 0 to 10 { affine.for %arg3 = 0 to 10 { affine.dma_start %0[%arg2, %arg3], %1[%arg2, %arg3], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> @@ -30,10 +30,10 @@ def parser(parser: Optional[Parser] = None) -> Parser: func.func @test1(%arg0: index, %arg1: index) { affine.for %arg2 = 0 to 10 { affine.for %arg3 = 0 to 10 { - %c0 = constant 0 : index - %c64 = constant 64 : index - %c128 = constant 128 : index - %c256 = constant 256 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index affine.dma_start %0[%arg2, %arg3], %1[%arg2, %arg3], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32> affine.dma_wait %2[%c0], %c64 : memref<1xi32> } @@ -41,7 +41,7 @@ def parser(parser: Optional[Parser] = None) -> Parser: return } func.func @test2(%arg0: index, %arg1: index) { - %0 = alloc() : memref<100x100xf32> + %0 = memref.alloc() : memref<100x100xf32> } } ''' From cdf79b57784209765da1d07517726c6202621881 Mon Sep 17 00:00:00 2001 From: qubitter <11893617+qubitter@users.noreply.github.com> Date: Tue, 20 May 2025 12:46:57 -0400 Subject: [PATCH 5/8] actually implement func.func as an operation, so that typing can work --- mlir/dialects/func.py | 14 ++++++++++++++ mlir/lark/mlir.lark | 8 ++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mlir/dialects/func.py b/mlir/dialects/func.py index c27d86a..876e273 100644 --- a/mlir/dialects/func.py +++ b/mlir/dialects/func.py @@ -34,6 +34,20 @@ class ConstantOperation(DialectOp): _syntax_ = ['func.constant {value.symbol_ref_id} : {type.type}'] # Note: The 'func.func' operation is defined as 'function' in mlir.lark. +# not anymore lmfaooooo +@dataclass +class FuncOperation(DialectOp): + name: mast.SymbolRefId + args: Optional[List[mast.NamedArgument]] + result_list: Optional[List[mast.OpResult]] | mast.OpResult + func_mod_attrs: Optional[mast.AttributeDict] + body: Optional[mast.Region] + trail: Optional[mast.Location] = None + + _syntax_ = [ + 'func.func {name.symbol_ref_id} ( {args.optional_arg_list} ) {result_list.optional_fn_result_list} {func_mod_attrs.optional_func_mod_attrs} {body.optional_fn_body}', + 'func.func {name.symbol_ref_id} ( {args.optional_arg_list} ) {result_list.optional_fn_result_list} {func_mod_attrs.optional_func_mod_attrs} {body.optional_fn_body} (loc ({trail.optional_location}))'] + @dataclass class ReturnOperation(DialectOp): diff --git a/mlir/lark/mlir.lark b/mlir/lark/mlir.lark index 9258111..d70e22d 100644 --- a/mlir/lark/mlir.lark +++ b/mlir/lark/mlir.lark @@ -172,7 +172,7 @@ custom_operation : bare_id "." bare_id optional_ssa_use_list trailing_type // Final operation definition // NOTE: "pymlir_dialect_ops" is defined externally by pyMLIR -operation : optional_op_result_list (pymlir_dialect_ops | custom_operation | generic_operation | module | generic_module | function) optional_trailing_loc +operation : optional_op_result_list (pymlir_dialect_ops | custom_operation | generic_operation | module | generic_module) optional_trailing_loc // ---------------------------------------------------------------------- // Blocks and regions @@ -236,7 +236,7 @@ function_result_list_parens : ("(" ")") | ("(" function_result_list_no_parens ") // Definition module : "module" optional_symbol_ref_id optional_func_mod_attrs region optional_trailing_loc -function : "func.func" symbol_ref_id "(" optional_arg_list ")" optional_fn_result_list optional_func_mod_attrs optional_fn_body optional_trailing_loc +// function : "func.func" symbol_ref_id "(" optional_arg_list ")" optional_fn_result_list optional_func_mod_attrs optional_fn_body optional_trailing_loc generic_module : string_literal "(" optional_arg_list ")" "(" region ")" optional_attr_dict trailing_type optional_trailing_loc // ---------------------------------------------------------------------- @@ -315,9 +315,9 @@ attribute_alias_def : attribute_alias "=" attribute_value // Structure of an MLIR parse-able string definition_list : definition* -function_list : function* +// function_list : function* module_list : (module | generic_module)* -definition_and_function_list : definition_list function_list +definition_and_function_list : definition_list // function_list definition_and_module_list : definition_list module_list mlir_file: definition_and_function_list* -> only_functions_and_definitions_file From 17544dd6c7a30cd428be2e9b1841f19b3997befe Mon Sep 17 00:00:00 2001 From: qubitter <11893617+qubitter@users.noreply.github.com> Date: Fri, 4 Jul 2025 15:21:39 -0400 Subject: [PATCH 6/8] update setup.py, requirements.txt to use new name for lark --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index af4bfeb..cd17ee2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -lark-parser==0.7.8 +lark==0.7.8 parse==1.14.0 pytest diff --git a/setup.py b/setup.py index 2c94ff0..dfe6604 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ }, include_package_data=True, install_requires=[ - 'lark-parser', 'parse' + 'lark', 'parse' ], tests_require=['pytest', 'pytest-cov'], test_suite='pytest', From 8b642e1e4da6d405ac8a3238f4551cacd4ba873d Mon Sep 17 00:00:00 2001 From: Qucheng Jiang Date: Wed, 24 Sep 2025 10:26:41 +0800 Subject: [PATCH 7/8] spcl/pymlir#45 Update constant_literal to include posneg_integer_literal spcl/pymlir#45 --- mlir/lark/mlir.lark | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lark/mlir.lark b/mlir/lark/mlir.lark index d70e22d..409a810 100644 --- a/mlir/lark/mlir.lark +++ b/mlir/lark/mlir.lark @@ -23,7 +23,7 @@ negated_integer_literal : "-" integer_literal ?posneg_integer_literal : integer_literal | negated_integer_literal float_literal : /[-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?/ string_literal : ESCAPED_STRING -constant_literal : bool_literal | integer_literal | float_literal | string_literal +constant_literal : bool_literal | posneg_integer_literal | float_literal | string_literal // Identifier syntax bare_id : (letter| underscore) (letter|digit|underscore|id_chars)* From dfa967de361b0e11178b11980da43b740eda0666 Mon Sep 17 00:00:00 2001 From: Qucheng Jiang Date: Wed, 24 Sep 2025 10:44:33 +0800 Subject: [PATCH 8/8] Update syntax for ConstantOperation in arith.py --- mlir/dialects/arith.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/dialects/arith.py b/mlir/dialects/arith.py index 5a2b5db..92ed0f3 100644 --- a/mlir/dialects/arith.py +++ b/mlir/dialects/arith.py @@ -90,7 +90,8 @@ class CmpfOperation(DialectOp): class ConstantOperation(DialectOp): value: Literal type: mast.Type - _syntax_ = 'arith.constant {value.constant_literal} : {type.type}' + _syntax_ = ['arith.constant {value.constant_literal} : {type.type}', 'arith.constant {value.constant_literal}'] + @dataclass