Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
f86b2c8
wip
edopao Jan 30, 2026
782f5ff
edit
edopao Feb 2, 2026
abdaa39
Merge branch 'main' into dace_mixed_ir
edopao Apr 8, 2026
2d24ba9
enable warning
edopao Apr 9, 2026
e3ac9de
edit
edopao Apr 13, 2026
d33bd83
fix sdfg lowering
edopao Apr 15, 2026
3d770bd
fix bindings test for unstructured grid
edopao Apr 15, 2026
7a70de9
remove noqa comment
edopao Apr 16, 2026
a367553
add test coverage
edopao Apr 16, 2026
33d60b7
add test marker for sliced out argument
edopao Apr 16, 2026
75e062a
remove initial from np.sum()
edopao Apr 16, 2026
a22ff35
add xfail for embedded backend
edopao May 5, 2026
0d17ef6
edit
edopao May 5, 2026
d270266
Merge branch 'main' into dace_mixed_ir
edopao May 5, 2026
34738e9
Merge branch 'compile_time_domain_enable_warning' into dace_mixed_ir
edopao May 5, 2026
a6d00a9
add backends to test matrix
edopao May 5, 2026
55f26f2
Switched to new GPU Codegen.
philip-paul-mueller May 6, 2026
574ed87
Was this the error.
philip-paul-mueller May 6, 2026
7a96e6b
This should be the thing.
philip-paul-mueller May 6, 2026
fd463b3
edit
edopao May 6, 2026
3b075ac
edit
edopao May 6, 2026
ae6db4e
Fix stride of local dimension in lowering of if-expressions
edopao May 6, 2026
dbbfd52
edit
edopao May 6, 2026
f3a02ba
Merge branch 'dace_fix_stride_if_region' into dace_mixed_ir
edopao May 6, 2026
2cd7728
edit
edopao May 6, 2026
48a8f1c
Let's try this fix.
philip-paul-mueller May 7, 2026
6d52e24
Let's try this fix.
philip-paul-mueller May 7, 2026
ecd5780
edit
edopao May 7, 2026
246cc24
Merge remote-tracking branch 'origin/main' into dace_new_gpu_codegen
philip-paul-mueller May 7, 2026
a8fa800
edit
edopao May 7, 2026
80c45e6
edit
edopao May 7, 2026
ba91d43
edit
edopao May 7, 2026
f560107
apply review comments
edopao May 7, 2026
5abea62
Merge branch 'dace_fix_stride_if_region' into dace_mixed_ir
edopao May 7, 2026
c174e5a
fix
edopao May 8, 2026
057ac27
Merge branch 'dace_fix_stride_if_region' into dace_mixed_ir
edopao May 8, 2026
17f68ad
Merge remote-tracking branch 'philip/dace_new_gpu_codegen' into dace_…
edopao May 8, 2026
07a5877
remove workaround
edopao May 8, 2026
bb33fd4
Merge branch 'main' into dace_mixed_ir
edopao May 11, 2026
705f512
re-enable test
edopao May 12, 2026
923ab93
preserve order of connectivities
edopao May 12, 2026
71f0d3a
Merge branch 'main' into dace_mixed_ir
edopao May 12, 2026
af56259
remove workaround
edopao May 12, 2026
39fca69
Revert "re-enable test"
edopao May 12, 2026
1621acf
fix stride gpu transformation
edopao May 13, 2026
1217498
edit
edopao May 13, 2026
f305c1c
edit
edopao May 13, 2026
fb7c6f7
edit
edopao May 13, 2026
e124ebf
edit
edopao May 13, 2026
c328157
edit
edopao May 13, 2026
e60bd28
edit
edopao May 13, 2026
609aa97
edit
edopao May 15, 2026
7a0a792
fix validation flag
edopao May 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dace-cartesian = [
'dace>=1.0.2' # refined in [tool.uv.sources]
]
dace-next = [
'dace==43!2026.04.27' # uses custom index at 'https://github.com/GridTools/pypi'
'dace==2.3.5' # uses custom index at 'https://github.com/GridTools/pypi'
]
dev = [
{include-group = 'build'},
Expand Down Expand Up @@ -486,7 +486,7 @@ url = 'https://gridtools.github.io/pypi/'
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/GridTools/dace", branch = "romanc/stree-v2", group = "dace-cartesian"},
{index = "gridtools", group = "dace-next"}
{git = "https://github.com/philip-paul-mueller/dace", branch = "phimuell__new-gpu-codegen-dev", group = "dace-next"}
]

