Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
finally:
sql.TEMP_TABLE_SUPPORTED = None
res = execute(str(query))
extra_idx = list(utils.get_extra_idx(self, return_superset=True))
extra_idx = list(self.get_extra_idx(return_superset=True))
indexes = split_by + extra_idx if split_by else extra_idx
columns = [a.alias_raw for a in query.groupby.add(query.columns)]
columns[:len(indexes)] = indexes
Expand Down Expand Up @@ -745,7 +745,7 @@ def to_sql(
"""
global_filter = utils.get_global_filter(self)
indexes = sql.Columns(split_by).add(
utils.get_extra_idx(self, return_superset=True)
self.get_extra_idx(return_superset=True)
)
with_data = sql.Datasources()
if isinstance(table, sql.Sql) and table.with_data:
Expand Down Expand Up @@ -1005,6 +1005,32 @@ def add_edges(metric):
add_edges(self)
return dot.to_string()

def get_extra_idx(self, return_superset=False):
"""Collects the extra indexes added by self and its descendants.

Args:
return_superset: If to return the superset of extra indexes if metric has
incompatible indexes.

Returns:
A tuple of column names which are just the index of metric.compute_on(df).
"""
extra_idx = self.extra_index[:]
children_idx = [
c.get_extra_idx(return_superset)
for c in self.children
if utils.is_metric(c)
]
if len(set(children_idx)) > 1:
if not return_superset:
raise ValueError('Incompatible indexes!')
children_idx_superset = set()
children_idx_superset.update(*children_idx)
children_idx = [list(children_idx_superset)]
if children_idx:
extra_idx += list(children_idx[0])
return tuple(extra_idx)

def traverse(self, include_self=True, include_constants=False):
ms = [self] if include_self else list(self.children)
while ms:
Expand Down Expand Up @@ -1358,7 +1384,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
The global with_data which holds all datasources we need in the WITH
clause.
"""
utils.get_extra_idx(self) # Check if indexes are compatible.
self.get_extra_idx() # Check if indexes are compatible.
local_filter = (
sql.Filters(self.where_).add(local_filter).remove(global_filter)
)
Expand Down
Loading