Skip to content

Commit bf8cb3a

Browse files
tcya authored and meterstick-copybara committed
Implement SQL generator for Bootstrap with percentile CIs.
PiperOrigin-RevId: 889671472
1 parent 185fb02 commit bf8cb3a

1 file changed

Lines changed: 95 additions & 29 deletions

File tree

operations.py

Lines changed: 95 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -2326,10 +2326,7 @@ def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None):
23262326
)
23272327
res = point_est.join(utils.melt(std))
23282328
if self.confidence:
2329-
res[self.prefix + ' CI-lower'] = (
2330-
res.iloc[:, 0] - res[self.prefix + ' CI-lower']
2331-
)
2332-
res[self.prefix + ' CI-upper'] += res.iloc[:, 0]
2329+
res = self.compute_ci(res)
23332330
res = utils.unmelt(res)
23342331
base = self.compute_change_base(table, split_by, execute, mode)
23352332
return self.add_base_to_res(res, base)
@@ -3041,23 +3038,81 @@ def __init__(
30413038
self.n_replicates = n_replicates
30423039
self.ci_method = ci_method
30433040

3044-
def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
  """Computes self in a SQL query and processes the result.

  When `ci_method` is 'std' this simply delegates to the parent class's
  compute_on_sql_sql_mode. When `ci_method` is 'percentile', the SQL query
  already computes the percentile CI bounds, so those columns are parsed
  directly without applying the normal approximation.

  This method supersedes the former `to_sql` / `compute_on_sql` overrides,
  which raised NotImplementedError for percentile bootstrap.

  Args:
    table: The table we want to query from.
    split_by: The columns that we use to split the data.
    execute: A function that can execute a SQL query and returns a DataFrame.

  Returns:
    The result DataFrame of Bootstrap.

  Raises:
    ValueError: If `ci_method` is neither 'std' nor 'percentile', if
      `confidence` is None in percentile mode, or if the number of columns
      returned by the query does not match the expected layout.
    NotImplementedError: If `select_percentiles()` does not return exactly two
      percentiles.
  """
  if self.ci_method == 'std':
    return super(Bootstrap, self).compute_on_sql_sql_mode(
        table, split_by, execute
    )
  if self.ci_method != 'percentile':
    raise ValueError('ci_method must be either "std" or "percentile"')

  # Validate the configuration *before* running the query so that a
  # misconfigured metric fails fast instead of after an expensive execution.
  if self.confidence is None:
    raise ValueError('confidence is required for percentile Bootstrap')
  percentiles = list(self.select_percentiles())
  if len(percentiles) != 2:
    raise NotImplementedError(
        'SQL mode for percentile bootstrap currently supports exactly 2'
        ' percentiles (e.g. CI-lower and CI-upper)'
    )
  lower_col, upper_col = (f'{self.prefix} {p}' for p in percentiles)

  # Deliberately skip MetricWithCI's own implementation in the MRO: the
  # percentile bounds come straight from the query, so its post-processing
  # (NOTE(review): presumably the normal-approximation step) must not run.
  res = super(MetricWithCI, self).compute_on_sql_sql_mode(
      table, split_by, execute
  )

  base = None
  if len(self.children) == 1 and isinstance(
      self.children[0], (PercentChange, AbsoluteChange)
  ):
    # The first 3n columns are Value, CI-lower, CI-upper for n Metrics. The
    # last n columns are the base values of Change.
    if len(res.columns) % 4:
      raise ValueError('Wrong shape for a MetricWithCI with confidence!')
    n_metrics = len(res.columns) // 4
    base = res.iloc[:, -n_metrics:]
    res = res.iloc[:, : 3 * n_metrics]
    change = self.children[0]
    base.columns = [change.name_tmpl.format(c) for c in base.columns]
    base = utils.melt(base)
    base.columns = ['_base_value']

  if len(res.columns) % 3:
    raise ValueError('Wrong shape for a MetricWithCI with confidence!')

  # The columns are like metric1, metric1 CI-lower, metric1 CI-upper, ...
  metric_names = res.columns[::3]
  sub_dfs = [
      pd.DataFrame(
          {
              'Value': res.iloc[:, i],
              lower_col: res.iloc[:, i + 1],
              upper_col: res.iloc[:, i + 2],
          },
          columns=['Value', lower_col, upper_col],
      )
      for i in range(0, len(res.columns), 3)
  ]

  res = pd.concat(sub_dfs, axis=1, keys=metric_names, names=['Metric'])
  return self.add_base_to_res(res, base)
3114+
3115+
def select_percentiles(self) -> dict[str, float]:
30613116
"""Returns the percentiles (as quantiles) to compute for Bootstrap.
30623117
30633118
By default, this method only uses the requested confidence level to select
@@ -3092,7 +3147,7 @@ def compute_on_children(
30923147
bucket_estimates = pd.concat(children, axis=1, sort=False)
30933148
stats_df = pd.DataFrame({
30943149
f'{self.prefix} {name}': bucket_estimates.quantile(q, axis=1)
3095-
for name, q in self._select_percentiles().items()
3150+
for name, q in self.select_percentiles().items()
30963151
})
30973152

30983153
return utils.unmelt(stats_df)
@@ -3755,18 +3810,29 @@ def get_se_sql(
37553810
groupby.add(c.alias)
37563811
else:
37573812
alias = c.alias
3758-
se = sql.Column(c.alias, sql.STDDEV_SAMP_FN,
3759-
'%s Bootstrap SE' % c.alias_raw)
3760-
if isinstance(metric, Jackknife):
3761-
adjustment = sql.Column(
3762-
sql.SAFE_DIVIDE_FN(
3763-
numer='COUNT({c}) - 1', denom='SQRT(COUNT({c}))'
3764-
).format(c=alias)
3765-
)
3766-
se = (se * adjustment).set_alias('%s Jackknife SE' % c.alias_raw)
3767-
columns.add(se)
3768-
if metric.confidence:
3769-
columns.add(sql.Column(alias, 'COUNT({}) - 1', '%s dof' % c.alias_raw))
3813+
ci_method = getattr(metric, 'ci_method', 'std')
3814+
if ci_method == 'percentile':
3815+
for k, v in metric.select_percentiles().items():
3816+
pct_col = sql.Column(alias, sql.QUANTILE_FN(v),
3817+
f'{c.alias_raw} {k}')
3818+
columns.add(pct_col)
3819+
elif ci_method == 'std':
3820+
se = sql.Column(c.alias, sql.STDDEV_SAMP_FN,
3821+
'%s Bootstrap SE' % c.alias_raw)
3822+
if isinstance(metric, Jackknife):
3823+
adjustment = sql.Column(
3824+
sql.SAFE_DIVIDE_FN(
3825+
numer='COUNT({c}) - 1', denom='SQRT(COUNT({c}))'
3826+
).format(c=alias)
3827+
)
3828+
se = (se * adjustment).set_alias('%s Jackknife SE' % c.alias_raw)
3829+
columns.add(se)
3830+
if metric.confidence:
3831+
columns.add(
3832+
sql.Column(alias, 'COUNT({}) - 1', '%s dof' % c.alias_raw)
3833+
)
3834+
else:
3835+
raise ValueError(f'Unknown ci_method: {ci_method}')
37703836
return sql.Sql(columns, samples_alias, groupby=groupby), with_data
37713837

37723838

0 commit comments

Comments
 (0)