Skip to content

Commit 3462a5d

Browse files
authored
feat(vscode): go to definition for standalone audits (#4344)
1 parent 0c323d6 commit 3462a5d

File tree

7 files changed

+168
-43
lines changed

7 files changed

+168
-43
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,6 @@ tests/_version.py
155155
# spark
156156
metastore_db/
157157
spark-warehouse/
158+
159+
# claude
160+
.claude/

sqlmesh/lsp/completions.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from sqlglot import Dialect, Tokenizer
33
from sqlmesh.lsp.custom import AllModelsResponse
44
import typing as t
5-
from sqlmesh.lsp.context import LSPContext
5+
from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget
66

77

88
def get_sql_completions(context: t.Optional[LSPContext], file_uri: str) -> AllModelsResponse:
@@ -24,11 +24,20 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.
2424
"""
2525
if context is None:
2626
return set()
27-
all_models = set(model for models in context.map.values() for model in models)
28-
if file_uri is not None:
29-
models_file_refers_to = context.map[file_uri]
30-
for model in models_file_refers_to:
31-
all_models.discard(model)
27+
28+
all_models = set()
29+
# Extract model names from ModelTarget entries
30+
for file_info in context.map.values():
31+
if isinstance(file_info, ModelTarget):
32+
all_models.update(file_info.names)
33+
34+
# Remove models from the current file
35+
if file_uri is not None and file_uri in context.map:
36+
file_info = context.map[file_uri]
37+
if isinstance(file_info, ModelTarget):
38+
for model in file_info.names:
39+
all_models.discard(model)
40+
3241
return all_models
3342

3443

@@ -43,16 +52,25 @@ def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[str]) ->
4352
If both a context and a file_uri are provided, returns the keywords
4453
for the dialect of the model that the file belongs to.
4554
"""
46-
if file_uri is not None and context is not None:
47-
models = context.map[file_uri]
48-
if models:
49-
model = models[0]
50-
model_from_context = context.context.get_model(model)
51-
if model_from_context is not None:
52-
if model_from_context.dialect:
53-
return get_keywords_from_tokenizer(model_from_context.dialect)
55+
if file_uri is not None and context is not None and file_uri in context.map:
56+
file_info = context.map[file_uri]
57+
58+
# Handle ModelTarget entries
59+
if isinstance(file_info, ModelTarget) and file_info.names:
60+
model_name = file_info.names[0]
61+
model_from_context = context.context.get_model(model_name)
62+
if model_from_context is not None and model_from_context.dialect:
63+
return get_keywords_from_tokenizer(model_from_context.dialect)
64+
65+
# Handle AuditTarget entries
66+
elif isinstance(file_info, AuditTarget) and file_info.name:
67+
audit = context.context.standalone_audits.get(file_info.name)
68+
if audit is not None and audit.dialect:
69+
return get_keywords_from_tokenizer(audit.dialect)
70+
5471
if context is not None:
5572
return get_keywords_from_tokenizer(context.context.default_dialect)
73+
5674
return get_keywords_from_tokenizer(None)
5775

5876

sqlmesh/lsp/context.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,53 @@
1-
from collections import defaultdict
1+
from dataclasses import dataclass
22
from pathlib import Path
33
from sqlmesh.core.context import Context
44
import typing as t
55

66

7+
@dataclass
8+
class ModelTarget:
9+
"""Information about models in a file."""
10+
11+
names: t.List[str]
12+
13+
14+
@dataclass
15+
class AuditTarget:
16+
"""Information about standalone audits in a file."""
17+
18+
name: str
19+
20+
721
class LSPContext:
822
"""
9-
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
23+
A context that is used for linting. It contains the context and a reverse map of file uri to
24+
model names and standalone audit names.
1025
"""
1126

1227
def __init__(self, context: Context) -> None:
1328
self.context = context
14-
map: t.Dict[str, t.List[str]] = defaultdict(list)
29+
30+
# Add models to the map
31+
model_map: t.Dict[str, ModelTarget] = {}
1532
for model in context.models.values():
1633
if model._path is not None:
1734
path = Path(model._path).resolve()
18-
map[f"file://{path.as_posix()}"].append(model.name)
35+
uri = f"file://{path.as_posix()}"
36+
if uri in model_map:
37+
model_map[uri].names.append(model.name)
38+
else:
39+
model_map[uri] = ModelTarget(names=[model.name])
40+
41+
# Add standalone audits to the map
42+
audit_map: t.Dict[str, AuditTarget] = {}
43+
for audit in context.standalone_audits.values():
44+
if audit._path is not None:
45+
path = Path(audit._path).resolve()
46+
uri = f"file://{path.as_posix()}"
47+
if uri not in audit_map:
48+
audit_map[uri] = AuditTarget(name=audit.name)
1949

20-
self.map = map
50+
self.map: t.Dict[str, t.Union[ModelTarget, AuditTarget]] = {
51+
**model_map,
52+
**audit_map,
53+
}

sqlmesh/lsp/main.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
1414
from sqlmesh.lsp.completions import get_sql_completions
15-
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.context import LSPContext, ModelTarget
1616
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1717
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
1818

@@ -91,8 +91,10 @@ def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> Non
9191
models = context.map[params.text_document.uri]
9292
if models is None:
9393
return
94+
if not isinstance(models, ModelTarget):
95+
return
9496
self.lint_cache[params.text_document.uri] = context.context.lint_models(
95-
models,
97+
models.names,
9698
raise_on_error=False,
9799
)
98100
ls.publish_diagnostics(
@@ -108,8 +110,10 @@ def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) ->
108110
models = context.map[params.text_document.uri]
109111
if models is None:
110112
return
113+
if not isinstance(models, ModelTarget):
114+
return
111115
self.lint_cache[params.text_document.uri] = context.context.lint_models(
112-
models,
116+
models.names,
113117
raise_on_error=False,
114118
)
115119
ls.publish_diagnostics(
@@ -125,8 +129,10 @@ def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> Non
125129
models = context.map[params.text_document.uri]
126130
if models is None:
127131
return
132+
if not isinstance(models, ModelTarget):
133+
return
128134
self.lint_cache[params.text_document.uri] = context.context.lint_models(
129-
models,
135+
models.names,
130136
raise_on_error=False,
131137
)
132138
ls.publish_diagnostics(

sqlmesh/lsp/reference.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from sqlmesh.core.dialect import normalize_model_name
55
from sqlmesh.core.model.definition import SqlModel
6-
from sqlmesh.lsp.context import LSPContext
6+
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
77
from sqlglot import exp
88

99
from sqlmesh.utils.pydantic import PydanticModel
@@ -20,7 +20,7 @@ def get_model_definitions_for_a_path(
2020
"""
2121
Get the model references for a given path.
2222
23-
Works for models and audits.
23+
Works for models and standalone audits.
2424
Works for targeting sql and python models.
2525
2626
Steps:
@@ -31,39 +31,63 @@ def get_model_definitions_for_a_path(
3131
- Try get_model before normalization
3232
- Match to models that the model refers to
3333
"""
34-
# Ensure the path is a sql model
34+
# Ensure the path is a sql file
3535
if not document_uri.endswith(".sql"):
3636
return []
3737

38-
# Get the model
39-
models = lint_context.map[document_uri]
40-
if not models:
38+
# Get the file info from the context map
39+
if document_uri not in lint_context.map:
4140
return []
42-
model = lint_context.context.get_model(model_or_snapshot=models[0], raise_if_missing=False)
43-
if model is None or not isinstance(model, SqlModel):
41+
42+
file_info = lint_context.map[document_uri]
43+
44+
# Process based on whether it's a model or standalone audit
45+
if isinstance(file_info, ModelTarget):
46+
# It's a model
47+
model = lint_context.context.get_model(
48+
model_or_snapshot=file_info.names[0], raise_if_missing=False
49+
)
50+
if model is None or not isinstance(model, SqlModel):
51+
return []
52+
53+
query = model.query
54+
dialect = model.dialect
55+
depends_on = model.depends_on
56+
file_path = model._path
57+
elif isinstance(file_info, AuditTarget):
58+
# It's a standalone audit
59+
audit = lint_context.context.standalone_audits.get(file_info.name)
60+
if audit is None:
61+
return []
62+
63+
query = audit.query
64+
dialect = audit.dialect
65+
depends_on = audit.depends_on
66+
file_path = audit._path
67+
else:
4468
return []
4569

4670
# Find all possible references
4771
references = []
48-
tables = list(model.query.find_all(exp.Table))
72+
73+
# Get SQL query and find all table references
74+
tables = list(query.find_all(exp.Table))
4975
if len(tables) == 0:
5076
return []
5177

52-
read_file = open(model._path, "r").readlines()
78+
read_file = open(file_path, "r").readlines()
5379

5480
for table in tables:
55-
depends_on = model.depends_on
56-
5781
# Normalize the table reference
5882
unaliased = table.copy()
5983
if unaliased.args.get("alias") is not None:
6084
unaliased.set("alias", None)
61-
reference_name = unaliased.sql(dialect=model.dialect)
85+
reference_name = unaliased.sql(dialect=dialect)
6286
try:
6387
normalized_reference_name = normalize_model_name(
6488
reference_name,
6589
default_catalog=lint_context.context.default_catalog,
66-
dialect=model.dialect,
90+
dialect=dialect,
6791
)
6892
if normalized_reference_name not in depends_on:
6993
continue

tests/lsp/test_context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from sqlmesh.core.context import Context
3-
from sqlmesh.lsp.context import LSPContext
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
44

55

66
@pytest.mark.fast
@@ -16,4 +16,7 @@ def test_lsp_context():
1616
active_customers_key = next(
1717
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
1818
)
19-
assert lsp_context.map[active_customers_key] == ["sushi.active_customers"]
19+
20+
# Check that the value is a ModelTarget containing the expected model name
21+
assert isinstance(lsp_context.map[active_customers_key], ModelTarget)
22+
assert "sushi.active_customers" in lsp_context.map[active_customers_key].names

tests/lsp/test_reference.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from sqlmesh.core.context import Context
3-
from sqlmesh.lsp.context import LSPContext
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
44
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
55

66

@@ -9,11 +9,16 @@ def test_reference() -> None:
99
context = Context(paths=["examples/sushi"])
1010
lsp_context = LSPContext(context)
1111

12+
# Find model URIs
1213
active_customers_uri = next(
13-
uri for uri, models in lsp_context.map.items() if "sushi.active_customers" in models
14+
uri
15+
for uri, info in lsp_context.map.items()
16+
if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names
1417
)
1518
sushi_customers_uri = next(
16-
uri for uri, models in lsp_context.map.items() if "sushi.customers" in models
19+
uri
20+
for uri, info in lsp_context.map.items()
21+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
1722
)
1823

1924
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
@@ -35,7 +40,9 @@ def test_reference_with_alias() -> None:
3540
lsp_context = LSPContext(context)
3641

3742
waiter_revenue_by_day_uri = next(
38-
uri for uri, models in lsp_context.map.items() if "sushi.waiter_revenue_by_day" in models
43+
uri
44+
for uri, info in lsp_context.map.items()
45+
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
3946
)
4047

4148
references = get_model_definitions_for_a_path(lsp_context, waiter_revenue_by_day_uri)
@@ -52,6 +59,37 @@ def test_reference_with_alias() -> None:
5259
assert get_string_from_range(read_file, references[2].range) == "sushi.items"
5360

5461

62+
@pytest.mark.fast
63+
def test_standalone_audit_reference() -> None:
64+
context = Context(paths=["examples/sushi"])
65+
lsp_context = LSPContext(context)
66+
67+
# Find the standalone audit URI
68+
audit_uri = next(
69+
uri
70+
for uri, info in lsp_context.map.items()
71+
if isinstance(info, AuditTarget) and info.name == "assert_item_price_above_zero"
72+
)
73+
74+
# Find the items model URI
75+
items_uri = next(
76+
uri
77+
for uri, info in lsp_context.map.items()
78+
if isinstance(info, ModelTarget) and "sushi.items" in info.names
79+
)
80+
81+
references = get_model_definitions_for_a_path(lsp_context, audit_uri)
82+
83+
assert len(references) == 1
84+
assert references[0].uri == items_uri
85+
86+
# Check that the text at the reference's range in the audit file is sushi.items
87+
path = audit_uri.removeprefix("file://")
88+
read_file = open(path, "r").readlines()
89+
referenced_text = get_string_from_range(read_file, references[0].range)
90+
assert referenced_text == "sushi.items"
91+
92+
5593
def get_string_from_range(file_lines, range_obj) -> str:
5694
start_line = range_obj.start.line
5795
end_line = range_obj.end.line

0 commit comments

Comments
 (0)