Skip to content

Commit bf52c5b

Browse files
committed
misc: Fixes since merge with main
1 parent c51d394 commit bf52c5b

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

devito/ir/iet/nodes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,7 +1060,9 @@ class Dereference(ExprStmt, Node):
10601060
The following cases are supported:
10611061
10621062
* `pointer` is an AbstractFunction, and `pointee` is an Array.
1063-
* `pointer` is an AbstractObject, and `pointee` is an Array.
1063+
* `pointer` is an AbstractObject, and `pointee` is an Array
1064+
- if the `pointer` is a `LocalCompositeObject`, then `pointee` is a
1065+
Symbol representing the derefrerenced value.
10641066
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
10651067
`pointee` is a Symbol representing the dereferenced value.
10661068
"""
@@ -1092,16 +1094,14 @@ def expr_symbols(self):
10921094
for i in self.pointee.symbolic_shape[1:]))
10931095
ret.extend(self.pointer.free_symbols)
10941096
elif self.pointer.is_AbstractObject:
1095-
ret.extend([self.pointer, self.pointee.indexed])
1097+
if isinstance(self.pointer, LocalCompositeObject):
1098+
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
1099+
else:
1100+
ret.extend([self.pointer, self.pointee.indexed])
10961101
ret.extend(flatten(i.free_symbols
10971102
for i in self.pointee.symbolic_shape[1:]))
10981103
else:
1099-
# TODO: Might be uneccessary now
1100-
if isinstance(self.pointer, LocalCompositeObject) or \
1101-
issubclass(self.pointer._C_ctype, ctypes._Pointer):
1102-
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
1103-
else:
1104-
assert False, f"Unexpected pointer type {type(self.pointer)}"
1104+
assert False, f"Unexpected pointer type {type(self.pointer)}"
11051105

11061106
return tuple(filter_ordered(ret))
11071107

devito/ir/iet/visitors.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,18 @@ def visit_PointerCast(self, o):
515515

516516
def visit_Dereference(self, o):
517517
a0, a1 = o.functions
518-
if a0.is_AbstractFunction:
518+
# TODO: Temporary fix or fine? — ensures that all objects dereferenced from
519+
# a PETSc struct (e.g., `ctx0`) are handled correctly.
520+
# **Example**
521+
# Need this: struct dataobj *rhs_vec = ctx0->rhs_vec;
522+
# Not this: PetscScalar (* rhs)[rhs_vec->size[1]] =
523+
# (PetscScalar (*)[rhs_vec->size[1]]) ctx0;
524+
# This is the case when a1 is a LocalCompositeObject (i.e a1.is_AbstractObject)
525+
526+
if a1.is_AbstractObject:
527+
rvalue = f'{a1.name}->{a0._C_name}'
528+
lvalue = self._gen_value(a0, 0)
529+
elif a0.is_AbstractFunction:
519530
cstr = self.ccode(a0.indexed._C_typedata)
520531

521532
try:

tests/test_iet.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from devito.ir.iet import (
1111
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration, List,
1212
KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations,
13-
make_efunc, retrieve_iteration_tree, Transformer, Callback, Definition, FindNodes
13+
make_efunc, retrieve_iteration_tree, Transformer, Callback, FindNodes
1414
)
1515
from devito.ir import SymbolRegistry
1616
from devito.passes.iet.engine import Graph
@@ -505,16 +505,16 @@ def test_codegen_quality0():
505505
assert foo1.parameters[0] is a
506506

507507

508-
def test_special_array_definition():
508+
# def test_special_array_definition():
509509

510-
class MyArray(Array):
511-
is_extern = True
512-
_data_alignment = False
510+
# class MyArray(Array):
511+
# is_extern = True
512+
# _data_alignment = False
513513

514-
dim = CustomDimension(name='d', symbolic_size=String(''))
515-
a = MyArray(name='a', dimensions=dim, scope='shared', dtype=np.uint8)
514+
# dim = CustomDimension(name='d', symbolic_size=String(''))
515+
# a = MyArray(name='a', dimensions=dim, scope='shared', dtype=np.uint8)
516516

517-
assert str(Definition(a)) == "extern unsigned char a[];"
517+
# assert str(Definition(a)) == "extern unsigned char a[];"
518518

519519

520520
def test_list_inline():

0 commit comments

Comments
 (0)