Skip to content

Commit 943e496

Browse files
feat: function to get range of key/value in model block (#5119)
Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
1 parent 24e3e50 commit 943e496

File tree

2 files changed

+214
-59
lines changed

2 files changed

+214
-59
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 155 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from sqlmesh.core.linter.rule import Range, Position
44
from sqlmesh.utils.pydantic import PydanticModel
5-
from sqlglot import tokenize, TokenType
5+
from sqlglot import tokenize, TokenType, Token
66
import typing as t
77

88

@@ -122,25 +122,65 @@ def read_range_from_file(file: Path, text_range: Range) -> str:
122122
return read_range_from_string("".join(lines), text_range)
123123

124124

125+
def get_start_and_end_of_model_block(
    tokens: t.List[Token],
) -> t.Optional[t.Tuple[int, int]]:
    """
    Locate the parentheses that delimit the property list of the MODEL block.

    The MODEL block is taken to be the first VAR token whose text is "MODEL"
    (case-insensitive), the first opening parenthesis after it, and the closing
    parenthesis that sits immediately before the first semicolon following that
    opening parenthesis.

    Note that the closing parenthesis is NOT found by balancing nested parentheses:
    the block is assumed to end with ");" and only that shape is recognised.

    Args:
        tokens: The tokens of the SQL file, in source order.

    Returns:
        A tuple (lparen_idx, rparen_idx) of indices into `tokens`, or None when there
        is no MODEL token, no opening parenthesis after it, no semicolon after the
        opening parenthesis, or the token before that semicolon is not a closing
        parenthesis.
    """
    # 1) Find the MODEL token
    model_idx = next(
        (
            i
            for i, tok in enumerate(tokens)
            if tok.token_type is TokenType.VAR and tok.text.upper() == "MODEL"
        ),
        None,
    )
    if model_idx is None:
        return None

    # 2) Find the opening parenthesis for the MODEL properties list
    lparen_idx = next(
        (
            i
            for i in range(model_idx + 1, len(tokens))
            if tokens[i].token_type is TokenType.L_PAREN
        ),
        None,
    )
    if lparen_idx is None:
        return None

    # 3) Find the first semicolon after the opening parenthesis and assume the MODEL
    #    block ends there.
    semicolon_idx = next(
        (
            i
            for i in range(lparen_idx + 1, len(tokens))
            if tokens[i].token_type is TokenType.SEMICOLON
        ),
        None,
    )
    if semicolon_idx is None:
        return None

    # The token right before the semicolon must be the closing parenthesis. Since the
    # semicolon comes after lparen_idx, this index is never smaller than lparen_idx.
    rparen_idx = semicolon_idx - 1
    if tokens[rparen_idx].token_type is not TokenType.R_PAREN:
        return None
    return (lparen_idx, rparen_idx)
168+
169+
125170
def get_range_of_model_block(
126171
sql: str,
127172
dialect: str,
128173
) -> t.Optional[Range]:
129174
"""
130-
Get the range of the model block in an SQL file.
175+
Get the range of the model block in an SQL file,
131176
"""
132177
tokens = tokenize(sql, dialect=dialect)
133-
134-
# Find start of the model block
135-
start = next(
136-
(t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
137-
None,
138-
)
139-
end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)
140-
141-
if start is None or end is None:
178+
block = get_start_and_end_of_model_block(tokens)
179+
if not block:
142180
return None
143-
181+
(start_idx, end_idx) = block
182+
start = tokens[start_idx - 1]
183+
end = tokens[end_idx + 1]
144184
start_position = TokenPositionDetails(
145185
line=start.line,
146186
col=start.col,
@@ -153,60 +193,123 @@ def get_range_of_model_block(
153193
start=end.start,
154194
end=end.end,
155195
)
156-
157196
splitlines = sql.splitlines()
158197
return Range(
159-
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
198+
start=start_position.to_range(splitlines).start,
199+
end=end_position.to_range(splitlines).end,
160200
)
161201

162202

163203
def _find_value_end_index(
    tokens: t.List[Token],
    value_start_idx: int,
    rparen_idx: int,
) -> int:
    """
    Return the index of the last token of the value starting at `value_start_idx`.

    The value extends until the first comma that is not nested inside (), [] or {},
    or until the token just before `rparen_idx` when no such comma exists.
    """
    opening = (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)
    closing = (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)

    nested = 0
    value_end_idx = value_start_idx
    for j in range(value_start_idx, rparen_idx):
        ttype = tokens[j].token_type
        if ttype in opening:
            nested += 1
        elif ttype in closing:
            nested -= 1
            if nested < 0:
                # A closer without a matching opener cannot be part of this value.
                break
        elif ttype is TokenType.COMMA and nested == 0:
            # The separating comma is not part of the value.
            break
        # Do not stop on a closer that merely brings nesting back to zero: the value
        # may continue after it (e.g. `f(x) + 1`).
        value_end_idx = j
    return value_end_idx


def get_range_of_a_key_in_model_block(
    sql: str,
    dialect: str,
    key: str,
) -> t.Optional[t.Tuple[Range, Range]]:
    """
    Get the ranges of a specific key and its value in the MODEL block of an SQL file.

    The key is matched case-insensitively against VAR tokens that sit directly inside
    the MODEL property list (not nested in (), [] or {}) and that immediately follow
    the opening parenthesis or a comma.

    Args:
        sql: The full text of the SQL file.
        dialect: The dialect passed to the tokenizer.
        key: The property name to look for.

    Returns:
        A tuple of (key_range, value_range) if the key is found and has a value,
        otherwise None. None is also returned when the SQL has no tokens, when no MODEL
        block is found, or when the key is immediately followed by a comma or by the
        end of the block.
    """
    tokens = tokenize(sql, dialect=dialect)
    if not tokens:
        return None

    block = get_start_and_end_of_model_block(tokens)
    if not block:
        return None
    lparen_idx, rparen_idx = block

    opening = (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)
    closing = (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)
    wanted = key.upper()

    # Scan the MODEL property list for the key at top level (depth == 1).
    # Depth starts at 1 because the scan begins inside the first parenthesis.
    depth = 1
    for i in range(lparen_idx + 1, rparen_idx):
        tok = tokens[i]
        tt = tok.token_type

        # Track every bracket kind so that names inside e.g. `[a, b]` are not
        # mistaken for top-level keys.
        if tt in opening:
            depth += 1
            continue
        if tt in closing:
            depth -= 1
            # If we somehow leave the block before rparen_idx, stop early.
            if depth <= 0:
                break
            continue

        if depth != 1 or tt is not TokenType.VAR or tok.text.upper() != wanted:
            continue

        # Validate key position: it must immediately follow '(' or ',' at top level.
        # i > lparen_idx, so tokens[i - 1] always exists.
        if tokens[i - 1].token_type not in (TokenType.L_PAREN, TokenType.COMMA):
            continue

        value_start_idx = i + 1
        if value_start_idx >= rparen_idx:
            return None
        if tokens[value_start_idx].token_type is TokenType.COMMA:
            # The key has no value; the separating comma must not be reported as one.
            return None

        value_end_idx = _find_value_end_index(tokens, value_start_idx, rparen_idx)

        lines = sql.splitlines()

        # Key range
        key_range = TokenPositionDetails(
            line=tok.line, col=tok.col, start=tok.start, end=tok.end
        ).to_range(lines)

        # Value range: from the start of the first value token to the end of the last
        value_start_tok = tokens[value_start_idx]
        value_end_tok = tokens[value_end_idx]
        value_start_pos = TokenPositionDetails(
            line=value_start_tok.line,
            col=value_start_tok.col,
            start=value_start_tok.start,
            end=value_start_tok.end,
        )
        value_end_pos = TokenPositionDetails(
            line=value_end_tok.line,
            col=value_end_tok.col,
            start=value_end_tok.start,
            end=value_end_tok.end,
        )
        value_range = Range(
            start=value_start_pos.to_range(lines).start,
            end=value_end_pos.to_range(lines).end,
        )

        return (key_range, value_range)

    return None

tests/core/linter/test_helpers.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,17 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
5252
]
5353
assert len(sql_models) > 0
5454

55+
# Test that the function works for all keys in the model block
5556
for model in sql_models:
56-
possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]
57+
possible_keys = [
58+
"name",
59+
"tags",
60+
"description",
61+
"column_descriptions",
62+
"owner",
63+
"cron",
64+
"dialect",
65+
]
5766

