Skip to content

Commit 3b03f51

Browse files
authored
Feat: allow formatter to use CAST over :: syntax (#3173)
1 parent b9b5e45 commit 3b03f51

File tree

9 files changed

+94
-28
lines changed

9 files changed

+94
-28
lines changed

docs/reference/cli.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ Options:
172172
-t, --transpile TEXT Transpile project models to the specified
173173
dialect.
174174
--append-newline Include a newline at the end of each file.
175+
--no-rewrite-casts Preserve the existing casts, without rewriting
176+
them to use the :: syntax.
175177
--normalize Whether or not to normalize identifiers to
176178
lowercase.
177179
--pad INTEGER Determines the pad size in a formatted string.

docs/reference/configuration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ Formatting settings for the `sqlmesh format` command and UI.
9393
| `leading_comma` | Whether to use leading commas (Default: False) | boolean | N |
9494
| `max_text_width` | The maximum text width in a segment before creating new lines (Default: 80) | int | N |
9595
| `append_newline` | Whether to append a newline to the end of the file (Default: False) | boolean | N |
96+
| `no_rewrite_casts` | Preserve the existing casts, without rewriting them to use the :: syntax. (Default: False) | boolean | N |
9697

9798
## UI
9899

docs/reference/notebook.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ options:
441441

442442
#### format
443443
```
444-
%format [--transpile TRANSPILE] [--append-newline] [--normalize]
445-
[--pad PAD] [--indent INDENT]
444+
%format [--transpile TRANSPILE] [--append-newline] [--no-rewrite-casts]
445+
[--normalize] [--pad PAD] [--indent INDENT]
446446
[--normalize-functions NORMALIZE_FUNCTIONS] [--leading-comma]
447447
[--max-text-width MAX_TEXT_WIDTH] [--check]
448448
@@ -453,6 +453,8 @@ options:
453453
Transpile project models to the specified dialect.
454454
--append-newline Whether or not to append a newline to the end of the
455455
file.
456+
--no-rewrite-casts Preserve the existing casts, without rewriting them
457+
to use the :: syntax.
456458
--normalize Whether or not to normalize identifiers to lowercase.
457459
--pad PAD Determines the pad size in a formatted string.
458460
--indent INDENT Determines the indentation size in a formatted string.

sqlmesh/cli/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ def evaluate(
223223
help="Include a newline at the end of each file.",
224224
default=None,
225225
)
226+
@click.option(
227+
"--no-rewrite-casts",
228+
is_flag=True,
229+
help="Preserve the existing casts, without rewriting them to use the :: syntax.",
230+
default=None,
231+
)
226232
@click.option(
227233
"--normalize",
228234
is_flag=True,
@@ -266,6 +272,9 @@ def evaluate(
266272
@cli_analytics
267273
def format(ctx: click.Context, **kwargs: t.Any) -> None:
268274
"""Format all SQL models and audits."""
275+
if kwargs.pop("no_rewrite_casts", None):
276+
kwargs["rewrite_casts"] = False
277+
269278
if not ctx.obj.format(**{k: v for k, v in kwargs.items() if v is not None}):
270279
ctx.exit(1)
271280

sqlmesh/core/config/format.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class FormatConfig(BaseConfig):
1616
leading_comma: Whether to use leading commas or not.
1717
max_text_width: The maximum text width in a segment before creating new lines.
1818
append_newline: Whether to append a newline to the end of the file or not.
19+
no_rewrite_casts: Preserve the existing casts, without rewriting them to use the :: syntax.
1920
"""
2021

2122
normalize: bool = False
@@ -25,6 +26,7 @@ class FormatConfig(BaseConfig):
2526
leading_comma: bool = False
2627
max_text_width: int = 80
2728
append_newline: bool = False
29+
no_rewrite_casts: bool = False
2830

2931
@property
3032
def generator_options(self) -> t.Dict[str, t.Any]:
@@ -33,4 +35,4 @@ def generator_options(self) -> t.Dict[str, t.Any]:
3335
Returns:
3436
The generator options.
3537
"""
36-
return self.dict(exclude={"append_newline"})
38+
return self.dict(exclude={"append_newline", "no_rewrite_casts"})

sqlmesh/core/context.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,7 @@ def evaluate(
900900
def format(
901901
self,
902902
transpile: t.Optional[str] = None,
903+
rewrite_casts: t.Optional[bool] = None,
903904
append_newline: t.Optional[bool] = None,
904905
*,
905906
check: t.Optional[bool] = None,
@@ -910,6 +911,7 @@ def format(
910911
for target in format_targets.values():
911912
if target._path is None or target._path.suffix != ".sql":
912913
continue
914+
913915
with open(target._path, "r+", encoding="utf-8") as file:
914916
before = file.read()
915917
expressions = parse(before, default_dialect=self.config_for_node(target).dialect)
@@ -922,13 +924,24 @@ def format(
922924
value=exp.Literal.string(transpile or target.dialect),
923925
)
924926
)
925-
format = self.config_for_node(target).format
926-
opts = {**format.generator_options, **kwargs}
927-
after = format_model_expressions(expressions, transpile or target.dialect, **opts)
927+
928+
format_config = self.config_for_node(target).format
929+
after = format_model_expressions(
930+
expressions,
931+
transpile or target.dialect,
932+
rewrite_casts=(
933+
rewrite_casts
934+
if rewrite_casts is not None
935+
else not format_config.no_rewrite_casts
936+
),
937+
**{**format_config.generator_options, **kwargs},
938+
)
939+
928940
if append_newline is None:
929-
append_newline = format.append_newline
941+
append_newline = format_config.append_newline
930942
if append_newline:
931943
after += "\n"
944+
932945
if not check:
933946
file.seek(0)
934947
file.write(after)

sqlmesh/core/dialect.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -643,13 +643,17 @@ def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
643643

644644

645645
def format_model_expressions(
646-
expressions: t.List[exp.Expression], dialect: t.Optional[str] = None, **kwargs: t.Any
646+
expressions: t.List[exp.Expression],
647+
dialect: t.Optional[str] = None,
648+
rewrite_casts: bool = True,
649+
**kwargs: t.Any,
647650
) -> str:
648651
"""Format a model's expressions into a standardized format.
649652
650653
Args:
651654
expressions: The model's expressions, must be at least model def + query.
652655
dialect: The dialect to render the expressions as.
656+
rewrite_casts: Whether to rewrite all casts to use the :: syntax.
653657
**kwargs: Additional keyword arguments to pass to the sql generator.
654658
655659
Returns:
@@ -660,26 +664,28 @@ def format_model_expressions(
660664

661665
*statements, query = expressions
662666

663-
def cast_to_colon(node: exp.Expression) -> exp.Expression:
664-
if isinstance(node, exp.Cast) and not any(
665-
# Only convert CAST into :: if it doesn't have additional args set, otherwise this
666-
# conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
667-
arg
668-
for name, arg in node.args.items()
669-
if name not in ("this", "to")
670-
):
671-
this = node.this
667+
if rewrite_casts:
672668

673-
if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
674-
cast = DColonCast(this=this, to=node.to)
675-
cast.comments = node.comments
676-
node = cast
669+
def cast_to_colon(node: exp.Expression) -> exp.Expression:
670+
if isinstance(node, exp.Cast) and not any(
671+
# Only convert CAST into :: if it doesn't have additional args set, otherwise this
672+
# conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
673+
arg
674+
for name, arg in node.args.items()
675+
if name not in ("this", "to")
676+
):
677+
this = node.this
677678

678-
exp.replace_children(node, cast_to_colon)
679-
return node
679+
if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
680+
cast = DColonCast(this=this, to=node.to)
681+
cast.comments = node.comments
682+
node = cast
683+
684+
exp.replace_children(node, cast_to_colon)
685+
return node
680686

681-
query = query.copy()
682-
exp.replace_children(query, cast_to_colon)
687+
query = query.copy()
688+
exp.replace_children(query, cast_to_colon)
683689

684690
return ";\n\n".join(
685691
[

sqlmesh/magics.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,10 @@ def model(self, context: Context, line: str, sql: t.Optional[str] = None) -> Non
199199
expressions = parse(file.read(), default_dialect=config.dialect)
200200

201201
formatted = format_model_expressions(
202-
expressions, model.dialect, **config.format.generator_options
202+
expressions,
203+
model.dialect,
204+
rewrite_casts=not config.format.no_rewrite_casts,
205+
**config.format.generator_options,
203206
)
204207

205208
self._shell.set_next_input(
@@ -703,6 +706,12 @@ def rewrite(self, context: Context, line: str, sql: str) -> None:
703706
help="Whether or not to append a newline to the end of the file.",
704707
default=None,
705708
)
709+
@argument(
710+
"--no-rewrite-casts",
711+
action="store_true",
712+
help="Whether or not to preserve the existing casts, without rewriting them to use the :: syntax.",
713+
default=None,
714+
)
706715
@argument(
707716
"--normalize",
708717
action="store_true",
@@ -745,8 +754,11 @@ def rewrite(self, context: Context, line: str, sql: str) -> None:
745754
@pass_sqlmesh_context
746755
def format(self, context: Context, line: str) -> bool:
747756
"""Format all SQL models and audits."""
748-
args = parse_argstring(self.format, line)
749-
return context.format(**{k: v for k, v in vars(args).items() if v is not None})
757+
format_opts = vars(parse_argstring(self.format, line))
758+
if format_opts.pop("no_rewrite_casts", None):
759+
format_opts["rewrite_casts"] = False
760+
761+
return context.format(**{k: v for k, v in format_opts.items() if v is not None})
750762

751763
@magic_arguments()
752764
@argument("environment", type=str, help="The environment to diff local state against.")

tests/core/test_dialect.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,25 @@ def test_format_model_expressions():
194194
SAFE_CAST('bla' AS INT64) AS FOO"""
195195
)
196196

197+
x = format_model_expressions(
198+
parse(
199+
"""
200+
MODEL(name foo);
201+
SELECT 1::INT AS bla
202+
"""
203+
),
204+
rewrite_casts=False,
205+
)
206+
assert (
207+
x
208+
== """MODEL (
209+
name foo
210+
);
211+
212+
SELECT
213+
CAST(1 AS INT) AS bla"""
214+
)
215+
197216

198217
def test_macro_format():
199218
assert parse_one("@EACH(ARRAY(1,2), x -> x)").sql() == "@EACH(ARRAY(1, 2), x -> x)"

0 commit comments

Comments
 (0)