
Commit 899ff88

update APO models documentation and enhance APOSTuningCoverageSimulation with loss metrics and parameter tuning
1 parent 2334050 commit 899ff88

5 files changed: +55 −66 lines changed

doc/irm/apo.qmd

Lines changed: 24 additions & 24 deletions
@@ -84,7 +84,7 @@ generate_and_show_styled_table(
 
 The simulations are based on the [make_irm_data_discrete_treatments](https://docs.doubleml.org/stable/api/datasets.html#dataset-generators)-DGP with $500$ observations. Due to the linearity of the DGP, Lasso and Logit Regression are nearly optimal choices for the nuisance estimation.
 
-The non-uniform results (coverage, ci length and bias) refer to averaged values over all quantiles (point-wise confidence intervals).
+The non-uniform results (coverage, ci length and bias) refer to averaged values over all levels (point-wise confidence intervals).
 
 ::: {.callout-note title="Metadata" collapse="true"}
 
@@ -140,7 +140,7 @@ generate_and_show_styled_table(
 
 The simulations are based on the [make_irm_data_discrete_treatments](https://docs.doubleml.org/stable/api/datasets.html#dataset-generators)-DGP with $500$ observations. Due to the linearity of the DGP, Lasso and Logit Regression are nearly optimal choices for the nuisance estimation.
 
-The non-uniform results (coverage, ci length and bias) refer to averaged values over all quantiles (point-wise confidence intervals).
+The non-uniform results (coverage, ci length and bias) refer to averaged values over all levels (point-wise confidence intervals).
 
 ::: {.callout-note title="Metadata" collapse="true"}
 
@@ -199,7 +199,7 @@ The simulations are based on the [make_irm_data_discrete_treatments](https:
 
 ### APOS Coverage
 
-The non-uniform results (coverage, ci length and bias) refer to averaged values over all quantiles (point-wise confidence intervals).
+The non-uniform results (coverage, ci length and bias) refer to averaged values over all levels (point-wise confidence intervals). The same holds for the loss values, which are averaged over all treatment levels.
 
 ::: {.callout-note title="Metadata" collapse="true"}
 
@@ -216,22 +216,22 @@ print(metadata_df.T.to_string(header=False))
 #| echo: false
 
 # set up data
-df_apos = pd.read_csv("../../results/irm/apos_tune_coverage.csv", index_col=None)
+df_apos_tune = pd.read_csv("../../results/irm/apos_tune_coverage.csv", index_col=None)
 
-assert df_apos["repetition"].nunique() == 1
-n_rep_apos = df_apos["repetition"].unique()[0]
+assert df_apos_tune["repetition"].nunique() == 1
+n_rep_apos_tune = df_apos_tune["repetition"].unique()[0]
 
-display_columns_apos = ["Learner g", "Learner m", "Tuned", "Bias", "CI Length", "Coverage", "Uniform CI Length", "Uniform Coverage"]
+display_columns_apos_tune = ["Learner g", "Learner m", "Tuned", "Bias", "CI Length", "Coverage", "Uniform CI Length", "Uniform Coverage", "Loss g_control", "Loss g_treated", "Loss m"]
 ```
 
 ```{python}
 #| echo: false
 
 generate_and_show_styled_table(
-    main_df=df_apos,
+    main_df=df_apos_tune,
     filters={"level": 0.95},
-    display_cols=display_columns_apos,
-    n_rep=n_rep_apos,
+    display_cols=display_columns_apos_tune,
+    n_rep=n_rep_apos_tune,
     level_col="level",
     coverage_highlight_cols=["Coverage", "Uniform Coverage"]
 )
@@ -242,10 +242,10 @@ generate_and_show_styled_table(
 #| echo: false
 
 generate_and_show_styled_table(
-    main_df=df_apos,
+    main_df=df_apos_tune,
     filters={"level": 0.9},
-    display_cols=display_columns_apos,
-    n_rep=n_rep_apos,
+    display_cols=display_columns_apos_tune,
+    n_rep=n_rep_apos_tune,
     level_col="level",
     coverage_highlight_cols=["Coverage", "Uniform Coverage"]
 )
@@ -254,7 +254,7 @@ generate_and_show_styled_table(
 
 ### Causal Contrast Coverage
 
-The non-uniform results (coverage, ci length and bias) refer to averaged values over all quantiles (point-wise confidence intervals).
+The non-uniform results (coverage, ci length and bias) refer to averaged values over all levels (point-wise confidence intervals). The same holds for the loss values, which are averaged over all treatment levels.
 
 
 ::: {.callout-note title="Metadata" collapse="true"}
@@ -272,22 +272,22 @@ print(metadata_df.T.to_string(header=False))
 #| echo: false
 
 # set up data
-df_contrast = pd.read_csv("../../results/irm/apos_tune_causal_contrast.csv", index_col=None)
+df_contrast_tune = pd.read_csv("../../results/irm/apos_tune_causal_contrast.csv", index_col=None)
 
-assert df_contrast["repetition"].nunique() == 1
-n_rep_contrast = df_contrast["repetition"].unique()[0]
+assert df_contrast_tune["repetition"].nunique() == 1
+n_rep_contrast_tune = df_contrast_tune["repetition"].unique()[0]
 
-display_columns_contrast = ["Learner g", "Learner m", "Tuned", "Bias", "CI Length", "Coverage", "Uniform CI Length", "Uniform Coverage"]
+display_columns_contrast_tune = ["Learner g", "Learner m", "Tuned", "Bias", "CI Length", "Coverage", "Uniform CI Length", "Uniform Coverage", "Loss g_control", "Loss g_treated", "Loss m"]
 ```
 
 ```{python}
 #| echo: false
 
 generate_and_show_styled_table(
-    main_df=df_contrast,
+    main_df=df_contrast_tune,
     filters={"level": 0.95},
-    display_cols=display_columns_contrast,
-    n_rep=n_rep_contrast,
+    display_cols=display_columns_contrast_tune,
+    n_rep=n_rep_contrast_tune,
     level_col="level",
     coverage_highlight_cols=["Coverage", "Uniform Coverage"]
 )
@@ -298,10 +298,10 @@ generate_and_show_styled_table(
 #| echo: false
 
 generate_and_show_styled_table(
-    main_df=df_contrast,
+    main_df=df_contrast_tune,
     filters={"level": 0.9},
-    display_cols=display_columns_contrast,
-    n_rep=n_rep_contrast,
+    display_cols=display_columns_contrast_tune,
+    n_rep=n_rep_contrast_tune,
     level_col="level",
     coverage_highlight_cols=["Coverage", "Uniform Coverage"]
 )
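The renamed setup code in the notebook only loads the results CSV and sanity-checks the repetition count before quoting it. A minimal standalone sketch of that logic, using a hypothetical in-memory frame in place of the CSV on disk:

```python
import pandas as pd

# Hypothetical stand-in for results/irm/apos_tune_coverage.csv, reduced
# to the columns the snippet actually touches.
df_apos_tune = pd.DataFrame({
    "level": [0.9, 0.9, 0.95, 0.95],
    "repetition": [200, 200, 200, 200],
})

# The notebook asserts a single repetition count across all rows
# before using it in the table caption.
assert df_apos_tune["repetition"].nunique() == 1
n_rep_apos_tune = df_apos_tune["repetition"].unique()[0]
print(n_rep_apos_tune)
```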

monte-cover/src/montecover/irm/apos_tune.py

Lines changed: 20 additions & 31 deletions
@@ -8,6 +8,7 @@
 
 from montecover.base import BaseSimulation
 from montecover.utils import create_learner_from_config
+from montecover.utils_tuning import lgbm_reg_params, lgbm_cls_params
 
 
 class APOSTuningCoverageSimulation(BaseSimulation):
@@ -31,37 +32,7 @@ def __init__(
         self._calculate_oracle_values()
 
         # tuning specific settings
-        # parameter space for the outcome regression tuning
-        def ml_g_params(trial):
-            return {
-                "n_estimators": trial.suggest_int("n_estimators", 100, 200, step=50),
-                "learning_rate": trial.suggest_float(
-                    "learning_rate", 1e-3, 0.1, log=True
-                ),
-                "min_child_samples": trial.suggest_int(
-                    "min_child_samples", 20, 50, step=5
-                ),
-                "max_depth": 5,
-                "lambda_l1": trial.suggest_float("lambda_l1", 1e-3, 10.0, log=True),
-                "lambda_l2": trial.suggest_float("lambda_l2", 1e-3, 10.0, log=True),
-            }
-
-        # parameter space for the propensity score tuning
-        def ml_m_params(trial):
-            return {
-                "n_estimators": trial.suggest_int("n_estimators", 100, 200, step=50),
-                "learning_rate": trial.suggest_float(
-                    "learning_rate", 1e-3, 0.1, log=True
-                ),
-                "min_child_samples": trial.suggest_int(
-                    "min_child_samples", 20, 50, step=5
-                ),
-                "max_depth": 5,
-                "lambda_l1": trial.suggest_float("lambda_l1", 1e-3, 10.0, log=True),
-                "lambda_l2": trial.suggest_float("lambda_l2", 1e-3, 10.0, log=True),
-            }
-
-        self._param_space = {"ml_g": ml_g_params, "ml_m": ml_m_params}
+        self._param_space = {"ml_g": lgbm_reg_params, "ml_m": lgbm_cls_params}
 
         self._optuna_settings = {
             "n_trials": 200,
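The two deleted inline parameter-space functions were identical except for the learner they tuned, so the commit factors them into shared helpers. A sketch of what `lgbm_reg_params` in `montecover.utils_tuning` presumably looks like (mirroring the removed definition; the exact module contents are an assumption, and `_FixedTrial` below is a hypothetical stand-in for an Optuna `Trial`):

```python
# Hypothetical sketch of montecover.utils_tuning.lgbm_reg_params: an
# Optuna-style search-space callback matching the deleted inline version.
def lgbm_reg_params(trial):
    return {
        "n_estimators": trial.suggest_int("n_estimators", 100, 200, step=50),
        "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.1, log=True),
        "min_child_samples": trial.suggest_int("min_child_samples", 20, 50, step=5),
        "max_depth": 5,  # fixed, not tuned
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-3, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-3, 10.0, log=True),
    }


class _FixedTrial:
    """Tiny stand-in for an Optuna Trial that always returns the lower bound."""

    def suggest_int(self, name, low, high, step=1):
        return low

    def suggest_float(self, name, low, high, log=False):
        return low


params = lgbm_reg_params(_FixedTrial())
print(sorted(params))
```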
@@ -155,6 +126,18 @@ def run_single_rep(
         model.bootstrap(n_rep_boot=2000)
         causal_contrast_model = model.causal_contrast(reference_levels=0)
         causal_contrast_model.bootstrap(n_rep_boot=2000)
+
+        # average all nuisance losses over treatment levels
+        n_lvls = len(model.modellist)
+        loss_dict = {
+            "ml_g_d_lvl0": np.full(n_lvls, np.nan),
+            "ml_g_d_lvl1": np.full(n_lvls, np.nan),
+            "ml_m": np.full(n_lvls, np.nan)
+        }
+        for key in loss_dict.keys():
+            for i_submodel, submodel in enumerate(model.modellist):
+                loss_dict[key][i_submodel] = submodel.nuisance_loss[key].mean()
+
         for level in self.confidence_parameters["level"]:
             level_result = dict()
             level_result["coverage"] = self._compute_coverage(
@@ -180,6 +163,9 @@ def run_single_rep(
                     "Learner m": learner_m_name,
                     "level": level,
                     "Tuned": model is dml_model_tuned,
+                    "Loss g_control": loss_dict["ml_g_d_lvl0"].mean(),
+                    "Loss g_treated": loss_dict["ml_g_d_lvl1"].mean(),
+                    "Loss m": loss_dict["ml_m"].mean(),
                 }
             )
             for key, res in level_result.items():
@@ -199,6 +185,9 @@ def summarize_results(self):
             "Bias": "mean",
             "Uniform Coverage": "mean",
             "Uniform CI Length": "mean",
+            "Loss g_control": "mean",
+            "Loss g_treated": "mean",
+            "Loss m": "mean",
             "repetition": "count",
         }
 
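The loss bookkeeping added above can be exercised in isolation: each APO submodel (one per treatment level) exposes per-fold nuisance losses, the simulation stores each submodel's mean, and the reported `Loss …` columns average those means across treatment levels. A minimal sketch with made-up numbers standing in for the real `nuisance_loss` dicts:

```python
import numpy as np

# Hypothetical per-fold nuisance losses for two treatment-level submodels.
submodel_losses = [
    {"ml_g_d_lvl0": np.array([0.8, 1.2]),
     "ml_g_d_lvl1": np.array([1.0, 1.0]),
     "ml_m": np.array([0.5, 0.7])},
    {"ml_g_d_lvl0": np.array([1.0, 1.0]),
     "ml_g_d_lvl1": np.array([2.0, 2.0]),
     "ml_m": np.array([0.6, 0.6])},
]

# Same aggregation pattern as the commit: one mean per submodel...
n_lvls = len(submodel_losses)
loss_dict = {key: np.full(n_lvls, np.nan)
             for key in ("ml_g_d_lvl0", "ml_g_d_lvl1", "ml_m")}
for key in loss_dict:
    for i_submodel, losses in enumerate(submodel_losses):
        loss_dict[key][i_submodel] = losses[key].mean()

# ...then a cross-level average for the summary table ("Loss m" etc.).
print({key: vals.mean() for key, vals in loss_dict.items()})
```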

results/irm/apos_tune_causal_contrast.csv

Lines changed: 5 additions & 5 deletions

@@ -1,5 +1,5 @@
-Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition
-LGBM Regr.,LGBM Clas.,0.9,False,0.915,37.253979808129365,8.903634313073434,0.95,44.16923532073818,200
-LGBM Regr.,LGBM Clas.,0.9,True,0.835,4.834836017081165,1.3591722777708926,0.815,5.703081839207788,200
-LGBM Regr.,LGBM Clas.,0.95,False,0.98,44.39085491153563,8.903634313073434,0.985,50.67060073707776,200
-LGBM Regr.,LGBM Clas.,0.95,True,0.915,5.761062449185174,1.3591722777708926,0.895,6.546100271377481,200
+Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,Loss g_control,Loss g_treated,Loss m,repetition
+LGBM Regr.,LGBM Clas.,0.9,False,0.905,37.557614215103456,9.702752376788483,0.93,44.43580521211171,10.231838281825558,13.632270699354638,0.7977017509425842,200
+LGBM Regr.,LGBM Clas.,0.9,True,0.8625,4.281908161317747,1.1456270552515193,0.885,5.058837856166666,9.74905255354395,11.553230793169227,0.6041491925910187,200
+LGBM Regr.,LGBM Clas.,0.95,False,0.9625,44.75265762296555,9.702752376788483,0.975,51.03212790722511,10.231838281825558,13.632270699354638,0.7977017509425842,200
+LGBM Regr.,LGBM Clas.,0.95,True,0.945,5.102208271774997,1.1456270552515193,0.95,5.809070370902955,9.74905255354395,11.553230793169227,0.6041491925910187,200

results/irm/apos_tune_coverage.csv

Lines changed: 5 additions & 5 deletions
@@ -1,5 +1,5 @@
-Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition
-LGBM Regr.,LGBM Clas.,0.9,False,0.93,27.82005877126217,6.439524063849977,0.96,35.676118532942496,200
-LGBM Regr.,LGBM Clas.,0.9,True,0.885,6.300962875030208,1.5916287837149021,0.88,7.710244822536492,200
-LGBM Regr.,LGBM Clas.,0.95,False,0.9766666666666667,33.14964465289176,6.439524063849977,0.975,40.33245838311548,200
-LGBM Regr.,LGBM Clas.,0.95,True,0.9466666666666668,7.50806035298818,1.5916287837149021,0.94,8.829514768318946,200
+Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,Loss g_control,Loss g_treated,Loss m,repetition
+LGBM Regr.,LGBM Clas.,0.9,False,0.9133333333333333,28.052950188600192,7.055413388225961,0.945,35.926212976449165,10.231838281825558,13.632270699354638,0.7977017509425842,200
+LGBM Regr.,LGBM Clas.,0.9,True,0.8866666666666667,6.138357605002483,1.524949927232772,0.865,7.417779175659066,9.74905255354395,11.553230793169227,0.6041491925910187,200
+LGBM Regr.,LGBM Clas.,0.95,False,0.9766666666666667,33.42715189293536,7.055413388225961,0.98,40.556851028772705,10.231838281825558,13.632270699354638,0.7977017509425842,200
+LGBM Regr.,LGBM Clas.,0.95,True,0.945,7.31430422312426,1.524949927232772,0.94,8.537534779534262,9.74905255354395,11.553230793169227,0.6041491925910187,200

results/irm/apos_tune_metadata.csv

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
-0.12.dev0,APOSTuningCoverageSimulation,2025-11-26 11:38,31.90539586544037,3.12.9,scripts/irm/apos_tune_config.yml
+0.12.dev0,APOSTuningCoverageSimulation,2025-12-01 13:09,38.63118334611257,3.12.9,scripts/irm/apos_tune_config.yml
