Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -661,29 +661,31 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:

def _visit_if_branch_arg(
self,
if_sdfg: dace.SDFG,
if_branch_state: dace.SDFGState,
sdfg: dace.SDFG,
state: dace.SDFGState,
param_name: str,
arg: IteratorExpr | DataExpr,
deref_on_input_memlet: bool,
if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr],
input_memlets: dict[str, MemletExpr | ValueExpr],
) -> IteratorExpr | ValueExpr:
"""
Helper method to be called by `_visit_if_branch()` to visit the input arguments.

Args:
if_sdfg: The nested SDFG where the if expression is lowered.
if_branch_state: The state inside the nested SDFG where the if branch is lowered.
sdfg: The SDFG where the if expression is lowered.
state: The state inside the given SDFG where the if branch is lowered.
param_name: The parameter name of the input argument.
arg: The input argument expression.
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet.
if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function.
input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function.
"""
use_full_shape = False
if isinstance(arg, (MemletExpr, ValueExpr)):
field_dims = []
arg_desc = arg.dc_node.desc(self.sdfg)
arg_expr = arg
elif isinstance(arg, IteratorExpr):
field_dims = [dim for dim, _ in arg.field_domain]
arg_desc = arg.field.desc(self.sdfg)
if deref_on_input_memlet:
# If the iterator is just dereferenced inside the branch state,
Expand All @@ -707,37 +709,50 @@ def _visit_if_branch_arg(
inner_desc = arg_desc.clone()
inner_desc.transient = False
elif isinstance(arg.gt_dtype, ts.ScalarType):
inner_desc = dace.data.Scalar(arg_desc.dtype)
if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1:
# TODO(edopao): we cannot use a scalar because of an issue in gpu codegen,
# which leads to compilation error: cannot convert 'const double' to 'const double*'
inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,))
else:
inner_desc = dace.data.Scalar(arg_desc.dtype)
else:
# for list of values, we retrieve the local size from the corresponding offset
assert arg.gt_dtype.offset_type is not None
offset_provider_type = self.subgraph_builder.get_offset_provider_type(
arg.gt_dtype.offset_type.value
local_dim = arg.gt_dtype.offset_type
assert local_dim is not None
assert isinstance(
self.subgraph_builder.get_offset_provider_type(local_dim.value),
gtx_common.NeighborConnectivityType,
)
assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
# find position of the local dimension in the field layout
assert isinstance(arg_desc, dace.data.Array)
assert all(dim.kind != gtx_common.DimensionKind.LOCAL for dim in field_dims)
extended_dims = gtx_common.order_dimensions([*field_dims, local_dim])
local_dim_pos = extended_dims.index(local_dim)
inner_desc = dace.data.Array(
dtype=arg_desc.dtype, shape=[offset_provider_type.max_neighbors]
dtype=arg_desc.dtype,
shape=(arg_desc.shape[local_dim_pos],),
strides=(arg_desc.strides[local_dim_pos],),
)

if param_name in if_sdfg.arrays:
if param_name in sdfg.arrays:
# the data desciptor was added by the visitor of the other branch expression
assert if_sdfg.data(param_name) == inner_desc
assert sdfg.data(param_name) == inner_desc
else:
if_sdfg.add_datadesc(param_name, inner_desc)
if_sdfg_input_memlets[param_name] = arg_expr
sdfg.add_datadesc(param_name, inner_desc)
input_memlets[param_name] = arg_expr
Comment on lines -722 to +742
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: If this function is called in a uniform way I would expect that the data descriptor is either already added to the sdfg or not in all cases. So this if locks a bit fishy to me but maybe I miss something.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The origin of this if is that I am adding globals as I encounter them while visiting the if branches. I am visiting one branch at a time, so it could be that some globals have already been added in the previous branch.


inner_node = if_branch_state.add_access(param_name)
inner_node = state.add_access(param_name)
if isinstance(arg, IteratorExpr) and use_full_shape:
return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices)
else:
return ValueExpr(inner_node, arg.gt_dtype)

def _visit_if_branch(
self,
if_sdfg: dace.SDFG,
if_branch_state: dace.SDFGState,
sdfg: dace.SDFG,
state: dace.SDFGState,
expr: gtir.Expr,
if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr],
input_memlets: dict[str, MemletExpr | ValueExpr],
direct_deref_iterators: Iterable[str],
) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]:
"""
Expand All @@ -746,58 +761,62 @@ def _visit_if_branch(
This function is called by `_visit_if()` for each if-branch.

Args:
if_sdfg: The nested SDFG where the if expression is lowered.
if_branch_state: The state inside the nested SDFG where the if branch is lowered.
sdfg: The SDFG where the if expression is lowered.
state: The state inside the given SDFG where the if branch is lowered.
expr: The if branch expression to lower.
if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function.
input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function.
direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift.

Returns:
A tuple containing:
- the list of input edges for the parent dataflow
- the tree representation of output data, in the form of a tuple of data edges.
"""
assert if_branch_state in if_sdfg.states()
assert state in sdfg.states()

