Skip to content

Commit d7b1866

Browse files
authored
feat: add on hover to models on lsp (#4351)
1 parent 00cc5ed commit d7b1866

File tree

3 files changed

+87
-33
lines changed

3 files changed

+87
-33
lines changed

sqlmesh/lsp/main.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from sqlmesh.lsp.context import LSPContext, ModelTarget
1616
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1717
from sqlmesh.lsp.reference import (
18-
get_model_definitions_for_a_path,
19-
filter_references_by_position,
18+
get_references,
2019
)
2120

2221

@@ -178,6 +177,34 @@ def formatting(
178177
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
179178
return []
180179

180+
@self.server.feature(types.TEXT_DOCUMENT_HOVER)
181+
def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hover]:
182+
"""Provide hover information for an object."""
183+
try:
184+
self._ensure_context_for_document(params.text_document.uri)
185+
document = ls.workspace.get_document(params.text_document.uri)
186+
if self.lsp_context is None:
187+
raise RuntimeError(f"No context found for document: {document.path}")
188+
189+
references = get_references(
190+
self.lsp_context, params.text_document.uri, params.position
191+
)
192+
if not references:
193+
return None
194+
reference = references[0]
195+
if not reference.description:
196+
return None
197+
return types.Hover(
198+
contents=types.MarkupContent(
199+
kind=types.MarkupKind.Markdown, value=reference.description
200+
),
201+
range=reference.range,
202+
)
203+
204+
except Exception as e:
205+
ls.show_message(f"Error getting hover information: {e}", types.MessageType.Error)
206+
return None
207+
181208
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
182209
def goto_definition(
183210
ls: LanguageServer, params: types.DefinitionParams
@@ -189,10 +216,9 @@ def goto_definition(
189216
if self.lsp_context is None:
190217
raise RuntimeError(f"No context found for document: {document.path}")
191218

192-
references = get_model_definitions_for_a_path(
193-
self.lsp_context, params.text_document.uri
219+
references = get_references(
220+
self.lsp_context, params.text_document.uri, params.position
194221
)
195-
filtered_references = filter_references_by_position(references, params.position)
196222
return [
197223
types.LocationLink(
198224
target_uri=reference.uri,
@@ -206,9 +232,8 @@ def goto_definition(
206232
),
207233
origin_selection_range=reference.range,
208234
)
209-
for reference in filtered_references
235+
for reference in references
210236
]
211-
212237
except Exception as e:
213238
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
214239
return []

sqlmesh/lsp/reference.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,64 @@
1010

1111

1212
class Reference(PydanticModel):
13+
"""
14+
A reference to a model.
15+
16+
Attributes:
17+
range: The range of the reference in the source file
18+
uri: The uri of the referenced model
19+
description: The description of the referenced model
20+
"""
21+
1322
range: Range
1423
uri: str
24+
description: t.Optional[str] = None
1525

1626

17-
def filter_references_by_position(
18-
references: t.List[Reference], position: Position
19-
) -> t.List[Reference]:
27+
def by_position(position: Position) -> t.Callable[[Reference], bool]:
2028
"""
21-
Filter references to only include those that contain the given position.
29+
Filter reference to only filter references that contain the given position.
2230
2331
Args:
24-
references: List of Reference objects
2532
position: The cursor position to check
2633
2734
Returns:
28-
List of Reference objects that contain the position
35+
A function that returns True if the reference contains the position, False otherwise
2936
"""
30-
filtered_references = []
31-
32-
for reference in references:
33-
# Check if position is within the reference range
34-
range_start = reference.range.start
35-
range_end = reference.range.end
3637

37-
# Position is within range if it's after or at start and before or at end
38-
if (
39-
range_start.line < position.line
40-
or (range_start.line == position.line and range_start.character <= position.character)
38+
def contains_position(r: Reference) -> bool:
39+
return (
40+
r.range.start.line < position.line
41+
or (
42+
r.range.start.line == position.line
43+
and r.range.start.character <= position.character
44+
)
4145
) and (
42-
range_end.line > position.line
43-
or (range_end.line == position.line and range_end.character >= position.character)
44-
):
45-
filtered_references.append(reference)
46+
r.range.end.line > position.line
47+
or (r.range.end.line == position.line and r.range.end.character >= position.character)
48+
)
49+
50+
return contains_position
4651

52+
53+
def get_references(
54+
lint_context: LSPContext, document_uri: str, position: Position
55+
) -> t.List[Reference]:
56+
"""
57+
Get references at a specific position in a document.
58+
59+
Used for hover information.
60+
61+
Args:
62+
lint_context: The LSP context
63+
document_uri: The URI of the document
64+
position: The position to check for references
65+
66+
Returns:
67+
A list of references at the given position
68+
"""
69+
references = get_model_definitions_for_a_path(lint_context, document_uri)
70+
filtered_references = list(filter(by_position(position), references))
4771
return filtered_references
4872

4973

@@ -154,7 +178,11 @@ def get_model_definitions_for_a_path(
154178
start_pos = catalog_or_db_range.start
155179

156180
references.append(
157-
Reference(uri=referenced_model_uri, range=Range(start=start_pos, end=end_pos))
181+
Reference(
182+
uri=referenced_model_uri,
183+
range=Range(start=start_pos, end=end_pos),
184+
description=referenced_model.description,
185+
)
158186
)
159187

160188
return references

tests/lsp/test_reference.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from lsprotocol.types import Position
33
from sqlmesh.core.context import Context
44
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
5-
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, filter_references_by_position
5+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, by_position
66

77

88
@pytest.mark.fast
@@ -54,6 +54,7 @@ def test_reference_with_alias() -> None:
5454

5555
assert references[0].uri.endswith("orders.py")
5656
assert get_string_from_range(read_file, references[0].range) == "sushi.orders"
57+
assert references[0].description == "Table of sushi orders."
5758
assert references[1].uri.endswith("order_items.py")
5859
assert get_string_from_range(read_file, references[1].range) == "sushi.order_items"
5960
assert references[2].uri.endswith("items.py")
@@ -139,7 +140,7 @@ def test_filter_references_by_position() -> None:
139140
middle_line = (reference.range.start.line + reference.range.end.line) // 2
140141
middle_char = (reference.range.start.character + reference.range.end.character) // 2
141142
position_inside = Position(line=middle_line, character=middle_char)
142-
filtered = filter_references_by_position(all_references, position_inside)
143+
filtered = list(filter(by_position(position_inside), all_references))
143144
assert len(filtered) == 1
144145
assert filtered[0].uri == reference.uri
145146
assert filtered[0].range == reference.range
@@ -155,14 +156,14 @@ def test_filter_references_by_position() -> None:
155156
outside_char = prev_ref.range.end.character + 5
156157

157158
position_outside = Position(line=outside_line, character=outside_char)
158-
filtered_outside = filter_references_by_position(all_references, position_outside)
159+
filtered_outside = list(filter(by_position(position_outside), all_references))
159160
assert reference not in filtered_outside, (
160161
f"Reference {i} should not match position outside its range"
161162
)
162163

163164
# Test case: cursor at beginning of file - no references should match
164165
position_start = Position(line=0, character=0)
165-
filtered_start = filter_references_by_position(all_references, position_start)
166+
filtered_start = list(filter(by_position(position_start), all_references))
166167
assert len(filtered_start) == 0 or all(
167168
ref.range.start.line == 0 and ref.range.start.character <= 0 for ref in filtered_start
168169
)
@@ -171,7 +172,7 @@ def test_filter_references_by_position() -> None:
171172
last_line = len(read_file) - 1
172173
last_char = len(read_file[last_line]) - 1
173174
position_end = Position(line=last_line, character=last_char)
174-
filtered_end = filter_references_by_position(all_references, position_end)
175+
filtered_end = list(filter(by_position(position_end), all_references))
175176
assert len(filtered_end) == 0 or all(
176177
ref.range.end.line >= last_line and ref.range.end.character >= last_char
177178
for ref in filtered_end

0 commit comments

Comments
 (0)