Skip to content
This repository was archived by the owner on Jun 14, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 6 additions & 27 deletions dagrt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,42 +318,21 @@ def map_variable(self, expr):
"nothing known about '%s'"
% expr.name)

def map_sum(self, expr):
kind = None

last_exc = None

# Sums must be homogeneous, so being able to
# infer one child is good enough.
for ch in expr.children:
try:
ch_kind = self.rec(ch)
except UnableToInferKind as e:
if self.check:
raise
else:
last_exc = e

else:
kind = unify(kind, ch_kind)

if kind is None:
raise last_exc
else:
return kind

def map_product_like(self, children):
def map_expr_with_children(self, children):
kind = None
for ch in children:
kind = unify(kind, self.rec(ch))

return kind

def map_sum(self, expr):
return self.map_expr_with_children(expr.children)

def map_product(self, expr):
return self.map_product_like(expr.children)
return self.map_expr_with_children(expr.children)

def map_quotient(self, expr):
return self.map_product_like((expr.numerator, expr.denominator))
return self.map_expr_with_children((expr.numerator, expr.denominator))

def map_power(self, expr):
if self.check and not isinstance(self.rec(expr.exponent), Scalar):
Expand Down
2 changes: 1 addition & 1 deletion test/test_codegen_fortran.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_arrays_and_linalg():
def test_self_dep_in_loop():
with CodeBuilder(name="primary") as cb:
cb("y", "<state>y")
cb("y", "<func>f(0, 2*i*<func>f(0, y if i > 2 else 2*y))",
cb("y", "5 + <func>f(0, 2*i*<func>f(0, y if i > 2 else 2*y))",
loops=(("i", 0, 5),))
cb("<state>y", "y")

Expand Down