Skip to content

Commit c06e5cd

Browse files
tcya authored and meterstick-copybara committed
Add sql generators to LinearRegression and Ridge.
PiperOrigin-RevId: 591602164
1 parent 9c4f260 commit c06e5cd

File tree

10 files changed

+701
-178
lines changed

10 files changed

+701
-178
lines changed

metrics.py

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -677,7 +677,7 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
677677
"""Executes the query from to_sql() and process the result."""
678678
query = self.to_sql(table, split_by)
679679
res = execute(str(query))
680-
extra_idx = list(utils.get_extra_idx(self, return_superset=True))
680+
extra_idx = list(self.get_extra_idx(return_superset=True))
681681
indexes = split_by + extra_idx if split_by else extra_idx
682682
columns = [a.alias_raw for a in query.groupby.add(query.columns)]
683683
columns[:len(indexes)] = indexes
@@ -692,7 +692,7 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
692692
"""Generates SQL query for the metric."""
693693
global_filter = utils.get_global_filter(self)
694694
indexes = sql.Columns(split_by).add(
695-
utils.get_extra_idx(self, return_superset=True)
695+
self.get_extra_idx(return_superset=True)
696696
)
697697
with_data = sql.Datasources()
698698
if isinstance(table, sql.Sql) and table.with_data:
@@ -941,6 +941,32 @@ def add_edges(metric):
941941
add_edges(self)
942942
return dot.to_string()
943943

944+
def get_extra_idx(self, return_superset=False):
945+
"""Collects the extra indexes added by self and its descendants.
946+
947+
Args:
948+
return_superset: If to return the superset of extra indexes if metric has
949+
incompatible indexes.
950+
951+
Returns:
952+
A tuple of column names which are just the index of metric.compute_on(df).
953+
"""
954+
extra_idx = self.extra_index[:]
955+
children_idx = [
956+
c.get_extra_idx(return_superset)
957+
for c in self.children
958+
if utils.is_metric(c)
959+
]
960+
if len(set(children_idx)) > 1:
961+
if not return_superset:
962+
raise ValueError('Incompatible indexes!')
963+
children_idx_superset = set()
964+
children_idx_superset.update(*children_idx)
965+
children_idx = [list(children_idx_superset)]
966+
if children_idx:
967+
extra_idx += list(children_idx[0])
968+
return tuple(extra_idx)
969+
944970
def traverse(self, include_self=True, include_constants=False):
945971
ms = [self] if include_self else list(self.children)
946972
while ms:
@@ -1291,7 +1317,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
12911317
The global with_data which holds all datasources we need in the WITH
12921318
clause.
12931319
"""
1294-
utils.get_extra_idx(self) # Check if indexes are compatible.
1320+
self.get_extra_idx() # Check if indexes are compatible.
12951321
local_filter = (
12961322
sql.Filters(self.where_).add(local_filter).remove(global_filter)
12971323
)

0 commit comments

Comments
 (0)