Skip to content

Commit 40bd72f

Browse files
authored
feat(lsp): add keywords in query to autocomplete (#4638)
1 parent aedb472 commit 40bd72f

File tree

4 files changed

+216
-8
lines changed

4 files changed

+216
-8
lines changed

sqlmesh/lsp/completions.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,25 @@
77

88

99
def get_sql_completions(
10-
context: t.Optional[LSPContext], file_uri: t.Optional[URI]
10+
context: t.Optional[LSPContext], file_uri: t.Optional[URI], content: t.Optional[str] = None
1111
) -> AllModelsResponse:
1212
"""
1313
Return a list of completions for a given file.
1414
"""
15+
# Get SQL keywords for the dialect
16+
sql_keywords = get_keywords(context, file_uri)
17+
18+
# Get keywords from file content if provided
19+
file_keywords = set()
20+
if content:
21+
file_keywords = extract_keywords_from_content(content, get_dialect(context, file_uri))
22+
23+
# Combine keywords - SQL keywords first, then file keywords
24+
all_keywords = list(sql_keywords) + list(file_keywords - sql_keywords)
25+
1526
return AllModelsResponse(
1627
models=list(get_models(context, file_uri)),
17-
keywords=list(get_keywords(context, file_uri)),
28+
keywords=all_keywords,
1829
)
1930

2031

@@ -97,3 +108,54 @@ def get_keywords_from_tokenizer(dialect: t.Optional[str] = None) -> t.Set[str]:
97108
parts = keyword.split(" ")
98109
expanded_keywords.update(parts)
99110
return expanded_keywords
111+
112+
113+
def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Optional[str]:
114+
"""
115+
Get the dialect for a given file.
116+
"""
117+
if file_uri is not None and context is not None and file_uri.to_path() in context.map:
118+
file_info = context.map[file_uri.to_path()]
119+
120+
# Handle ModelTarget entries
121+
if isinstance(file_info, ModelTarget) and file_info.names:
122+
model_name = file_info.names[0]
123+
model_from_context = context.context.get_model(model_name)
124+
return model_from_context.dialect
125+
126+
# Handle AuditTarget entries
127+
if isinstance(file_info, AuditTarget) and file_info.name:
128+
audit = context.context.standalone_audits.get(file_info.name)
129+
if audit is not None and audit.dialect:
130+
return audit.dialect
131+
132+
if context is not None:
133+
return context.context.default_dialect
134+
135+
return None
136+
137+
138+
def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]:
139+
"""
140+
Extract identifiers from SQL content using the tokenizer.
141+
Only extracts identifiers (variable names, table names, column names, etc.)
142+
that are not SQL keywords.
143+
"""
144+
if not content:
145+
return set()
146+
147+
tokenizer_class = Dialect.get_or_raise(dialect).tokenizer_class
148+
keywords = set()
149+
try:
150+
tokenizer = tokenizer_class()
151+
tokens = tokenizer.tokenize(content)
152+
for token in tokens:
153+
# Don't include keywords in the set
154+
if token.text.upper() not in tokenizer_class.KEYWORDS:
155+
keywords.add(token.text)
156+
157+
except Exception:
158+
# If tokenization fails, return empty set
159+
pass
160+
161+
return keywords

sqlmesh/lsp/context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,18 @@ def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
176176
if audit._path is not None
177177
]
178178

179-
def get_autocomplete(self, uri: t.Optional[URI]) -> AllModelsResponse:
179+
def get_autocomplete(
180+
self, uri: t.Optional[URI], content: t.Optional[str] = None
181+
) -> AllModelsResponse:
180182
"""Get autocomplete suggestions for a file.
181183
182184
Args:
183185
uri: The URI of the file to get autocomplete suggestions for.
186+
content: The content of the file (optional).
184187
185188
Returns:
186189
AllModelsResponse containing models and keywords.
187190
"""
188191
from sqlmesh.lsp.completions import get_sql_completions
189192

190-
return get_sql_completions(self, uri)
193+
return get_sql_completions(self, uri, content)

sqlmesh/lsp/main.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,22 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
114114
@self.server.feature(ALL_MODELS_FEATURE)
115115
def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
116116
uri = URI(params.textDocument.uri)
117+
118+
# Get the document content
119+
content = None
120+
try:
121+
document = ls.workspace.get_text_document(params.textDocument.uri)
122+
content = document.source
123+
except Exception:
124+
pass
125+
117126
try:
118127
context = self._context_get_or_load(uri)
119-
return context.get_autocomplete(uri)
128+
return context.get_autocomplete(uri, content)
120129
except Exception as e:
121130
from sqlmesh.lsp.completions import get_sql_completions
122131

123-
return get_sql_completions(None, URI(params.textDocument.uri))
132+
return get_sql_completions(None, URI(params.textDocument.uri), content)
124133

