Skip to content

Commit bc8d3d7

Browse files
Feat(lsp): Add support for go to and find all References for CTEs (#4652)
1 parent 40bd72f commit bc8d3d7

File tree

4 files changed

+425
-13
lines changed

4 files changed

+425
-13
lines changed

sqlmesh/lsp/main.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@
3636
RenderModelRequest,
3737
RenderModelResponse,
3838
)
39-
from sqlmesh.lsp.reference import (
40-
get_references,
41-
)
39+
from sqlmesh.lsp.reference import get_references, get_cte_references
4240
from sqlmesh.lsp.uri import URI
4341
from web.server.api.endpoints.lineage import column_lineage, model_lineage
4442
from web.server.api.endpoints.models import get_models
@@ -378,6 +376,28 @@ def goto_definition(
378376
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
379377
return []
380378

379+
@self.server.feature(types.TEXT_DOCUMENT_REFERENCES)
380+
def find_references(
381+
ls: LanguageServer, params: types.ReferenceParams
382+
) -> t.Optional[t.List[types.Location]]:
383+
"""Find all references of a symbol (currently supporting CTEs)"""
384+
try:
385+
uri = URI(params.text_document.uri)
386+
self._ensure_context_for_document(uri)
387+
document = ls.workspace.get_text_document(params.text_document.uri)
388+
if self.lsp_context is None:
389+
raise RuntimeError(f"No context found for document: {document.path}")
390+
391+
cte_references = get_cte_references(self.lsp_context, uri, params.position)
392+
393+
# Convert references to Location objects
394+
locations = [types.Location(uri=ref.uri, range=ref.range) for ref in cte_references]
395+
396+
return locations if locations else None
397+
except Exception as e:
398+
ls.show_message(f"Error getting locations: {e}", types.MessageType.Error)
399+
return None
400+
381401
@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
382402
def diagnostic(
383403
ls: LanguageServer, params: types.DocumentDiagnosticParams

sqlmesh/lsp/reference.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,7 @@ def by_position(position: Position) -> t.Callable[[Reference], bool]:
4747
"""
4848

4949
def contains_position(r: Reference) -> bool:
50-
return (
51-
r.range.start.line < position.line
52-
or (
53-
r.range.start.line == position.line
54-
and r.range.start.character <= position.character
55-
)
56-
) and (
57-
r.range.end.line > position.line
58-
or (r.range.end.line == position.line and r.range.end.character >= position.character)
59-
)
50+
return _position_within_range(position, r.range)
6051

6152
return contains_position
6253

@@ -478,3 +469,78 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
478469
),
479470
markdown_description=func.__doc__ if func.__doc__ else None,
480471
)
472+
473+
474+
def get_cte_references(
475+
lint_context: LSPContext, document_uri: URI, position: Position
476+
) -> t.List[Reference]:
477+
"""
478+
Get all references to a CTE at a specific position in a document.
479+
480+
This function finds both the definition and all usages of a CTE within the same file.
481+
482+
Args:
483+
lint_context: The LSP context
484+
document_uri: The URI of the document
485+
position: The position to check for CTE references
486+
487+
Returns:
488+
A list of references to the CTE (including its definition and all usages)
489+
"""
490+
references = get_model_definitions_for_a_path(lint_context, document_uri)
491+
492+
# Filter for CTE references (those with target_range set and same URI)
493+
# TODO: Consider extending Reference class to explicitly indicate reference type instead
494+
cte_references = [
495+
ref for ref in references if ref.target_range is not None and ref.uri == document_uri.value
496+
]
497+
498+
if not cte_references:
499+
return []
500+
501+
target_cte_definition_range = None
502+
for ref in cte_references:
503+
# Check if cursor is on a CTE usage
504+
if _position_within_range(position, ref.range):
505+
target_cte_definition_range = ref.target_range
506+
break
507+
# Check if cursor is on the CTE definition
508+
elif ref.target_range and _position_within_range(position, ref.target_range):
509+
target_cte_definition_range = ref.target_range
510+
break
511+
512+
if target_cte_definition_range is None:
513+
return []
514+
515+
# Add the CTE definition
516+
matching_references = [
517+
Reference(
518+
uri=document_uri.value,
519+
range=target_cte_definition_range,
520+
markdown_description="CTE definition",
521+
)
522+
]
523+
524+
# Add all usages
525+
for ref in cte_references:
526+
if ref.target_range == target_cte_definition_range:
527+
matching_references.append(
528+
Reference(
529+
uri=document_uri.value,
530+
range=ref.range,
531+
markdown_description="CTE usage",
532+
)
533+
)
534+
535+
return matching_references
536+
537+
538+
def _position_within_range(position: Position, range: Range) -> bool:
539+
"""Check if a position is within a given range."""
540+
return (
541+
range.start.line < position.line
542+
or (range.start.line == position.line and range.start.character <= position.character)
543+
) and (
544+
range.end.line > position.line
545+
or (range.end.line == position.line and range.end.character >= position.character)
546+
)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from lsprotocol.types import Position
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_cte_references
5+
from sqlmesh.lsp.uri import URI
6+
from tests.lsp.test_reference_cte import find_ranges_from_regex
7+
8+
9+
def test_cte_find_all_references():
10+
context = Context(paths=["examples/sushi"])
11+
lsp_context = LSPContext(context)
12+
13+
sushi_customers_path = next(
14+
path
15+
for path, info in lsp_context.map.items()
16+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
17+
)
18+
19+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
20+
read_file = file.readlines()
21+
22+
# Test finding all references of "current_marketing"
23+
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
24+
assert len(ranges) == 2
25+
26+
# Click on the CTE definition
27+
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
28+
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
29+
30+
# Should find both the definition and the usage
31+
assert len(references) == 2
32+
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
33+
34+
reference_ranges = [ref.range for ref in references]
35+
for expected_range in ranges:
36+
assert any(
37+
ref_range.start.line == expected_range.start.line
38+
and ref_range.start.character == expected_range.start.character
39+
for ref_range in reference_ranges
40+
), (
41+
f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}"
42+
)
43+
44+
# Click on the CTE usage
45+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
46+
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
47+
48+
# Should find the same references
49+
assert len(references) == 2
50+
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
51+
52+
reference_ranges = [ref.range for ref in references]
53+
for expected_range in ranges:
54+
assert any(
55+
ref_range.start.line == expected_range.start.line
56+
and ref_range.start.character == expected_range.start.character
57+
for ref_range in reference_ranges
58+
), (
59+
f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}"
60+
)
61+
62+
63+
def test_cte_find_all_references_outer():
64+
context = Context(paths=["examples/sushi"])
65+
lsp_context = LSPContext(context)
66+
67+
sushi_customers_path = next(
68+
path
69+
for path, info in lsp_context.map.items()
70+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
71+
)
72+
73+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
74+
read_file = file.readlines()
75+
76+
# Test finding all references of "current_marketing_outer"
77+
ranges = find_ranges_from_regex(read_file, r"current_marketing_outer")
78+
assert len(ranges) == 2
79+
80+
# Click on the CTE definition
81+
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
82+
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
83+
84+
# Should find both the definition and the usage
85+
assert len(references) == 2
86+
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
87+
88+
# Verify that we found both occurrences
89+
reference_ranges = [ref.range for ref in references]
90+
for expected_range in ranges:
91+
assert any(
92+
ref_range.start.line == expected_range.start.line
93+
and ref_range.start.character == expected_range.start.character
94+
for ref_range in reference_ranges
95+
), (
96+
f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}"
97+
)
98+
99+
# Click on the CTE usage
100+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
101+
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
102+
103+
# Should find the same references
104+
assert len(references) == 2
105+
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
106+
107+
reference_ranges = [ref.range for ref in references]
108+
for expected_range in ranges:
109+
assert any(
110+
ref_range.start.line == expected_range.start.line
111+
and ref_range.start.character == expected_range.start.character
112+
for ref_range in reference_ranges
113+
), (
114+
f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}"
115+
)
116+
117+
118+
def test_cte_no_references_on_non_cte():
119+
# Test that clicking on non-CTE elements returns nothing, once this is supported adapt this test accordingly
120+
context = Context(paths=["examples/sushi"])
121+
lsp_context = LSPContext(context)
122+
123+
sushi_customers_path = next(
124+
path
125+
for path, info in lsp_context.map.items()
126+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
127+
)
128+
129+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
130+
read_file = file.readlines()
131+
132+
# Click on a regular table reference
133+
ranges = find_ranges_from_regex(read_file, r"sushi\.orders")
134+
assert len(ranges) >= 1
135+
136+
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
137+
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
138+
139+
# Should find no references since this is not a CTE
140+
assert len(references) == 0

0 commit comments

Comments
 (0)