Skip to content

Commit 46522ac

Browse files
committed
rerun lplr sim and add tuning with 100 reps and 200 trials
1 parent e21b357 commit 46522ac

File tree

12 files changed

+265
-18
lines changed

12 files changed

+265
-18
lines changed

monte-cover/src/montecover/plm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from montecover.plm.plr_cate import PLRCATECoverageSimulation
88
from montecover.plm.plr_gate import PLRGATECoverageSimulation
99
from montecover.plm.lplr_ate import LPLRATECoverageSimulation
10+
from montecover.plm.lplr_ate_tune import LPLRATETuningCoverageSimulation
1011

1112
__all__ = [
1213
"PLRATECoverageSimulation",
@@ -16,4 +17,5 @@
1617
"PLRATESensitivityCoverageSimulation",
1718
"PLRATETuningCoverageSimulation",
1819
"LPLRATECoverageSimulation",
20+
"LPLRATETuningCoverageSimulation",
1921
]

monte-cover/src/montecover/plm/lplr_ate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from typing import Any, Dict, Optional
32

43
import doubleml as dml
@@ -46,7 +45,7 @@ def _calculate_oracle_values(self):
4645
self.logger.info("Calculating oracle values")
4746

4847
self.oracle_values = dict()
49-
self.oracle_values["theta"] = self.dgp_parameters["theta"]
48+
self.oracle_values["theta"] = self.dgp_parameters["alpha"]
5049

5150
def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
5251
"""Run a single repetition with the given parameters."""
@@ -64,7 +63,8 @@ def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
6463
ml_M=ml_M,
6564
ml_t=ml_t,
6665
score=score,
67-
error_on_convergence_failure= not self._use_failed_scores,)
66+
error_on_convergence_failure=(not self._use_failed_scores),
67+
)
6868

