diff --git a/dagrt/data.py b/dagrt/data.py index 298a10c..92ff198 100644 --- a/dagrt/data.py +++ b/dagrt/data.py @@ -318,42 +318,21 @@ def map_variable(self, expr): "nothing known about '%s'" % expr.name) - def map_sum(self, expr): - kind = None - - last_exc = None - - # Sums must be homogeneous, so being able to - # infer one child is good enough. - for ch in expr.children: - try: - ch_kind = self.rec(ch) - except UnableToInferKind as e: - if self.check: - raise - else: - last_exc = e - - else: - kind = unify(kind, ch_kind) - - if kind is None: - raise last_exc - else: - return kind - - def map_product_like(self, children): + def map_expr_with_children(self, children): kind = None for ch in children: kind = unify(kind, self.rec(ch)) return kind + def map_sum(self, expr): + return self.map_expr_with_children(expr.children) + def map_product(self, expr): - return self.map_product_like(expr.children) + return self.map_expr_with_children(expr.children) def map_quotient(self, expr): - return self.map_product_like((expr.numerator, expr.denominator)) + return self.map_expr_with_children((expr.numerator, expr.denominator)) def map_power(self, expr): if self.check and not isinstance(self.rec(expr.exponent), Scalar): diff --git a/test/test_codegen_fortran.py b/test/test_codegen_fortran.py index 420efd3..bcdb6ee 100755 --- a/test/test_codegen_fortran.py +++ b/test/test_codegen_fortran.py @@ -118,7 +118,7 @@ def test_arrays_and_linalg(): def test_self_dep_in_loop(): with CodeBuilder(name="primary") as cb: cb("y", "y") - cb("y", "f(0, 2*i*f(0, y if i > 2 else 2*y))", + cb("y", "5 + f(0, 2*i*f(0, y if i > 2 else 2*y))", loops=(("i", 0, 5),)) cb("y", "y")