Skip to content

Commit d2eabe7

Browse files
authored
chore(vscode): improve uri handling in lsp (#4433)
1 parent 06a51dc commit d2eabe7

File tree

8 files changed

+105
-74
lines changed

8 files changed

+105
-74
lines changed

sqlmesh/lsp/completions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from sqlmesh.lsp.custom import AllModelsResponse
44
import typing as t
55
from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget
6+
from sqlmesh.lsp.uri import URI
67

78

8-
def get_sql_completions(context: t.Optional[LSPContext], file_uri: str) -> AllModelsResponse:
9+
def get_sql_completions(context: t.Optional[LSPContext], file_uri: URI) -> AllModelsResponse:
910
"""
1011
Return a list of completions for a given file.
1112
"""
@@ -15,7 +16,7 @@ def get_sql_completions(context: t.Optional[LSPContext], file_uri: str) -> AllMo
1516
)
1617

1718

18-
def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.Set[str]:
19+
def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
1920
"""
2021
Return a list of models for a given file.
2122
@@ -41,7 +42,7 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.
4142
return all_models
4243

4344

44-
def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.Set[str]:
45+
def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
4546
"""
4647
Return a list of sql keywords for a given file.
4748
If no context is provided, return ANSI SQL keywords.

sqlmesh/lsp/context.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from dataclasses import dataclass
2-
from pathlib import Path
32
from sqlmesh.core.context import Context
43
import typing as t
54

5+
from sqlmesh.lsp.uri import URI
6+
67

78
@dataclass
89
class ModelTarget:
@@ -28,26 +29,24 @@ def __init__(self, context: Context) -> None:
2829
self.context = context
2930

3031
# Add models to the map
31-
model_map: t.Dict[str, ModelTarget] = {}
32+
model_map: t.Dict[URI, ModelTarget] = {}
3233
for model in context.models.values():
3334
if model._path is not None:
34-
path = Path(model._path).resolve()
35-
uri = f"file://{path.as_posix()}"
35+
uri = URI.from_path(model._path)
3636
if uri in model_map:
3737
model_map[uri].names.append(model.name)
3838
else:
3939
model_map[uri] = ModelTarget(names=[model.name])
4040

4141
# Add standalone audits to the map
42-
audit_map: t.Dict[str, AuditTarget] = {}
42+
audit_map: t.Dict[URI, AuditTarget] = {}
4343
for audit in context.standalone_audits.values():
4444
if audit._path is not None:
45-
path = Path(audit._path).resolve()
46-
uri = f"file://{path.as_posix()}"
45+
uri = URI.from_path(audit._path)
4746
if uri not in audit_map:
4847
audit_map[uri] = AuditTarget(name=audit.name)
4948