6969
try:
7070
dml_model.fit()
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from typing import Any, Dict, Optional
2+
import optuna
3+
4+
import doubleml as dml
5+
from doubleml.plm.datasets import make_lplr_LZZ2020
6+
7+
from montecover.base import BaseSimulation
8+
from montecover.utils import create_learner_from_config
9+
10+
11+
class LPLRATETuningCoverageSimulation(BaseSimulation):
    """Simulation class for coverage properties of DoubleMLLPLR for ATE estimation.

    Each repetition fits two DoubleMLLPLR models on the same data — one with the
    learners' default hyperparameters and one tuned via Optuna — and records
    coverage results for both so tuned vs. untuned coverage can be compared.
    """

    def __init__(
        self,
        config_file: str,
        suppress_warnings: bool = True,
        log_level: str = "INFO",
        log_file: Optional[str] = None,
        use_failed_scores: bool = False,
    ):
        """Initialize the tuning coverage simulation.

        Args:
            config_file: Path to the YAML configuration file.
            suppress_warnings: Whether to suppress warnings during simulation runs.
            log_level: Logging level name for the simulation logger.
            log_file: Optional path to a log file.
            use_failed_scores: If True, keep results from nuisance fits that
                failed to converge instead of raising an error
                (sets ``error_on_convergence_failure=False`` on the models).
        """
        super().__init__(
            config_file=config_file,
            suppress_warnings=suppress_warnings,
            log_level=log_level,
            log_file=log_file,
        )

        # Calculate oracle values
        self._calculate_oracle_values()
        self._use_failed_scores = use_failed_scores

        # For simplicity, we use the same (LightGBM-style) parameter space for
        # all learners.
        def ml_params(trial):
            return {
                "n_estimators": trial.suggest_int("n_estimators", 100, 500, step=50),
                "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.1, log=True),
                "min_child_samples": trial.suggest_int("min_child_samples", 20, 100, step=5),
                "max_depth": trial.suggest_int("max_depth", 3, 10, step=1),
                "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
                "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
            }

        # NOTE(review): "ml_a" is tuned although it is neither required from the
        # config nor passed to the model explicitly — presumably an internal
        # nuisance learner of DoubleMLLPLR; confirm against the doubleml API.
        self._param_space = {
            "ml_M": ml_params,
            "ml_t": ml_params,
            "ml_m": ml_params,
            "ml_a": ml_params,
        }

        self._optuna_settings = {
            "n_trials": 200,
            "show_progress_bar": False,
            "verbosity": optuna.logging.WARNING,  # Suppress Optuna logs
        }

    def _process_config_parameters(self):
        """Process simulation-specific parameters from config."""
        # Process ML models in parameter grid
        assert "learners" in self.dml_parameters, "No learners specified in the config file"

        # "ml_a" is tuned (see _param_space) but intentionally not required here.
        required_learners = ["ml_m", "ml_M", "ml_t"]
        for learner in self.dml_parameters["learners"]:
            for ml in required_learners:
                assert ml in learner, f"No {ml} specified in the config file"

    def _calculate_oracle_values(self):
        """Calculate oracle values for the simulation."""
        self.logger.info("Calculating oracle values")

        self.oracle_values = dict()
        # The DGP parameter "alpha" is the true treatment effect (theta).
        self.oracle_values["theta"] = self.dgp_parameters["alpha"]

    def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
        """Run a single repetition with the given parameters.

        Fits an untuned and an Optuna-tuned DoubleMLLPLR model on the same data
        and collects per-level coverage results for both.

        Args:
            dml_data: The DoubleMLData object for this repetition.
            dml_params: Dict with the learner configs ("learners") and the
                score ("score") to use.

        Returns:
            A dict with a "coverage" list holding one result dict per
            (confidence level, model) combination, or None if fitting failed.
        """
        # Extract parameters
        learner_config = dml_params["learners"]
        learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])
        learner_M_name, ml_M = create_learner_from_config(learner_config["ml_M"])
        learner_t_name, ml_t = create_learner_from_config(learner_config["ml_t"])
        score = dml_params["score"]

        # NOTE(review): both models receive the same learner instances; this
        # assumes doubleml clones learners internally so that tuning one model
        # does not leak parameters into the other — confirm against the API.
        model_inputs = {
            "obj_dml_data": dml_data,
            "ml_m": ml_m,
            "ml_M": ml_M,
            "ml_t": ml_t,
            "score": score,
            "error_on_convergence_failure": not self._use_failed_scores,
        }
        # Model
        dml_model = dml.DoubleMLLPLR(**model_inputs)
        dml_model_tuned = dml.DoubleMLLPLR(**model_inputs)
        dml_model_tuned.tune_ml_models(
            ml_param_space=self._param_space,
            optuna_settings=self._optuna_settings,
        )

        result = {
            "coverage": [],
        }

        for model in [dml_model, dml_model_tuned]:
            try:
                model.fit()
            except RuntimeError as e:
                # Discard the whole repetition (including any results already
                # collected for the untuned model) if either fit fails.
                self.logger.info(f"Exception during fit: {e}")
                return None

            for level in self.confidence_parameters["level"]:
                level_result = dict()
                level_result["coverage"] = self._compute_coverage(
                    thetas=model.coef,
                    oracle_thetas=self.oracle_values["theta"],
                    confint=model.confint(level=level),
                    joint_confint=None,
                )

                # add parameters to the result
                for res in level_result.values():
                    res.update(
                        {
                            "Learner m": learner_m_name,
                            "Learner M": learner_M_name,
                            "Learner t": learner_t_name,
                            "Score": score,
                            "level": level,
                            # Identity check distinguishes the tuned model.
                            "Tuned": model is dml_model_tuned,
                        }
                    )
                for key, res in level_result.items():
                    result[key].append(res)

        return result

    def summarize_results(self):
        """Summarize the simulation results.

        Returns:
            Dict mapping each result name to a DataFrame of mean coverage,
            CI length, and bias per parameter combination (incl. tuned flag).
        """
        self.logger.info("Summarizing simulation results")

        # Group by parameter combinations
        groupby_cols = ["Learner m", "Learner M", "Learner t", "Score", "level", "Tuned"]
        aggregation_dict = {
            "Coverage": "mean",
            "CI Length": "mean",
            "Bias": "mean",
            "repetition": "count",
        }

        # Aggregate results (possibly multiple result dfs)
        result_summary = dict()
        for result_name, result_df in self.results.items():
            result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
            self.logger.debug(f"Summarized {result_name} results")

        return result_summary

    def _generate_dml_data(self, dgp_params) -> dml.DoubleMLData:
        """Generate data for the simulation from the LZZ2020 LPLR DGP."""
        return make_lplr_LZZ2020(**dgp_params)

results/plm/lplr_ate_config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ simulation_parameters:
44
random_seed: 42
55
n_jobs: -2
66
dgp_parameters:
7-
theta:
7+
alpha:
88
- 0.5
99
n_obs:
1010
- 500