# -- versioningit --
Expand Down
10 changes: 9 additions & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def apply_common_transforms(
unroll_reduce=False,
common_subexpression_elimination=True,
force_inline_lambda_args=False,
transform_concat_where_to_as_fieldop=True,
#: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for
#: more details.
symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None,
Expand Down Expand Up @@ -186,7 +187,8 @@ def apply_common_transforms(
ir = prune_empty_concat_where.prune_empty_concat_where(ir)
ir = remove_broadcast.RemoveBroadcast.apply(ir)

ir = concat_where.transform_to_as_fieldop(ir)
if transform_concat_where_to_as_fieldop:
ir = concat_where.transform_to_as_fieldop(ir)

for _ in range(10):
inlined = ir
Expand Down Expand Up @@ -258,6 +260,12 @@ def apply_common_transforms(
ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args
)

ir = infer_domain.infer_program(
ir,
offset_provider=offset_provider,
symbolic_domain_sizes=symbolic_domain_sizes,
)

assert isinstance(ir, itir.Program)
return ir

Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/program_processors/runners/dace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
make_dace_backend,
run_dace_cpu,
run_dace_cpu_cached,
run_dace_cpu_gt,
run_dace_cpu_gt_noopt,
run_dace_cpu_noopt,
run_dace_gpu,
run_dace_gpu_cached,
Expand All @@ -24,6 +26,8 @@
"make_dace_backend",
"run_dace_cpu",
"run_dace_cpu_cached",
"run_dace_cpu_gt",
"run_dace_cpu_gt_noopt",
"run_dace_cpu_noopt",
"run_dace_gpu",
"run_dace_gpu_cached",
Expand Down
184 changes: 122 additions & 62 deletions src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def connect(
map_exit: Optional[dace_nodes.MapExit],
dest: dace_nodes.AccessNode,
dest_subset: dace_subsets.Range,
allow_removal_of_last_node: bool,
) -> bool:
"""Create a connection to the `dest` node, writing the given `dest_subset`.

Expand All @@ -280,24 +281,27 @@ def connect(
dest_desc = self.result.dc_node.desc(self.state)
write_edge = self.state.in_edges(self.result.dc_node)[0]

# Check the kind of node which writes the result
if isinstance(write_edge.src, dace_nodes.Tasklet):
# The temporary data written by a tasklet can be safely deleted.
assert map_exit is not None
remove_last_node = True
elif isinstance(write_edge.src, dace_nodes.NestedSDFG):
if isinstance(dest_desc, dace.data.Scalar):
# We keep scalar temporary storage, as a general rule, since it
# does not affect performance of the generated code. This scalar
# is only required in some cases, e.g. for nested SDFGs implementing
# reduction, which use a WCR memlet for the reduction operation.
if allow_removal_of_last_node:
if self.state.out_degree(self.result.dc_node) != 0:
remove_last_node = False
else:
# We remove the transient array on the output connection of a nested
# SDFG and write directly to the destination node.
# The caller is responsible to propagate the strides of the destination
# array to the array inside the nested SDFG.
elif isinstance(write_edge.src, dace_nodes.Tasklet):
# The temporary data written by a tasklet can be safely deleted.
remove_last_node = True
elif isinstance(write_edge.src, dace_nodes.NestedSDFG):
if isinstance(dest_desc, dace.data.Scalar):
# We keep scalar temporary storage, as a general rule, since it
# does not affect performance of the generated code. This scalar
# is only required in some cases, e.g. for nested SDFGs implementing
# reduction, which use a WCR memlet for the reduction operation.
remove_last_node = False
else:
# We remove the transient array on the output connection of a nested
# SDFG and write directly to the destination node.
# The caller is responsible to propagate the strides of the destination
# array to the array inside the nested SDFG.
remove_last_node = True
else:
remove_last_node = False
else:
remove_last_node = False

Expand Down Expand Up @@ -552,6 +556,51 @@ def _construct_tasklet_result(
),
)

def _visit_can_deref(self, node: gtir.FunCall) -> DataExpr:
    """
    Lower a `can_deref` call into a tasklet that produces a boolean value.

    The single argument must be an applied `shift` expression. It is visited
    with `_visit_shift()`, and every index of the resulting iterator that is
    not a `SymbolExpr` is passed to the tasklet, which compares each of them
    against `gtx_common._DEFAULT_SKIP_VALUE` and and-combines the outcomes.

    Args:
        node: The `can_deref` function call; its type must be a boolean scalar.

    Returns:
        The result of `_construct_tasklet_result()` for the `valid` output
        connector of the tasklet, with dtype `dace.bool_`.

    Raises:
        NotImplementedError: If the argument is not an applied `shift`.
        ValueError: If all indices of the shifted iterator are `SymbolExpr`.
    """
    assert isinstance(node.type, ts.ScalarType) and node.type.kind == ts.ScalarKind.BOOL
    assert len(node.args) == 1
    shift_arg = node.args[0]
    if not cpm.is_applied_shift(shift_arg):
        raise NotImplementedError(
            f"Only `can_deref` of unstructured `shift` expressions is supported, got {shift_arg}."
        )

    it = self._visit_shift(shift_arg)
    # Only the indices that are not symbolic expressions become tasklet inputs.
    dynamic_indices = {
        dim: index for dim, index in it.indices.items() if not isinstance(index, SymbolExpr)
    }
    if not dynamic_indices:
        raise ValueError(f"Unexpected `can_deref` argument: {it}.")

    # One input connector per non-symbolic index, named after the dimension value.
    connector_names = {dim: f"index_{dim.value}" for dim in dynamic_indices}
    validity_checks = [
        f"{connector} != {gtx_common._DEFAULT_SKIP_VALUE}"
        for connector in connector_names.values()
    ]
    tasklet_node, connector_mapping = self._add_tasklet(
        name="can_deref",
        inputs=set(connector_names.values()),
        outputs={"valid"},
        code="valid = " + " and ".join(validity_checks),
    )

    for dim, index_expr in dynamic_indices.items():
        connector = connector_names[dim]
        if isinstance(index_expr, MemletExpr):
            # The index is read from existing data through the memlet subset.
            self._add_input_data_edge(
                index_expr.dc_node,
                index_expr.subset,
                tasklet_node,
                connector_mapping[connector],
            )
            continue
        # NOTE(review): presumably a `ValueExpr` holding a single element,
        # hence the fixed subset "0" -- confirm against `_visit_shift()`.
        self._add_edge(
            index_expr.dc_node,
            None,
            tasklet_node,
            connector_mapping[connector],
            dace.Memlet(data=index_expr.dc_node.data, subset="0"),
        )

    return self._construct_tasklet_result(
        dc_dtype=dace.bool_,
        src_node=tasklet_node,
        src_connector=connector_mapping["valid"],
    )

def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
"""
Visit a `deref` node, which represents dereferencing of an iterator.
Expand Down Expand Up @@ -709,12 +758,7 @@ def _visit_if_branch_arg(
inner_desc = arg_desc.clone()
inner_desc.transient = False
elif isinstance(arg.gt_dtype, ts.ScalarType):
if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1:
# TODO(edopao): we cannot use a scalar because of an issue in gpu codegen,
# which leads to compilation error: cannot convert 'const double' to 'const double*'
inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,))
else:
inner_desc = dace.data.Scalar(arg_desc.dtype)
inner_desc = dace.data.Scalar(arg_desc.dtype)
else:
# for list of values, we retrieve the local size from the corresponding offset
local_dim = arg.gt_dtype.offset_type
Expand Down Expand Up @@ -837,16 +881,9 @@ def _visit_if_branch_result(
# If the result is currently written to a transient node, inside the nested SDFG,
# we need to allocate a non-transient data node.
result_desc = edge.result.dc_node.desc(sdfg)
if isinstance(sym.type, ts.ScalarType) and isinstance(result_desc, dace.data.Array):
# TODO(edopao): a scalar should not be represented as an array, but
# currently this can happen because of an issue workaround, see the
# todo comment above in `_visit_if_branch_arg()`.
assert len(result_desc.shape) == 1 and result_desc.shape[0] == 1
_, output_desc = sdfg.add_scalar(output_data, result_desc.dtype)
else:
output_desc = result_desc.clone()
output_desc.transient = False
output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True)
output_desc = result_desc.clone()
output_desc.transient = False
output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True)
output_node = state.add_access(output_data)
state.add_nedge(
edge.result.dc_node,
Expand Down Expand Up @@ -1181,7 +1218,7 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr:
assert index_arg.dc_dtype in dace.dtypes.INTEGER_TYPES
src_subset = (
dace_subsets.Range(src_subset[:local_dim_index])
+ dace_subsets.Range.from_string(index_arg.value)
+ dace_subsets.Range.from_indices([index_arg.value])
+ dace_subsets.Range(src_subset[local_dim_index + 1 :])
)
if isinstance(src_arg, MemletExpr):
Expand Down Expand Up @@ -1586,7 +1623,6 @@ def _make_cartesian_shift(
self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr
) -> IteratorExpr:
"""Implements cartesian shift along one dimension."""
assert any(dim == offset_dim for dim, _ in it.field_domain)
new_index: SymbolExpr | ValueExpr
index_expr = it.indices[offset_dim]
if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr):
Expand Down Expand Up @@ -1658,9 +1694,9 @@ def _make_cartesian_shift(

def _make_dynamic_neighbor_offset(
self,
offset_expr: MemletExpr | ValueExpr,
offset_expr: MemletExpr | ValueExpr | SymbolExpr,
offset_table_node: dace_nodes.AccessNode,
origin_index: SymbolExpr,
origin_index: MemletExpr | ValueExpr | SymbolExpr,
) -> ValueExpr:
"""
Implements access to neighbor connectivity table by means of a tasklet node.
Expand All @@ -1669,33 +1705,56 @@ def _make_dynamic_neighbor_offset(
or computed by another tasklet (`DataExpr`).
"""
new_index_connector = "neighbor_index"
tasklet_node, connector_mapping = self._add_tasklet(
"dynamic_neighbor_offset",
{"table", "offset"},
{new_index_connector},
f"{new_index_connector} = table[{origin_index.value}, offset]",
)
if isinstance(offset_expr, SymbolExpr) and isinstance(origin_index, SymbolExpr):
tasklet_node, connector_mapping = self._add_tasklet(
"dynamic_neighbor_offset",
{"table"},
{new_index_connector},
f"{new_index_connector} = table[{origin_index.value}, {offset_expr.value}]",
)
elif isinstance(origin_index, SymbolExpr):
tasklet_node, connector_mapping = self._add_tasklet(
"dynamic_neighbor_offset",
{"table", "offset"},
{new_index_connector},
f"{new_index_connector} = table[{origin_index.value}, offset]",
)
elif isinstance(offset_expr, SymbolExpr):
tasklet_node, connector_mapping = self._add_tasklet(
"dynamic_neighbor_offset",
{"table", "origin"},
{new_index_connector},
f"{new_index_connector} = table[origin, {offset_expr.value}]",
)
else:
tasklet_node, connector_mapping = self._add_tasklet(
"dynamic_neighbor_offset",
{"table", "offset", "origin"},
{new_index_connector},
f"{new_index_connector} = table[origin, offset]",
)
self._add_input_data_edge(
offset_table_node,
dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)),
tasklet_node,
connector_mapping["table"],
)
if isinstance(offset_expr, MemletExpr):
self._add_input_data_edge(
offset_expr.dc_node,
offset_expr.subset,
tasklet_node,
connector_mapping["offset"],
)
else:
self._add_edge(
offset_expr.dc_node,
None,
tasklet_node,
connector_mapping["offset"],
dace.Memlet(data=offset_expr.dc_node.data, subset="0"),
)
for conn, input_expr in [("offset", offset_expr), ("origin", origin_index)]:
if isinstance(input_expr, MemletExpr):
self._add_input_data_edge(
input_expr.dc_node,
input_expr.subset,
tasklet_node,
connector_mapping[conn],
)
elif isinstance(input_expr, ValueExpr):
self._add_edge(
input_expr.dc_node,
None,
tasklet_node,
connector_mapping[conn],
dace.Memlet(data=input_expr.dc_node.data, subset="0"),
)

