Skip to content

Commit 454361c

Browse files
committed
formatting
1 parent 46522ac commit 454361c

35 files changed

+580
-245
lines changed

.github/workflows/apo_sim.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
uses: astral-sh/setup-uv@v5
5353
with:
5454
version: "0.7.8"
55-
55+
5656
- name: Set up Python
5757
uses: actions/setup-python@v5
5858
with:

.github/workflows/pliv_sim.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
cd monte-cover
6363
uv venv
6464
uv sync
65-
65+
6666
- name: Install DoubleML from correct branch
6767
run: |
6868
source monte-cover/.venv/bin/activate

doc/irm/irm.qmd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ print(metadata_df.T.to_string(header=False))
3737

3838
:::
3939

40-
### ATE
40+
### ATE
4141

4242
```{python}
4343
#| echo: false
@@ -264,7 +264,7 @@ print(metadata_df.T.to_string(header=False))
264264

265265
:::
266266

267-
### ATE
267+
### ATE
268268

269269
```{python}
270270
#| echo: false
@@ -304,4 +304,4 @@ generate_and_show_styled_table(
304304
level_col="level",
305305
coverage_highlight_cols=["Coverage"]
306306
)
307-
```
307+
```

doc/plm/plr.qmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,4 @@ generate_and_show_styled_table(
254254
rename_map={"Learner g": "Learner l"},
255255
coverage_highlight_cols=["Coverage"]
256256
)
257-
```
257+
```

monte-cover/src/montecover/base.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def run_simulation(self, n_jobs=None):
107107

108108
rep_end_time = time.time()
109109
rep_duration = rep_end_time - rep_start_time
110-
self.logger.info(f"Repetition {i_rep+1} completed in {rep_duration:.2f}s")
110+
self.logger.info(
111+
f"Repetition {i_rep+1} completed in {rep_duration:.2f}s"
112+
)
111113

112114
else:
113115
self.logger.info(f"Starting parallel execution with n_jobs={n_jobs}")
@@ -138,7 +140,9 @@ def save_results(self, output_path: str = "results", file_prefix: str = ""):
138140
"Script": [self.__class__.__name__],
139141
"Date": [datetime.now().strftime("%Y-%m-%d %H:%M")],
140142
"Total Runtime (minutes)": [self.total_runtime / 60],
141-
"Python Version": [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"],
143+
"Python Version": [
144+
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
145+
],
142146
"Config File": [self.config_file],
143147
}
144148
)
@@ -161,7 +165,14 @@ def save_config(self, output_path: str):
161165
self.logger.warning(f"Adding .yaml extension to output path: {output_path}")
162166

163167
with open(output_path, "w") as file:
164-
yaml.dump(self.config, file, sort_keys=False, default_flow_style=False, indent=2, allow_unicode=True)
168+
yaml.dump(
169+
self.config,
170+
file,
171+
sort_keys=False,
172+
default_flow_style=False,
173+
indent=2,
174+
allow_unicode=True,
175+
)
165176

166177
self.logger.info(f"Configuration saved to {output_path}")
167178

@@ -174,7 +185,9 @@ def _load_config(self, config_path: str) -> Dict[str, Any]:
174185
with open(config_path, "r") as file:
175186
config = yaml.safe_load(file)
176187
else:
177-
raise ValueError(f"Unsupported config file format: {config_path}. Use .yaml or .yml")
188+
raise ValueError(
189+
f"Unsupported config file format: {config_path}. Use .yaml or .yml"
190+
)
178191

179192
return config
180193

@@ -198,7 +211,9 @@ def _setup_logging(self, log_level: str, log_file: Optional[str]):
198211
# Console handler
199212
ch = logging.StreamHandler()
200213
ch.setLevel(level)
201-
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
214+
formatter = logging.Formatter(
215+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
216+
)
202217
ch.setFormatter(formatter)
203218
self.logger.addHandler(ch)
204219

@@ -256,7 +271,9 @@ def _process_repetition(self, i_rep):
256271
dml_params = dict(zip(self.dml_parameters.keys(), dml_param_values))
257272
i_param_comb += 1
258273

