Skip to content

Commit 7566b8a

Browse files
authored
Refactor!: use a rewriter instead of a renderer for metrics (#1340)
1 parent 88da4d9 commit 7566b8a

File tree

7 files changed

+270
-198
lines changed

7 files changed

+270
-198
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"requests",
4747
"rich",
4848
"ruamel.yaml",
49-
"sqlglot~=17.14.1",
49+
"sqlglot~=17.15.0",
5050
],
5151
extras_require={
5252
"bigquery": [

sqlmesh/core/dialect.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ class DColonCast(exp.Cast):
7676
pass
7777

7878

79+
class MetricAgg(exp.AggFunc):
    """Aggregate-function expression used for computing metrics.

    It carries a single argument, ``this`` (flagged ``True`` in ``arg_types``).
    """

    arg_types = {"this": True}
83+
84+
7985
def _scan_var(self: Tokenizer) -> None:
8086
param = False
8187
bracket = False
@@ -283,10 +289,13 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
283289

284290

285291
def _parse_types(
286-
self: Parser, check_func: bool = False, schema: bool = False
292+
self: Parser,
293+
check_func: bool = False,
294+
schema: bool = False,
295+
allow_identifiers: bool = True,
287296
) -> t.Optional[exp.Expression]:
288297
start = self._curr
289-
parsed_type = self.__parse_types(check_func=check_func, schema=schema) # type: ignore
298+
parsed_type = self.__parse_types(check_func=check_func, schema=schema, allow_identifiers=allow_identifiers) # type: ignore
290299

291300
if schema and parsed_type:
292301
parsed_type.meta["sql"] = self._find_sql(start, self._prev)
@@ -602,7 +611,7 @@ def extend_sqlglot() -> None:
602611
tokenizer.VAR_SINGLE_TOKENS.update("@")
603612

604613
for parser in parsers:
605-
parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list})
614+
parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list, "METRIC": MetricAgg.from_arg_list})
606615
parser.PLACEHOLDER_PARSERS.update({TokenType.PARAMETER: _parse_macro})
607616
parser.QUERY_MODIFIER_PARSERS.update(
608617
{TokenType.PARAMETER: lambda self: _parse_body_macro(self)}

sqlmesh/core/metric/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
expand_metrics,
55
load_metric_ddl,
66
)
7-
from sqlmesh.core.metric.renderer import Renderer
7+
from sqlmesh.core.metric.rewriter import rewrite

sqlmesh/core/metric/renderer.py

Lines changed: 0 additions & 138 deletions
This file was deleted.

