Skip to content

Commit 418d450

Browse files
committed
add support for bufferization dialect's to_tensor and materialize_in_destination.
add support for arith dialect's vadd and vmul. Update memref.py and frontend.py.
1 parent 75c68d7 commit 418d450

5 files changed

Lines changed: 680 additions & 62 deletions

File tree

src/pydsl/arith.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pydsl.macro import CallMacro, Compiled
33
from pydsl.protocols import lower_single, SubtreeOut, ToMLIRBase
44
from pydsl.type import Int, Float, Sign
5+
from pydsl.vector import Vector
56

67
import mlir.dialects.arith as arith
78

@@ -59,3 +60,69 @@ def min(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut:
5960
return rett(arith.MinimumFOp(av, bv))
6061
else:
6162
raise TypeError(f"cannot take min of {rett.__qualname__}")
63+
64+
65+
@CallMacro.generate()
def trunc(
    visitor: ToMLIRBase,
    a: Compiled,
    truncated_type: Compiled,
    *,
    round_mode: Compiled = None,
) -> SubtreeOut:
    """
    Truncate `a` to the narrower `truncated_type`.

    Works on scalar Int/Float values and on Vectors of them; for a Vector
    the result keeps the input's shape with the truncated element type.
    The optional keyword-only `round_mode` is attached to the op as the
    `round_mode` attribute. Raises TypeError when the target type is not
    strictly narrower, or when the element type is neither Int nor Float.
    """
    # Scalars truncate to the target type directly; vectors keep their
    # shape and only swap the element type.
    elem_type = type(a)
    result_type = truncated_type
    if isinstance(a, Vector):
        result_type = Vector.get(a.shape, truncated_type)
        elem_type = a.element_type

    if truncated_type.width >= elem_type.width:
        raise TypeError("truncated type must be smaller than called type.")

    # Dispatch on the element type: integer vs floating-point truncation.
    if issubclass(elem_type, Int):
        op = arith.TruncIOp(lower_single(result_type), lower_single(a))
    elif issubclass(elem_type, Float):
        op = arith.TruncFOp(lower_single(result_type), lower_single(a))
    else:
        raise TypeError(f"cannot take trunc of {elem_type.__qualname__}")

    if round_mode is not None:
        op.attributes["round_mode"] = lower_single(round_mode)
    return result_type(op)
91+
92+
93+
@CallMacro.generate()
def vadd(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut:
    """
    Element-wise vector addition of two Vectors of identical type.

    Emits `arith.addi` for Int elements and `arith.addf` for Float
    elements; the result has the same Vector type as the operands.
    Raises TypeError when `a` is not a Vector, when the operand types
    differ, or when the element type is neither Int nor Float.
    """
    rett = type(a)

    if not isinstance(a, Vector):
        # Fix: plain string — the original used an f-string with no
        # placeholders (ruff F541). Message bytes are unchanged.
        raise TypeError("NOT a vector addition operation")
    if type(a) is not type(b):
        raise TypeError(f"VADD type {type(a)} does not match {type(b)}")

    a_type = a.element_type
    if issubclass(a_type, Int):
        op = arith.addi(lower_single(a), lower_single(b))
    elif issubclass(a_type, Float):
        op = arith.addf(lower_single(a), lower_single(b))
    else:
        raise TypeError(f"unsupported vector addition type: {a_type}")
    return rett(op)
110+
111+
112+
@CallMacro.generate()
def vmul(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut:
    """
    Element-wise vector multiplication of two Vectors of identical type.

    Emits `arith.muli` for Int elements and `arith.mulf` for Float
    elements; the result has the same Vector type as the operands.
    Raises TypeError when `a` is not a Vector, when the operand types
    differ, or when the element type is neither Int nor Float.
    """
    rett = type(a)

    if not isinstance(a, Vector):
        # Fix: plain string — the original used an f-string with no
        # placeholders (ruff F541). Message bytes are unchanged.
        raise TypeError("NOT a vector multiplication operation")
    if type(a) is not type(b):
        raise TypeError(f"VMUL type {type(a)} does not match {type(b)}")

    a_type = a.element_type
    if issubclass(a_type, Int):
        op = arith.muli(lower_single(a), lower_single(b))
    elif issubclass(a_type, Float):
        op = arith.mulf(lower_single(a), lower_single(b))
    else:
        raise TypeError(f"unsupported vector multiplication type: {a_type}")
    return rett(op)

src/pydsl/bufferization.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from mlir.dialects import bufferization
2+
from pydsl.macro import CallMacro, Compiled
3+
from pydsl.tensor import Tensor
4+
from pydsl.memref import MemRef
5+
from pydsl.protocols import ToMLIRBase, lower_single, SubtreeOut
6+
7+
TensorFactory = Tensor.class_factory
8+
9+
10+
def verify_all_memref(*args):
    """
    Checks that all arguments are MemRef.
    Raises a TypeError otherwise.
    """
    # Fix: the original always built the type-name string with a manual
    # append loop (two passes, work done even on the success path).
    # Build the error message lazily, only when a check actually fails.
    if not all(isinstance(arg, MemRef) for arg in args):
        arg_type_str = ", ".join(type(arg).__qualname__ for arg in args)
        raise TypeError(
            "bufferization operation expects arguments of type MemRef, "
            f"got {arg_type_str}"
        )
29+
30+
31+
def verify_all_tensor(*args):
    """
    Checks that all arguments are Tensor.
    Raises a TypeError otherwise.
    """
    # Fix: the original always built the type-name string with a manual
    # append loop (two passes, work done even on the success path).
    # Build the error message lazily, only when a check actually fails.
    if not all(isinstance(arg, Tensor) for arg in args):
        arg_type_str = ", ".join(type(arg).__qualname__ for arg in args)
        raise TypeError(
            "bufferization operation expects arguments of type Tensor, "
            f"got {arg_type_str}"
        )
50+
51+
52+
@CallMacro.generate()
def to_tensor(visitor: "ToMLIRBase", x: Compiled) -> SubtreeOut:
    """
    Wrap a MemRef in a `bufferization.to_tensor` op and return the
    result as a Tensor whose shape/element type mirror the op result.

    The op is emitted with restrict=True and writable=True.
    Raises TypeError (via verify_all_memref) when `x` is not a MemRef.
    """
    verify_all_memref(x)

    op = bufferization.to_tensor(
        lower_single(x), restrict=True, writable=True
    )
    # Derive the PyDSL Tensor type from the MLIR result type.
    result_type = TensorFactory(tuple(op.type.shape), op.type.element_type)
    return result_type(op)
63+
64+
65+
@CallMacro.generate()
def materialize_in_destination(
    visitor: "ToMLIRBase", x: Compiled, y: Compiled
):
    """
    Emit `bufferization.materialize_in_destination`, copying Tensor `x`
    into MemRef destination `y` (writable=True). Emitted purely for its
    side effect; returns None.

    Raises TypeError (via the verifiers) when `x` is not a Tensor or
    `y` is not a MemRef.
    """
    # Validate x first, then y — preserves which error is raised first.
    verify_all_tensor(x)
    verify_all_memref(y)
    src = lower_single(x)
    dest = lower_single(y)
    # First positional argument (result type) is None: the op is used in
    # its memref-destination form and produces no result here.
    bufferization.MaterializeInDestinationOp(None, src, dest, writable=True)

src/pydsl/frontend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ def get_supported_dialects(self) -> set[Dialect]:
846846
Dialect.from_name("transform"),
847847
Dialect.from_name("transform.loop"),
848848
Dialect.from_name("transform.structured"),
849+
Dialect.from_name("vector"),
849850
}
850851

851852

0 commit comments

Comments
 (0)