Commit e21b357

first apos tuning sim

1 parent 27fa90e

File tree

9 files changed: +304 -1 lines changed

monte-cover/src/montecover/irm/__init__.py
monte-cover/src/montecover/irm/apos_tune.py
monte-cover/src/montecover/irm/irm_ate_tune.py
results/irm/apos_tune_causal_contrast.csv
results/irm/apos_tune_config.yml
results/irm/apos_tune_coverage.csv
results/irm/apos_tune_metadata.csv
scripts/irm/apos_tune.py
scripts/irm/apos_tune_config.yml

monte-cover/src/montecover/irm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 from montecover.irm.apo import APOCoverageSimulation
 from montecover.irm.apos import APOSCoverageSimulation
+from montecover.irm.apos_tune import APOSTuningCoverageSimulation
 from montecover.irm.cvar import CVARCoverageSimulation
 from montecover.irm.iivm_late import IIVMLATECoverageSimulation
 from montecover.irm.irm_ate import IRMATECoverageSimulation
@@ -17,6 +18,7 @@
 __all__ = [
     "APOCoverageSimulation",
     "APOSCoverageSimulation",
+    "APOSTuningCoverageSimulation",
     "CVARCoverageSimulation",
     "IRMATECoverageSimulation",
     "IRMATETuningCoverageSimulation",
monte-cover/src/montecover/irm/apos_tune.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
from typing import Any, Dict, Optional

import doubleml as dml
import numpy as np
import optuna
import pandas as pd
from doubleml.irm.datasets import make_irm_data_discrete_treatments

from montecover.base import BaseSimulation
from montecover.utils import create_learner_from_config


class APOSTuningCoverageSimulation(BaseSimulation):
    """Simulation class for coverage properties of DoubleMLAPOS for APO estimation with tuning."""

    def __init__(
        self,
        config_file: str,
        suppress_warnings: bool = True,
        log_level: str = "INFO",
        log_file: Optional[str] = None,
    ):
        super().__init__(
            config_file=config_file,
            suppress_warnings=suppress_warnings,
            log_level=log_level,
            log_file=log_file,
        )

        # Calculate oracle values
        self._calculate_oracle_values()

        # Tuning-specific settings
        # Parameter space for the outcome regression tuning
        def ml_g_params(trial):
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 200, step=50),
                'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
                'min_child_samples': trial.suggest_int('min_child_samples', 20, 50, step=5),
                'max_depth': 5,
                'lambda_l1': trial.suggest_float('lambda_l1', 1e-3, 10.0, log=True),
                'lambda_l2': trial.suggest_float('lambda_l2', 1e-3, 10.0, log=True),
            }

        # Parameter space for the propensity score tuning (same search space as ml_g)
        def ml_m_params(trial):
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 200, step=50),
                'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
                'min_child_samples': trial.suggest_int('min_child_samples', 20, 50, step=5),
                'max_depth': 5,
                'lambda_l1': trial.suggest_float('lambda_l1', 1e-3, 10.0, log=True),
                'lambda_l2': trial.suggest_float('lambda_l2', 1e-3, 10.0, log=True),
            }

        self._param_space = {
            'ml_g': ml_g_params,
            'ml_m': ml_m_params
        }

        self._optuna_settings = {
            'n_trials': 200,
            'show_progress_bar': False,
            'verbosity': optuna.logging.WARNING,  # Suppress Optuna logs
        }

    def _process_config_parameters(self):
        """Process simulation-specific parameters from config."""
        # Process ML models in parameter grid
        assert "learners" in self.dml_parameters, "No learners specified in the config file"

        required_learners = ["ml_g", "ml_m"]
        for learner in self.dml_parameters["learners"]:
            for ml in required_learners:
                assert ml in learner, f"No {ml} specified in the config file"

    def _calculate_oracle_values(self):
        """Calculate oracle values for the simulation."""
        self.logger.info("Calculating oracle values")

        n_levels = self.dgp_parameters["n_levels"][0]
        data_apo_oracle = make_irm_data_discrete_treatments(
            n_obs=int(1e6), n_levels=n_levels, linear=self.dgp_parameters["linear"][0]
        )

        y0 = data_apo_oracle["oracle_values"]["y0"]
        ite = data_apo_oracle["oracle_values"]["ite"]
        d = data_apo_oracle["d"]

        # APO at level i is E[y0] plus the mean ITE among units at that level;
        # the factor (i > 0) zeroes out the ITE for the control level 0
        average_ites = np.full(n_levels + 1, np.nan)
        apos = np.full(n_levels + 1, np.nan)
        for i in range(n_levels + 1):
            average_ites[i] = np.mean(ite[d == i]) * (i > 0)
            apos[i] = np.mean(y0) + average_ites[i]

        # ATEs are causal contrasts of each treatment level against level 0
        ates = np.full(n_levels, np.nan)
        for i in range(n_levels):
            ates[i] = apos[i + 1] - apos[0]

        self.logger.info(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
        self.logger.info(f"True APOs: {apos}")
        self.logger.info(f"True ATEs: {ates}")

        self.oracle_values = dict()
        self.oracle_values["apos"] = apos
        self.oracle_values["ates"] = ates

    def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]) -> Dict[str, Any]:
        """Run a single repetition with the given parameters."""
        # Extract parameters
        learner_config = dml_params["learners"]
        learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"])
        learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"])
        treatment_levels = dml_params["treatment_levels"]
        trimming_threshold = dml_params["trimming_threshold"]

        # Model with default (untuned) learners
        dml_model = dml.DoubleMLAPOS(
            obj_dml_data=dml_data,
            ml_g=ml_g,
            ml_m=ml_m,
            treatment_levels=treatment_levels,
            trimming_threshold=trimming_threshold,
        )
        # Identical model, but with Optuna-tuned nuisance learners
        dml_model_tuned = dml.DoubleMLAPOS(
            obj_dml_data=dml_data,
            ml_g=ml_g,
            ml_m=ml_m,
            treatment_levels=treatment_levels,
            trimming_threshold=trimming_threshold,
        )
        dml_model_tuned.tune_ml_models(
            ml_param_space=self._param_space,
            optuna_settings=self._optuna_settings,
        )

        result = {
            "coverage": [],
            "causal_contrast": [],
        }
        for model in [dml_model, dml_model_tuned]:
            model.fit()
            model.bootstrap(n_rep_boot=2000)
            causal_contrast_model = model.causal_contrast(reference_levels=0)
            causal_contrast_model.bootstrap(n_rep_boot=2000)
            for level in self.confidence_parameters["level"]:
                level_result = dict()
                level_result["coverage"] = self._compute_coverage(
                    thetas=model.coef,
                    oracle_thetas=self.oracle_values["apos"],
                    confint=model.confint(level=level),
                    joint_confint=model.confint(level=level, joint=True),
                )
                level_result["causal_contrast"] = self._compute_coverage(
                    thetas=causal_contrast_model.thetas,
                    oracle_thetas=self.oracle_values["ates"],
                    confint=causal_contrast_model.confint(level=level),
                    joint_confint=causal_contrast_model.confint(level=level, joint=True),
                )

                # add parameters to the result
                for res_metric in level_result.values():
                    res_metric.update(
                        {
                            "Learner g": learner_g_name,
                            "Learner m": learner_m_name,
                            "level": level,
                            "Tuned": model is dml_model_tuned,
                        }
                    )
                for key, res in level_result.items():
                    result[key].append(res)

        return result

    def summarize_results(self):
        """Summarize the simulation results."""
        self.logger.info("Summarizing simulation results")

        # Group by parameter combinations
        groupby_cols = ["Learner g", "Learner m", "level", "Tuned"]
        aggregation_dict = {
            "Coverage": "mean",
            "CI Length": "mean",
            "Bias": "mean",
            "Uniform Coverage": "mean",
            "Uniform CI Length": "mean",
            "repetition": "count",
        }

        # Aggregate results (possibly multiple result dfs)
        result_summary = dict()
        for result_name, result_df in self.results.items():
            result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
            self.logger.debug(f"Summarized {result_name} results")

        return result_summary

    def _generate_dml_data(self, dgp_params: Dict[str, Any]) -> dml.DoubleMLData:
        """Generate data for the simulation."""
        data = make_irm_data_discrete_treatments(
            n_obs=dgp_params["n_obs"],
            n_levels=dgp_params["n_levels"],
            linear=dgp_params["linear"],
        )
        df_apo = pd.DataFrame(
            np.column_stack((data["y"], data["d"], data["x"])),
            columns=["y", "d"] + ["x" + str(i) for i in range(data["x"].shape[1])],
        )
        dml_data = dml.DoubleMLData(df_apo, "y", "d")
        return dml_data
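
The entries of _param_space are plain callables mapping an Optuna trial to a LightGBM parameter dict; tune_ml_models consumes them together with _optuna_settings, but its internals are not part of this commit. As a rough, self-contained illustration of how such a trial-callable drives an Optuna search (the objective, CV scheme, and trial count below are illustrative assumptions, not DoubleML's actual implementation):

# Minimal sketch, assuming a generic "maximize CV score" objective;
# objective() and n_trials=20 are hypothetical, not DoubleML internals.
import lightgbm as lgb
import optuna
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score

X, y = make_regression(n_samples=500, n_features=10, noise=1.0, random_state=42)

def ml_g_params(trial):
    # same search space as in APOSTuningCoverageSimulation above
    return {
        'n_estimators': trial.suggest_int('n_estimators', 100, 200, step=50),
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.1, log=True),
        'min_child_samples': trial.suggest_int('min_child_samples', 20, 50, step=5),
        'max_depth': 5,
        'lambda_l1': trial.suggest_float('lambda_l1', 1e-3, 10.0, log=True),
        'lambda_l2': trial.suggest_float('lambda_l2', 1e-3, 10.0, log=True),
    }

def objective(trial):
    model = lgb.LGBMRegressor(**ml_g_params(trial), verbose=-1)
    # negative MSE, so "maximize" means a better regression fit
    return cross_val_score(model, X, y, cv=3, scoring='neg_mean_squared_error').mean()

optuna.logging.set_verbosity(optuna.logging.WARNING)  # as in _optuna_settings
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20, show_progress_bar=False)
print(study.best_params)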

monte-cover/src/montecover/irm/irm_ate_tune.py

Lines changed: 1 addition & 1 deletion
The only change is the indentation of the closing parenthesis of the constructor call.

@@ -98,7 +98,7 @@ def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any])
         obj_dml_data=dml_data,
         ml_g=ml_g,
         ml_m=ml_m,
-        )
+    )
     dml_model_tuned.tune_ml_models(
         ml_param_space=self._param_space,
         optuna_settings=self._optuna_settings,
results/irm/apos_tune_causal_contrast.csv

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition
LGBM Regr.,LGBM Clas.,0.9,False,0.915,37.253979808129365,8.903634313073434,0.95,44.16923532073818,200
LGBM Regr.,LGBM Clas.,0.9,True,0.835,4.834836017081165,1.3591722777708926,0.815,5.703081839207788,200
LGBM Regr.,LGBM Clas.,0.95,False,0.98,44.39085491153563,8.903634313073434,0.985,50.67060073707776,200
LGBM Regr.,LGBM Clas.,0.95,True,0.915,5.761062449185174,1.3591722777708926,0.895,6.546100271377481,200
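
These columns are produced by _compute_coverage, which lives in BaseSimulation and is not part of this diff. A hedged sketch of one plausible reading of the per-repetition metrics, with a hypothetical helper and toy numbers (the Uniform columns would use the joint bootstrap confidence intervals instead):

import numpy as np

def compute_coverage(thetas, oracle_thetas, ci_lower, ci_upper):
    """Coverage metrics for one repetition (illustrative helper, not the library's)."""
    covered = (ci_lower <= oracle_thetas) & (oracle_thetas <= ci_upper)
    return {
        'Coverage': covered.mean(),                     # share of oracle values inside their CI
        'CI Length': (ci_upper - ci_lower).mean(),      # average interval width
        'Bias': np.abs(thetas - oracle_thetas).mean(),  # average absolute estimation error
    }

# Toy example: three APO estimates with illustrative 90% intervals
print(compute_coverage(
    thetas=np.array([1.8, 2.9, 4.2]),
    oracle_thetas=np.array([2.0, 3.0, 4.0]),
    ci_lower=np.array([1.2, 2.3, 3.1]),
    ci_upper=np.array([2.6, 3.6, 5.0]),
))  # {'Coverage': 1.0, 'CI Length': 1.53..., 'Bias': 0.16...}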

results/irm/apos_tune_config.yml

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
simulation_parameters:
  repetitions: 200
  max_runtime: 19800
  random_seed: 42
  n_jobs: -2
dgp_parameters:
  n_obs:
  - 500
  n_levels:
  - 2
  linear:
  - true
learner_definitions:
  lgbmr: &id001
    name: LGBM Regr.
  lgbmc: &id002
    name: LGBM Clas.
dml_parameters:
  treatment_levels:
  - - 0
    - 1
    - 2
  trimming_threshold:
  - 0.01
  learners:
  - ml_g: *id001
    ml_m: *id002
confidence_parameters:
  level:
  - 0.95
  - 0.9

results/irm/apos_tune_coverage.csv

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
Learner g,Learner m,level,Tuned,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition
LGBM Regr.,LGBM Clas.,0.9,False,0.93,27.82005877126217,6.439524063849977,0.96,35.676118532942496,200
LGBM Regr.,LGBM Clas.,0.9,True,0.885,6.300962875030208,1.5916287837149021,0.88,7.710244822536492,200
LGBM Regr.,LGBM Clas.,0.95,False,0.9766666666666667,33.14964465289176,6.439524063849977,0.975,40.33245838311548,200
LGBM Regr.,LGBM Clas.,0.95,True,0.9466666666666668,7.50806035298818,1.5916287837149021,0.94,8.829514768318946,200
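
Read against the nominal levels, tuning cuts the average CI length sharply (e.g. 27.8 to 6.3 for the APOs at the 90% level) but pushes marginal coverage below nominal (0.93 to 0.885) — the same pattern as in the causal-contrast table above.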

results/irm/apos_tune_metadata.csv

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File
0.12.dev0,APOSTuningCoverageSimulation,2025-11-26 11:38,31.90539586544037,3.12.9,scripts/irm/apos_tune_config.yml

scripts/irm/apos_tune.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
from montecover.irm import APOSTuningCoverageSimulation

# Create and run simulation with config file
sim = APOSTuningCoverageSimulation(
    config_file="scripts/irm/apos_tune_config.yml",
    log_level="INFO",
    log_file="logs/irm/apos_tune_sim.log",
)
sim.run_simulation()
sim.save_results(output_path="results/irm/", file_prefix="apos_tune")

# Save config file for reproducibility
sim.save_config("results/irm/apos_tune_config.yml")
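
Since the config, log, and results paths are all relative, the script is presumably run from the repository root, e.g. as python scripts/irm/apos_tune.py, once the montecover package is installed.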

scripts/irm/apos_tune_config.yml

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# Simulation parameters for APOS Coverage

simulation_parameters:
  repetitions: 200
  max_runtime: 19800 # 5.5 hours in seconds
  random_seed: 42
  n_jobs: -2

dgp_parameters:
  n_obs: [500] # Sample size
  n_levels: [2]
  linear: [True]

# Define reusable learner configurations
learner_definitions:
  lgbmr: &lgbmr
    name: "LGBM Regr."

  lgbmc: &lgbmc
    name: "LGBM Clas."

dml_parameters:
  treatment_levels: [[0, 1, 2]]
  trimming_threshold: [0.01]
  learners:
    - ml_g: *lgbmr
      ml_m: *lgbmc

confidence_parameters:
  level: [0.95, 0.90] # Confidence levels
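
The &lgbmr/&lgbmc anchors make each learner definition reusable: the *lgbmr and *lgbmc aliases under dml_parameters.learners point back to learner_definitions, so a further ml_g/ml_m combination only needs a new alias pair rather than a repeated definition. The saved copy under results/irm/ is the same config after a YAML round-trip, which renames the anchors to &id001/&id002 and expands the flow-style lists into block style.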