dc_dtype = offset_table_node.desc(self.sdfg).dtype
return self._construct_tasklet_result(
Expand All @@ -1710,17 +1769,14 @@ def _make_unstructured_shift(
offset_expr: DataExpr,
) -> IteratorExpr:
"""Implements shift in unstructured domain by means of a neighbor table."""
# make sure that the field can be dereferenced with the given connectivity type
assert any(dim == conn_type.codomain for dim, _ in it.field_domain)
# make sure that the iterator can access the connectivity table
assert conn_type.source_dim in it.indices
conn_source_index = it.indices[conn_type.source_dim]
assert isinstance(conn_source_index, SymbolExpr)

shifted_indices = {
dim: idx for dim, idx in it.indices.items() if dim != conn_type.source_dim
}
if isinstance(offset_expr, SymbolExpr):
if isinstance(offset_expr, SymbolExpr) and isinstance(conn_source_index, SymbolExpr):
# use memlet to retrieve the neighbor index
shifted_indices[conn_type.codomain] = MemletExpr(
dc_node=conn_node,
Expand Down Expand Up @@ -1869,7 +1925,10 @@ def _visit_tuple_get(
return tuple_fields[index]

def visit_FunCall(self, node: gtir.FunCall) -> MaybeNestedInTuple[IteratorExpr | DataExpr]:
if cpm.is_call_to(node, "deref"):
if cpm.is_call_to(node, "can_deref"):
return self._visit_can_deref(node)

elif cpm.is_call_to(node, "deref"):
return self._visit_deref(node)

elif cpm.is_call_to(node, "if_"):
Expand Down Expand Up @@ -2003,6 +2062,7 @@ def translate_lambda_to_dataflow(
flat_arg_nodes = (
x.field if isinstance(x, IteratorExpr) else x.dc_node # type: ignore[attr-defined]
for x in gtx_utils.flatten_nested_tuple(tuple(args))
if not isinstance(x, SymbolExpr)
)
state.remove_nodes_from([node for node in flat_arg_nodes if state.degree(node) == 0])

Expand Down
Loading
Loading