Skip to content

Commit 70929e6

Browse files
authored
Merge pull request Huawei-CPLLab#74 from Balint-R/bool-number-fix
Fix Int range and Number to Bool cast
2 parents 3684309 + ff7e18f commit 70929e6

3 files changed

Lines changed: 142 additions & 70 deletions

File tree

src/pydsl/type.py

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def lower_class(cls) -> tuple[mlir.Type]:
275275
def val_range(cls) -> tuple[int, int]:
276276
match cls.sign:
277277
case Sign.SIGNED:
278-
return (-(1 << (cls.width - 1)), 1 << (cls.width - 1))
278+
return (-(1 << (cls.width - 1)), (1 << (cls.width - 1)) - 1)
279279
case Sign.UNSIGNED:
280-
return (0, (1 << cls.width) - 2)
280+
return (0, (1 << cls.width) - 1)
281281
case _:
282282
AssertionError("unimplemented sign")
283283

@@ -464,14 +464,11 @@ def to_CType(cls, arg_cont: ArgContainer, pyval: Any):
464464
f"{pyval} cannot be converted into an Int ctype"
465465
) from e
466466

467-
if (1 << cls.width) <= pyval:
468-
raise TypeError(
469-
f"{pyval} cannot fit into an Int of size {cls.width}"
470-
)
471-
472-
if cls.sign is Sign.UNSIGNED and pyval < 0:
467+
if not cls.in_range(pyval):
468+
lo, hi = cls.val_range()
473469
raise ValueError(
474-
f"expected positive pyval for unsigned Int, got {pyval}"
470+
f"{pyval} cannot fit into {cls.__qualname__}, must be in "
471+
f"the range [{lo}, {hi}]"
475472
)
476473

477474
arg_cont.add_arg(pyval)
@@ -1060,72 +1057,71 @@ def Int(self, target_type: type[AnyInt]) -> AnyInt:
10601057
def Float(self, target_type: type[F]) -> F:
10611058
return target_type(self.value)
10621059

1063-
def Index(self) -> "Index":
1060+
def Index(self) -> Index:
10641061
return Index(self.value)
10651062

1063+
def Bool(self) -> Bool:
1064+
return Bool(self.value)
1065+
10661066

10671067
# These are for unary operators in Number class
1068-
UnNumberOp = namedtuple(
1069-
"UnNumberOp", "dunder_name, internal_op, default_ret_type"
1070-
)
1068+
UnNumberOp = namedtuple("UnNumberOp", "dunder_name, internal_op")
10711069
un_number_op = {
1072-
UnNumberOp("op_neg", operator.neg, Number),
1073-
UnNumberOp("op_not", operator.not_, Number),
1074-
UnNumberOp("op_pos", operator.pos, Number),
1075-
UnNumberOp("op_abs", operator.abs, Number),
1076-
UnNumberOp("op_truth", operator.truth, Number),
1077-
UnNumberOp("op_floor", math.floor, Number),
1078-
UnNumberOp("op_ceil", math.ceil, Number),
1079-
UnNumberOp("op_round", round, Number),
1080-
UnNumberOp("op_invert", operator.invert, Number),
1081-
# Support CompileTimeTestable
1082-
UnNumberOp("Bool", operator.truth, Number),
1070+
UnNumberOp("op_neg", operator.neg),
1071+
UnNumberOp("op_not", operator.not_),
1072+
UnNumberOp("op_pos", operator.pos),
1073+
UnNumberOp("op_abs", operator.abs),
1074+
UnNumberOp("op_truth", operator.truth),
1075+
UnNumberOp("op_floor", math.floor),
1076+
UnNumberOp("op_ceil", math.ceil),
1077+
UnNumberOp("op_round", round),
1078+
UnNumberOp("op_invert", operator.invert),
10831079
}
10841080

10851081
for tup in un_number_op:
10861082

1087-
def method_gen(op):
1083+
def method_gen(tup):
10881084
"""
10891085
This function exists simply to allow a unique generic_unary_op to be
10901086
generated whose variables are bound to the arguments of this function
10911087
rather than the variable of the for loop.
10921088
"""
1093-
_, internal_op, default_ret_type = tup
1089+
# TODO: why is the above useful? What's wrong with binding to for loop
1090+
# variables?
1091+
_, internal_op = tup
10941092

10951093
# perform the unary operation on the underlying value
1096-
def generic_unary_op(
1097-
self: "Number", *args, **kwargs
1098-
) -> numbers.Number:
1099-
return default_ret_type(internal_op(self.value))
1094+
def generic_unary_op(self: Number) -> Number:
1095+
return Number(internal_op(self.value))
11001096

11011097
return generic_unary_op
11021098

1103-
ldunder_name, internal_op, rdunder_name = tup
1099+
ldunder_name, internal_op = tup
11041100
setattr(Number, ldunder_name, method_gen(tup))
11051101

