@@ -2326,10 +2326,7 @@ def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None):
23262326 )
23272327 res = point_est .join (utils .melt (std ))
23282328 if self .confidence :
2329- res [self .prefix + ' CI-lower' ] = (
2330- res .iloc [:, 0 ] - res [self .prefix + ' CI-lower' ]
2331- )
2332- res [self .prefix + ' CI-upper' ] += res .iloc [:, 0 ]
2329+ res = self .compute_ci (res )
23332330 res = utils .unmelt (res )
23342331 base = self .compute_change_base (table , split_by , execute , mode )
23352332 return self .add_base_to_res (res , base )
@@ -3041,23 +3038,81 @@ def __init__(
30413038 self .n_replicates = n_replicates
30423039 self .ci_method = ci_method
30433040
3044- def to_sql (self , * args , ** kwargs ):
3045- if self .ci_method == 'percentile' :
3046- raise NotImplementedError (
3047- 'to_sql and compute_on_sql are not implemented for percentile'
3048- ' bootstrap.'
3041+ def compute_on_sql_sql_mode (self , table , split_by = None , execute = None ):
3042+ """Computes self in a SQL query and processes the result.
3043+
3044+ It behaves identically to MetricWithCI.compute_on_sql_sql_mode when
3045+ `ci_method` is 'std'. When `ci_method` is 'percentile', the SQL query
3046+ already computes the percentile CI bounds, so we just parse those columns
3047+ directly without applying the normal approximation.
3048+
3049+ Args:
3050+ table: The table we want to query from.
3051+ split_by: The columns that we use to split the data.
3052+ execute: A function that can execute a SQL query and returns a DataFrame.
3053+
3054+ Returns:
3055+ The result DataFrame of Bootstrap.
3056+ """
3057+ if self .ci_method == 'std' :
3058+ return super (Bootstrap , self ).compute_on_sql_sql_mode (
3059+ table , split_by , execute
30493060 )
3050- return super (Bootstrap , self ).to_sql (* args , ** kwargs )
3061+ if self .ci_method != 'percentile' :
3062+ raise ValueError ('ci_method must be either "std" or "percentile"' )
30513063
3052- def compute_on_sql (self , * args , ** kwargs ):
3053- if self .ci_method == 'percentile' :
3064+ res = super (MetricWithCI ,
3065+ self ).compute_on_sql_sql_mode (table , split_by , execute )
3066+ sub_dfs = []
3067+ base = None
3068+ if self .confidence is None :
3069+ raise ValueError ('confidence is required for percentile Bootstrap' )
3070+
3071+ if len (self .children ) == 1 and isinstance (
3072+ self .children [0 ], (PercentChange , AbsoluteChange )):
3073+ # The first 3n columns are Value, CI-lower, CI-upper for n Metrics. The
3074+ # last n columns are the base values of Change.
3075+ if len (res .columns ) % 4 :
3076+ raise ValueError ('Wrong shape for a MetricWithCI with confidence!' )
3077+ n_metrics = len (res .columns ) // 4
3078+ base = res .iloc [:, - n_metrics :]
3079+ res = res .iloc [:, :3 * n_metrics ]
3080+ change = self .children [0 ]
3081+ base .columns = [change .name_tmpl .format (c ) for c in base .columns ]
3082+ base = utils .melt (base )
3083+ base .columns = ['_base_value' ]
3084+
3085+ if len (res .columns ) % 3 :
3086+ raise ValueError ('Wrong shape for a MetricWithCI with confidence!' )
3087+
3088+ # The columns are like metric1, metric1 CI-lower, metric1 CI-upper, ...
3089+ metric_names = res .columns [::3 ]
3090+ sub_dfs = []
3091+
3092+ percentiles = list (self .select_percentiles ().keys ())
3093+ if len (percentiles ) != 2 :
30543094 raise NotImplementedError (
3055- 'to_sql and compute_on_sql are not implemented for percentile '
3056- ' bootstrap. '
3095+ 'SQL mode for percentile bootstrap currently supports exactly 2 '
3096+ ' percentiles (e.g. CI-lower and CI-upper) '
30573097 )
3058- return super (Bootstrap , self ).compute_on_sql (* args , ** kwargs )
30593098
3060- def _select_percentiles (self ) -> dict [str , float ]:
3099+ col1 = self .prefix + ' ' + percentiles [0 ]
3100+ col2 = self .prefix + ' ' + percentiles [1 ]
3101+
3102+ for i in range (0 , len (res .columns ), 3 ):
3103+ sub_df = pd .DataFrame (
3104+ {
3105+ 'Value' : res .iloc [:, i ],
3106+ col1 : res .iloc [:, i + 1 ],
3107+ col2 : res .iloc [:, i + 2 ]
3108+ },
3109+ columns = ['Value' , col1 , col2 ])
3110+ sub_dfs .append (sub_df )
3111+
3112+ res = pd .concat ((sub_dfs ), axis = 1 , keys = metric_names , names = ['Metric' ])
3113+ return self .add_base_to_res (res , base )
3114+
3115+ def select_percentiles (self ) -> dict [str , float ]:
30613116 """Returns the percentiles (as quantiles) to compute for Bootstrap.
30623117
30633118 By default, this method only uses the requested confidence level to select
@@ -3092,7 +3147,7 @@ def compute_on_children(
30923147 bucket_estimates = pd .concat (children , axis = 1 , sort = False )
30933148 stats_df = pd .DataFrame ({
30943149 f'{ self .prefix } { name } ' : bucket_estimates .quantile (q , axis = 1 )
3095- for name , q in self ._select_percentiles ().items ()
3150+ for name , q in self .select_percentiles ().items ()
30963151 })
30973152
30983153 return utils .unmelt (stats_df )
@@ -3755,18 +3810,29 @@ def get_se_sql(
37553810 groupby .add (c .alias )
37563811 else :
37573812 alias = c .alias
3758- se = sql .Column (c .alias , sql .STDDEV_SAMP_FN ,
3759- '%s Bootstrap SE' % c .alias_raw )
3760- if isinstance (metric , Jackknife ):
3761- adjustment = sql .Column (
3762- sql .SAFE_DIVIDE_FN (
3763- numer = 'COUNT({c}) - 1' , denom = 'SQRT(COUNT({c}))'
3764- ).format (c = alias )
3765- )
3766- se = (se * adjustment ).set_alias ('%s Jackknife SE' % c .alias_raw )
3767- columns .add (se )
3768- if metric .confidence :
3769- columns .add (sql .Column (alias , 'COUNT({}) - 1' , '%s dof' % c .alias_raw ))
3813+ ci_method = getattr (metric , 'ci_method' , 'std' )
3814+ if ci_method == 'percentile' :
3815+ for k , v in metric .select_percentiles ().items ():
3816+ pct_col = sql .Column (alias , sql .QUANTILE_FN (v ),
3817+ f'{ c .alias_raw } { k } ' )
3818+ columns .add (pct_col )
3819+ elif ci_method == 'std' :
3820+ se = sql .Column (c .alias , sql .STDDEV_SAMP_FN ,
3821+ '%s Bootstrap SE' % c .alias_raw )
3822+ if isinstance (metric , Jackknife ):
3823+ adjustment = sql .Column (
3824+ sql .SAFE_DIVIDE_FN (
3825+ numer = 'COUNT({c}) - 1' , denom = 'SQRT(COUNT({c}))'
3826+ ).format (c = alias )
3827+ )
3828+ se = (se * adjustment ).set_alias ('%s Jackknife SE' % c .alias_raw )
3829+ columns .add (se )
3830+ if metric .confidence :
3831+ columns .add (
3832+ sql .Column (alias , 'COUNT({}) - 1' , '%s dof' % c .alias_raw )
3833+ )
3834+ else :
3835+ raise ValueError (f'Unknown ci_method: { ci_method } ' )
37703836 return sql .Sql (columns , samples_alias , groupby = groupby ), with_data
37713837
37723838
0 commit comments