Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, *args, **kwargs):
for c in obj.children))
# Set dtype if not set already.
if not hasattr(obj, 'dtype'):
obj.dtype = obj.inherit_dtype_from_children(obj.children)
obj.dtype = obj.inherit_dtype_from_children(*obj.children)

return obj

Expand Down Expand Up @@ -157,7 +157,7 @@ def __rmod__(self, other):
return as_gem_uint(other).__mod__(self)

@staticmethod
def inherit_dtype_from_children(children):
def inherit_dtype_from_children(*children):
if any(c.dtype is None for c in children):
# Set dtype = None will let _assign_dtype()
# assign the default dtype for this node later.
Expand Down Expand Up @@ -340,7 +340,7 @@ def __new__(cls, *args):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b)))
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children(a, b))

self = super(Sum, cls).__new__(cls)
self.children = a, b
Expand Down Expand Up @@ -369,7 +369,7 @@ def __new__(cls, *args):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b)))
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children(a, b))

self = super(Product, cls).__new__(cls)
self.children = a, b
Expand All @@ -393,7 +393,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b)))
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children(a, b))

self = super(Division, cls).__new__(cls)
self.children = a, b
Expand All @@ -406,7 +406,7 @@ class FloorDiv(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
dtype = Node.inherit_dtype_from_children((a, b))
dtype = Node.inherit_dtype_from_children(a, b)
if dtype != uint_type:
raise ValueError(f"dtype ({dtype}) != unit_type ({uint_type})")
# Constant folding
Expand All @@ -429,7 +429,7 @@ class Remainder(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
dtype = Node.inherit_dtype_from_children((a, b))
dtype = Node.inherit_dtype_from_children(a, b)
if dtype != uint_type:
raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})")
# Constant folding
Expand All @@ -452,7 +452,7 @@ class Power(Scalar):
def __new__(cls, base, exponent):
assert not base.shape
assert not exponent.shape
dtype = Node.inherit_dtype_from_children((base, exponent))
dtype = Node.inherit_dtype_from_children(base, exponent)

# Constant folding
if isinstance(base, Zero):
Expand Down Expand Up @@ -568,7 +568,7 @@ def __new__(cls, condition, then, else_):
self = super(Conditional, cls).__new__(cls)
self.children = condition, then, else_
self.shape = then.shape
self.dtype = Node.inherit_dtype_from_children((then, else_))
self.dtype = Node.inherit_dtype_from_children(then, else_)
return self


Expand Down Expand Up @@ -888,7 +888,7 @@ class ListTensor(Node):
def __new__(cls, array):
array = asarray(array)
assert numpy.prod(array.shape)
dtype = Node.inherit_dtype_from_children(tuple(array.flat))
dtype = Node.inherit_dtype_from_children(*array.flat)

# Handle children with shape
child_shape = array.flat[0].shape
Expand Down Expand Up @@ -950,7 +950,7 @@ class Concatenate(Node):
__slots__ = ('children',)

def __new__(cls, *children):
dtype = Node.inherit_dtype_from_children(children)
dtype = Node.inherit_dtype_from_children(*children)
if all(isinstance(child, Zero) for child in children):
size = int(sum(numpy.prod(child.shape, dtype=int) for child in children))
return Zero((size,), dtype=dtype)
Expand Down
Loading