Skip to content

Commit 598dfbd

Browse files
authored
Chore: fix mypy type hints, function prototypes to get parity with sqlglot (#901)
1 parent 74420d9 commit 598dfbd

File tree

1 file changed

+37
-40
lines changed

1 file changed

+37
-40
lines changed

sqlmesh/core/dialect.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class DColonCast(exp.Cast):
6161
pass
6262

6363

64-
@t.no_type_check
6564
def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
6665
if self._curr is None:
6766
return None
@@ -73,16 +72,16 @@ def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
7372
comments = self._curr.comments
7473

7574
self._advance()
76-
meta = self._parse_wrapped(lambda: parser(self))
75+
meta = self._parse_wrapped(lambda: t.cast(t.Callable, parser)(self))
7776

7877
meta.comments = comments
7978
return meta
80-
return self.__parse_statement()
8179

80+
return self.__parse_statement() # type: ignore
8281

83-
@t.no_type_check
84-
def _parse_lambda(self: Parser) -> t.Optional[exp.Expression]:
85-
node = self.__parse_lambda()
82+
83+
def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expression]:
84+
node = self.__parse_lambda(alias=alias) # type: ignore
8685
if isinstance(node, exp.Lambda):
8786
node.set("this", self._parse_alias(node.this))
8887
return node
@@ -97,6 +96,7 @@ def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expres
9796
if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
9897
self._retreat(index)
9998
return None
99+
100100
if isinstance(field, exp.Anonymous):
101101
name = field.name.upper()
102102
if name == "DEF":
@@ -106,13 +106,15 @@ def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expres
106106
if name == "SQL":
107107
into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
108108
return self.expression(MacroSQL, this=field.expressions[0], into=into)
109+
109110
return self.expression(MacroFunc, this=field)
110111

111112
if field is None:
112113
return None
113114

114115
if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
115116
return self.expression(MacroStrReplace, this=exp.Literal.string(field.this))
117+
116118
return self.expression(MacroVar, this=field.this)
117119

118120

@@ -125,30 +127,28 @@ def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]
125127
):
126128
return None
127129

128-
self._advance(1)
130+
self._advance()
129131
return _parse_macro(self, keyword_macro=name)
130132

131133

132-
@t.no_type_check
133-
def _parse_with(self: Parser) -> t.Optional[exp.Expression]:
134+
def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
134135
macro = _parse_matching_macro(self, "WITH")
135136
if not macro:
136-
return self.__parse_with()
137+
return self.__parse_with() # type: ignore
137138

138-
macro.this.append("expressions", self.__parse_with(True))
139+
macro.this.append("expressions", self.__parse_with(skip_with_token=True)) # type: ignore
139140
return macro
140141

141142