125134
@self.server.feature(RENDER_MODEL_FEATURE)
126135
def render_model(ls: LanguageServer, params: RenderModelRequest) -> RenderModelResponse:
@@ -471,8 +480,16 @@ def completion(
471480
uri = URI(params.text_document.uri)
472481
context = self._context_get_or_load(uri)
473482

483+
# Get the document content
484+
content = None
485+
try:
486+
document = ls.workspace.get_text_document(params.text_document.uri)
487+
content = document.source
488+
except Exception:
489+
pass
490+
474491
# Get completions using the existing completions module
475-
completion_response = context.get_autocomplete(uri)
492+
completion_response = context.get_autocomplete(uri, content)
476493

477494
completion_items = []
478495
# Add model completions

tests/lsp/test_completions.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from sqlglot import Tokenizer
22
from sqlmesh.core.context import Context
3-
from sqlmesh.lsp.completions import get_keywords_from_tokenizer, get_sql_completions
3+
from sqlmesh.lsp.completions import (
4+
get_keywords_from_tokenizer,
5+
get_sql_completions,
6+
extract_keywords_from_content,
7+
)
48
from sqlmesh.lsp.context import LSPContext
59
from sqlmesh.lsp.uri import URI
610

@@ -36,3 +40,125 @@ def test_get_sql_completions_with_context_and_file_uri():
3640
completions = lsp_context.get_autocomplete(URI.from_path(file_uri))
3741
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
3842
assert "sushi.active_customers" not in completions.models
43+
44+
45+
def test_extract_keywords_from_content():
46+
# Test extracting keywords from SQL content
47+
content = """
48+
SELECT customer_id, order_date, total_amount
49+
FROM orders o
50+
JOIN customers c ON o.customer_id = c.id
51+
WHERE order_date > '2024-01-01'
52+
"""
53+
54+
keywords = extract_keywords_from_content(content)
55+
56+
# Check that identifiers are extracted
57+
assert "customer_id" in keywords
58+
assert "order_date" in keywords
59+
assert "total_amount" in keywords
60+
assert "orders" in keywords
61+
assert "customers" in keywords
62+
assert "o" in keywords # alias
63+
assert "c" in keywords # alias
64+
assert "id" in keywords
65+
66+
# Check that SQL keywords are NOT included
67+
assert "SELECT" not in keywords
68+
assert "FROM" not in keywords
69+
assert "JOIN" not in keywords
70+
assert "WHERE" not in keywords
71+
assert "ON" not in keywords
72+
73+
74+
def test_get_sql_completions_with_file_content():
75+
context = Context(paths=["examples/sushi"])
76+
lsp_context = LSPContext(context)
77+
78+
# SQL content with custom identifiers
79+
content = """
80+
SELECT my_custom_column, another_identifier
81+
FROM my_custom_table mct
82+
JOIN some_other_table sot ON mct.id = sot.table_id
83+
WHERE my_custom_column > 100
84+
"""
85+
86+
file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
87+
completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
88+
89+
# Check that SQL keywords are included
90+
assert any(k in ["SELECT", "FROM", "WHERE", "JOIN"] for k in completions.keywords)
91+
92+
# Check that file-specific identifiers are included at the end
93+
keywords_list = completions.keywords
94+
assert "my_custom_column" in keywords_list
95+
assert "another_identifier" in keywords_list
96+
assert "my_custom_table" in keywords_list
97+
assert "some_other_table" in keywords_list
98+
assert "mct" in keywords_list # alias
99+
assert "sot" in keywords_list # alias
100+
assert "table_id" in keywords_list
101+
102+
# Check that file keywords come after SQL keywords
103+
# SQL keywords should appear first in the list
104+
sql_keyword_indices = [
105+
i for i, k in enumerate(keywords_list) if k in ["SELECT", "FROM", "WHERE", "JOIN"]
106+
]
107+
file_keyword_indices = [
108+
i for i, k in enumerate(keywords_list) if k in ["my_custom_column", "my_custom_table"]
109+
]
110+
111+
if sql_keyword_indices and file_keyword_indices:
112+
assert max(sql_keyword_indices) < min(file_keyword_indices), (
113+
"SQL keywords should come before file keywords"
114+
)
115+
116+
117+
def test_get_sql_completions_with_partial_cte_query():
118+
context = Context(paths=["examples/sushi"])
119+
lsp_context = LSPContext(context)
120+
121+
# Partial SQL query with CTEs
122+
content = """
123+
WITH _latest_complete_month AS (
124+
SELECT MAX(date_trunc('month', order_date)) as month
125+
FROM orders
126+
),
127+
_filtered AS (
128+
SELECT * FROM
129+
"""
130+
131+
file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
132+
completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
133+
134+
# Check that CTE names are included in the keywords
135+
keywords_list = completions.keywords
136+
assert "_latest_complete_month" in keywords_list
137+
assert "_filtered" in keywords_list
138+
139+
# Also check other identifiers from the partial query
140+
assert "month" in keywords_list
141+
assert "order_date" in keywords_list
142+
assert "orders" in keywords_list
143+
144+
145+
def test_extract_keywords_from_partial_query():
146+
# Test extracting keywords from an incomplete SQL query
147+
content = """
148+
WITH cte1 AS (
149+
SELECT col1, col2 FROM table1
150+
),
151+
cte2 AS (
152+
SELECT * FROM cte1 WHERE
153+
"""
154+
155+
keywords = extract_keywords_from_content(content)
156+
157+
# Check that CTEs are extracted
158+
assert "cte1" in keywords
159+
assert "cte2" in keywords
160+
161+
# Check that columns and tables are extracted
162+
assert "col1" in keywords
163+
assert "col2" in keywords
164+
assert "table1" in keywords

0 commit comments

Comments
 (0)