From 4a31a1799a43075a263698bc14576eb9b3a795c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= Date: Thu, 27 May 2021 22:14:50 -0500 Subject: [PATCH] Revert "Revert "Kind inference: do not treat sums specially"" This reverts commit 28a328d98520be2d2dfc38f533615abd54d719db. --- dagrt/data.py | 33 ++++++--------------------------- test/test_codegen_fortran.py | 2 +- 2 files changed, 7 insertions(+), 28 deletions(-) diff --git a/dagrt/data.py b/dagrt/data.py index 298a10c..92ff198 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -318,42 +318,21 @@ def map_variable(self, expr): "nothing known about '%s'" % expr.name) - def map_sum(self, expr): - kind = None - - last_exc = None - - # Sums must be homogeneous, so being able to - # infer one child is good enough. - for ch in expr.children: - try: - ch_kind = self.rec(ch) - except UnableToInferKind as e: - if self.check: - raise - else: - last_exc = e - - else: - kind = unify(kind, ch_kind) - - if kind is None: - raise last_exc - else: - return kind - - def map_product_like(self, children): + def map_expr_with_children(self, children): kind = None for ch in children: kind = unify(kind, self.rec(ch)) return kind + def map_sum(self, expr): + return self.map_expr_with_children(expr.children) + def map_product(self, expr): - return self.map_product_like(expr.children) + return self.map_expr_with_children(expr.children) def map_quotient(self, expr): - return self.map_product_like((expr.numerator, expr.denominator)) + return self.map_expr_with_children((expr.numerator, expr.denominator)) def map_power(self, expr): if self.check and not isinstance(self.rec(expr.exponent), Scalar): diff --git a/test/test_codegen_fortran.py b/test/test_codegen_fortran.py index 420efd3..bcdb6ee 100755 --- a/test/test_codegen_fortran.py +++ b/test/test_codegen_fortran.py @@ -118,7 +118,7 @@ def test_arrays_and_linalg(): def test_self_dep_in_loop(): with CodeBuilder(name="primary") as cb: cb("y", "y") - cb("y", "f(0, 2*i*f(0, y if i > 2 else 2*y))", + cb("y", "5 + f(0, 2*i*f(0, y if i > 2 else 2*y))", loops=(("i", 0, 5),)) cb("y", "y")