11061102
# These are for binary operators in Number
11071103
BinNumberOp = namedtuple(
1108-
"BinNumberOp", "ldunder_name, internal_op, rdunder_name, default_ret_type"
1104+
"BinNumberOp", "ldunder_name, internal_op, rdunder_name"
11091105
)
11101106
bin_number_op = {
1111-
BinNumberOp("op_add", operator.add, "op_radd", Number),
1112-
BinNumberOp("op_sub", operator.sub, "op_rsub", Number),
1113-
BinNumberOp("op_mul", operator.mul, "op_rmul", Number),
1114-
BinNumberOp("op_truediv", operator.truediv, "op_rtruediv", Number),
1115-
BinNumberOp("op_pow", operator.pow, "op_rpow", Number),
1116-
BinNumberOp("op_divmod", divmod, "op_rdivmod", Number),
1117-
BinNumberOp("op_floordiv", operator.floordiv, "op_rfloordiv", Number),
1118-
BinNumberOp("op_mod", operator.mod, "op_rmod", Number),
1119-
BinNumberOp("op_lshift", operator.lshift, "op_rlshift", Number),
1120-
BinNumberOp("op_rshift", operator.rshift, "op_rrshift", Number),
1121-
BinNumberOp("op_and", operator.and_, "op_rand", Number),
1122-
BinNumberOp("op_xor", operator.xor, "op_rxor", Number),
1123-
BinNumberOp("op_or", operator.or_, "op_ror", Number),
1124-
BinNumberOp("op_lt", operator.lt, "op_gt", Bool),
1125-
BinNumberOp("op_le", operator.le, "op_ge", Bool),
1126-
BinNumberOp("op_eq", operator.le, "op_eq", Bool),
1127-
BinNumberOp("op_ge", operator.ge, "op_le", Bool),
1128-
BinNumberOp("op_gt", operator.gt, "op_lt", Bool),
1107+
BinNumberOp("op_add", operator.add, "op_radd"),
1108+
BinNumberOp("op_sub", operator.sub, "op_rsub"),
1109+
BinNumberOp("op_mul", operator.mul, "op_rmul"),
1110+
BinNumberOp("op_truediv", operator.truediv, "op_rtruediv"),
1111+
BinNumberOp("op_pow", operator.pow, "op_rpow"),
1112+
BinNumberOp("op_divmod", divmod, "op_rdivmod"),
1113+
BinNumberOp("op_floordiv", operator.floordiv, "op_rfloordiv"),
1114+
BinNumberOp("op_mod", operator.mod, "op_rmod"),
1115+
BinNumberOp("op_lshift", operator.lshift, "op_rlshift"),
1116+
BinNumberOp("op_rshift", operator.rshift, "op_rrshift"),
1117+
BinNumberOp("op_and", operator.and_, "op_rand"),
1118+
BinNumberOp("op_xor", operator.xor, "op_rxor"),
1119+
BinNumberOp("op_or", operator.or_, "op_ror"),
1120+
BinNumberOp("op_lt", operator.lt, "op_gt"),
1121+
BinNumberOp("op_le", operator.le, "op_ge"),
1122+
BinNumberOp("op_eq", operator.le, "op_eq"),
1123+
BinNumberOp("op_ge", operator.ge, "op_le"),
1124+
BinNumberOp("op_gt", operator.gt, "op_lt"),
11291125
}
11301126

11311127

