From c639a44a4fee71cad07827eb33266e1eec7183ae Mon Sep 17 00:00:00 2001 From: Egor Kraev Date: Thu, 9 Apr 2026 12:22:31 +0200 Subject: [PATCH] Support all transforms for joined measures --- .../05_joined_measures/joined_measures.md | 5 +- slayer/sql/generator.py | 307 +++++++++++++----- tests/integration/test_integration.py | 178 ++++++++++ 3 files changed, 412 insertions(+), 78 deletions(-) diff --git a/docs/examples/05_joined_measures/joined_measures.md b/docs/examples/05_joined_measures/joined_measures.md index 497cc4c..acf35a2 100644 --- a/docs/examples/05_joined_measures/joined_measures.md +++ b/docs/examples/05_joined_measures/joined_measures.md @@ -20,7 +20,7 @@ We then evaluate both queries (the results may have different cardinality becaus This way we guarantee that the values of that joined measure are exactly the same as in the original — as that is exactly how it's evaluated. -[Transforms](../../concepts/formulas.md) like `cumsum()` and `change()` work on cross-model measures too — the transform is applied after the sub-query join: +All [transforms](../../concepts/formulas.md) work on cross-model measures — window transforms (`cumsum`, `lag`, `lead`, `rank`, `last`) and self-join transforms (`change`, `change_pct`, `time_shift`) alike. Window transforms are applied as window functions over the sub-query result; self-join transforms generate their own CTE chain on top of the cross-model sub-query. ```json { @@ -28,7 +28,8 @@ This way we guarantee that the values of that joined measure are exactly the sam "time_dimensions": [{"dimension": {"name": "ordered_at"}, "granularity": "month"}], "fields": [ {"formula": "customers.count"}, - {"formula": "cumsum(customers.count)", "name": "cumulative_customers"} + {"formula": "cumsum(customers.count)", "name": "cumulative_customers"}, + {"formula": "change(customers.count)", "name": "count_change"} ] } ``` diff --git a/slayer/sql/generator.py b/slayer/sql/generator.py index e44c801..00beff2 100644 --- a/slayer/sql/generator.py +++ b/slayer/sql/generator.py @@ -486,20 +486,179 @@ def _generate_with_computed(self, enriched: EnrichedQuery, base_sql: str, return sql + def _build_cm_cte_sql(self, cm, enriched: EnrichedQuery, + time_offset=None) -> str: + """Build the SQL body for a cross-model measure CTE. + + Args: + cm: CrossModelMeasure to build the CTE for. + enriched: The parent EnrichedQuery (for WHERE filters). + time_offset: Optional (offset, granularity) tuple for calendar-based + self-join transforms. Shifts time dimension expressions. + """ + select_parts = [] + group_parts = [] + + # Shared dimensions + for dim in cm.shared_dimensions: + col_expr = self._resolve_sql(sql=dim.sql, name=dim.name, model_name=cm.source_model_name) + col_sql = col_expr.sql(dialect=self.dialect) + select_parts.append(f'{col_sql} AS "{dim.alias}"') + group_parts.append(col_sql) + + # Shared time dimensions + for td in cm.shared_time_dimensions: + col_expr = self._resolve_sql(sql=td.sql, name=td.name, model_name=cm.source_model_name) + if time_offset is not None: + offset_val, gran = time_offset + col_expr = self._build_time_offset_expr(col_expr=col_expr, offset=offset_val, granularity=gran) + td_expr = self._build_date_trunc(col_expr=col_expr, granularity=td.granularity) + td_sql = td_expr.sql(dialect=self.dialect) + select_parts.append(f'{td_sql} AS "{td.alias}"') + group_parts.append(td_sql) + + # The measure aggregation + agg_expr, _ = self._build_agg(measure=cm.measure) + select_parts.append(f'{agg_expr.sql(dialect=self.dialect)} AS "{cm.alias}"') + + # FROM: source table with JOIN to target + if cm.source_sql: + from_sql = f"({cm.source_sql}) AS {cm.source_model_name}" + else: + from_sql = f"{cm.source_sql_table} AS {cm.source_model_name}" + + if cm.target_model_sql: + target_from = f"({cm.target_model_sql}) AS {cm.target_model_name}" + else: + target_from = f"{cm.target_model_sql_table} AS {cm.target_model_name}" + + join_conditions = [] + for src_dim, tgt_dim in cm.join_pairs: + join_conditions.append( + f"{cm.source_model_name}.{src_dim} = {cm.target_model_name}.{tgt_dim}" + ) + join_on = " AND ".join(join_conditions) + + cte_sql = ( + f"SELECT {', '.join(select_parts)}\n" + f"FROM {from_sql}\n" + f"LEFT JOIN {target_from} ON {join_on}" + ) + + # Apply the main query's WHERE filters + where_clause, _ = self._build_where_and_having(enriched=enriched) + if where_clause is not None: + cte_sql += f"\nWHERE {where_clause.sql(dialect=self.dialect)}" + + if group_parts: + cte_sql += f"\nGROUP BY {', '.join(group_parts)}" + + return cte_sql + + def _build_cm_self_join_ctes(self, t, cm, cm_cte_name: str, + enriched: EnrichedQuery) -> list: + """Build the CTE chain for a self-join transform on a cross-model measure. + + Returns a list of (cte_name, cte_sql) tuples to append to the top-level CTEs. + The last CTE in the chain contains the transform result. + """ + result_ctes = [] + + # Column aliases in the CM CTE + cm_col_aliases = [] + for dim in cm.shared_dimensions: + cm_col_aliases.append(dim.alias) + for td in cm.shared_time_dimensions: + cm_col_aliases.append(td.alias) + cm_col_aliases.append(cm.alias) + + time_col = f'"{t.time_alias}"' if t.time_alias else None + + # Determine effective join granularity + has_date_ranges = any( + td.date_range and len(td.date_range) == 2 + for td in enriched.time_dimensions + ) + join_granularity = t.granularity + if not join_granularity and has_date_ranges: + for td in enriched.time_dimensions: + if td.alias == t.time_alias: + join_granularity = td.granularity.value + break + + is_calendar = join_granularity is not None + src_cte = cm_cte_name + + # Add ROW_NUMBER if using row-number join + if not is_calendar: + rn_cte_name = f"{cm_cte_name}_rn" + all_cols = ", ".join(f'"{a}"' for a in cm_col_aliases) + rn_sql = f"SELECT {all_cols}, ROW_NUMBER() OVER (ORDER BY {time_col}) AS _rn FROM {cm_cte_name}" + result_ctes.append((rn_cte_name, rn_sql)) + src_cte = rn_cte_name + + # Build shifted base CTE + shift_base_name = f"shifted_base_cm_{t.name}" + if is_calendar: + # Calendar-based: regenerate CM CTE with shifted time expressions + gran = join_granularity + offset = t.offset + shifted_sql = self._build_cm_cte_sql( + cm=cm, enriched=enriched, + time_offset=(-offset, gran), + ) + else: + # Row-based: shifted base is identical to original + shifted_sql = self._build_cm_cte_sql(cm=cm, enriched=enriched) + result_ctes.append((shift_base_name, shifted_sql)) + + # Add ROW_NUMBER to shifted CTE + shift_name = f"shifted_cm_{t.name}" + if not is_calendar: + shift_cols = ", ".join(f'"{a}"' for a in cm_col_aliases) + shift_rn_sql = f"SELECT {shift_cols}, ROW_NUMBER() OVER (ORDER BY {time_col}) AS _rn FROM {shift_base_name}" + result_ctes.append((shift_name, shift_rn_sql)) + else: + result_ctes.append((shift_name, f"SELECT * FROM {shift_base_name}")) + + # Build self-join CTE + if is_calendar: + join_cond = f'{src_cte}.{time_col} = {shift_name}.{time_col}' + else: + join_cond = self._build_row_number_join( + left_table=src_cte, right_table=shift_name, offset=t.offset, + ) + + col_sql = self._build_self_join_column( + transform=t.transform, left_table=src_cte, + right_table=shift_name, measure_alias=cm.alias, + ) + join_cols = ", ".join(f'{src_cte}."{a}"' for a in cm_col_aliases) + sjoin_name = f"sjoin_cm_{t.name}" + sjoin_sql = ( + f"SELECT {join_cols}, {col_sql} AS \"{t.alias}\"\n" + f"FROM {src_cte}\n" + f"LEFT JOIN {shift_name}\n" + f" ON {join_cond}" + ) + result_ctes.append((sjoin_name, sjoin_sql)) + + return result_ctes + def _generate_with_cross_model(self, enriched: EnrichedQuery, - base_sql: str, is_cte: bool) -> str: + base_sql: str, is_cte: bool = False) -> str: """Wrap the main query with cross-model measure sub-queries. Each cross-model measure becomes a CTE that aggregates the target model's measure scoped to shared dimensions, then LEFT JOINed to the main query. + + Window transforms (cumsum, lag, lead, rank, last) are applied as window + functions in the outer SELECT. Self-join transforms (change, change_pct, + time_shift) generate additional CTE layers on top of the cross-model CTE. """ + _ = is_cte # All paths wrap base_sql as a CTE # Wrap the base/computed SQL as a CTE - if is_cte: - # base_sql is already a WITH ... SELECT — wrap it as a subquery CTE - main_cte = f"_main AS (\n{base_sql}\n)" - else: - main_cte = f"_main AS (\n{base_sql}\n)" - + main_cte = f"_main AS (\n{base_sql}\n)" ctes = [main_cte] # Build join columns from the main query (for the final SELECT) @@ -513,7 +672,7 @@ def _generate_with_cross_model(self, enriched: EnrichedQuery, for expr in enriched.expressions: main_columns.append(expr.alias) # Transforms that depend on cross-model aliases are computed in the - # outer SELECT, not inside _main — exclude them from main_columns + # outer SELECT or via extra CTEs — exclude them from main_columns cm_aliases_pre = {cm.alias for cm in enriched.cross_model_measures} for t in enriched.transforms: if t.measure_alias not in cm_aliases_pre: @@ -531,105 +690,101 @@ def _generate_with_cross_model(self, enriched: EnrichedQuery, if is_duplicate: continue # CTE already generated, just reuse in final SELECT - # Build the sub-query: SELECT shared_dims, AGG(measure) FROM target GROUP BY shared_dims - select_parts = [] - group_parts = [] - - # Shared dimensions - for dim in cm.shared_dimensions: - col_expr = self._resolve_sql(sql=dim.sql, name=dim.name, model_name=cm.source_model_name) - col_sql = col_expr.sql(dialect=self.dialect) - select_parts.append(f'{col_sql} AS "{dim.alias}"') - group_parts.append(col_sql) - - # Shared time dimensions - for td in cm.shared_time_dimensions: - col_expr = self._resolve_sql(sql=td.sql, name=td.name, model_name=cm.source_model_name) - td_expr = self._build_date_trunc(col_expr=col_expr, granularity=td.granularity) - td_sql = td_expr.sql(dialect=self.dialect) - select_parts.append(f'{td_sql} AS "{td.alias}"') - group_parts.append(td_sql) - - # The measure aggregation - agg_expr, _ = self._build_agg(measure=cm.measure) - select_parts.append(f'{agg_expr.sql(dialect=self.dialect)} AS "{cm.alias}"') - - # FROM: source table with JOIN to target - if cm.source_sql: - from_sql = f"({cm.source_sql}) AS {cm.source_model_name}" - else: - from_sql = f"{cm.source_sql_table} AS {cm.source_model_name}" - - # JOIN to target model - if cm.target_model_sql: - target_from = f"({cm.target_model_sql}) AS {cm.target_model_name}" - else: - target_from = f"{cm.target_model_sql_table} AS {cm.target_model_name}" - - join_conditions = [] - for src_dim, tgt_dim in cm.join_pairs: - join_conditions.append( - f"{cm.source_model_name}.{src_dim} = {cm.target_model_name}.{tgt_dim}" - ) - join_on = " AND ".join(join_conditions) - - cte_sql = ( - f"SELECT {', '.join(select_parts)}\n" - f"FROM {from_sql}\n" - f"LEFT JOIN {target_from} ON {join_on}" - ) - - # Apply the main query's WHERE filters to the cross-model CTE - where_clause, _ = self._build_where_and_having(enriched=enriched) - if where_clause is not None: - cte_sql += f"\nWHERE {where_clause.sql(dialect=self.dialect)}" - - if group_parts: - cte_sql += f"\nGROUP BY {', '.join(group_parts)}" - + cte_sql = self._build_cm_cte_sql(cm=cm, enriched=enriched) ctes.append(f"{cte_name} AS (\n{cte_sql}\n)") # Identify transforms that depend on cross-model measure aliases cm_aliases = {cm.alias for _, cm in cm_cte_names} post_cm_transforms = [t for t in enriched.transforms if t.measure_alias in cm_aliases] + cm_window_transforms = [t for t in post_cm_transforms if t.transform not in _SELF_JOIN_TRANSFORMS] + cm_self_join_transforms = [t for t in post_cm_transforms if t.transform in _SELF_JOIN_TRANSFORMS] + + # Build self-join CTE chains for self-join transforms on cross-model measures. + # Maps transform alias -> sjoin CTE name (for the final SELECT/JOIN). + sjoin_cte_map = {} + for t in cm_self_join_transforms: + # Find the CM and CTE name this transform targets + target_cm = None + target_cte_name = "" + for cte_name, cm in cm_cte_names: + if cm.alias == t.measure_alias: + target_cm = cm + target_cte_name = cte_name + break + if target_cm is None: + raise ValueError(f"No cross-model measure found for transform '{t.name}'") + + extra_ctes = self._build_cm_self_join_ctes( + t=t, cm=target_cm, cm_cte_name=target_cte_name, enriched=enriched, + ) + for name, sql in extra_ctes: + ctes.append(f"{name} AS (\n{sql}\n)") + # The last CTE in the chain has the transform result + sjoin_cte_map[t.alias] = (extra_ctes[-1][0], target_cm) - # Build final SELECT: main columns + cross-model measure columns + post-CM transforms + # Build final SELECT: main columns + cross-model measure columns + transforms final_parts = [f'_main."{a}"' for a in main_columns] + + # Add bare cross-model measure columns (from base CM CTEs or sjoin CTEs) seen_cm_aliases = set() for cte_name, cm in cm_cte_names: if cm.alias not in seen_cm_aliases: seen_cm_aliases.add(cm.alias) - final_parts.append(f'{cte_name}."{cm.alias}"') - for t in post_cm_transforms: + # If a self-join transform targets this CM, get the measure from the + # sjoin CTE (which carries it through); otherwise from the base CM CTE + source_cte = cte_name + for sjoin_cte_name, sjoin_cm in sjoin_cte_map.values(): + if sjoin_cm.alias == cm.alias: + source_cte = sjoin_cte_name + break + final_parts.append(f'{source_cte}."{cm.alias}"') + + # Add window transforms in outer SELECT + for t in cm_window_transforms: window_sql = self._build_transform_sql(t) - # Replace the quoted measure alias with the cross-model CTE reference for cte_name, cm in cm_cte_names: if cm.alias == t.measure_alias: + # If a sjoin CTE exists for this CM, reference it + source_cte = cte_name + for sjoin_cte_name, sjoin_cm in sjoin_cte_map.values(): + if sjoin_cm.alias == cm.alias: + source_cte = sjoin_cte_name + break window_sql = window_sql.replace( - f'"{t.measure_alias}"', f'{cte_name}."{cm.alias}"' + f'"{t.measure_alias}"', f'{source_cte}."{cm.alias}"' ) break - # Qualify time alias with _main to avoid ambiguity in JOINed context if t.time_alias: window_sql = window_sql.replace( f'"{t.time_alias}"', f'_main."{t.time_alias}"' ) final_parts.append(f'{window_sql} AS "{t.alias}"') - # Build JOINs: join each cross-model CTE to _main on shared dimensions (deduplicate) + # Add self-join transform columns + for t in cm_self_join_transforms: + sjoin_cte_name, _ = sjoin_cte_map[t.alias] + final_parts.append(f'{sjoin_cte_name}."{t.alias}"') + + # Build JOINs: join each cross-model CTE (or its sjoin CTE) to _main from_clause = "FROM _main" joined_ctes = set() for cte_name, cm in cm_cte_names: - if cte_name in joined_ctes: + # Determine which CTE to join: sjoin CTE if self-join transforms exist, else base CM CTE + join_cte = cte_name + for sjoin_cte_name, sjoin_cm in sjoin_cte_map.values(): + if sjoin_cm.alias == cm.alias: + join_cte = sjoin_cte_name + break + if join_cte in joined_ctes: continue - joined_ctes.add(cte_name) + joined_ctes.add(join_cte) join_on_parts = [] for dim in cm.shared_dimensions: - join_on_parts.append(f'_main."{dim.alias}" = {cte_name}."{dim.alias}"') + join_on_parts.append(f'_main."{dim.alias}" = {join_cte}."{dim.alias}"') for td in cm.shared_time_dimensions: - join_on_parts.append(f'_main."{td.alias}" = {cte_name}."{td.alias}"') + join_on_parts.append(f'_main."{td.alias}" = {join_cte}."{td.alias}"') if join_on_parts: - from_clause += f"\nLEFT JOIN {cte_name} ON {' AND '.join(join_on_parts)}" + from_clause += f"\nLEFT JOIN {join_cte} ON {' AND '.join(join_on_parts)}" sql = f"WITH {','.join(ctes)}\nSELECT {', '.join(final_parts)}\n{from_clause}" diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 1016357..bb44b41 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1083,6 +1083,184 @@ def test_transform_on_cross_model(cross_model_env): assert response.data[2]["orders.running"] == pytest.approx(235.0) +@pytest.mark.integration +def test_change_on_cross_model(cross_model_env): + """change() on cross-model measure uses self-join CTE chain.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="customers.avg_score"), + Field(formula="change(customers.avg_score)", name="score_change"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # Jan: 90, Feb: 60, Mar: 85 + # change: None, 60-90=-30, 85-60=25 + assert response.data[0]["orders.score_change"] is None + assert response.data[1]["orders.score_change"] == pytest.approx(-30.0) + assert response.data[2]["orders.score_change"] == pytest.approx(25.0) + + +@pytest.mark.integration +def test_change_pct_on_cross_model(cross_model_env): + """change_pct() on cross-model measure.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="change_pct(customers.avg_score)", name="score_change_pct"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # change_pct: None, (60-90)/90 = -0.333, (85-60)/60 = 0.4167 + assert response.data[0]["orders.score_change_pct"] is None + assert response.data[1]["orders.score_change_pct"] == pytest.approx(-30.0 / 90.0) + assert response.data[2]["orders.score_change_pct"] == pytest.approx(25.0 / 60.0) + + +@pytest.mark.integration +def test_time_shift_on_cross_model(cross_model_env): + """time_shift() on cross-model measure.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="customers.avg_score"), + Field(formula="time_shift(customers.avg_score, -1)", name="prev_score"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # time_shift(-1) = previous row: None, 90, 60 + assert response.data[0]["orders.prev_score"] is None + assert response.data[1]["orders.prev_score"] == pytest.approx(90.0) + assert response.data[2]["orders.prev_score"] == pytest.approx(60.0) + + +@pytest.mark.integration +def test_lag_on_cross_model(cross_model_env): + """lag() window transform on cross-model measure.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="lag(customers.avg_score)", name="prev_score"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # lag: None, 90, 60 + assert response.data[0]["orders.prev_score"] is None + assert response.data[1]["orders.prev_score"] == pytest.approx(90.0) + assert response.data[2]["orders.prev_score"] == pytest.approx(60.0) + + +@pytest.mark.integration +def test_lead_on_cross_model(cross_model_env): + """lead() window transform on cross-model measure.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="lead(customers.avg_score)", name="next_score"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # lead: 60, 85, None + assert response.data[0]["orders.next_score"] == pytest.approx(60.0) + assert response.data[1]["orders.next_score"] == pytest.approx(85.0) + assert response.data[2]["orders.next_score"] is None + + +@pytest.mark.integration +def test_rank_on_cross_model(cross_model_env): + """rank() window transform on cross-model measure.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="customers.avg_score"), + Field(formula="rank(customers.avg_score)", name="score_rank"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # rank DESC: 90→1, 60→3, 85→2 + assert response.data[0]["orders.score_rank"] == 1 + assert response.data[1]["orders.score_rank"] == 3 + assert response.data[2]["orders.score_rank"] == 2 + + +@pytest.mark.integration +def test_mixed_window_and_selfjoin_on_cross_model(cross_model_env): + """Combining window and self-join transforms on the same cross-model measure.""" + engine = cross_model_env + + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="customers.avg_score"), + Field(formula="cumsum(customers.avg_score)", name="running"), + Field(formula="change(customers.avg_score)", name="score_change"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # cumsum: 90, 150, 235 + assert response.data[0]["orders.running"] == pytest.approx(90.0) + assert response.data[1]["orders.running"] == pytest.approx(150.0) + assert response.data[2]["orders.running"] == pytest.approx(235.0) + # change: None, -30, 25 + assert response.data[0]["orders.score_change"] is None + assert response.data[1]["orders.score_change"] == pytest.approx(-30.0) + assert response.data[2]["orders.score_change"] == pytest.approx(25.0) + + # --------------------------------------------------------------------------- # Query as model (multistage queries) # ---------------------------------------------------------------------------