Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 87 additions & 10 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sqlmesh.dbt.test import TestConfig
from sqlmesh.dbt.util import DBT_VERSION
from sqlmesh.utils import AttributeDict
from sqlmesh.utils.dag import find_path_with_dfs
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import field_validator

Expand Down Expand Up @@ -270,9 +271,10 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:

def fix_circular_test_refs(self, context: DbtContext) -> None:
"""
Checks for direct circular references between two models and moves the test to the downstream
model if found. This addresses the most common circular reference - relationship tests in both
directions. In the future, we may want to increase coverage by checking for indirect circular references.
Checks for circular references between models and moves tests to break cycles.
This handles both direct circular references (A -> B -> A) and indirect circular
references (A -> B -> C -> A). Tests are moved to the model that appears latest
in the dependency chain to ensure the cycle is broken.

Args:
context: The dbt context this model resides within.
Expand All @@ -284,16 +286,91 @@ def fix_circular_test_refs(self, context: DbtContext) -> None:
for ref in test.dependencies.refs:
if ref == self.name or ref in self.dependencies.refs:
continue
model = context.refs[ref]
if (
self.name in model.dependencies.refs
or self.name in model.tests_ref_source_dependencies.refs
):

# Check if moving this test would create or maintain a cycle
cycle_path = self._find_circular_path(ref, context, set())
if cycle_path:
# Find the model in the cycle that should receive the test
# We want to move to the model that appears latest in the dependency chain
target_model_name = self._select_target_model_for_test(cycle_path, context)
target_model = context.refs[target_model_name]

logger.info(
f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference."
f"Moving test '{test.name}' from model '{self.name}' to '{target_model_name}' "
f"to avoid circular reference through path: {' -> '.join(cycle_path)}"
)
model.tests.append(test)
target_model.tests.append(test)
self.tests.remove(test)
break

def _find_circular_path(
    self, ref: str, context: DbtContext, visited: t.Set[str]
) -> t.Optional[t.List[str]]:
    """
    Look for a dependency path that leads from ``ref`` back to this model.

    Args:
        ref: Name of the model the search starts from.
        context: The dbt context; ``context.refs`` is used to look models up by name.
        visited: Model names that must not be expanded. The set is only read,
            never modified.

    Returns:
        The model names along a path from ``ref`` to this model, or None when
        no such path exists.
    """
    dependency_graph: t.Dict[str, t.Set[str]] = {}

    # Collect every model reachable from ``ref``. A name is expanded at most once:
    # it is skipped when the caller excluded it or when its edges are already recorded.
    worklist = [ref]
    while worklist:
        current = worklist.pop()
        if current in visited or current in dependency_graph:
            continue

        node = context.refs[current]
        # Edges cover both the model's own refs and the refs pulled in by its tests.
        edges = node.dependencies.refs | node.tests_ref_source_dependencies.refs
        dependency_graph[current] = edges
        worklist.extend(edges)

    # The search only follows edges into nodes that are keys of the graph, so this
    # model has to be present even when it was never reached above.
    dependency_graph.setdefault(self.name, set())

    return find_path_with_dfs(dependency_graph, start_node=ref, target_node=self.name)

def _select_target_model_for_test(self, cycle_path: t.List[str], context: DbtContext) -> str:
    """
    Choose which model on a circular path should take over the test.

    For every model on the path, the refs of the model and of its tests are
    combined and the ones naming another entry of the path are counted. The
    model with the lowest count wins; on a tie the entry appearing first in
    ``cycle_path`` is returned.

    Args:
        cycle_path: Model names along the circular dependency path.
        context: The dbt context; ``context.refs`` is used to look models up by name.

    Returns:
        Name of the model that should receive the test.
    """
    members = set(cycle_path)
    refs_inside_cycle: t.Dict[str, int] = {}

    for name in cycle_path:
        candidate = context.refs[name]
        combined = candidate.dependencies.refs | candidate.tests_ref_source_dependencies.refs
        refs_inside_cycle[name] = len(combined & members)

    if not refs_inside_cycle:
        # Nothing was counted, which only happens for an empty path; fall back to
        # the last entry of the path exactly as before.
        return cycle_path[-1]

    # min() keeps the first minimal key, i.e. the earliest model on the path.
    return min(refs_inside_cycle, key=refs_inside_cycle.__getitem__)

@property
def sqlmesh_config_fields(self) -> t.Set[str]:
Expand Down
114 changes: 103 additions & 11 deletions sqlmesh/utils/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,82 @@
T = t.TypeVar("T", bound=t.Hashable)


def find_path_with_dfs(
    graph: t.Dict[T, t.Set[T]],
    start_node: t.Optional[T] = None,
    target_node: t.Optional[T] = None,
) -> t.Optional[t.List[T]]:
    """
    Find a path in a graph using an iterative depth-first search.

    The mode is selected by ``target_node``:
    1. ``target_node is None``: return the first cycle encountered, as the list
       of nodes on it (the first node is not repeated at the end).
    2. Otherwise: return a path ending at ``target_node``. If a starting node is
       itself the target, the path is just ``[target_node]``.

    Only edges pointing at nodes that are keys of ``graph`` are followed; all
    other neighbors are ignored.

    The traversal keeps an explicit stack instead of recursing, so long
    dependency chains cannot exceed the interpreter's recursion limit.

    Args:
        graph: Dictionary mapping nodes to their dependencies/neighbors.
        start_node: Optional node to start from. When None, every key of
            ``graph`` is tried in turn, in dictionary order.
        target_node: Optional node to search for. If None, finds any cycle.

    Returns:
        List of nodes forming the path or cycle, or None if nothing was found
        (including when ``graph`` is empty or ``start_node`` is not a key).
    """
    if not graph:
        return None

    cycle_mode = target_node is None
    roots = [start_node] if start_node is not None else list(graph)

    # Shared across roots: a node that was fully explored once never needs a second look.
    visited: t.Set[T] = set()

    for root in roots:
        if root in visited or root not in graph:
            continue
        if not cycle_mode and root == target_node:
            return [root]

        visited.add(root)
        path: t.List[T] = [root]
        on_path: t.Set[T] = {root}
        # One partially consumed neighbor iterator per node on the current path.
        pending: t.List[t.Iterator[T]] = [iter(graph[root])]

        while pending:
            for neighbor in pending[-1]:
                if neighbor not in graph:
                    continue
                if cycle_mode:
                    if neighbor in on_path:
                        # Back edge: the cycle is the tail of the current path.
                        return path[path.index(neighbor) :]
                elif neighbor == target_node:
                    return path + [neighbor]
                if neighbor in visited:
                    continue

                # Descend into the neighbor; its iterator becomes the top of the stack.
                visited.add(neighbor)
                on_path.add(neighbor)
                path.append(neighbor)
                pending.append(iter(graph[neighbor]))
                break
            else:
                # All neighbors of the deepest node are exhausted: backtrack.
                pending.pop()
                on_path.discard(path.pop())

    return None


class DAG(t.Generic[T]):
def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
self._dag: t.Dict[T, t.Set[T]] = {}
Expand Down Expand Up @@ -99,6 +175,17 @@ def upstream(self, node: T) -> t.Set[T]:

return self._upstream[node]

def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]:
    """Return one concrete cycle among the given nodes, if there is one.

    Args:
        nodes_in_cycle: Mapping of the nodes suspected to form a cycle to their dependencies.

    Returns:
        The nodes along the cycle, or None when the search does not find one.
    """
    # No start or target node is passed, so the shared helper searches for any cycle.
    cycle = find_path_with_dfs(nodes_in_cycle)
    return cycle

@property
def roots(self) -> t.Set[T]:
"""Returns all nodes in the graph without any upstream dependencies."""
Expand All @@ -125,23 +212,28 @@ def sorted(self) -> t.List[T]:
next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

if not next_nodes:
# Sort cycle candidates to make the order deterministic
cycle_candidates_msg = (
"\nPossible candidates to check for circular references: "
+ ", ".join(str(node) for node in sorted(cycle_candidates))
)
# A cycle was detected - find the exact cycle path
cycle_path = self._find_cycle_path(unprocessed_nodes)

if last_processed_nodes:
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
str(node) for node in last_processed_nodes
)
last_processed_msg = ""
if cycle_path:
cycle_msg = f"\nCycle: {' -> '.join(str(node) for node in cycle_path)} -> {cycle_path[0]}"
else:
last_processed_msg = ""
# Fallback message in case a cycle can't be found
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there scenarios where a cycle won't be found? I'm wondering if we can remove the else.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No known scenarios at this time but it seems like a safe fallback to have in place for now.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to not have a fallback to get user feedback on what scenario hits it more quickly?

cycle_candidates_msg = (
"\nPossible candidates to check for circular references: "
+ ", ".join(str(node) for node in sorted(cycle_candidates))
)
cycle_msg = cycle_candidates_msg
if last_processed_nodes:
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
str(node) for node in last_processed_nodes
)

raise SQLMeshError(
"Detected a cycle in the DAG. "
"Please make sure there are no circular references between nodes."
f"{last_processed_msg}{cycle_candidates_msg}"
f"{last_processed_msg}{cycle_msg}"
)

for node in next_nodes:
Expand Down
120 changes: 120 additions & 0 deletions tests/dbt/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,126 @@ def test_model_test_circular_references() -> None:
assert downstream_model.tests == [downstream_test, upstream_test]


def test_model_test_indirect_circular_references() -> None:
    """A test on model_a that also refs model_c must end up on model_c."""
    # Model refs form a straight chain: model_c refs model_b, model_b refs model_a.
    upstream = ModelConfig(name="model_a")
    middle = ModelConfig(name="model_b", dependencies=Dependencies(refs={"model_a"}))
    downstream = ModelConfig(name="model_c", dependencies=Dependencies(refs={"model_b"}))

    context = DbtContext(_refs={"model_a": upstream, "model_b": middle, "model_c": downstream})

    # Attached to model_a, this test closes the loop:
    # model_a (via test) -> model_c -> model_b -> model_a
    relationship_test = TestConfig(
        name="test_a_refs_c",
        sql="",
        dependencies=Dependencies(refs={"model_a", "model_c"}),
    )

    upstream.tests = [relationship_test]
    assert middle.tests == []
    assert downstream.tests == []

    upstream.fix_circular_test_refs(context)

    # The test leaves model_a and lands on model_c, which breaks the loop.
    assert upstream.tests == []
    assert relationship_test in downstream.tests


def test_model_test_complex_indirect_circular_references() -> None:
    """Check test relocation along a four-model chain of refs, one scenario at a time."""
    # The models themselves contain no cycle:
    # model_b refs model_a, model_c refs model_b, model_d refs model_c.
    model_a = ModelConfig(name="model_a")
    model_b = ModelConfig(name="model_b", dependencies=Dependencies(refs={"model_a"}))
    model_c = ModelConfig(name="model_c", dependencies=Dependencies(refs={"model_b"}))
    model_d = ModelConfig(name="model_d", dependencies=Dependencies(refs={"model_c"}))
    chain = (model_a, model_b, model_c, model_d)

    context = DbtContext(
        _refs={"model_a": model_a, "model_b": model_b, "model_c": model_c, "model_d": model_d}
    )

    def make_test(name, *refs):
        # Build a test whose dependencies are exactly the given model refs.
        return TestConfig(name=name, sql="", dependencies=Dependencies(refs=set(refs)))

    def place(on_a, on_b, on_c, on_d):
        # Give every model in the chain a fresh list of tests.
        for model, tests in zip(chain, (on_a, on_b, on_c, on_d)):
            model.tests = list(tests)

    def owned():
        # Snapshot of the tests currently attached to each model, in chain order.
        return [model.tests for model in chain]

    test_a_refs_d = make_test("test_a_refs_d", "model_a", "model_d")
    test_b_refs_d = make_test("test_b_refs_d", "model_b", "model_d")
    test_a_refs_c = make_test("test_a_refs_c", "model_a", "model_c")

    # Scenario 1: model_a (via test) -> model_d -> model_c -> model_b -> model_a
    model_a.tests = [test_a_refs_d]
    model_b.tests = []
    assert model_c.tests == []
    assert model_d.tests == []

    model_a.fix_circular_test_refs(context)
    assert model_a.tests == []
    assert model_d.tests == [test_a_refs_d]

    # Scenario 2: model_b (via test) -> model_d -> model_c -> model_b
    place([], [test_b_refs_d], [], [])
    model_b.fix_circular_test_refs(context)
    assert owned() == [[], [], [], [test_b_refs_d]]

    # Scenario 3: both tests at once; model_d collects them in the order they are fixed.
    place([test_a_refs_d], [test_b_refs_d], [], [])
    model_a.fix_circular_test_refs(context)
    model_b.fix_circular_test_refs(context)
    assert owned() == [[], [], [], [test_a_refs_d, test_b_refs_d]]

    # Scenario 4: a test on model_a that refs model_c has to end up on model_c.
    place([test_a_refs_c], [], [], [])
    model_a.fix_circular_test_refs(context)
    assert owned() == [[], [], [test_a_refs_c], []]


@pytest.mark.slow
def test_load_invalid_ref_audit_constraints(
tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
Expand Down
Loading