We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2a940b1 commit 516cd76Copy full SHA for 516cd76
4 files changed
imex_version.txt
@@ -1 +1 @@
1
-571f54577e2301c70033fef9a05b8a96fa841d2b
+4366dfe334c62a1321d9aac2bae183a10f5259a1
src/jit/mlir.cpp
@@ -432,7 +432,9 @@ static const char *pass_pipeline =
432
getenv("DDPT_PASSES") ? getenv("DDPT_PASSES")
433
: "func.func(ptensor-dist),"
434
"func.func(dist-coalesce),"
435
+ "func.func(dist-infer-elementwise-cores),"
436
"convert-dist-to-standard,"
437
+ "canonicalize,"
438
"convert-ptensor-to-linalg,"
439
"canonicalize,"
440
"func.func(tosa-to-linalg),"
test/test_ewb.py
@@ -47,6 +47,15 @@ def test_add3(self):
47
v = 16 * 16 * 3
48
assert float(r1) == v
49
50
+ def test_add_mul(self):
51
+ def doit(aapi):
52
+ a = aapi.zeros((16, 16), dtype=aapi.int64)
53
+ b = aapi.ones((12, 12), dtype=aapi.int64)
54
+ a[3:13, 3:13] = b[0:10, 1:11] + b[1:11, 1:11] * b[1, 1]
55
+ return a
56
+
57
+ assert runAndCompare(doit)
58
59
def test_add_shifted1(self):
60
for dtyp in mpi_idtypes:
61
aa = dt.ones((16, 16), dtype=dtyp)
@@ -103,6 +112,16 @@ def test_add_shifted6(self):
103
112
r1 = dt.sum(c)
104
113
assert int(r1) == 388
105
114
115
+ @pytest.mark.skip(reason="FIXME halo update")
116
+ def test_add_broadcast(self):
117
118
119
+ b = aapi.arange(1, 16, 1, dtype=aapi.int64)
120
+ a[3:13, 3:13] = a[0:10, 1:11] + b[0]
121
122
123
124
106
125
@pytest.mark.skip(reason="FIXME")
107
126
def test_prod_het(self):
108
127
a = dt.full([16, 16], 2, dt.float64)
test/test_spmd.py
@@ -36,10 +36,11 @@ def test_get_locals(self):
36
assert float(c) == v
37
MPI.COMM_WORLD.barrier()
38
39
- @pytest.mark.skipif(
40
- MPI.COMM_WORLD.size == 1 and os.getenv("DDPT_FORCE_DIST", "") == "",
41
- reason="FIXME extra memref.copy",
42
- )
+ # @pytest.mark.skipif(
+ # MPI.COMM_WORLD.size == 1, and os.getenv("DDPT_FORCE_DIST", "") == "",
+ # reason="FIXME extra memref.copy",
+ # )
43
+ @pytest.mark.skip(reason="FIXME imex-remove-temporaries")
44
def test_get_locals_of_view(self):
45
a = dt.ones((32, 32), dt.float64)
46
b = a[0:32:2, 0:32:2]
0 commit comments