50-
self.map: t.Dict[str, t.Union[ModelTarget, AuditTarget]] = {
49+
self.map: t.Dict[URI, t.Union[ModelTarget, AuditTarget]] = {
5150
**model_map,
5251
**audit_map,
5352
}

sqlmesh/lsp/main.py

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sqlmesh.lsp.reference import (
2424
get_references,
2525
)
26+
from sqlmesh.lsp.uri import URI
2627
from web.server.api.endpoints.lineage import model_lineage
2728
from web.server.api.endpoints.models import get_models
2829

@@ -42,7 +43,7 @@ def __init__(
4243
self.server = LanguageServer(server_name, version)
4344
self.context_class = context_class
4445
self.lsp_context: t.Optional[LSPContext] = None
45-
self.lint_cache: t.Dict[str, t.List[AnnotatedRuleViolation]] = {}
46+
self.lint_cache: t.Dict[URI, t.List[AnnotatedRuleViolation]] = {}
4647

4748
# Register LSP features (e.g., formatting, hover, etc.)
4849
self._register_features()
@@ -82,10 +83,11 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
8283
@self.server.feature(ALL_MODELS_FEATURE)
8384
def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
8485
try:
85-
context = self._context_get_or_load(params.textDocument.uri)
86-
return get_sql_completions(context, params.textDocument.uri)
86+
uri = URI(params.textDocument.uri)
87+
context = self._context_get_or_load(uri)
88+
return get_sql_completions(context, uri)
8789
except Exception as e:
88-
return get_sql_completions(None, params.textDocument.uri)
90+
return get_sql_completions(None, uri)
8991

9092
@self.server.feature(API_FEATURE)
9193
def api(
@@ -106,67 +108,62 @@ def api(
106108

107109
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
108110
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
109-
context = self._context_get_or_load(params.text_document.uri)
110-
if self.lint_cache.get(params.text_document.uri) is not None:
111+
uri = URI(params.text_document.uri)
112+
context = self._context_get_or_load(uri)
113+
if self.lint_cache.get(uri) is not None:
111114
ls.publish_diagnostics(
112115
params.text_document.uri,
113-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
114-
self.lint_cache[params.text_document.uri]
115-
),
116+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
116117
)
117118
return
118-
models = context.map[params.text_document.uri]
119+
models = context.map[uri]
119120
if models is None:
120121
return
121122
if not isinstance(models, ModelTarget):
122123
return
123-
self.lint_cache[params.text_document.uri] = context.context.lint_models(
124+
self.lint_cache[uri] = context.context.lint_models(
124125
models.names,
125126
raise_on_error=False,
126127
)
127128
ls.publish_diagnostics(
128129
params.text_document.uri,
129-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
130-
self.lint_cache[params.text_document.uri]
131-
),
130+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
132131
)
133132

134133
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
135134
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
136-
context = self._context_get_or_load(params.text_document.uri)
137-
models = context.map[params.text_document.uri]
135+
uri = URI(params.text_document.uri)
136+
context = self._context_get_or_load(uri)
137+
models = context.map[uri]
138138
if models is None:
139139
return
140140
if not isinstance(models, ModelTarget):
141141
return
142-
self.lint_cache[params.text_document.uri] = context.context.lint_models(
142+
self.lint_cache[uri] = context.context.lint_models(
143143
models.names,
144144
raise_on_error=False,
145145
)
146146
ls.publish_diagnostics(
147147
params.text_document.uri,
148-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
149-
self.lint_cache[params.text_document.uri]
150-
),
148+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
151149
)
152150

153151
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
154152
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
155-
context = self._context_get_or_load(params.text_document.uri)
156-
models = context.map[params.text_document.uri]
153+
uri = URI(params.text_document.uri)
154+
context = self._context_get_or_load(uri)
155+
models = context.map[uri]
157156
if models is None:
158157
return
159158
if not isinstance(models, ModelTarget):
160159
return
161-
self.lint_cache[params.text_document.uri] = context.context.lint_models(
160+
self.lint_cache[uri] = context.context.lint_models(
162161
models.names,
163162
raise_on_error=False,
164163
)
165164
ls.publish_diagnostics(
166165
params.text_document.uri,
167-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
168-
self.lint_cache[params.text_document.uri]
169-
),
166+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
170167
)
171168

172169
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
@@ -175,14 +172,15 @@ def formatting(
175172
) -> t.List[types.TextEdit]:
176173
"""Format the document using SQLMesh `format_model_expressions`."""
177174
try:
178-
self._ensure_context_for_document(params.text_document.uri)
175+
uri = URI(params.text_document.uri)
176+
self._ensure_context_for_document(uri)
179177
document = ls.workspace.get_text_document(params.text_document.uri)
180178
if self.lsp_context is None:
181179
raise RuntimeError(f"No context found for document: {document.path}")
182180

183181
# Perform formatting using the loaded context
184-
self.lsp_context.context.format(paths=(Path(document.path),))
185-
with open(document.path, "r+", encoding="utf-8") as file:
182+
self.lsp_context.context.format(paths=(str(uri.to_path()),))
183+
with open(uri.to_path(), "r+", encoding="utf-8") as file:
186184
new_text = file.read()
187185

188186
# Return a single edit that replaces the entire file.
@@ -206,14 +204,13 @@ def formatting(
206204
def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hover]:
207205
"""Provide hover information for an object."""
208206
try:
209-
self._ensure_context_for_document(params.text_document.uri)
207+
uri = URI(params.text_document.uri)
208+
self._ensure_context_for_document(uri)
210209
document = ls.workspace.get_text_document(params.text_document.uri)
211210
if self.lsp_context is None:
212211
raise RuntimeError(f"No context found for document: {document.path}")
213212

214-
references = get_references(
215-
self.lsp_context, params.text_document.uri, params.position
216-
)
213+
references = get_references(self.lsp_context, uri, params.position)
217214
if not references:
218215
return None
219216
reference = references[0]
@@ -236,14 +233,13 @@ def goto_definition(
236233
) -> t.List[types.LocationLink]:
237234
"""Jump to an object's definition."""
238235
try:
239-
self._ensure_context_for_document(params.text_document.uri)
236+
uri = URI(params.text_document.uri)
237+
self._ensure_context_for_document(uri)
240238
document = ls.workspace.get_text_document(params.text_document.uri)
241239
if self.lsp_context is None:
242240
raise RuntimeError(f"No context found for document: {document.path}")
243241

244-
references = get_references(
245-
self.lsp_context, params.text_document.uri, params.position
246-
)
242+
references = get_references(self.lsp_context, uri, params.position)
247243
return [
248244
types.LocationLink(
249245
target_uri=reference.uri,
@@ -263,7 +259,7 @@ def goto_definition(
263259
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
264260
return []
265261

266-
def _context_get_or_load(self, document_uri: str) -> LSPContext:
262+
def _context_get_or_load(self, document_uri: URI) -> LSPContext:
267263
if self.lsp_context is None:
268264
self._ensure_context_for_document(document_uri)
269265
if self.lsp_context is None:
@@ -272,7 +268,7 @@ def _context_get_or_load(self, document_uri: str) -> LSPContext:
272268

273269
def _ensure_context_for_document(
274270
self,
275-
document_uri: str,
271+
document_uri: URI,
276272
) -> None:
277273
"""
278274
Ensure that a context exists for the given document if applicable by searching
@@ -285,7 +281,7 @@ def _ensure_context_for_document(
285281
return
286282

287283
# No context yet: try to find config and load it
288-
path = Path(self._uri_to_path(document_uri)).resolve()
284+
path = document_uri.to_path()
289285
if path.suffix not in (".sql", ".py"):
290286
return
291287

@@ -321,7 +317,8 @@ def _diagnostic_to_lsp_diagnostic(
321317

322318
# Get rule definition location for diagnostics link
323319
rule_location = diagnostic.rule.get_definition_location()
324-
rule_uri = f"file://{rule_location.file_path}#L{rule_location.start_line}"
320+
rule_uri_wihout_extension = URI.from_path(rule_location.file_path)
321+
rule_uri = f"{rule_uri_wihout_extension.value}#L{rule_location.start_line}"
325322

326323
# Use URI format to create a link for "related information"
327324
return types.Diagnostic(
@@ -350,11 +347,9 @@ def _diagnostics_to_lsp_diagnostics(
350347
return lsp_diagnostics
351348

352349
@staticmethod
353-
def _uri_to_path(uri: str) -> str:
350+
def _uri_to_path(uri: str) -> Path:
354351
"""Convert a URI to a path."""
355-
if uri.startswith("file://"):
356-
return Path(uri[7:]).resolve().as_posix()
357-
return Path(uri).resolve().as_posix()
352+
return URI(uri).to_path()
358353

359354
def start(self) -> None:
360355
"""Start the server with I/O transport."""

sqlmesh/lsp/reference.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlmesh.core.model.definition import SqlModel
66
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
77
from sqlglot import exp
8-
8+
from sqlmesh.lsp.uri import URI
99
from sqlmesh.utils.pydantic import PydanticModel
1010

1111

@@ -51,7 +51,7 @@ def contains_position(r: Reference) -> bool:
5151

5252

5353
def get_references(
54-
lint_context: LSPContext, document_uri: str, position: Position
54+
lint_context: LSPContext, document_uri: URI, position: Position
5555
) -> t.List[Reference]:
5656
"""
5757
Get references at a specific position in a document.
@@ -72,7 +72,7 @@ def get_references(
7272

7373

7474
def get_model_definitions_for_a_path(
75-
lint_context: LSPContext, document_uri: str
75+
lint_context: LSPContext, document_uri: URI
7676
) -> t.List[Reference]:
7777
"""
7878
Get the model references for a given path.
@@ -89,9 +89,8 @@ def get_model_definitions_for_a_path(
8989
- Match to models that the model refers to
9090
"""
9191
# Ensure the path is a sql file
92-
if not document_uri.endswith(".sql"):
92+
if document_uri.to_path().suffix != ".sql":
9393
return []
94-
9594
# Get the file info from the context map
9695
if document_uri not in lint_context.map:
9796
return []
@@ -163,7 +162,7 @@ def get_model_definitions_for_a_path(
163162
# Check whether the path exists
164163
if not referenced_model_path.is_file():
165164
continue
166-
referenced_model_uri = f"file://{referenced_model_path}"
165+
referenced_model_uri = URI.from_path(referenced_model_path)
167166

168167
# Extract metadata for positioning
169168
table_meta = TokenPositionDetails.from_meta(table.this.meta)
@@ -180,7 +179,7 @@ def get_model_definitions_for_a_path(
180179

181180
references.append(
182181
Reference(
183-
uri=referenced_model_uri,
182+
uri=referenced_model_uri.value,
184183
range=Range(start=start_pos, end=end_pos),
185184
description=referenced_model.description,
186185
)

sqlmesh/lsp/uri.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pathlib import Path
2+
from pygls.uris import from_fs_path, to_fs_path
3+
import typing as t
4+
5+
6+
class URI:
7+
"""
8+
A URI is a unique identifier for a file used in the LSP.
9+
"""
10+
11+
def __init__(self, uri: str):
12+
self.value: str = uri
13+
14+
def __hash__(self) -> int:
15+
return hash(self.value)
16+
17+
def __eq__(self, other: object) -> bool:
18+
if not isinstance(other, URI):
19+
return False
20+
return self.value == other.value
21+
22+
def __repr__(self) -> str:
23+
return f"URI({self.value})"
24+
25+
def to_path(self) -> Path:
26+
p = to_fs_path(self.value)
27+
return Path(p)
28+
29+
@staticmethod
30+
def from_path(path: t.Union[str, Path]) -> "URI":
31+
if isinstance(path, Path):
32+
path = path.as_posix()
33+
return URI(from_fs_path(path))

tests/lsp/test_completions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_get_sql_completions_with_context_and_file_uri():
3737
lsp_context = LSPContext(context)
3838

3939
file_uri = next(
40-
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
40+
key for key in lsp_context.map.keys() if str(key.to_path()).endswith("active_customers.sql")
4141
)
4242
completions = get_sql_completions(lsp_context, file_uri)
4343
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)

0 commit comments

Comments
 (0)