@@ -677,7 +677,7 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
677677 """Executes the query from to_sql() and process the result."""
678678 query = self .to_sql (table , split_by )
679679 res = execute (str (query ))
680- extra_idx = list (utils .get_extra_idx (self , return_superset = True ))
680+ extra_idx = list (self .get_extra_idx (return_superset = True ))
681681 indexes = split_by + extra_idx if split_by else extra_idx
682682 columns = [a .alias_raw for a in query .groupby .add (query .columns )]
683683 columns [:len (indexes )] = indexes
@@ -692,7 +692,7 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
692692 """Generates SQL query for the metric."""
693693 global_filter = utils .get_global_filter (self )
694694 indexes = sql .Columns (split_by ).add (
695- utils .get_extra_idx (self , return_superset = True )
695+ self .get_extra_idx (return_superset = True )
696696 )
697697 with_data = sql .Datasources ()
698698 if isinstance (table , sql .Sql ) and table .with_data :
@@ -941,6 +941,32 @@ def add_edges(metric):
941941 add_edges (self )
942942 return dot .to_string ()
943943
944+ def get_extra_idx (self , return_superset = False ):
945+ """Collects the extra indexes added by self and its descendants.
946+
947+ Args:
948+ return_superset: If to return the superset of extra indexes if metric has
949+ incompatible indexes.
950+
951+ Returns:
952+ A tuple of column names which are just the index of metric.compute_on(df).
953+ """
954+ extra_idx = self .extra_index [:]
955+ children_idx = [
956+ c .get_extra_idx (return_superset )
957+ for c in self .children
958+ if utils .is_metric (c )
959+ ]
960+ if len (set (children_idx )) > 1 :
961+ if not return_superset :
962+ raise ValueError ('Incompatible indexes!' )
963+ children_idx_superset = set ()
964+ children_idx_superset .update (* children_idx )
965+ children_idx = [list (children_idx_superset )]
966+ if children_idx :
967+ extra_idx += list (children_idx [0 ])
968+ return tuple (extra_idx )
969+
944970 def traverse (self , include_self = True , include_constants = False ):
945971 ms = [self ] if include_self else list (self .children )
946972 while ms :
@@ -1291,7 +1317,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
12911317 The global with_data which holds all datasources we need in the WITH
12921318 clause.
12931319 """
1294- utils .get_extra_idx (self ) # Check if indexes are compatible.
1320+ self .get_extra_idx () # Check if indexes are compatible.
12951321 local_filter = (
12961322 sql .Filters (self .where_ ).add (local_filter ).remove (global_filter )
12971323 )
0 commit comments