|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import typing as t |
| 4 | + |
| 5 | +from sqlglot import exp |
| 6 | +from sqlglot.dialects.dialect import DialectType |
| 7 | +from sqlglot.optimizer import find_all_in_scope, optimize |
| 8 | +from sqlglot.optimizer.qualify import qualify |
| 9 | + |
| 10 | +from sqlmesh.core import dialect as d |
| 11 | +from sqlmesh.core.metric.definition import Metric, remove_namespace |
| 12 | + |
| 13 | +if t.TYPE_CHECKING: |
| 14 | + from sqlmesh.core.reference import ReferenceGraph |
| 15 | + |
| 16 | + |
# Per-source accumulator used while expanding metrics: the first set collects
# the aggregate expressions to compute from that source, the second the names
# handed to the reference graph to find the join paths those aggregates need.
AggsAndJoins = t.Tuple[t.Set[exp.AggFunc], t.Set[str]]
| 18 | + |
| 19 | + |
class Rewriter:
    """Expands metric references in SELECT statements into aggregate subqueries.

    Every ``MetricAgg`` reference in a select's projections is replaced by the
    metric's formula. The aggregates the formula depends on are computed in one
    subquery per measure source, grouped by the select's GROUP BY expressions,
    and those subqueries are joined back together on the same expressions.

    Args:
        graph: Reference graph used to find join paths between models.
        metrics: Mapping from metric name to its definition.
        dialect: Dialect used to normalize model names and to parse join targets.
        join_type: Join type used to combine the per-source subqueries.
        semantic_prefix: Prefix of the placeholder table ``<prefix>.__table``.
    """

    def __init__(
        self,
        graph: ReferenceGraph,
        metrics: t.Dict[str, Metric],
        dialect: DialectType = "",
        join_type: str = "FULL",
        semantic_prefix: str = "__semantic",
    ):
        self.graph = graph
        self.metrics = metrics
        self.dialect = dialect
        self.join_type = join_type
        self.semantic_prefix = semantic_prefix
        # A query selecting from this placeholder gets the source of the first
        # expanded measure substituted as its base table (see `_expand`).
        self.semantic_table = f"{self.semantic_prefix}.__table"

    def rewrite(self, expression: exp.Expression) -> exp.Expression:
        """Expand metric references in every SELECT found in `expression`.

        The expression is modified in place and the same object is returned.
        """
        # Snapshot the selects first: expanding one inserts new subqueries into
        # the tree, which must not be visited by this loop.
        for select in list(expression.find_all(exp.Select)):
            self._expand(select)

        return expression

    def _expand(self, select: exp.Select) -> None:
        """Expand the metric references of a single SELECT in place.

        A select whose projections contain no metric reference is left
        untouched (its GROUP BY and WHERE clauses are preserved, and it is not
        required to have a FROM clause).

        Raises:
            KeyError: If a referenced metric is not in `self.metrics`, or if a
                select that references metrics has no FROM clause.
        """
        # Materialize the references up front: the tree is mutated while they
        # are processed, and we need to know whether there is anything to do
        # before touching any clause of the select.
        metric_refs = [
            ref
            for projection in select.selects
            for ref in find_all_in_scope(projection, d.MetricAgg)
        ]

        if not metric_refs:
            # Nothing to expand. Returning here keeps the GROUP BY that the
            # code below would otherwise pop without ever re-attaching it.
            return

        base = select.args["from"].this
        base_name = d.normalize_model_name(base.find(exp.Table), dialect=self.dialect)
        base_alias = base.alias_or_name

        sources: t.Dict[str, AggsAndJoins] = {}

        for ref in metric_refs:
            metric = self.metrics[ref.this.name]
            ref.replace(metric.formula.this)

            for agg, (measure, dims) in metric.aggs.items():
                # The placeholder table is swapped for the first measure
                # source encountered; afterwards `base_name` no longer matches.
                if base_name.lower() == self.semantic_table:
                    base_name = measure
                    base = exp.to_table(base_name)

                aggs, joins = sources.setdefault(measure, (set(), set()))
                aggs.add(agg)
                joins.update(dims)

        def remove_table(node: exp.Expression) -> exp.Expression:
            # Strip the qualifier from columns that reference the base alias.
            for column in find_all_in_scope(node, exp.Column):
                if column.table == base_alias:
                    column.args["table"].pop()
            return node

        def replace_table(node: exp.Expression, table: str) -> exp.Expression:
            # Re-point columns that reference the base alias to `table`.
            for column in find_all_in_scope(node, exp.Column):
                if column.table == base_alias:
                    column.args["table"] = exp.to_identifier(table)
            return node

        # The grouping moves into the per-source subqueries built below.
        group = select.args.pop("group", None)
        group_by = group.expressions if group else []

        for name, (aggs, joins) in sources.items():
            is_base = name == base_name
            source = base if is_base else exp.to_table(name)
            table_name = remove_namespace(name)

            if not isinstance(source, exp.Subqueryable):
                source = exp.Select().from_(
                    exp.alias_(source, table_name, table=True, copy=False), copy=False
                )

            self._add_joins(source, name, joins)

            grain = [replace_table(e.copy(), table_name) for e in group_by]

            # Aggregates are sorted by their string form so the generated SQL
            # is deterministic regardless of set iteration order.
            query = source.select(
                *grain,
                *sorted(aggs, key=str),
                copy=False,
            ).group_by(*grain, copy=False)

            if not query.selects:
                query.select("*", copy=False)

            if is_base:
                # The outer filter is moved into the base subquery, with the
                # base-alias qualifiers removed from its columns.
                where = select.args.pop("where", None)

                if where:
                    query.where(remove_table(where.this), copy=False)

                select.from_(query.subquery(base_alias, copy=False), copy=False)
            else:
                select.join(
                    query,
                    on=[e.eq(replace_table(e.copy(), table_name)) for e in group_by],  # type: ignore
                    join_type=self.join_type,
                    join_alias=table_name,
                    copy=False,
                )

    def _add_joins(self, source: exp.Select, name: str, joins: t.Collection[str]) -> None:
        """LEFT JOIN onto `source` every model along the path from `name` to each join target.

        Args:
            source: Select to which the joins are appended in place.
            name: Name the join paths start from.
            joins: Names of the join targets; paths are looked up in `self.graph`.
        """
        # Sorted so the order of the emitted joins does not depend on set
        # iteration order (which varies between interpreter runs for strings).
        for join in sorted(joins):
            path = self.graph.find_path(name, join)

            # Walk consecutive pairs of references along the path.
            for a_ref, b_ref in zip(path, path[1:]):
                a_model_alias = remove_namespace(a_ref.model_name)
                b_model_alias = remove_namespace(b_ref.model_name)

                # Qualify both sides of the join condition with their model alias.
                a = a_ref.expression.copy()
                a.set("table", exp.to_identifier(a_model_alias))
                b = b_ref.expression.copy()
                b.set("table", exp.to_identifier(b_model_alias))

                source.join(
                    b_ref.model_name,
                    on=a.eq(b),
                    join_type="LEFT",
                    join_alias=b_model_alias,
                    dialect=self.dialect,
                    copy=False,
                )
| 142 | + |
| 143 | + |
def rewrite(
    sql: str | exp.Expression,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: str = "",
) -> exp.Expression:
    """Qualify a query and expand its metric references.

    Args:
        sql: Query text, parsed with `dialect`, or an already parsed expression.
        graph: Reference graph handed to the `Rewriter`.
        metrics: Mapping from metric name to its definition.
        dialect: Dialect used for parsing, optimizing and rewriting.

    Returns:
        The expression produced by running the `qualify` rule followed by the
        metric rewriter through `optimize`.
    """
    expander = Rewriter(graph=graph, metrics=metrics, dialect=dialect)

    if isinstance(sql, str):
        expression = d.parse_one(sql, dialect=dialect)
    else:
        expression = sql

    # Qualification runs first, then the metric expansion.
    rules = (qualify, expander.rewrite)
    return optimize(expression, dialect=dialect, rules=rules)
0 commit comments