results/plm/lplr_ate_coverage.csv

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
Learner m,Learner M,Learner t,Score,level,Coverage,CI Length,Bias,repetition
2-
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.9,0.872,0.6540916267945179,0.17501445022837125,500
3-
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.95,0.928,0.7793982455949509,0.17501445022837125,500
4-
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,0.88,0.598241346108922,0.15586913796966942,500
5-
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,0.946,0.7128485314583201,0.15586913796966942,500
6-
LassoCV,Logistic,LassoCV,instrument,0.9,0.856,0.5890452894815547,0.16482024691605957,500
7-
LassoCV,Logistic,LassoCV,instrument,0.95,0.924,0.7018907541253692,0.16482024691605957,500
8-
LassoCV,Logistic,LassoCV,nuisance_space,0.9,0.868,0.5820699058557912,0.1507959338822808,500
9-
LassoCV,Logistic,LassoCV,nuisance_space,0.95,0.93,0.6935790718815301,0.1507959338822808,500
10-
RF Regr.,RF Clas.,RF Regr.,instrument,0.9,0.884,0.39484117997902796,0.09883032061915417,500
11-
RF Regr.,RF Clas.,RF Regr.,instrument,0.95,0.95,0.4704822846799266,0.09883032061915417,500
12-
RF Regr.,RF Clas.,RF Regr.,nuisance_space,0.9,0.886,0.38499391911236014,0.09772003875711463,500
13-
RF Regr.,RF Clas.,RF Regr.,nuisance_space,0.95,0.94,0.45874854963578754,0.09772003875711463,500
2+
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.9,0.866,0.6573798859045776,0.17600558265832575,500
3+
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.95,0.936,0.7833164479942107,0.17600558265832575,500
4+
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,0.89,0.5881153537384244,0.15332249272864673,500
5+
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,0.936,0.700782667342079,0.15332249272864673,500
6+
LassoCV,Logistic,LassoCV,instrument,0.9,0.858,0.5897233516083383,0.16268441455635813,500
7+
LassoCV,Logistic,LassoCV,instrument,0.95,0.916,0.7026987149834061,0.16268441455635813,500
8+
LassoCV,Logistic,LassoCV,nuisance_space,0.9,0.8937875751503006,0.576947311075238,0.1492081384708213,499
9+
LassoCV,Logistic,LassoCV,nuisance_space,0.95,0.9278557114228457,0.6874751237169234,0.1492081384708213,499
10+
RF Regr.,RF Clas.,RF Regr.,instrument,0.9,0.902,0.39485055228075816,0.09886061010323771,500
11+
RF Regr.,RF Clas.,RF Regr.,instrument,0.95,0.942,0.4704934524662526,0.09886061010323771,500
12+
RF Regr.,RF Clas.,RF Regr.,nuisance_space,0.9,0.892,0.38461199091029774,0.09604302638290617,500
13+
RF Regr.,RF Clas.,RF Regr.,nuisance_space,0.95,0.942,0.4582934541133308,0.09604302638290617,500

results/plm/lplr_ate_metadata.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
2-
0.11.dev0,LPLRATECoverageSimulation,2025-11-18 03:13,39.79484195311864,3.12.9,scripts/plm/lplr_ate_config.yml
2+
0.12.dev0,LPLRATECoverageSimulation,2025-11-26 13:24,14.800051196416218,3.12.9,scripts/plm/lplr_ate_config.yml
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
simulation_parameters:
2+
repetitions: 100
3+
max_runtime: 19800
4+
random_seed: 42
5+
n_jobs: -2
6+
dgp_parameters:
7+
alpha:
8+
- 0.5
9+
n_obs:
10+
- 500
11+
dim_x:
12+
- 20
13+
learner_definitions:
14+
lgbm: &id001
15+
name: LGBM Regr.
16+
lgbm-class: &id002
17+
name: LGBM Clas.
18+
dml_parameters:
19+
learners:
20+
- ml_m: *id001
21+
ml_M: *id002
22+
ml_t: *id001
23+
score:
24+
- nuisance_space
25+
- instrument
26+
confidence_parameters:
27+
level:
28+
- 0.95
29+
- 0.9
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Learner m,Learner M,Learner t,Score,level,Tuned,Coverage,CI Length,Bias,repetition
2+
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.9,False,0.91,0.9117258212067718,0.240354871477558,100
3+
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.9,True,0.95,0.8692681775643711,0.2054770002796413,100
4+
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.95,False,0.98,1.0863883229855305,0.240354871477558,100
5+
LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.95,True,0.96,1.0357969201737371,0.2054770002796413,100
6+
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,False,0.91,0.7841573908306078,0.18430486050109982,100
7+
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,True,0.86,0.7221800622589235,0.1665060542122647,100
8+
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,False,0.95,0.9343811625885382,0.18430486050109982,100
9+
LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,True,0.93,0.8605306205900738,0.1665060542122647,100
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
2+
0.12.dev0,LPLRATETuningCoverageSimulation,2025-11-26 17:47,44.12576818863551,3.12.9,scripts/plm/lplr_ate_tune_config.yml

scripts/plm/lplr_ate_config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ simulation_parameters:
77
n_jobs: -2
88

99
dgp_parameters:
10-
theta: [0.5] # Treatment effect
10+
alpha: [0.5] # Treatment effect
1111
n_obs: [500] # Sample size
1212
dim_x: [20] # Number of covariates
1313

0 commit comments

Comments
 (0)