Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions sqlmesh/utils/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,53 @@ def upstream(self, node: T) -> t.Set[T]:

return self._upstream[node]

def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]:
"""Find the exact cycle path using DFS when a cycle is detected.
Args:
nodes_in_cycle: Dictionary of nodes that are part of the cycle and their dependencies
Returns:
List of nodes forming the cycle path, or None if no cycle found
"""
if not nodes_in_cycle:
return None

# Use DFS to find a cycle path
visited: t.Set[T] = set()
path: t.List[T] = []

def dfs(node: T) -> t.Optional[t.List[T]]:
if node in path:
# Found a cycle - extract the cycle path
cycle_start = path.index(node)
return path[cycle_start:] + [node]

if node in visited:
return None

visited.add(node)
path.append(node)

# Only follow edges to nodes that are still in the unprocessed set
for neighbor in nodes_in_cycle.get(node, set()):
if neighbor in nodes_in_cycle:
cycle = dfs(neighbor)
if cycle:
return cycle

path.pop()
return None

# Try starting DFS from each unvisited node
for start_node in nodes_in_cycle:
if start_node not in visited:
cycle = dfs(start_node)
if cycle:
return cycle[:-1] # Remove the duplicate node at the end

return None

@property
def roots(self) -> t.Set[T]:
"""Returns all nodes in the graph without any upstream dependencies."""
Expand All @@ -125,23 +172,31 @@ def sorted(self) -> t.List[T]:
next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

if not next_nodes:
# Sort cycle candidates to make the order deterministic
cycle_candidates_msg = (
"\nPossible candidates to check for circular references: "
+ ", ".join(str(node) for node in sorted(cycle_candidates))
)
# A cycle was detected - find the exact cycle path
cycle_path = self._find_cycle_path(unprocessed_nodes)

if last_processed_nodes:
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
str(node) for node in last_processed_nodes
last_processed_msg = ""
if cycle_path:
node_output = " ->\n".join(
str(node) for node in (cycle_path + [cycle_path[0]])
)
cycle_msg = f"\nCycle:\n{node_output}"
else:
last_processed_msg = ""
# Fallback message in case a cycle can't be found
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there scenarios where a cycle won't be found? I'm wondering if we can remove the else.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No known scenarios at this time but it seems like a safe fallback to have in place for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to not have a fallback to get user feedback on what scenario hits it more quickly?

cycle_candidates_msg = (
"\nPossible candidates to check for circular references: "
+ ", ".join(str(node) for node in sorted(cycle_candidates))
)
cycle_msg = cycle_candidates_msg
if last_processed_nodes:
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
str(node) for node in last_processed_nodes
)

raise SQLMeshError(
"Detected a cycle in the DAG. "
"Please make sure there are no circular references between nodes."
f"{last_processed_msg}{cycle_candidates_msg}"
f"{last_processed_msg}{cycle_msg}"
)

for node in next_nodes:
Expand Down
11 changes: 5 additions & 6 deletions tests/utils/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def test_sorted_with_cycles():

expected_error_message = (
"Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n"
"Last nodes added to the DAG: c\n"
"Possible candidates to check for circular references: d, e"
"Cycle:\nd ->\ne ->\nd"
)

assert expected_error_message == str(ex.value)
Expand All @@ -70,7 +69,7 @@ def test_sorted_with_cycles():

expected_error_message = (
"Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n"
"Possible candidates to check for circular references: a, b, c"
"Cycle:\na ->\nb ->\nc ->\na"
)

assert expected_error_message == str(ex.value)
Expand All @@ -81,11 +80,11 @@ def test_sorted_with_cycles():
dag.sorted

expected_error_message = (
"Last nodes added to the DAG: c\n"
+ "Possible candidates to check for circular references: b, d"
"Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n"
+ "Cycle:\nb ->\nd ->\nb"
)

assert expected_error_message in str(ex.value)
assert expected_error_message == str(ex.value)


def test_reversed_graph():
Expand Down