5867
dialect = model.dialect
5968
assert dialect is not None
@@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
6776
count_properties_checked = 0
6877

6978
for key in possible_keys:
70-
range = get_range_of_a_key_in_model_block(content, dialect, key)
71-
72-
# Check that the range starts with the key and ends with ;
73-
if range:
74-
read_range = read_range_from_file(path, range)
75-
assert read_range.lower() == key.lower()
79+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
80+
81+
if ranges:
82+
key_range, value_range = ranges
83+
read_key = read_range_from_file(path, key_range)
84+
assert read_key.lower() == key.lower()
85+
# Value range should be non-empty
86+
read_value = read_range_from_file(path, value_range)
87+
assert len(read_value) > 0
7688
count_properties_checked += 1
7789

7890
assert count_properties_checked > 0
91+
92+
# Test that the function works for different kind of value blocks
93+
tests = [
94+
("sushi.customers", "name", "sushi.customers"),
95+
(
96+
"sushi.customers",
97+
"tags",
98+
"(pii, fact)",
99+
),
100+
("sushi.customers", "description", "'Sushi customer data'"),
101+
(
102+
"sushi.customers",
103+
"column_descriptions",
104+
"( customer_id = 'customer_id uniquely identifies customers' )",
105+
),
106+
("sushi.customers", "owner", "jen"),
107+
("sushi.customers", "cron", "'@daily'"),
108+
]
109+
for model_name, key, value in tests:
110+
model = context.get_model(model_name)
111+
assert model is not None
112+
113+
dialect = model.dialect
114+
assert dialect is not None
115+
116+
path = model._path
117+
assert path is not None
118+
119+
with open(path, "r", encoding="utf-8") as file:
120+
content = file.read()
121+
122+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
123+
assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'"
124+
125+
key_range, value_range = ranges
126+
read_key = read_range_from_file(path, key_range)
127+
assert read_key.lower() == key.lower()
128+
129+
read_value = read_range_from_file(path, value_range)
130+
assert read_value == value

0 commit comments

Comments
 (0)