142-
@t.no_type_check
143-
def _parse_join(self: Parser) -> t.Optional[exp.Expression]:
143+
def _parse_join(self: Parser, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
144144
index = self._index
145145
natural, side, kind = self._parse_join_side_and_kind()
146146
macro = _parse_matching_macro(self, "JOIN")
147147
if not macro:
148148
self._retreat(index)
149-
return self.__parse_join()
149+
return self.__parse_join() # type: ignore
150150

151-
join = self.__parse_join(True)
151+
join = self.__parse_join(skip_join_token=True) # type: ignore
152152
if natural:
153153
join.set("natural", True)
154154
if side:
@@ -160,48 +160,46 @@ def _parse_join(self: Parser) -> t.Optional[exp.Expression]:
160160
return macro
161161

162162

163-
@t.no_type_check
164-
def _parse_where(self: Parser) -> t.Optional[exp.Expression]:
163+
def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
165164
macro = _parse_matching_macro(self, "WHERE")
166165
if not macro:
167-
return self.__parse_where()
166+
return self.__parse_where() # type: ignore
168167

169-
macro.this.append("expressions", self.__parse_where(True))
168+
macro.this.append("expressions", self.__parse_where(skip_where_token=True)) # type: ignore
170169
return macro
171170

172171

173-
@t.no_type_check
174-
def _parse_group(self: Parser) -> t.Optional[exp.Expression]:
172+
def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
175173
macro = _parse_matching_macro(self, "GROUP_BY")
176174
if not macro:
177-
return self.__parse_group()
175+
return self.__parse_group() # type: ignore
178176

179-
macro.this.append("expressions", self.__parse_group(True))
177+
macro.this.append("expressions", self.__parse_group(skip_group_by_token=True)) # type: ignore
180178
return macro
181179

182180

183-
@t.no_type_check
184-
def _parse_having(self: Parser) -> t.Optional[exp.Expression]:
181+
def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
185182
macro = _parse_matching_macro(self, "HAVING")
186183
if not macro:
187-
return self.__parse_having()
184+
return self.__parse_having() # type: ignore
188185

189-
macro.this.append("expressions", self.__parse_having(True))
186+
macro.this.append("expressions", self.__parse_having(skip_having_token=True)) # type: ignore
190187
return macro
191188

192189

193-
@t.no_type_check
194-
def _parse_order(self: Parser, this: exp.Expression = None) -> t.Optional[exp.Expression]:
190+
def _parse_order(
191+
self: Parser, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
192+
) -> t.Optional[exp.Expression]:
195193
macro = _parse_matching_macro(self, "ORDER_BY")
196194
if not macro:
197-
return self.__parse_order(this)
195+
return self.__parse_order(this) # type: ignore
198196

199-
macro.this.append("expressions", self.__parse_order(this, True))
197+
macro.this.append("expressions", self.__parse_order(this, skip_order_token=True)) # type: ignore
200198
return macro
201199

202200

203201
def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
204-
key = self._parse_id_var(True)
202+
key = self._parse_id_var(any_token=True)
205203

206204
if not key:
207205
return None
@@ -344,7 +342,7 @@ def format_model_expressions(
344342

345343
if not isinstance(expression, exp.Alias):
346344
if expression.name:
347-
expression = expression.replace(exp.alias_(expression.copy(), expression.name))
345+
expression = expression.replace(exp.alias_(expression, expression.name))
348346

349347
column = column or expression
350348
expression = expression.this
@@ -353,6 +351,7 @@ def format_model_expressions(
353351
this = expression.this
354352
if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
355353
expression.replace(DColonCast(this=this, to=expression.to))
354+
356355
column.comments = comments
357356
selects.append(column)
358357

@@ -385,7 +384,7 @@ def text_diff(
385384
)
386385

387386

388-
def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
387+
def parse(sql: str, default_dialect: t.Optional[str] = None) -> t.List[exp.Expression]:
389388
"""Parse a sql string.
390389
391390
Supports parsing model definition.
@@ -437,7 +436,6 @@ def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression
437436
return expressions
438437

439438

440-
@t.no_type_check
441439
def extend_sqlglot() -> None:
442440
"""Extend SQLGlot with SQLMesh's custom macro aware dialect."""
443441
parsers = {Parser}
@@ -465,10 +463,8 @@ def extend_sqlglot() -> None:
465463
PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
466464
}
467465
)
468-
generator.WITH_SEPARATED_COMMENTS = (
469-
*generator.WITH_SEPARATED_COMMENTS,
470-
Model,
471-
)
466+
467+
generator.WITH_SEPARATED_COMMENTS = (*generator.WITH_SEPARATED_COMMENTS, Model) # type: ignore
472468

473469
for parser in parsers:
474470
parser.FUNCTIONS.update(
@@ -510,7 +506,8 @@ def select_from_values(
510506
This method operates as a generator and yields a VALUES expression.
511507
"""
512508
casted_columns = [
513-
exp.alias_(exp.cast(column, to=kind), column) for column, kind in columns_to_types.items()
509+
exp.alias_(exp.cast(column, to=kind), column, copy=False)
510+
for column, kind in columns_to_types.items()
514511
]
515512
batch = []
516513
for row in values:

0 commit comments

Comments
 (0)