Skip to content

Commit d53e787

Browse files
benfdkingclaude
andauthored
fix(vscode): filter go to definitions by position (#4345)
Co-authored-by: Claude <noreply@anthropic.com>
1 parent a5e67d0 commit d53e787

File tree

3 files changed

+108
-6
lines changed

3 files changed

+108
-6
lines changed

sqlmesh/lsp/main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from sqlmesh.lsp.completions import get_sql_completions
1515
from sqlmesh.lsp.context import LSPContext, ModelTarget
1616
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
17-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
17+
from sqlmesh.lsp.reference import (
18+
get_model_definitions_for_a_path,
19+
filter_references_by_position,
20+
)
1821

1922

2023
class SQLMeshLanguageServer:
@@ -189,9 +192,7 @@ def goto_definition(
189192
references = get_model_definitions_for_a_path(
190193
self.lsp_context, params.text_document.uri
191194
)
192-
if not references:
193-
return []
194-
195+
filtered_references = filter_references_by_position(references, params.position)
195196
return [
196197
types.LocationLink(
197198
target_uri=reference.uri,
@@ -205,7 +206,7 @@ def goto_definition(
205206
),
206207
origin_selection_range=reference.range,
207208
)
208-
for reference in references
209+
for reference in filtered_references
209210
]
210211

211212
except Exception as e:

sqlmesh/lsp/reference.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,39 @@ class Reference(PydanticModel):
1414
uri: str
1515

1616

17+
def filter_references_by_position(
18+
references: t.List[Reference], position: Position
19+
) -> t.List[Reference]:
20+
"""
21+
Filter references to only include those that contain the given position.
22+
23+
Args:
24+
references: List of Reference objects
25+
position: The cursor position to check
26+
27+
Returns:
28+
List of Reference objects that contain the position
29+
"""
30+
filtered_references = []
31+
32+
for reference in references:
33+
# Check if position is within the reference range
34+
range_start = reference.range.start
35+
range_end = reference.range.end
36+
37+
# Position is within range if it's after or at start and before or at end
38+
if (
39+
range_start.line < position.line
40+
or (range_start.line == position.line and range_start.character <= position.character)
41+
) and (
42+
range_end.line > position.line
43+
or (range_end.line == position.line and range_end.character >= position.character)
44+
):
45+
filtered_references.append(reference)
46+
47+
return filtered_references
48+
49+
1750
def get_model_definitions_for_a_path(
1851
lint_context: LSPContext, document_uri: str
1952
) -> t.List[Reference]:

tests/lsp/test_reference.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
2+
from lsprotocol.types import Position
23
from sqlmesh.core.context import Context
34
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
4-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
5+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, filter_references_by_position
56

67

78
@pytest.mark.fast
@@ -108,3 +109,70 @@ def get_string_from_range(file_lines, range_obj) -> str:
108109
result += file_lines[line_num]
109110
result += file_lines[end_line][:end_character] # Last line up to end_character
110111
return result
112+
113+
114+
@pytest.mark.fast
115+
def test_filter_references_by_position() -> None:
116+
"""Test that we can filter references correctly based on cursor position."""
117+
context = Context(paths=["examples/sushi"])
118+
lsp_context = LSPContext(context)
119+
120+
# Use a file with multiple references (waiter_revenue_by_day)
121+
waiter_revenue_by_day_uri = next(
122+
uri
123+
for uri, info in lsp_context.map.items()
124+
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
125+
)
126+
127+
# Get all references in the file
128+
all_references = get_model_definitions_for_a_path(lsp_context, waiter_revenue_by_day_uri)
129+
assert len(all_references) == 3
130+
131+
# Get file contents to locate positions for testing
132+
path = waiter_revenue_by_day_uri.removeprefix("file://")
133+
with open(path, "r") as file:
134+
read_file = file.readlines()
135+
136+
# Test positions for each reference
137+
for i, reference in enumerate(all_references):
138+
# Position inside the reference - should return exactly one reference
139+
middle_line = (reference.range.start.line + reference.range.end.line) // 2
140+
middle_char = (reference.range.start.character + reference.range.end.character) // 2
141+
position_inside = Position(line=middle_line, character=middle_char)
142+
filtered = filter_references_by_position(all_references, position_inside)
143+
assert len(filtered) == 1
144+
assert filtered[0].uri == reference.uri
145+
assert filtered[0].range == reference.range
146+
147+
# For testing outside position, use a position before the current reference
148+
# or after the last reference for the last one
149+
if i == 0:
150+
outside_line = reference.range.start.line
151+
outside_char = max(0, reference.range.start.character - 5)
152+
else:
153+
prev_ref = all_references[i - 1]
154+
outside_line = prev_ref.range.end.line
155+
outside_char = prev_ref.range.end.character + 5
156+
157+
position_outside = Position(line=outside_line, character=outside_char)
158+
filtered_outside = filter_references_by_position(all_references, position_outside)
159+
assert reference not in filtered_outside, (
160+
f"Reference {i} should not match position outside its range"
161+
)
162+
163+
# Test case: cursor at beginning of file - no references should match
164+
position_start = Position(line=0, character=0)
165+
filtered_start = filter_references_by_position(all_references, position_start)
166+
assert len(filtered_start) == 0 or all(
167+
ref.range.start.line == 0 and ref.range.start.character <= 0 for ref in filtered_start
168+
)
169+
170+
# Test case: cursor at end of file - no references should match (unless there's a reference at the end)
171+
last_line = len(read_file) - 1
172+
last_char = len(read_file[last_line]) - 1
173+
position_end = Position(line=last_line, character=last_char)
174+
filtered_end = filter_references_by_position(all_references, position_end)
175+
assert len(filtered_end) == 0 or all(
176+
ref.range.end.line >= last_line and ref.range.end.character >= last_char
177+
for ref in filtered_end
178+
)

0 commit comments

Comments
 (0)