@@ -1145,20 +1141,20 @@ def method_gen(tup):
11451141
generated whose variables are bound to the arguments of this function
11461142
rather than the variable of the for loop.
11471143
"""
1148-
_, internal_op, rdunder_name, default_ret_type = tup
1144+
_, internal_op, rdunder_name = tup
11491145

1150-
def generic_bin_op(self: "Number", rhs: NumberLike) -> NumberLike:
1146+
def generic_bin_op(self: Number, rhs: NumberLike) -> NumberLike:
11511147
# if RHS is also a Number
11521148
if isinstance(rhs, Number):
11531149
# perform the binary operation on the underlying values
1154-
return default_ret_type(internal_op(self.value, rhs.value))
1150+
return Number(internal_op(self.value, rhs.value))
11551151

1156-
# otherwise use RHS' implementation instead
1152+
# otherwise use RHS's implementation instead
11571153
return getattr(rhs, rdunder_name)(self)
11581154

11591155
return generic_bin_op
11601156

1161-
ldunder_name, internal_op, rdunder_name, default_ret_type = tup
1157+
ldunder_name, internal_op, rdunder_name = tup
11621158
setattr(Number, ldunder_name, method_gen(tup))
11631159

11641160

tests/e2e/test_arith.py

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import itertools
3+
from contextlib import nullcontext
34

45
from helper import (
56
f32_edges,
@@ -10,12 +11,16 @@
1011
)
1112

1213
from pydsl.frontend import compile
14+
from pydsl.macro import CallMacro, Compiled
1315
from pydsl.math import abs as p_abs
16+
from pydsl.protocols import ToMLIRBase
1417
from pydsl.type import (
18+
Bool,
1519
F32,
1620
F64,
17-
Bool,
1821
Index,
22+
Int,
23+
Number,
1924
SInt8,
2025
SInt64,
2126
UInt8,
@@ -24,24 +29,60 @@
2429
)
2530

2631

27-
def test_illegal_unfit_Int_input():
28-
with failed_from(TypeError):
32+
# TODO: rewrite with templates
33+
def test_val_range_Int8():
34+
def check(typ: type[Int], val: int, good: bool) -> None:
35+
with nullcontext() if good else compilation_failed_from(ValueError):
2936

30-
@compile(globals())
31-
def f(_: UInt8):
32-
pass
37+
@compile()
38+
def f() -> typ:
39+
return typ(val)
3340

34-
f(1 << 8)
41+
assert f() == val
3542

43+
check(UInt8, 0, True)
44+
check(UInt8, 123, True)
45+
check(UInt8, 255, True)
46+
check(UInt8, -1, False)
47+
check(UInt8, 256, False)
3648

37-
def test_illegal_Int_sign_input():
38-
with failed_from(ValueError):
49+
check(SInt8, -128, True)
50+
check(SInt8, 12, True)
51+
check(SInt8, 127, True)
52+
check(SInt8, -129, False)
53+
check(SInt8, 128, False)
54+
55+
56+
def test_ctype_range_UInt8():
57+
@compile()
58+
def f(x: UInt8) -> UInt8:
59+
return x
60+
61+
def check(val: int, good: bool) -> None:
62+
with nullcontext() if good else failed_from(ValueError):
63+
assert f(val) == val
64+
65+
check(0, True)
66+
check(98, True)
67+
check(255, True)
68+
check(-1, False)
69+
check(256, False)
3970

40-
@compile(globals())
41-
def f(_: UInt8):
42-
pass
4371

44-
f(-1)
72+
def test_ctype_range_SInt8():
73+
@compile()
74+
def f(x: SInt8) -> SInt8:
75+
return x
76+
77+
def check(val: int, good: bool) -> None:
78+
with nullcontext() if good else failed_from(ValueError):
79+
assert f(val) == val
80+
81+
check(-128, True)
82+
check(-56, True)
83+
check(127, True)
84+
check(-129, False)
85+
check(128, False)
4586

4687

4788
def test_cast_UInt8_to_Floats():
@@ -354,9 +395,29 @@ def imp_un() -> Tuple[SInt64, SInt64, SInt64, SInt64]:
354395
assert imp_un() == (-5, +5, abs(5), ~5)
355396

356397

398+
def test_Number_bool():
399+
@CallMacro.generate()
400+
def assert_number(visitor: ToMLIRBase, x: Compiled):
401+
assert isinstance(x, Number)
402+
403+
@CallMacro.generate()
404+
def assert_bool(visitor: ToMLIRBase, x: Compiled):
405+
assert isinstance(x, Bool)
406+
407+
@compile()
408+
def f():
409+
assert_number(True or False and True)
410+
assert_bool(Bool(False or True and False))
411+
assert_bool(Bool(True) and False)
412+
assert_bool(Bool(False) or 1)
413+
414+
f()
415+
416+
357417
if __name__ == "__main__":
358-
run(test_illegal_unfit_Int_input)
359-
run(test_illegal_Int_sign_input)
418+
run(test_val_range_Int8)
419+
run(test_ctype_range_UInt8)
420+
run(test_ctype_range_SInt8)
360421
run(test_cast_UInt8_to_Floats)
361422
run(test_cast_UInt64_to_Floats)
362423
run(test_cast_SInt8_to_Floats)
@@ -387,3 +448,4 @@ def imp_un() -> Tuple[SInt64, SInt64, SInt64, SInt64]:
387448
run(test_cast_Index_to_Floats)
388449
run(test_SInt_unary)
389450
run(test_Number_unary)
451+
run(test_Number_bool)

tests/e2e/test_scf.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,22 @@ def f(m: MemRefSingle, b: Bool):
6565
assert n[0] == 5
6666

6767

68+
def test_const_if():
69+
@compile()
70+
def f(m: MemRefSingle):
71+
if True:
72+
m[0] = 123
73+
else:
74+
m[0] = 456
75+
76+
n1 = np.asarray([0], dtype=np.uint32)
77+
f(n1)
78+
assert n1[0] == 123
79+
80+
6881
if __name__ == "__main__":
6982
run(test_range_basic)
7083
run(test_range_implicit_type)
7184
run(test_if)
7285
run(test_if_else)
86+
run(test_const_if)

0 commit comments

Comments
 (0)