sqlmesh/core/metric/rewriter.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
from sqlglot import exp
6+
from sqlglot.dialects.dialect import DialectType
7+
from sqlglot.optimizer import find_all_in_scope, optimize
8+
from sqlglot.optimizer.qualify import qualify
9+
10+
from sqlmesh.core import dialect as d
11+
from sqlmesh.core.metric.definition import Metric, remove_namespace
12+
13+
if t.TYPE_CHECKING:
14+
from sqlmesh.core.reference import ReferenceGraph
15+
16+
17+
AggsAndJoins = t.Tuple[t.Set[exp.AggFunc], t.Set[str]]
18+
19+
20+
class Rewriter:
    """Expands ``MetricAgg`` references inside SELECT statements.

    Each metric reference in a projection is replaced by the metric's formula,
    and the aggregations the metric needs are computed in per-measure
    subqueries that are attached to the SELECT (as its FROM source or as a
    join).

    Args:
        graph: Reference graph used to find join paths between models.
        metrics: Metric definitions keyed by metric name.
        dialect: Dialect used to normalize model names and parse join targets.
        join_type: Join type used to attach the subquery of every measure
            other than the base one.
        semantic_prefix: Prefix of the placeholder table name; a SELECT whose
            base table is ``<semantic_prefix>.__table`` takes its base from
            the first measure encountered.
    """

    def __init__(
        self,
        graph: ReferenceGraph,
        metrics: t.Dict[str, Metric],
        dialect: DialectType = "",
        join_type: str = "FULL",
        semantic_prefix: str = "__semantic",
    ):
        self.graph = graph
        self.metrics = metrics
        self.dialect = dialect
        self.join_type = join_type
        self.semantic_prefix = semantic_prefix
        self.semantic_table = f"{self.semantic_prefix}.__table"

    def rewrite(self, expression: exp.Expression) -> exp.Expression:
        """Expand metric references in every SELECT found in ``expression``.

        The expression is modified in place and also returned.
        """
        # Snapshot the SELECT nodes up front: _expand attaches new subqueries
        # to the tree, and those must not be visited by this loop.
        for select in list(expression.find_all(exp.Select)):
            self._expand(select)

        return expression

    def _expand(self, select: exp.Select) -> None:
        """Rewrite a single SELECT in place.

        SELECTs without a FROM clause, and SELECTs whose projections yield no
        metric aggregations, are left structurally untouched.
        """
        from_ = select.args.get("from")
        if from_ is None:
            # Nothing to resolve metrics against (e.g. ``SELECT 1``).
            return

        base = from_.this
        base_name = d.normalize_model_name(base.find(exp.Table), dialect=self.dialect)
        base_alias = base.alias_or_name

        # measure name -> (aggregations to compute, models that must be joined in)
        sources: t.Dict[str, AggsAndJoins] = {}

        for projection in select.selects:
            for ref in find_all_in_scope(projection, d.MetricAgg):
                metric = self.metrics[ref.this.name]
                ref.replace(metric.formula.this)

                for agg, (measure, dims) in metric.aggs.items():
                    # The placeholder table is swapped for the first measure seen.
                    if base_name.lower() == self.semantic_table:
                        base_name = measure
                        base = exp.to_table(base_name)

                    aggs, joins = sources.setdefault(measure, (set(), set()))
                    aggs.add(agg)
                    joins.update(dims)

        if not sources:
            # No metric aggregations were collected: keep the SELECT's own
            # GROUP BY instead of stripping it.
            return

        def remove_table(node: exp.Expression) -> exp.Expression:
            # Drop the base alias qualifier from columns.
            for column in find_all_in_scope(node, exp.Column):
                if column.table == base_alias:
                    column.args["table"].pop()
            return node

        def replace_table(node: exp.Expression, table: str) -> exp.Expression:
            # Re-point columns qualified with the base alias at ``table``.
            for column in find_all_in_scope(node, exp.Column):
                if column.table == base_alias:
                    column.args["table"] = exp.to_identifier(table)
            return node

        # The grouping moves into the per-measure subqueries built below.
        group = select.args.pop("group", None)
        group_by = group.expressions if group else []

        for name, (aggs, joins) in sources.items():
            if name == base_name:
                source = base
            else:
                source = exp.to_table(name)

            table_name = remove_namespace(name)

            if not isinstance(source, exp.Subqueryable):
                source = exp.Select().from_(
                    exp.alias_(source, table_name, table=True, copy=False), copy=False
                )

            self._add_joins(source, name, joins)

            grain = [replace_table(e.copy(), table_name) for e in group_by]

            query = source.select(
                *grain,
                # Sorted by their string form so the projection order is stable.
                *sorted(aggs, key=str),
                copy=False,
            ).group_by(*grain, copy=False)

            if not query.selects:
                query.select("*", copy=False)

            if name == base_name:
                # The outer WHERE is moved into the base subquery.
                where = select.args.pop("where", None)

                if where:
                    query.where(remove_table(where.this), copy=False)

                select.from_(query.subquery(base_alias, copy=False), copy=False)
            else:
                select.join(
                    query,
                    on=[e.eq(replace_table(e.copy(), table_name)) for e in group_by],  # type: ignore
                    join_type=self.join_type,
                    join_alias=table_name,
                    copy=False,
                )

    def _add_joins(self, source: exp.Select, name: str, joins: t.Collection[str]) -> None:
        """LEFT JOIN onto ``source`` every model along the graph path from ``name`` to each join target."""
        for join in joins:
            path = self.graph.find_path(name, join)
            # Walk consecutive pairs of references along the path.
            for a_ref, b_ref in zip(path, path[1:]):
                a_model_alias = remove_namespace(a_ref.model_name)
                b_model_alias = remove_namespace(b_ref.model_name)

                a = a_ref.expression.copy()
                a.set("table", exp.to_identifier(a_model_alias))
                b = b_ref.expression.copy()
                b.set("table", exp.to_identifier(b_model_alias))

                source.join(
                    b_ref.model_name,
                    on=a.eq(b),
                    join_type="LEFT",
                    join_alias=b_model_alias,
                    dialect=self.dialect,
                    copy=False,
                )
143+
144+
def rewrite(
    sql: str | exp.Expression,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: str = "",
) -> exp.Expression:
    """Qualify a query and expand its metric references.

    Args:
        sql: SQL text (parsed with ``dialect``) or an already-parsed expression.
        graph: Reference graph handed to the ``Rewriter``.
        metrics: Metric definitions keyed by name, handed to the ``Rewriter``.
        dialect: Dialect used for parsing, optimizing and rewriting.

    Returns:
        The expression produced by running the ``qualify`` rule followed by
        ``Rewriter.rewrite`` through ``optimize``.
    """
    metric_rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)

    if isinstance(sql, str):
        expression = d.parse_one(sql, dialect=dialect)
    else:
        expression = sql

    # Order matters: qualification runs before the metric rewrite.
    pipeline = (qualify, metric_rewriter.rewrite)

    return optimize(expression, dialect=dialect, rules=pipeline)

0 commit comments

Comments
 (0)