lambda_args = []
lambda_params = []
lambda_args: list[MaybeNestedInTuple[IteratorExpr | DataExpr]] = []
lambda_params: list[gtir.Sym] = []
for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()):
arg = self.symbol_map[pname]
if isinstance(arg, tuple):
inner_arg: MaybeNestedInTuple[IteratorExpr | DataExpr]
if isinstance(arg, SymbolExpr):
psymbol = im.sym(pname, gtx_dace_args.as_itir_type(arg.dc_dtype))
inner_arg = arg
elif isinstance(arg, tuple):
ptype = get_tuple_type(arg) # type: ignore[arg-type]
psymbol = im.sym(pname, ptype)
psymbol_tree = gtir_to_sdfg_utils.make_symbol_tree(pname, ptype)
deref_on_input_memlet = pname in direct_deref_iterators
inner_arg = gtx_utils.tree_map(
lambda tsym, targ, deref_on_input_memlet=deref_on_input_memlet: (
self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
sdfg,
state,
str(tsym.id),
targ,
deref_on_input_memlet,
if_sdfg_input_memlets,
input_memlets,
)
)
)(psymbol_tree, arg)
else:
psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr]
psymbol = im.sym(pname, arg.gt_dtype)
deref_on_input_memlet = pname in direct_deref_iterators
inner_arg = self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
sdfg,
state,
pname,
arg,
deref_on_input_memlet,
if_sdfg_input_memlets,
input_memlets,
)
lambda_args.append(inner_arg)
lambda_params.append(psymbol)

# visit each branch of the if-statement as if it was a Lambda node
lambda_node = gtir.Lambda(params=lambda_params, expr=expr)
input_edges, output_tree = translate_lambda_to_dataflow(
if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args
sdfg, state, self.subgraph_builder, lambda_node, lambda_args
)

return input_edges, output_tree
Expand All @@ -818,9 +837,16 @@ def _visit_if_branch_result(
# If the result is currently written to a transient node, inside the nested SDFG,
# we need to allocate a non-transient data node.
result_desc = edge.result.dc_node.desc(sdfg)
output_desc = result_desc.clone()
output_desc.transient = False
output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True)
if isinstance(sym.type, ts.ScalarType) and isinstance(result_desc, dace.data.Array):
# TODO(edopao): a scalar should not be represented as an array, but
# currently this can happen because of an issue workaround, see the
# todo comment above in `_visit_if_branch_arg()`.
assert len(result_desc.shape) == 1 and result_desc.shape[0] == 1
_, output_desc = sdfg.add_scalar(output_data, result_desc.dtype)
else:
output_desc = result_desc.clone()
output_desc.transient = False
output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True)
output_node = state.add_access(output_data)
state.add_nedge(
edge.result.dc_node,
Expand Down Expand Up @@ -868,6 +894,13 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("if_stmt"))
nsdfg.debuginfo = gtir_to_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo)

# add connectivities
for aname, adesc in self.sdfg.arrays.items():
if gtx_dace_args.is_connectivity_identifier(aname):
adesc = adesc.clone()
adesc.transient = True
nsdfg.add_datadesc(aname, adesc)

# create states inside the nested SDFG for the if-branches
if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region, ensure_unique_name=True)
Expand Down Expand Up @@ -948,14 +981,21 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp

outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))}

# map the connectivities that were used inside the nested SDFG
used_connectivities = [
aname
for aname, adesc in nsdfg.arrays.items()
if gtx_dace_args.is_connectivity_identifier(aname) and not adesc.transient
]

# all free symbols are mapped to the symbols available in parent SDFG
nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols}
if isinstance(condition_value, SymbolExpr):
nsdfg_symbols_mapping["__cond"] = condition_value.value
nsdfg_node = self.state.add_nested_sdfg(
nsdfg,
inputs=set(input_memlets.keys()),
outputs=outputs,
inputs=sorted(used_connectivities | input_memlets.keys()),
outputs=sorted(outputs),
symbol_mapping=nsdfg_symbols_mapping,
)

Expand All @@ -971,6 +1011,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
self.sdfg.make_array_memlet(input_expr.dc_node.data),
)

for conn in used_connectivities:
desc = self.sdfg.data(conn)
desc.transient = False
self._add_input_data_edge(
self.state.add_access(conn),
dace_subsets.Range.from_array(desc),
nsdfg_node,
conn,
)

return (
gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result)
if isinstance(result, tuple)
Expand Down Expand Up @@ -1904,6 +1954,9 @@ def visit_Literal(self, node: gtir.Literal) -> SymbolExpr:
dc_dtype = gtx_dace_args.as_dace_type(node.type)
return SymbolExpr(node.value, dc_dtype)

def visit_OffsetLiteral(self, node: gtir.OffsetLiteral) -> SymbolExpr:
return SymbolExpr(node.value, gtir_to_sdfg_types.INDEX_DTYPE)

def visit_SymRef(self, node: gtir.SymRef) -> MaybeNestedInTuple[IteratorExpr | DataExpr]:
param = str(node.id)
if param in self.symbol_map:
Expand All @@ -1918,7 +1971,7 @@ def translate_lambda_to_dataflow(
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.DataflowBuilder,
node: gtir.Lambda,
args: Sequence[MaybeNestedInTuple[IteratorExpr | MemletExpr | ValueExpr]],
args: Sequence[MaybeNestedInTuple[IteratorExpr | DataExpr]],
) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]:
"""
Entry point to visit a `Lambda` node and lower it to a dataflow graph,
Expand Down