Skip to content

Commit 23a8caf

Browse files
authored
Feat: add more core macros (#1664)
* feat: add star macro * feat: add surrogate key macro * chore: add example of direct usage, use alias if passed * chore: add doctring to surrogate key and use exp true * chore: incorporate pr feedback * ci: ensure formatter run * chore: remove unused vars * fix: intersperse in surrogate key, typing for mypy * feat: simple safe math macros * ci: make style * feat: union macro and pr feedback * fix: docstring examples * feat: add width bucket macro * feat: add haversine distance macro * ci: use float for conv rate * feat: add pivot macro * ci: correct the docstring * chore: remove width bucket * feat: set type in union, use builtin dict since its insertion ordered
1 parent 7524013 commit 23a8caf

File tree

1 file changed

+242
-2
lines changed

1 file changed

+242
-2
lines changed

sqlmesh/core/macros.py

Lines changed: 242 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def send(
146146
return func(self, *args)
147147
except Exception as e:
148148
print_exception(e, self.python_env)
149-
raise MacroEvalError(f"Error trying to eval macro.") from e
149+
raise MacroEvalError("Error trying to eval macro.") from e
150150

151151
def transform(
152152
self, expression: exp.Expression
@@ -218,7 +218,7 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
218218
return node
219219

220220
if isinstance(node, (MacroSQL, MacroStrReplace)):
221-
result: t.Optional[t.Union[exp.Expression | t.List[exp.Expression]]] = exp.convert(
221+
result: t.Optional[exp.Expression | t.List[exp.Expression]] = exp.convert(
222222
self.eval_expression(node)
223223
)
224224
else:
@@ -676,6 +676,246 @@ def eval_(evaluator: MacroEvaluator, condition: exp.Condition) -> t.Any:
676676
return evaluator.eval_expression(condition)
677677

678678

679+
@macro()
680+
def star(
681+
evaluator: MacroEvaluator,
682+
relation: exp.Table,
683+
alias: t.Optional[exp.Identifier | exp.Column] = None,
684+
except_: t.Optional[exp.Array | exp.Tuple] = None,
685+
prefix: exp.Literal = exp.Literal.string(""),
686+
suffix: exp.Literal = exp.Literal.string(""),
687+
quote_identifiers: exp.Boolean = exp.true(),
688+
) -> t.List[exp.Alias]:
689+
"""Returns a list of projections for the given relation.
690+
691+
Example:
692+
>>> from sqlglot import parse_one
693+
>>> from sqlmesh.core.macros import MacroEvaluator
694+
>>> sql = "SELECT @STAR(foo, bar, [c], 'baz_') FROM foo AS bar"
695+
>>> MacroEvaluator(schema={"foo": {"a": "string", "b": "string", "c": "string", "d": "int"}}).transform(parse_one(sql)).sql()
696+
'SELECT CAST("bar"."a" AS TEXT) AS "baz_a", CAST("bar"."b" AS TEXT) AS "baz_b", CAST("bar"."d" AS INT) AS "baz_d" FROM foo AS bar'
697+
"""
698+
if alias and not isinstance(alias, (exp.Identifier, exp.Column)):
699+
raise SQLMeshError(f"Invalid alias '{alias}'. Expected an identifier.")
700+
if except_ and not isinstance(except_, (exp.Array, exp.Tuple)):
701+
raise SQLMeshError(f"Invalid except '{except_}'. Expected an array.")
702+
if prefix and not isinstance(prefix, exp.Literal):
703+
raise SQLMeshError(f"Invalid prefix '{prefix}'. Expected a literal.")
704+
if suffix and not isinstance(suffix, exp.Literal):
705+
raise SQLMeshError(f"Invalid suffix '{suffix}'. Expected a literal.")
706+
if not isinstance(quote_identifiers, exp.Boolean):
707+
raise SQLMeshError(f"Invalid quote_identifiers '{quote_identifiers}'. Expected a boolean.")
708+
projections: t.List[exp.Alias] = []
709+
exclude = set()
710+
kwargs = {"quoted": quote_identifiers.this}
711+
if alias:
712+
kwargs["table"] = alias.name
713+
if except_:
714+
exclude |= {
715+
e.name for e in except_.expressions if isinstance(e, (exp.Identifier, exp.Column))
716+
}
717+
for column, type_ in evaluator.columns_to_types(relation.sql()).items():
718+
if column in exclude:
719+
continue
720+
projections.append(
721+
exp.cast(exp.column(column, **kwargs), type_).as_(
722+
f"{prefix.this}{column}{suffix.this}", quoted=kwargs["quoted"]
723+
)
724+
)
725+
return projections
726+
727+
728+
@macro()
729+
def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Column | exp.Identifier) -> exp.Func:
730+
"""Generates a surrogate key for the given fields.
731+
732+
Example:
733+
>>> from sqlglot import parse_one
734+
>>> from sqlmesh.core.macros import MacroEvaluator
735+
>>> sql = "SELECT @GENERATE_SURROGATE_KEY(a, b, c) FROM foo"
736+
>>> MacroEvaluator().transform(parse_one(sql)).sql()
737+
"SELECT MD5(CONCAT(COALESCE(CAST(a AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS TEXT), '_sqlmesh_surrogate_key_null_'))) FROM foo"
738+
"""
739+
default_null_value = exp.Literal.string("_sqlmesh_surrogate_key_null_")
740+
string_fields: t.List[exp.Expression] = []
741+
for i, field in enumerate(fields):
742+
if i > 0:
743+
string_fields.append(exp.Literal.string("|"))
744+
string_fields.append(
745+
exp.func(
746+
"COALESCE",
747+
exp.cast(field, exp.DataType.build("string")),
748+
default_null_value,
749+
)
750+
)
751+
return exp.func("MD5", exp.func("CONCAT", *string_fields))
752+
753+
754+
@macro()
755+
def safe_add(_: MacroEvaluator, *fields: exp.Column) -> exp.Case:
756+
"""Adds numbers together, substitutes nulls for 0s and only returns null if all fields are null.
757+
758+
Example:
759+
>>> from sqlglot import parse_one
760+
>>> from sqlmesh.core.macros import MacroEvaluator
761+
>>> sql = "SELECT @SAFE_ADD(a, b) FROM foo"
762+
>>> MacroEvaluator().transform(parse_one(sql)).sql()
763+
'SELECT CASE WHEN a IS NULL AND b IS NULL THEN NULL ELSE COALESCE(a, 0) + COALESCE(b, 0) END FROM foo'
764+
"""
765+
null_cond = exp.and_(*[field.is_(exp.null()) for field in fields])
766+
case = exp.Case().when(null_cond, exp.null())
767+
terms: t.List[exp.Func | exp.Add] = []
768+
for field in fields:
769+
terms.append(exp.func("COALESCE", field, 0))
770+
return case.else_(reduce(lambda a, b: a + b, terms))
771+
772+
773+
@macro()
774+
def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
775+
"""Subtract numbers, substitutes nulls for 0s and only returns null if all fields are null.
776+
777+
Example:
778+
>>> from sqlglot import parse_one
779+
>>> from sqlmesh.core.macros import MacroEvaluator
780+
>>> sql = "SELECT @SAFE_SUB(a, b) FROM foo"
781+
>>> MacroEvaluator().transform(parse_one(sql)).sql()
782+
'SELECT CASE WHEN a IS NULL AND b IS NULL THEN NULL ELSE COALESCE(a, 0) - COALESCE(b, 0) END FROM foo'
783+
"""
784+
null_cond = exp.and_(*[field.is_(exp.null()) for field in fields])
785+
case = exp.Case().when(null_cond, exp.null())
786+
terms: t.List[exp.Func | exp.Sub] = []
787+
for field in fields:
788+
terms.append(exp.func("COALESCE", field, 0))
789+
return case.else_(reduce(lambda a, b: a - b, terms))
790+
791+
792+
@macro()
793+
def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expression) -> exp.Div:
794+
"""Divides numbers, returns null if the denominator is 0.
795+
796+
Example:
797+
>>> from sqlglot import parse_one
798+
>>> from sqlmesh.core.macros import MacroEvaluator
799+
>>> sql = "SELECT @SAFE_DIV(a, b) FROM foo"
800+
>>> MacroEvaluator().transform(parse_one(sql)).sql()
801+
'SELECT a / CASE WHEN b = 0 THEN NULL ELSE b END FROM foo'
802+
"""
803+
return numerator / exp.Case().when(denominator.eq(0), exp.null()).else_(denominator)
804+
805+
806+
@macro()
807+
def union(
808+
evaluator: MacroEvaluator,
809+
type_: exp.Literal = exp.Literal.string("ALL"),
810+
*tables: exp.Table,
811+
) -> exp.Union:
812+
"""Returns a UNION of the given tables.
813+
814+
Example:
815+
>>> from sqlglot import parse_one
816+
>>> from sqlmesh.core.macros import MacroEvaluator
817+
>>> sql = "@UNION('distinct', foo, bar)"
818+
>>> MacroEvaluator(schema={"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"a": "int", "b": "int", "c": "string"}}).transform(parse_one(sql)).sql()
819+
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
820+
"""
821+
if type_.this.upper() not in ("ALL", "DISTINCT"):
822+
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
823+
column_sets: t.List[t.Set[t.Tuple[str, exp.DataType]]] = []
824+
columns_seen: t.Dict[str, None] = {} # Ensure order is deterministic, 3.6+ dicts are ordered
825+
for table in tables:
826+
map = evaluator.columns_to_types(table.sql())
827+
column_sets.append(set(map.items()))
828+
for c in map:
829+
columns_seen[c] = None
830+
superset = reduce(lambda a, b: a.intersection(b), column_sets)
831+
precedence = {c: i for i, c in enumerate(columns_seen.keys())}
832+
projection = [
833+
exp.cast(exp.column(name), typ).as_(name)
834+
for name, typ in sorted(superset, key=lambda c: precedence[c[0]])
835+
]
836+
disinct = type_.this.upper() == "DISTINCT"
837+
selects: t.List[exp.Unionable] = [exp.select(*projection).from_(t) for t in tables]
838+
return t.cast(exp.Union, reduce(lambda a, b: a.union(b, disinct=disinct), selects))
839+
840+
841+
@macro()
842+
def haversine_distance(
843+
_: MacroEvaluator,
844+
lat1: exp.Expression,
845+
lon1: exp.Expression,
846+
lat2: exp.Expression,
847+
lon2: exp.Expression,
848+
unit: exp.Literal = exp.Literal.string("mi"),
849+
) -> exp.Mul:
850+
"""Returns the haversine distance between two points.
851+
852+
Example:
853+
>>> from sqlglot import parse_one
854+
>>> from sqlmesh.core.macros import MacroEvaluator
855+
>>> sql = "SELECT @HAVERSINE_DISTANCE(driver_y, driver_x, passenger_y, passenger_x, 'mi') FROM rides"
856+
>>> MacroEvaluator().transform(parse_one(sql)).sql()
857+
'SELECT 7922 * ASIN(SQRT((POWER(SIN(RADIANS((passenger_y - driver_y) / 2)), 2)) + (COS(RADIANS(driver_y)) * COS(RADIANS(passenger_y)) * POWER(SIN(RADIANS((passenger_x - driver_x) / 2)), 2)))) * 1.0 FROM rides'
858+
"""
859+
if unit.this == "mi":
860+
conversion_rate = 1.0
861+
elif unit.this == "km":
862+
conversion_rate = 1.60934
863+
else:
864+
raise SQLMeshError(f"Invalid unit '{unit}'. Expected 'mi' or 'km'.")
865+
return (
866+
2
867+
* 3961
868+
* exp.func(
869+
"ASIN",
870+
exp.func(
871+
"SQRT",
872+
exp.func("POWER", exp.func("SIN", exp.func("RADIANS", (lat2 - lat1) / 2)), 2)
873+
+ exp.func("COS", exp.func("RADIANS", lat1))
874+
* exp.func("COS", exp.func("RADIANS", lat2))
875+
* exp.func("POWER", exp.func("SIN", exp.func("RADIANS", (lon2 - lon1) / 2)), 2),
876+
),
877+
)
878+
* conversion_rate
879+
)
880+
881+
882+
@macro()
883+
def pivot(
884+
evaluator: MacroEvaluator,
885+
column: exp.Column,
886+
values: exp.Array | exp.Tuple,
887+
alias: exp.Boolean = exp.true(),
888+
agg: exp.Literal = exp.Literal.string("SUM"),
889+
cmp: exp.Literal = exp.Literal.string("="),
890+
prefix: exp.Literal = exp.Literal.string(""),
891+
suffix: exp.Literal = exp.Literal.string(""),
892+
then_value: exp.Literal = exp.Literal.number(1),
893+
else_value: exp.Literal = exp.Literal.number(0),
894+
quote: exp.Boolean = exp.true(),
895+
distinct: exp.Boolean = exp.false(),
896+
) -> t.List[exp.Expression]:
897+
"""Returns a list of projections as a result of pivoting the given column on the given values.
898+
899+
Example:
900+
>>> from sqlglot import parse_one
901+
>>> from sqlmesh.core.macros import MacroEvaluator
902+
>>> sql = "SELECT date_day, @PIVOT(status, ['cancelled', 'completed']) FROM rides GROUP BY 1"
903+
>>> MacroEvaluator().transform(parse_one(sql)).sql()
904+
"SELECT date_day, SUM(CASE WHEN status = 'cancelled' THEN 1 ELSE 0 END), SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) FROM rides GROUP BY 1"
905+
"""
906+
aggregates: t.List[exp.Expression] = []
907+
for value in values.expressions:
908+
proj = f"{agg.this}("
909+
if distinct.this:
910+
proj += "DISTINCT "
911+
proj += f"CASE WHEN {column} {cmp.this} {value} THEN {then_value} ELSE {else_value} END) "
912+
node = evaluator.parse_one(proj)
913+
if alias.this:
914+
node.as_(f"{prefix.this}{value}{suffix.this}", quoted=quote.this, copy=False)
915+
aggregates.append(node)
916+
return aggregates
917+
918+
679919
def normalize_macro_name(name: str) -> str:
680920
"""Prefix macro name with @ and upcase"""
681921
return f"@{name.upper()}"

0 commit comments

Comments
 (0)