Skip to content

Commit 4626414

Browse files
committed
Update tag decorator to support attributes with string value
1 parent 7cdfe7e commit 4626414

3 files changed

Lines changed: 34 additions & 6 deletions

File tree

src/pydsl/transform.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from mlir.dialects import transform
88
from mlir.dialects.transform import loop, structured
9-
from mlir.ir import IndexType, IntegerAttr, OpView, UnitAttr
9+
from mlir.ir import IndexType, IntegerAttr, OpView, UnitAttr, Attribute
1010

1111
from pydsl.macro import CallMacro, Compiled, Evaluated, Uncompiled
1212
from pydsl.protocols import SubtreeOut, ToMLIRBase, lower_single
@@ -98,17 +98,29 @@ def decorator(mlir):
9898

9999

100100
@CallMacro.generate()
101-
def tag(visitor: ToMLIRBase, attr_name: Evaluated[str]) -> Evaluated[Callable]:
101+
def tag(
102+
visitor: ToMLIRBase,
103+
attr_name: Evaluated[str],
104+
attr_value: Evaluated[str | None] = None,
105+
) -> Evaluated[Callable]:
102106
"""
103-
Tags the `mlir` MLIR operation with a MLIR unit attribute with name
107+
Tags the `mlir` MLIR operation with a MLIR attribute with name
104108
`attr_name`.
105109
106110
Arguments:
107111
- `mlir`: AST. The AST node whose equivalent MLIR Operator is to be tagged
108-
with the unit attribute
109-
- `attr_name`: str. The name of the unit attribute
112+
with the attribute
113+
- `attr_name`: str. The name of the attribute
114+
- `attr_value`: str. An optional argument that will convert the attribute
115+
from a unit attribute to a key/value attribute pair.
110116
"""
111-
return attr_setter(attr_name, UnitAttr.get())
117+
if type(attr_name) is not str:
118+
raise TypeError("Attribute name is not a string")
119+
120+
if attr_value is None:
121+
return attr_setter(attr_name, UnitAttr.get())
122+
else:
123+
return attr_setter(attr_name, Attribute.parse(attr_value))
112124

113125

114126
@CallMacro.generate()

src/pydsl/type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import ast
45
import collections.abc as cabc
56
import ctypes
@@ -782,6 +783,7 @@ def op_ge(self, rhs: SupportsInt) -> "Bool":
782783
@classmethod
783784
def CType(cls) -> tuple[type]:
784785
ctypes_map = {
786+
16: np.float16,
785787
32: ctypes.c_float,
786788
64: ctypes.c_double,
787789
80: ctypes.c_longdouble,

tests/e2e/test_transform.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,22 @@ def f(m: MemRef[UInt32, 4, 4]):
128128
)
129129

130130

131+
def test_tag_with_string_value():
132+
@compile(globals())
133+
@tag("mytag", "0")
134+
def f():
135+
a: F32 = 0.0
136+
b: F32 = 1.0
137+
138+
mlir = f.emit_mlir()
139+
140+
# No runtime check: testing function of tag only
141+
assert r"mytag = 0" in mlir
142+
143+
131144
if __name__ == "__main__":
132145
run(test_multiple_recursively_tag)
133146
run(test_multiple_recursively_int_attr)
134147
run(test_cse_then_coalesce)
135148
run(test_outline_loop)
149+
run(test_tag_with_string_value)

0 commit comments

Comments
 (0)