From 21ed8175b6ef92bc4e17810933cb1c95f9ed0123 Mon Sep 17 00:00:00 2001 From: Xunmo Yang Date: Sat, 16 Dec 2023 20:41:00 -0800 Subject: [PATCH] Add sql generators to LinearRegression and Ridge. PiperOrigin-RevId: 591602164 --- metrics.py | 32 +++- models.py | 401 +++++++++++++++++++++++++++++++++++++++-------- models_test.py | 34 ++++ operations.py | 25 +-- requirements.txt | 1 + sql.py | 4 + utils.py | 126 ++++++++++++--- utils_test.py | 354 +++++++++++++++++++++++------------------ 8 files changed, 718 insertions(+), 259 deletions(-) diff --git a/metrics.py b/metrics.py index eabd349..8aab271 100644 --- a/metrics.py +++ b/metrics.py @@ -699,7 +699,7 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): finally: sql.TEMP_TABLE_SUPPORTED = None res = execute(str(query)) - extra_idx = list(utils.get_extra_idx(self, return_superset=True)) + extra_idx = list(self.get_extra_idx(return_superset=True)) indexes = split_by + extra_idx if split_by else extra_idx columns = [a.alias_raw for a in query.groupby.add(query.columns)] columns[:len(indexes)] = indexes @@ -745,7 +745,7 @@ def to_sql( """ global_filter = utils.get_global_filter(self) indexes = sql.Columns(split_by).add( - utils.get_extra_idx(self, return_superset=True) + self.get_extra_idx(return_superset=True) ) with_data = sql.Datasources() if isinstance(table, sql.Sql) and table.with_data: @@ -1005,6 +1005,32 @@ def add_edges(metric): add_edges(self) return dot.to_string() + def get_extra_idx(self, return_superset=False): + """Collects the extra indexes added by self and its descendants. + + Args: + return_superset: If to return the superset of extra indexes if metric has + incompatible indexes. + + Returns: + A tuple of column names which are just the index of metric.compute_on(df). + """ + extra_idx = self.extra_index[:] + children_idx = [ + c.get_extra_idx(return_superset) + for c in self.children + if utils.is_metric(c) + ] + if len(set(children_idx)) > 1: + if not return_superset: + raise ValueError('Incompatible indexes!') + children_idx_superset = set() + children_idx_superset.update(*children_idx) + children_idx = [list(children_idx_superset)] + if children_idx: + extra_idx += list(children_idx[0]) + return tuple(extra_idx) + def traverse(self, include_self=True, include_constants=False): ms = [self] if include_self else list(self.children) while ms: @@ -1358,7 +1384,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ - utils.get_extra_idx(self) # Check if indexes are compatible. + self.get_extra_idx() # Check if indexes are compatible. local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) ) diff --git a/models.py b/models.py index 80ba026..bac9821 100644 --- a/models.py +++ b/models.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function +import copy import itertools from typing import List, Optional, Sequence, Text, Union @@ -75,22 +76,13 @@ def __init__( raise ValueError( 'y must be a 1D array but is %iD!' % operations.count_features(y) ) - if isinstance(x, Sequence): - x = metrics.MetricList(x) + if isinstance(x, metrics.Metric): + x = [x] child = None if x and y: - self.x = x - self.y = y - child = metrics.MetricList((y, x)) + child = metrics.MetricList([y] + x) self.model = model - self.k = operations.count_features(x) self.model_name = model_name - if not name and x and y: - x_names = ( - [m.name for m in x] if isinstance(x, metrics.MetricList) else [x.name] - ) - name = '%s(%s ~ %s)' % (model_name, y.name, ' + '.join(x_names)) - name_tmpl = '%s Coefficient: {}' % name additional_fingerprint_attrs = ( [additional_fingerprint_attrs] if isinstance(additional_fingerprint_attrs, str) @@ -98,7 +90,7 @@ def __init__( ) super(Model, self).__init__( child, - name_tmpl, + None, group_by, [], name=name, @@ -131,6 +123,17 @@ def compute(self, df): def compute_through_sql(self, table, split_by, execute, mode): try: + if ( + not mode + and isinstance(self, (LinearRegression, Ridge)) + and not self.normalize + and self.k > 5 + ): + print( + 'INFO: SQL generation for your Model can be slow because the number' + ' of features > 5. Try compute_on_sql(mode="mixed") (for small' + ' data) or compute_on_sql(mode="magic") (for large data).' + ) if mode == 'magic': if self.where: table = sql.Sql(None, table, self.where_) @@ -155,25 +158,65 @@ def compute_on_sql_magic_mode(self, table, split_by, execute): raise NotImplementedError @property - def group_by(self): - return self.extra_split_by + def y(self): + if not self.children or not isinstance( + self.children[0], metrics.MetricList + ): + raise ValueError('y must be a Metric!') + return self.children[0][0] + + @property + def x(self): + if not self.children or not isinstance( + self.children[0], metrics.MetricList + ): + raise ValueError('x must be a MetricList!') + return metrics.MetricList(self.children[0][1:]) + + @property + def k(self): + return operations.count_features(self.x) - def __call__(self, child): - if not isinstance(child, metrics.MetricList): - raise ValueError(f'Model can only take a MetricList but got {child}!') - model = super(Model, self).__call__(child) - model.y = child[0] - model.x = metrics.MetricList(child[1:]) - model.k = operations.count_features(model.x) - x_names = [m.name for m in model.x] - model.name = '%s(%s ~ %s)' % ( - model.model_name, - model.y.name, + @property + def name(self): + if self.name_: # pytype: disable=attribute-error + return self.name_ # pytype: disable=attribute-error + if not self.children: + return self.model_name + x_names = [m.name for m in self.x] + return '%s(%s ~ %s)' % ( + self.model_name, + self.y.name, ' + '.join(x_names), ) - model.name_tmpl = model.name + ' Coefficient: {}' + + @name.setter + def name(self, name): + self.name_ = name + + @property + def name_tmpl(self): + if self.name_tmpl_: # pytype: disable=attribute-error + return self.name_tmpl_ # pytype: disable=attribute-error + return self.name + ' Coefficient: {}' + + @name_tmpl.setter + def name_tmpl(self, name_tmpl): + self.name_tmpl_ = name_tmpl + + @property + def group_by(self): + return self.extra_split_by + + def __call__(self, child: metrics.Metric): + model = copy.deepcopy(self) if self.children else self + model.children = (child,) return model + def get_extra_idx(self, return_superset=False): + # Model blocks the propagation of extra split_by from the descendants. + return () + class LinearRegression(Model): """A class that can fit a linear regression.""" @@ -196,6 +239,21 @@ def __init__( y, x, group_by, model, 'OLS', where, name, fit_intercept, normalize ) + def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, + local_filter, with_data): + return Ridge( + self.y, + self.x, + self.group_by, + 0, + self.fit_intercept, + self.normalize, + self.where_, + self.name, + ).get_sql_and_with_clause( + table, split_by, global_filter, indexes, local_filter, with_data + ) + def compute_on_sql_magic_mode(self, table, split_by, execute): return Ridge( self.y, @@ -254,11 +312,134 @@ def __init__( ) self.alpha = alpha + def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, + local_filter, with_data): + """Gets the SQL query and WITH clause. + + First we get the query that computes all the elements of X'X and X'y. This + step is same to that in the 'magic' mode. Then we get the elements of + (X'X)^(-1)*(X'y) by doing symbolic computation in SymPy, and translate + them to SQL queries. + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + The SQL instance for metric, without the WITH clause component. + The global with_data which holds all datasources we need in the WITH + clause. + """ + import sympy # pylint: disable=g-import-not-at-top + + normalize = self.fit_intercept and self.normalize and self.alpha + if normalize: + data_to_fit, with_data = get_data_to_fit( + self, table, split_by, global_filter, indexes, local_filter, with_data + ) + data_to_fit_alias = with_data.merge( + sql.Datasource(data_to_fit, 'DataToFit') + ) + split_by = sql.Columns(split_by.aliases) + indexes = sql.Columns(indexes.aliases) + groupby = sql.Columns(self.group_by).aliases + all_split_by = sql.Columns(split_by).add(indexes).add(groupby) + cols = [] + xs = sql.Columns() + for c in data_to_fit.all_columns: + if c.alias in all_split_by: + cols.append(c.alias) + else: + cols.append(c.alias) + y = c.alias + break # y is the 1st column that is not in all_split_by. + partition_by = ', '.join([c for c in cols[:-1] if c not in groupby]) + for x in data_to_fit.all_columns[len(cols):]: + x = x.alias + centered = sql.Column(x) - sql.Column( + x, 'AVG({})', partition=partition_by + ) + cols.append(centered.set_alias(x)) + xs.add(x) + centered_sql = sql.Sql(cols, data_to_fit_alias) + centered_sql_table = sql.Datasource(centered_sql, 'DataWithXCentered') + table = with_data.merge(centered_sql_table) + + x_t_x, x_t_y = utils.get_x_t_x_and_x_t_y_cols(xs, y, normalize=True) + cols = sql.Columns(x_t_x + x_t_y) + cols.add(sql.Column('COUNT(*)', alias='n_obs')) + sufficient_stats = sql.Sql(cols, table, groupby=split_by) + sufficient_stats_table = sql.Datasource( + sufficient_stats, 'SufficientStatElements' + ) + sufficient_stats_alias = with_data.merge(sufficient_stats_table) + else: + xs, sufficient_stats, _, _ = get_sufficient_stats_elements_sql( + self, + table, + split_by, + False, + self.alpha, + global_filter, + indexes, + local_filter, + with_data, + ) + with_data.merge(sufficient_stats.with_data) + sufficient_stats.with_data = None + sufficient_stats_table = sql.Datasource( + sufficient_stats, 'SufficientStatElements' + ) + sufficient_stats_alias = with_data.merge(sufficient_stats_table) + n = len(xs) + bool(self.fit_intercept) + split_by = sql.Columns(split_by.aliases) + sufficient_stats_cols = [ + c for c in sufficient_stats.columns.aliases if c not in split_by + ] + n_x_t_x_elements = n * (n + 1) // 2 - bool(self.fit_intercept) + x_t_x_elements = sufficient_stats_cols[:n_x_t_x_elements] + x_t_y_elements = sufficient_stats_cols[ + n_x_t_x_elements : n_x_t_x_elements + n + ] + if self.fit_intercept: + x_t_x_elements = [1] + x_t_x_elements + penalty = 0 + if isinstance(self, Ridge) and self.alpha: + # if normalize: + # penalty = self.alpha + # else: + n_obs = sufficient_stats_cols[-1] + # We use AVG() to compute x_t_x so the penalty needs to be scaled. + penalty = self.alpha / sympy.Symbol(n_obs) + coefs = utils.get_ridge_coefficients( + x_t_x_elements, x_t_y_elements, self.fit_intercept, penalty, normalize + ) + xs = xs.raw_aliases + cols = sql.Columns(split_by) + if self.fit_intercept: + xs = ['intercept'] + xs + for x, c in zip(xs, coefs): + # ccode prints x**2 to pow(x, 2) which works in SQL. + cols.add( + [sql.Column(sympy.printing.ccode(c), alias=self.name_tmpl.format(x))] + ) + return sql.Sql(cols, sufficient_stats_alias), with_data + def compute_on_sql_magic_mode(self, table, split_by, execute): # Never normalize for the sufficient_stats. Normalization is handled in # compute_ridge_coefs() instead. xs, sufficient_stats, _, _ = get_sufficient_stats_elements( - self, table, split_by, execute, normalize=False, include_n_obs=True + self, + table, + split_by, + execute, + normalize=False, + include_n_obs=self.alpha, ) return apply_algorithm_to_sufficient_stats_elements( sufficient_stats, split_by, compute_ridge_coefs, xs, self @@ -270,7 +451,6 @@ def get_sufficient_stats_elements( table, split_by, execute, - fit_intercept=None, normalize=None, include_n_obs=False, ): @@ -281,7 +461,6 @@ def get_sufficient_stats_elements( table: The table we want to query from. split_by: The columns that we use to split the data. execute: A function that can executes a SQL query and returns a DataFrame. - fit_intercept: If to include intercept in the model. normalize: If to normalize the X. Note that only has effect when m.fit_intercept is True, which is consistent to sklearn. include_n_obs: If to include the number of observations in the return. @@ -306,43 +485,107 @@ def get_sufficient_stats_elements( norms: Nonempty only when normalize. A pd.DataFrame which holds the l2-norm values of all centered-x columns. """ - fit_intercept = m.fit_intercept if fit_intercept is None else fit_intercept if normalize is None: normalize = m.normalize and m.fit_intercept + xs_cols, sufficient_stats_elements, avg_x, norms = ( + get_sufficient_stats_elements_sql( + m, + table, + split_by, + normalize, + include_n_obs, + ) + ) + sufficient_stats_elements = execute(str(sufficient_stats_elements)) + if normalize: + col_names = list(sufficient_stats_elements.columns) + avg_x_names = [f'x{i}' for i in range(len(xs_cols))] + sufficient_stats_elements[avg_x_names] = 0 + sufficient_stats_elements = sufficient_stats_elements[ + col_names[: len(split_by)] + avg_x_names + col_names[len(split_by) :] + ] + return xs_cols, sufficient_stats_elements, avg_x, norms + + +def get_sufficient_stats_elements_sql( + m, + table, + split_by, + normalize=None, + include_n_obs=False, + global_filter=None, + indexes=None, + local_filter=None, + with_data=None, +): + """Generates the SQL columns for the elements of X'X and X'y. + + Args: + m: A Model instance. + table: The table we want to query from. + split_by: The columns that we use to split the data. + normalize: If to normalize the X. Note that only has effect when + m.fit_intercept is True, which is consistent to sklearn. + include_n_obs: If to include the number of observations in the return. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + xs: A list of the column names of x1, x2, ... + sufficient_stats_elements: A SQL query that has all unique elements of + sufficient stats. Each row corresponds to one slice in split_by. The + columns are + split_by, + avg(x0), avg(x1), ..., # if fit_intercept + avg(x0 * x0), avg(x0 * x1), avg(x0 * x2), avg(x1 * x2), ..., + avg(y), # if fit_intercept + avg(x0 * y), avg(x1 * y), ..., + n_observation # if include_n_obs. + The column are named as + split_by, x0, x1,..., x0x0, x0x1,..., y, x0y, x1y,..., n_obs. + avg_x: Nonempty only when normalize. A pd.DataFrame which holds the + avg(x0), avg(x1), ... of the UNNORMALIZED x. + Don't confuse it with the ones in the sufficient_stats_elements, which are + the average of normalized x, which are just 0s. + norms: Nonempty only when normalize. A pd.DataFrame which holds the l2-norm + values of all centered-x columns. + """ table, with_data, xs_cols, y, avg_x, norms = get_data( - m, table, split_by, execute, normalize + m, + table, + split_by, + normalize, + global_filter, + indexes, + local_filter, + with_data, ) xs = xs_cols.aliases - x_t_x = [] - x_t_y = [] - if m.fit_intercept: - if not normalize: - x_t_x = [sql.Column(f'AVG({x})', alias=f'x{i}') for i, x in enumerate(xs)] - x_t_y = [sql.Column(f'AVG({y})', alias='y')] - for i, x1 in enumerate(xs): - for j, x2 in enumerate(xs[i:]): - x_t_x.append(sql.Column(f'AVG({x1} * {x2})', alias=f'x{i}x{i + j}')) - x_t_y += [ - sql.Column(f'AVG({x} * {y})', alias=f'x{i}y') for i, x in enumerate(xs) - ] + x_t_x, x_t_y = utils.get_x_t_x_and_x_t_y_cols( + xs, y, '', m.fit_intercept, normalize + ) cols = sql.Columns(x_t_x + x_t_y) if include_n_obs: cols.add(sql.Column('COUNT(*)', alias='n_obs')) sufficient_stats_elements = sql.Sql( cols, table, groupby=sql.Columns(split_by).aliases, with_data=with_data ) - sufficient_stats_elements = execute(str(sufficient_stats_elements)) - if normalize: - col_names = list(sufficient_stats_elements.columns) - avg_x_names = [f'x{i}' for i in range(len(xs))] - sufficient_stats_elements[avg_x_names] = 0 - sufficient_stats_elements = sufficient_stats_elements[ - col_names[: len(split_by)] + avg_x_names + col_names[len(split_by) :] - ] return xs_cols, sufficient_stats_elements, avg_x, norms -def get_data(m, table, split_by, execute, normalize=False): +def get_data( + m, + table, + split_by, + normalize=False, + global_filter=None, + indexes=None, + local_filter=None, + with_data=None, +): """Retrieves the data that the model will be fit on. We compute a Model by first computing its children, and then fitting @@ -357,8 +600,13 @@ def get_data(m, table, split_by, execute, normalize=False): m: A Model instance. table: The table we want to query from. split_by: The columns that we use to split the data. - execute: A function that can executes a SQL query and returns a DataFrame. normalize: If the Model normalizes x. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: The WITH clause that holds all necessary subqueries so we can + query the `table`. Returns: table: A string representing the table name which we can query from. The @@ -373,9 +621,9 @@ def get_data(m, table, split_by, execute, normalize=False): norms: Nonempty only when normalize is True. A pd.DataFrame which holds the l2-norm values of all centered-x columns. """ - data = m.children[0].to_sql(table, split_by + m.group_by) - with_data = data.with_data - data.with_data = None + data, with_data = get_data_to_fit( + m, table, split_by, global_filter, indexes, local_filter, with_data + ) table = with_data.merge(sql.Datasource(data, 'DataToFit')) y = data.columns[-m.k - 1].alias xs_cols = sql.Columns(data.columns[-m.k :]) @@ -387,9 +635,7 @@ def get_data(m, table, split_by, execute, normalize=False): avg_x_and_y = sql.Columns([sql.Column(f'AVG({x})', alias=x) for x in xs]) avg_x_and_y.add(sql.Column(f'AVG({y})', alias=y)) cols = sql.Columns(split_by).add(avg_x_and_y) - avgs = execute( - str(sql.Sql(cols, table, groupby=split_by, with_data=with_data)) - ) + avgs = sql.Sql(cols, table, groupby=split_by, with_data=with_data) avg_table = sql.Sql( cols, table, @@ -422,7 +668,6 @@ def get_data(m, table, split_by, execute, normalize=False): groupby=split_by, with_data=with_data, ) - norms = execute(str(norms)) x_norm_squared = [sql.Column(f'SUM(POWER({x}, 2))', alias=x) for x in xs] norm_squared_table = sql.Sql( @@ -456,6 +701,32 @@ def get_data(m, table, split_by, execute, normalize=False): return table, with_data, xs_cols, y, avgs, norms +def get_data_to_fit( + m, + table, + split_by, + global_filter=None, + indexes=None, + local_filter=None, + with_data=None, +): + """Gets data for model fitting.""" + # All filters are global when getting data to fit. + global_filter = sql.Filters(global_filter).add(local_filter).add(m.where_) + if indexes is None: + indexes = sql.Columns(split_by) + return m.children[0].get_sql_and_with_clause( + table, + sql.Columns(split_by).add(m.extra_split_by), + global_filter, + sql.Columns(indexes) + .add(m.extra_split_by) + .add(m.children[0].get_extra_idx()), + sql.Filters(), + sql.Datasources(with_data), + ) + + def apply_algorithm_to_sufficient_stats_elements( sufficient_stats_elements, split_by, algorithm, *args, **kwargs ): @@ -491,7 +762,7 @@ def compute_ridge_coefs(sufficient_stats, xs, m): if fit_intercept and m.normalize: return compute_coef_for_normalize_ridge(sufficient_stats, xs, m) x_t_x, x_t_y = construct_matrix_from_elements(sufficient_stats, fit_intercept) - if isinstance(m, Ridge): + if isinstance(m, Ridge) and m.alpha: n_obs = sufficient_stats['n_obs'] penalty = np.identity(len(x_t_y)) if fit_intercept: @@ -758,6 +1029,8 @@ def compute_on_sql_magic_mode(self, table, split_by, execute): self.max_iter, ) if self.fit_intercept and self.normalize: + avgs = execute(str(avgs)) + norms = execute(str(norms)) coef = compute_normalized_coef(coef, norms, avgs, split_by) columns = list(coef.columns) columns[-len(xs) :] = [x.alias_raw for x in xs] @@ -1064,9 +1337,7 @@ def compute_on_sql_magic_mode(self, table, split_by, execute): self, table, split_by, execute, include_n_obs=True ) - table, with_data, xs_cols, y, _, _ = get_data( - self, table, split_by, execute - ) + table, with_data, xs_cols, y, _, _ = get_data(self, table, split_by) xs = xs_cols.aliases if self.fit_intercept: xs.append('1') diff --git a/models_test.py b/models_test.py index b5e4798..5fbe915 100644 --- a/models_test.py +++ b/models_test.py @@ -56,6 +56,40 @@ def test_model(self, model, sklearn_model, name): ]) pd.testing.assert_frame_equal(output, expected) + def test_model_on_operations(self, model, sklearn_model, name): + del name # unused + s = metrics.Ratio('X1', 'Y') + s2 = metrics.Sum('Y') + pct = operations.PercentChange('grp1', 'A', s, include_base=True) + ab = operations.AbsoluteChange('grp1', 'A', s, include_base=True) + mh = operations.MH('grp1', 'A', 'grp2', s, include_base=True) + prepost = operations.PrePostChange( + 'grp1', 'A', s, s2, 'grp2', include_base=True + ) + cuped = operations.CUPED('grp1', 'A', s, s2, 'grp2', include_base=True) + all_changes = metrics.MetricList((pct, ab, mh, prepost, cuped)) + m1 = model(pct, [ab, mh, prepost, cuped], name='foo') + m2 = model(name='foo')(all_changes) + + output1 = m1.compute_on(DF) + output2 = m2.compute_on(DF) + + data_to_fit = all_changes.compute_on(DF) + model = sklearn_model().fit(data_to_fit.iloc[:, 1:], data_to_fit.iloc[:, 0]) + expected = pd.DataFrame([[model.intercept_] + list(model.coef_)]) + expected.columns = ['foo Coefficient: intercept'] + [ + f'foo Coefficient: sum(X1) / sum(Y) {c}' + for c in ( + 'Absolute Change', + 'MH Ratio', + 'PrePost Percent Change', + 'CUPED Change', + ) + ] + + pd.testing.assert_frame_equal(output1, expected) + pd.testing.assert_frame_equal(output2, expected) + def test_melted(self, model, sklearn_model, name): del sklearn_model, name # unused m = model(metrics.Sum('Y'), metrics.Sum('X1'), 'grp1') diff --git a/operations.py b/operations.py index ec68e5d..0f91ec3 100644 --- a/operations.py +++ b/operations.py @@ -374,7 +374,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, child_table = sql.Datasource(dist_sql, 'CumulativeDistributionRaw') child_table_alias = with_data.merge(child_table) columns = sql.Columns(indexes.aliases) - order = list(utils.get_extra_idx(self)) + order = list(self.get_extra_idx()) order = [ sql.Column(self.get_ordered_col(sql.Column(o).alias), auto_alias=False) for o in order @@ -521,23 +521,28 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, sql_template_for_comparison = self.get_sql_template_for_comparison( raw_table_alias, base_table_alias ) - columns = sql.Columns() - val_col_len = len(raw_table_sql.all_columns) - len(indexes) + columns = [] for r, b in zip( - raw_table_sql.all_columns[-val_col_len:], - base_value.columns[-val_col_len:], + raw_table_sql.all_columns[::-1], + base_value.columns[::-1], ): + if r.alias in sql.Columns(utils.get_extra_split_by(self)).aliases: + break col = sql.Column( sql_template_for_comparison % {'r': r.alias, 'b': b.alias}, alias=self.name_tmpl.format(r.alias_raw), ) - columns.add(col) + columns = [col] + columns using = indexes.difference(cond_cols) join = '' if using else 'CROSS' - return sql.Sql( - sql.Columns(indexes.aliases).add(columns), - sql.Join(raw_table_alias, base_table_alias, join=join, using=using), - cond), with_data + return ( + sql.Sql( + sql.Columns(indexes.aliases).add(columns), + sql.Join(raw_table_alias, base_table_alias, join=join, using=using), + cond, + ), + with_data, + ) def get_change_raw_sql( self, table, split_by, global_filter, indexes, local_filter, with_data diff --git a/requirements.txt b/requirements.txt index 87623ed..84140f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ six numpy>=2.0.0 scipy>=1.9.3 sklearn>=1.6.1 +sympy>=1.12 pandas>=2.2.2 pydot>=1.4.2 \ No newline at end of file diff --git a/sql.py b/sql.py index d4a3d0a..3ce2d4a 100644 --- a/sql.py +++ b/sql.py @@ -491,6 +491,10 @@ def __init__(self, columns=None, distinct=None): # pylint: disable=super-init-n def aliases(self): return [c.alias for c in self] + @property + def raw_aliases(self): + return [c.alias_raw for c in self] + @property def original_columns(self): # Returns the original Column instances added. diff --git a/utils.py b/utils.py index f1873b9..4b7c555 100644 --- a/utils.py +++ b/utils.py @@ -24,6 +24,7 @@ from typing import Iterable, List, Optional, Text, Union from meterstick import sql +import numpy as np import pandas as pd @@ -155,32 +156,6 @@ def apply_name_tmpl(name_tmpl, res, melted=False): return res -def get_extra_idx(metric, return_superset=False): - """Collects the extra indexes added by Operations for the metric tree. - - Args: - metric: A Metric instance. - return_superset: If to return the superset of extra indexes if metric has - incompatible indexes. - - Returns: - A tuple of column names which are just the index of metric.compute_on(df). - """ - extra_idx = metric.extra_index[:] - children_idx = [ - get_extra_idx(c, return_superset) for c in metric.children if is_metric(c) - ] - if len(set(children_idx)) > 1: - if not return_superset: - raise ValueError('Incompatible indexes!') - children_idx_superset = set() - children_idx_superset.update(*children_idx) - children_idx = [list(children_idx_superset)] - if children_idx: - extra_idx += list(children_idx[0]) - return tuple(extra_idx) - - def get_extra_split_by(metric, return_superset=False): """Collects the extra split_by added by Operations for the metric tree. @@ -711,3 +686,102 @@ def pcollection_to_df_via_file_io( if cleanup: os.remove(f) return pd.concat(res, ignore_index=True) + + +def get_x_t_x_and_x_t_y_cols( + xs: List[str], y: str, prefix='', fit_intercept=True, normalize=False +): + """Computes the x_t_x and x_t_y elements. + + When solving LinearRegression or Ridge using sufficient stats, we need to + constuct SQL columns for X'X and X'Y. This function takes the SQL columns of + X and y, and output SQL columns for the elements of X'X and X'Y. + + Args: + xs: A list of column names of the features. + y: The column name of y. + prefix: A prefix to be added to the alias of the generated SQL columns. + fit_intercept: If the model in question fits intercept. + normalize: If the model in question normalizes x. + + Returns: + The SQL columns for the elements of X'X and X'Y. The elements of X'X are + avg(x0), avg(x1), ..., # if fit_intercept + avg(x0 * x0), avg(x0 * x1), avg(x0 * x2), avg(x1 * x2), .... + The elements of X'Y are + avg(y), # if fit_intercept + avg(x0 * y), avg(x1 * y), ..., + Note that when fit_intercept, the return cannot be directly fed to the + get_ridge_coefficients() below. You need to prepend a '1' to x_t_x. + """ + x_t_x = [] + x_t_y = [] + if fit_intercept: + if not normalize: + x_t_x = [ + sql.Column(f'AVG({x})', alias=f'{prefix}x{i}') + for i, x in enumerate(xs) + ] + x_t_y = [sql.Column(f'AVG({y})', alias=f'{prefix}y')] + for i, x1 in enumerate(xs): + for j, x2 in enumerate(xs[i:]): + x_t_x.append( + sql.Column(f'AVG({x1} * {x2})', alias=f'{prefix}x{i}x{i + j}') + ) + x_t_y += [ + sql.Column(f'AVG({x} * {y})', alias=f'{prefix}x{i}y') + for i, x in enumerate(xs) + ] + return x_t_x, x_t_y + + +def get_ridge_coefficients( + x_t_x_elements, + x_t_y_elements, + fit_intercept=True, + penalty=0, + normalize=False, +): + """Computes coefficients of Ridge regression. + + Args: + x_t_x_elements: The SQL column names of the elements of X'X. It's the 1st + return of get_x_t_x_and_x_t_y_cols. + x_t_y_elements: The SQL column names of the elements of X'Y. It's the 2nd + return of get_x_t_x_and_x_t_y_cols. + fit_intercept: If the model in question fits intercept. + penalty: The penalty of Ridge regression. + normalize: If the model normalizes x. It only has effect when fit_intercept + is also True. If normalize, all the AVG(x) columns are 0. + + Returns: + (X'X)^(-1)(X'Y) as a Sympy matrix. + """ + import sympy # pylint: disable=g-import-not-at-top + del normalize + # normalize = fit_intercept and normalize + if fit_intercept: + x_t_y_elements = ['1'] + x_t_y_elements + n = len(x_t_y_elements) + if n > 5: + print( + f'WARNING: Doing symbolic computation on {n}*{n} matrix. It will be' + " slow. Consider using compute_on_sql(mode='mixed')." + ) + x_t_x = np.empty([n, n], dtype=object) + x_t_x[np.triu_indices(n)] = x_t_x_elements + x_t_x[np.tril_indices(n)] = x_t_x.T[np.tril_indices(n)] + x_t_x = sympy.Matrix(x_t_x) + if penalty: + iden = np.identity(n) + if fit_intercept: + iden[0, 0] = 0 + x_t_x += penalty * sympy.Matrix(iden) + # Do not use x_t_x.inv(). It's very slow + # https://stackoverflow.com/questions/75553096/why-is-sympy-matrix-inv-slow. + x_t_x_inv = x_t_x.adjugate() / x_t_x.det() + x_t_y = sympy.Matrix(x_t_y_elements) + coefs = x_t_x_inv * x_t_y + if n < 5: + coefs = sympy.simplify(coefs) # slow if n is large + return coefs diff --git a/utils_test.py b/utils_test.py index 2fb55f7..b2d5082 100644 --- a/utils_test.py +++ b/utils_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from meterstick import metrics from meterstick import operations +from meterstick import sql from meterstick import utils import numpy as np import pandas as pd @@ -32,36 +33,37 @@ def test_adjust_slices_for_loo_no_splitby_no_operation_unit_filled(self): df = pd.DataFrame({'unit': list('abc'), 'x': range(1, 4)}) bucket_res = df[df.unit != 'a'].groupby('unit').sum() output = utils.adjust_slices_for_loo(bucket_res, [], df) - expected = pd.DataFrame({'x': [0, 2, 3]}, - index=pd.Index(list('abc'), name='unit')) + expected = pd.DataFrame( + {'x': [0, 2, 3]}, index=pd.Index(list('abc'), name='unit') + ) testing.assert_frame_equal(output, expected) def test_adjust_slices_for_loo_no_splitby_operation(self): - df = pd.DataFrame({ - 'unit': list('abb'), - 'grp': list('bbc'), - 'x': range(1, 4) - }) + df = pd.DataFrame( + {'unit': list('abb'), 'grp': list('bbc'), 'x': range(1, 4)} + ) bucket_res = df[df.unit != 'a'].groupby(['unit', 'grp']).sum() output = utils.adjust_slices_for_loo(bucket_res, [], df) - expected = pd.DataFrame({'x': [0, 0]}, - index=pd.MultiIndex.from_tuples( - (('a', 'b'), ('a', 'c')), - names=('unit', 'grp'))) + expected = pd.DataFrame( + {'x': [0, 0]}, + index=pd.MultiIndex.from_tuples( + (('a', 'b'), ('a', 'c')), names=('unit', 'grp') + ), + ) testing.assert_frame_equal(output, expected) def test_adjust_slices_for_loo_splitby_no_operation(self): - df = pd.DataFrame({ - 'unit': list('abc'), - 'grp': list('abb'), - 'x': range(1, 4) - }) + df = pd.DataFrame( + {'unit': list('abc'), 'grp': list('abb'), 'x': range(1, 4)} + ) bucket_res = df[df.grp != 'b'].groupby(['grp', 'unit']).sum() output = utils.adjust_slices_for_loo(bucket_res, ['grp'], df) - expected = pd.DataFrame({'x': [1, 0, 0]}, - index=pd.MultiIndex.from_tuples( - (('a', 'a'), ('b', 'b'), ('b', 'c')), - names=('grp', 'unit'))) + expected = pd.DataFrame( + {'x': [1, 0, 0]}, + index=pd.MultiIndex.from_tuples( + (('a', 'a'), ('b', 'b'), ('b', 'c')), names=('grp', 'unit') + ), + ) testing.assert_frame_equal(output, expected) def test_adjust_slices_for_loo_splitby_operation(self): @@ -69,22 +71,25 @@ def test_adjust_slices_for_loo_splitby_operation(self): 'grp': list('aaabbb'), 'op': ['x'] * 2 + ['y'] * 2 + ['z'] * 2, 'unit': [1, 2, 3, 2, 3, 2], - 'x': range(1, 7) + 'x': range(1, 7), }) bucket_res = df[df.unit != 1].groupby(['grp', 'unit', 'op']).sum() output = utils.adjust_slices_for_loo(bucket_res, ['grp'], df) - expected = pd.DataFrame({'x': [0, 0, 0, 0, 6, 0, 5]}, - index=pd.MultiIndex.from_tuples( - ( - ('a', 1, 'x'), - ('a', 1, 'y'), - ('a', 2, 'y'), - ('a', 3, 'x'), - ('b', 2, 'z'), - ('b', 3, 'y'), - ('b', 3, 'z'), - ), - names=('grp', 'unit', 'op'))) + expected = pd.DataFrame( + {'x': [0, 0, 0, 0, 6, 0, 5]}, + index=pd.MultiIndex.from_tuples( + ( + ('a', 1, 'x'), + ('a', 1, 'y'), + ('a', 2, 'y'), + ('a', 3, 'x'), + ('b', 2, 'z'), + ('b', 3, 'y'), + ('b', 3, 'z'), + ), + names=('grp', 'unit', 'op'), + ), + ) testing.assert_frame_equal(output, expected) def test_one_level_column_and_no_splitby_melt(self): @@ -103,173 +108,183 @@ def test_one_level_value_column_and_no_splitby_unmelt(self): def test_one_level_not_value_column_and_no_splitby_unmelt(self): melted = pd.DataFrame({'Baz': [1, 2]}, index=['foo', 'bar']) melted.index.name = 'Metric' - expected = pd.DataFrame([[1, 2]], - columns=pd.MultiIndex.from_product( - [['foo', 'bar'], ['Baz']], - names=['Metric', None])) + expected = pd.DataFrame( + [[1, 2]], + columns=pd.MultiIndex.from_product( + [['foo', 'bar'], ['Baz']], names=['Metric', None] + ), + ) testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_one_level_column_and_single_splitby_melt(self): unmelted = pd.DataFrame( - data={ - 'foo': [0, 1], - 'bar': [2, 3] - }, + data={'foo': [0, 1], 'bar': [2, 3]}, columns=['foo', 'bar'], - index=['B', 'A']) + index=['B', 'A'], + ) unmelted.index.name = 'grp' - expected = pd.DataFrame({'Value': range(4)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp'])) + expected = pd.DataFrame( + {'Value': range(4)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + ) expected.index.name = 'Metric' testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_one_level_column_and_single_splitby_unmelt(self): expected = pd.DataFrame( - data={ - 'foo': [0, 1], - 'bar': [2, 3] - }, + data={'foo': [0, 1], 'bar': [2, 3]}, columns=['foo', 'bar'], - index=['B', 'A']) + index=['B', 'A'], + ) expected.index.name = 'grp' expected.columns.name = 'Metric' - melted = pd.DataFrame({'Value': range(4)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp'])) + melted = pd.DataFrame( + {'Value': range(4)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + ) melted.index.name = 'Metric' testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_one_level_column_and_multiple_splitby_melt(self): unmelted = pd.DataFrame( - data={ - 'foo': range(4), - 'bar': range(4, 8) - }, + data={'foo': range(4), 'bar': range(4, 8)}, columns=['foo', 'bar'], - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) - expected = pd.DataFrame({'Value': range(8)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) + expected = pd.DataFrame( + {'Value': range(8)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), + names=['Metric', 'grp', 'country'], + ), + ) expected.index.name = 'Metric' testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_one_level_column_and_multiple_splitby_unmelt(self): - melted = pd.DataFrame({'Value': range(8)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country'])) + melted = pd.DataFrame( + {'Value': range(8)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), + names=['Metric', 'grp', 'country'], + ), + ) expected = pd.DataFrame( - data={ - 'foo': range(4), - 'bar': range(4, 8) - }, + data={'foo': range(4), 'bar': range(4, 8)}, columns=['foo', 'bar'], - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) expected.columns.name = 'Metric' testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_multiple_index_columns_and_no_splitby_melt(self): - unmelted = pd.DataFrame([[1, 2, 3, 4]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE']))) + unmelted = pd.DataFrame( + [[1, 2, 3, 4]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + ) expected = pd.DataFrame( - data={ - 'Value': [1, 3], - 'SE': [2, 4] - }, + data={'Value': [1, 3], 'SE': [2, 4]}, index=['foo', 'bar'], - columns=['Value', 'SE']) + columns=['Value', 'SE'], + ) expected.index.name = 'Metric' testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_multiple_index_columns_and_no_splitby_unmelt(self): melted = pd.DataFrame( - data={ - 'Value': [1, 3], - 'SE': [2, 4] - }, + data={'Value': [1, 3], 'SE': [2, 4]}, index=['foo', 'bar'], - columns=['Value', 'SE']) + columns=['Value', 'SE'], + ) melted.index.name = 'Metric' - expected = pd.DataFrame([[1, 2, 3, 4]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE']))) + expected = pd.DataFrame( + [[1, 2, 3, 4]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + ) expected.columns.names = ['Metric', None] testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_multiple_index_column_and_single_splitby_melt(self): - unmelted = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE'])), - index=['B', 'A']) + unmelted = pd.DataFrame( + [[1, 2, 3, 4], [5, 6, 7, 8]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + index=['B', 'A'], + ) unmelted.index.name = 'grp' expected = pd.DataFrame( - data={ - 'Value': [1, 5, 3, 7], - 'SE': [2, 6, 4, 8] - }, - index=pd.MultiIndex.from_product((['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp']), - columns=['Value', 'SE']) + data={'Value': [1, 5, 3, 7], 'SE': [2, 6, 4, 8]}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + columns=['Value', 'SE'], + ) testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_multiple_index_column_and_single_splitby_unmelt(self): melted = pd.DataFrame( - data={ - 'Value': [1, 5, 3, 7], - 'SE': [2, 6, 4, 8] - }, - index=pd.MultiIndex.from_product((['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp']), - columns=['Value', 'SE']) - expected = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE'])), - index=['B', 'A']) + data={'Value': [1, 5, 3, 7], 'SE': [2, 6, 4, 8]}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + columns=['Value', 'SE'], + ) + expected = pd.DataFrame( + [[1, 2, 3, 4], [5, 6, 7, 8]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + index=['B', 'A'], + ) expected.index.name = 'grp' expected.columns.names = ['Metric', None] testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_multiple_index_column_and_multiple_splitby_melt(self): unmelted = pd.DataFrame( - [range(4), range(4, 8), - range(8, 12), range(12, 16)], + [range(4), range(4, 8), range(8, 12), range(12, 16)], columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) expected = pd.DataFrame( data={ 'Value': [0, 4, 8, 12, 2, 6, 10, 14], - 'SE': [1, 5, 9, 13, 3, 7, 11, 15] + 'SE': [1, 5, 9, 13, 3, 7, 11, 15], }, index=pd.MultiIndex.from_product( (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country']), - columns=['Value', 'SE']) + names=['Metric', 'grp', 'country'], + ), + columns=['Value', 'SE'], + ) testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_multiple_index_column_and_multiple_splitby_unmelt(self): melted = pd.DataFrame( data={ 'Value': [0, 4, 8, 12, 2, 6, 10, 14], - 'SE': [1, 5, 9, 13, 3, 7, 11, 15] + 'SE': [1, 5, 9, 13, 3, 7, 11, 15], }, index=pd.MultiIndex.from_product( (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country']), - columns=['Value', 'SE']) + names=['Metric', 'grp', 'country'], + ), + columns=['Value', 'SE'], + ) expected = pd.DataFrame( - [range(4), range(4, 8), - range(8, 12), range(12, 16)], + [range(4), range(4, 8), range(8, 12), range(12, 16)], columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) expected.columns.names = ['Metric', None] testing.assert_frame_equal(expected, utils.unmelt(melted)) @@ -288,29 +303,6 @@ def test_remove_empty_level(self): actual = utils.remove_empty_level(df) testing.assert_frame_equal(expected, actual) - def test_get_extra_idx(self): - mh = operations.MH('foo', 'f', 'bar', metrics.Ratio('a', 'b')) - ab = operations.AbsoluteChange('foo', 'f', metrics.Sum('c')) - m = operations.Jackknife('unit', metrics.MetricList((mh, ab))) - self.assertEqual(utils.get_extra_idx(m), ('foo',)) - - def test_get_extra_idx_return_superset(self): - s = metrics.Sum('x') - m = metrics.MetricList(( - operations.AbsoluteChange('g', 0, s), - operations.AbsoluteChange('g2', 1, s), - )) - actual = utils.get_extra_idx(m, True) - self.assertEqual(set(actual), set(('g', 'g2'))) - - def test_get_extra_idx_raises(self): - mh = operations.MH('foo', 'f', 'bar', metrics.Ratio('a', 'b')) - ab = operations.AbsoluteChange('baz', 'f', metrics.Sum('c')) - m = operations.Jackknife('unit', metrics.MetricList((mh, ab))) - with self.assertRaises(ValueError) as cm: - utils.get_extra_idx(m) - self.assertEqual(str(cm.exception), 'Incompatible indexes!') - def test_get_extra_split_by(self): mh = operations.MH('foo', 'f', 'bar', metrics.Ratio('a', 'b')) m = operations.AbsoluteChange('unit', 'a', mh) @@ -352,11 +344,9 @@ def test_get_equivalent_metric_with_df(self): expected = metrics.Sum('meterstick_tmp:(x * y)') expected.where = 'a' expected.name = 'foo' - expected_df = pd.DataFrame({ - 'x': [1, 2], - 'y': [2, 3], - 'meterstick_tmp:(x * y)': [2, 6] - }) + expected_df = pd.DataFrame( + {'x': [1, 2], 'y': [2, 3], 'meterstick_tmp:(x * y)': [2, 6]} + ) self.assertEqual(output, expected) testing.assert_frame_equal(df, expected_df) @@ -424,6 +414,60 @@ def test_get_leaf_metrics_include_constants(self): expected = [metrics.Sum('x'), metrics.Sum('y'), metrics.Sum('c'), 1] self.assertEqual(output, expected) + def test_get_x_t_x_and_x_t_y_cols_one_x(self): + xs = ['a'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y, 'foo_') + x_t_x = sql.Columns([ + sql.Column('AVG(a)', alias='foo_x0'), + sql.Column('AVG(a * a)', alias='foo_x0x0'), + ]) + x_t_y = sql.Columns([ + sql.Column('AVG(y)', alias='foo_y'), + sql.Column('AVG(a * y)', alias='foo_x0y'), + ]) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + + def test_get_x_t_x_and_x_t_y_cols_multiple_xs(self): + xs = ['a', 'b'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y) + x_t_x = sql.Columns([ + sql.Column('AVG(a)', alias='x0'), + sql.Column('AVG(b)', alias='x1'), + sql.Column('AVG(a * a)', alias='x0x0'), + sql.Column('AVG(a * b)', alias='x0x1'), + sql.Column('AVG(b * b)', alias='x1x1'), + ]) + x_t_y = sql.Columns([ + sql.Column('AVG(y)', alias='y'), + sql.Column('AVG(a * y)', alias='x0y'), + sql.Column('AVG(b * y)', alias='x1y'), + ]) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + + def test_get_x_t_x_and_x_t_y_cols_no_intercept(self): + xs = ['a'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y, fit_intercept=False) + x_t_x = sql.Columns([sql.Column('AVG(a * a)', alias='x0x0')]) + x_t_y = sql.Columns([sql.Column('AVG(a * y)', alias='x0y')]) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + + def test_get_x_t_x_and_x_t_y_cols_normalize(self): + xs = ['a'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y, normalize=True) + x_t_x = sql.Columns([sql.Column('AVG(a * a)', alias='x0x0')]) + x_t_y = sql.Columns( + [sql.Column('AVG(y)', alias='y'), sql.Column('AVG(a * y)', alias='x0y')] + ) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + if __name__ == '__main__': absltest.main()