259-
comb_results = self._process_parameter_combination(i_rep, i_param_comb, dgp_params, dml_params, dml_data)
274+
comb_results = self._process_parameter_combination(
275+
i_rep, i_param_comb, dgp_params, dml_params, dml_data
276+
)
260277

261278
# Merge results
262279
for result_name, result_list in comb_results.items():
@@ -266,11 +283,14 @@ def _process_repetition(self, i_rep):
266283

267284
return rep_results
268285

269-
def _process_parameter_combination(self, i_rep, i_param_comb, dgp_params, dml_params, dml_data):
286+
def _process_parameter_combination(
287+
self, i_rep, i_param_comb, dgp_params, dml_params, dml_data
288+
):
270289
"""Process a single parameter combination."""
271290
# Log parameter combination
272291
self.logger.debug(
273-
f"Rep {i_rep+1}, Combo {i_param_comb}/{self.total_combinations}: " f"DGPs {dgp_params}, DML {dml_params}"
292+
f"Rep {i_rep+1}, Combo {i_param_comb}/{self.total_combinations}: "
293+
f"DGPs {dgp_params}, DML {dml_params}"
274294
)
275295
param_start_time = time.time()
276296

@@ -279,7 +299,9 @@ def _process_parameter_combination(self, i_rep, i_param_comb, dgp_params, dml_pa
279299

280300
# Log timing
281301
param_duration = time.time() - param_start_time
282-
self.logger.debug(f"Parameter combination completed in {param_duration:.2f}s")
302+
self.logger.debug(
303+
f"Parameter combination completed in {param_duration:.2f}s"
304+
)
283305

284306
# Process results
285307
if repetition_results is None:
@@ -298,7 +320,8 @@ def _process_parameter_combination(self, i_rep, i_param_comb, dgp_params, dml_pa
298320

299321
except Exception as e:
300322
self.logger.error(
301-
f"Error: repetition {i_rep+1}, DGP parameters {dgp_params}, " f"DML parameters {dml_params}: {str(e)}"
323+
f"Error: repetition {i_rep+1}, DGP parameters {dgp_params}, "
324+
f"DML parameters {dml_params}: {str(e)}"
302325
)
303326
self.logger.exception("Exception details:")
304327
return {}
@@ -333,9 +356,13 @@ def _compute_coverage(thetas, oracle_thetas, confint, joint_confint=None):
333356
if joint_confint is not None:
334357
joint_lower_bound = joint_confint.iloc[:, 0]
335358
joint_upper_bound = joint_confint.iloc[:, 1]
336-
joint_coverage_mask = (joint_lower_bound < oracle_thetas) & (oracle_thetas < joint_upper_bound)
359+
joint_coverage_mask = (joint_lower_bound < oracle_thetas) & (
360+
oracle_thetas < joint_upper_bound
361+
)
337362

338363
result_dict["Uniform Coverage"] = np.all(joint_coverage_mask)
339-
result_dict["Uniform CI Length"] = np.mean(joint_upper_bound - joint_lower_bound)
364+
result_dict["Uniform CI Length"] = np.mean(
365+
joint_upper_bound - joint_lower_bound
366+
)
340367

341368
return result_dict

monte-cover/src/montecover/did/did_cs_multi.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def __init__(
3535
def _process_config_parameters(self):
3636
"""Process simulation-specific parameters from config"""
3737
# Process ML models in parameter grid
38-
assert "learners" in self.dml_parameters, "No learners specified in the config file"
38+
assert (
39+
"learners" in self.dml_parameters
40+
), "No learners specified in the config file"
3941

4042
required_learners = ["ml_g", "ml_m"]
4143
for learner in self.dml_parameters["learners"]:
@@ -54,19 +56,23 @@ def _calculate_oracle_values(self):
5456
lambda_t=self.dgp_parameters["lambda_t"][0],
5557
) # does not depend on the DGP type or lambda_t
5658
df_oracle["ite"] = df_oracle["y1"] - df_oracle["y0"]
57-
self.oracle_values["detailed"] = df_oracle.groupby(["d", "t"])["ite"].mean().reset_index()
59+
self.oracle_values["detailed"] = (
60+
df_oracle.groupby(["d", "t"])["ite"].mean().reset_index()
61+
)
5862

5963
# Oracle group aggregation
6064
df_oracle_post_treatment = df_oracle[df_oracle["t"] >= df_oracle["d"]]
61-
self.oracle_values["group"] = df_oracle_post_treatment.groupby("d")["ite"].mean()
65+
self.oracle_values["group"] = df_oracle_post_treatment.groupby("d")[
66+
"ite"
67+
].mean()
6268

6369
# Oracle time aggregation
6470
self.oracle_values["time"] = df_oracle_post_treatment.groupby("t")["ite"].mean()
6571

6672
# Oracle eventstudy aggregation
67-
df_oracle["e"] = pd.to_datetime(df_oracle["t"]).values.astype("datetime64[M]") - pd.to_datetime(
68-
df_oracle["d"]
69-
).values.astype("datetime64[M]")
73+
df_oracle["e"] = pd.to_datetime(df_oracle["t"]).values.astype(
74+
"datetime64[M]"
75+
) - pd.to_datetime(df_oracle["d"]).values.astype("datetime64[M]")
7076
self.oracle_values["eventstudy"] = df_oracle.groupby("e")["ite"].mean()[1:]
7177

7278
def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
@@ -96,7 +102,9 @@ def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
96102
for i, (g, _, t) in enumerate(dml_model.gt_combinations):
97103
group_index = self.oracle_values["detailed"]["d"] == g
98104
time_index = self.oracle_values["detailed"]["t"] == t
99-
oracle_thetas[i] = self.oracle_values["detailed"][group_index & time_index]["ite"].iloc[0]
105+
oracle_thetas[i] = self.oracle_values["detailed"][group_index & time_index][
106+
"ite"
107+
].iloc[0]
100108

101109
result = {
102110
"detailed": [],
@@ -121,7 +129,9 @@ def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]:
121129
thetas=agg_obj.aggregated_frameworks.thetas,
122130
oracle_thetas=self.oracle_values[aggregation_method].values,
123131
confint=agg_obj.aggregated_frameworks.confint(level=level),
124-
joint_confint=agg_obj.aggregated_frameworks.confint(level=level, joint=True),
132+
joint_confint=agg_obj.aggregated_frameworks.confint(
133+
level=level, joint=True
134+
),
125135
)
126136

127137
# add parameters to the result
@@ -163,14 +173,20 @@ def summarize_results(self):
163173

164174
result_summary = dict()
165175
for result_name, result_df in self.results.items():
166-
result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
176+
result_summary[result_name] = (
177+
result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
178+
)
167179
self.logger.debug(f"Summarized {result_name} results")
168180

169181
return result_summary
170182

171183
def _generate_dml_data(self, dgp_params) -> dml.data.DoubleMLPanelData:
172184
"""Generate data for the simulation."""
173-
data = make_did_cs_CS2021(n_obs=dgp_params["n_obs"], dgp_type=dgp_params["DGP"], lambda_t=dgp_params["lambda_t"])
185+
data = make_did_cs_CS2021(
186+
n_obs=dgp_params["n_obs"],
187+
dgp_type=dgp_params["DGP"],
188+
lambda_t=dgp_params["lambda_t"],
189+
)
174190
dml_data = dml.data.DoubleMLPanelData(
175191
data,
176192
y_col="y",

monte-cover/src/montecover/did/did_pa_multi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ def _process_config_parameters(self):
3636
"""Process simulation-specific parameters from config"""
3737
# Process ML models in parameter grid
3838
# Process ML models in parameter grid
39-
assert "learners" in self.dml_parameters, "No learners specified in the config file"
39+
assert (
40+
"learners" in self.dml_parameters
41+
), "No learners specified in the config file"
4042

4143
required_learners = ["ml_g", "ml_m"]
4244
for learner in self.dml_parameters["learners"]:

monte-cover/src/montecover/irm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from montecover.irm.cvar import CVARCoverageSimulation
77
from montecover.irm.iivm_late import IIVMLATECoverageSimulation
88
from montecover.irm.irm_ate import IRMATECoverageSimulation
9-
from montecover.irm.irm_ate_tune import IRMATETuningCoverageSimulation
109
from montecover.irm.irm_ate_sensitivity import IRMATESensitivityCoverageSimulation
10+
from montecover.irm.irm_ate_tune import IRMATETuningCoverageSimulation
1111
from montecover.irm.irm_atte import IRMATTECoverageSimulation
1212
from montecover.irm.irm_atte_sensitivity import IRMATTESensitivityCoverageSimulation
1313
from montecover.irm.irm_cate import IRMCATECoverageSimulation

monte-cover/src/montecover/irm/apo.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __init__(
3232
def _process_config_parameters(self):
3333
"""Process simulation-specific parameters from config"""
3434
# Process ML models in parameter grid
35-
assert "learners" in self.dml_parameters, "No learners specified in the config file"
35+
assert (
36+
"learners" in self.dml_parameters
37+
), "No learners specified in the config file"
3638

3739
required_learners = ["ml_g", "ml_m"]
3840
for learner in self.dml_parameters["learners"]:
@@ -62,15 +64,19 @@ def _calculate_oracle_values(self):
6264
for i in range(n_levels):
6365
ates[i] = apos[i + 1] - apos[0]
6466

65-
self.logger.info(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
67+
self.logger.info(
68+
f"Levels and their counts:\n{np.unique(d, return_counts=True)}"
69+
)
6670
self.logger.info(f"True APOs: {apos}")
6771
self.logger.info(f"True ATEs: {ates}")
6872

6973
self.oracle_values = dict()
7074
self.oracle_values["apos"] = apos
7175
self.oracle_values["ates"] = ates
7276

73-
def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]) -> Dict[str, Any]:
77+
def run_single_rep(
78+
self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]
79+
) -> Dict[str, Any]:
7480
"""Run a single repetition with the given parameters."""
7581
# Extract parameters
7682
learner_config = dml_params["learners"]
@@ -132,7 +138,9 @@ def summarize_results(self):
132138
# Aggregate results (possibly multiple result dfs)
133139
result_summary = dict()
134140
for result_name, result_df in self.results.items():
135-
result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
141+
result_summary[result_name] = (
142+
result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
143+
)
136144
self.logger.debug(f"Summarized {result_name} results")
137145

138146
return result_summary

monte-cover/src/montecover/irm/apos.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __init__(
3232
def _process_config_parameters(self):
3333
"""Process simulation-specific parameters from config"""
3434
# Process ML models in parameter grid
35-
assert "learners" in self.dml_parameters, "No learners specified in the config file"
35+
assert (
36+
"learners" in self.dml_parameters
37+
), "No learners specified in the config file"
3638

3739
required_learners = ["ml_g", "ml_m"]
3840
for learner in self.dml_parameters["learners"]:
@@ -62,15 +64,19 @@ def _calculate_oracle_values(self):
6264
for i in range(n_levels):
6365
ates[i] = apos[i + 1] - apos[0]
6466

65-
self.logger.info(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
67+
self.logger.info(
68+
f"Levels and their counts:\n{np.unique(d, return_counts=True)}"
69+
)
6670
self.logger.info(f"True APOs: {apos}")
6771
self.logger.info(f"True ATEs: {ates}")
6872

6973
self.oracle_values = dict()
7074
self.oracle_values["apos"] = apos
7175
self.oracle_values["ates"] = ates
7276

73-
def run_single_rep(self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]) -> Dict[str, Any]:
77+
def run_single_rep(
78+
self, dml_data: dml.DoubleMLData, dml_params: Dict[str, Any]
79+
) -> Dict[str, Any]:
7480
"""Run a single repetition with the given parameters."""
7581
# Extract parameters
7682
learner_config = dml_params["learners"]
@@ -144,7 +150,9 @@ def summarize_results(self):
144150
# Aggregate results (possibly multiple result dfs)
145151
result_summary = dict()
146152
for result_name, result_df in self.results.items():
147-
result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
153+
result_summary[result_name] = (
154+
result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index()
155+
)
148156
self.logger.debug(f"Summarized {result_name} results")
149157

150158
return result_summary

0 commit comments

Comments
 (0)