-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
205 lines (171 loc) · 7.53 KB
/
test.py
File metadata and controls
205 lines (171 loc) · 7.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Test the models on the test dataset.
Author: H. Kaan Kale
Email: hkaankale1@gmail.com
"""
from pathlib import Path
from typing import Tuple, Union
import logging
from omegaconf import DictConfig, OmegaConf
import hydra
from src.utils.general import (
set_seed,
configure_torch_backend,
)
from src.utils.config import (
ExperimentParams,
ExperimentType,
DatasetName,
MNLIParams,
MNLICombinationType,
load_experiment_params,
)
from src.test.utility_privacy import UtilityPrivacyTester
from src.test.compression import CompressionTester
from src.test.utils import TestParams
def find_experiment_config(
    *,
    experiments_dir: Union[Path, str],
    experiment_date: Union[str, None] = None,
    test_config: Union[DictConfig, None] = None,
) -> Tuple[ExperimentParams, Path]:
    """
    Locate an experiment run in the experiments directory and load its params.

    experiment_date takes precedence: when it is given, the run directory
    ``experiments_dir / experiment_date`` is loaded directly and test_config is ignored.
    Otherwise the search looks for a run whose params equal test_config.

    :params experiments_dir: Path or str, the directory that holds the experiment runs.
    :params experiment_date: str or None, name of the run directory to load.
    :params test_config: DictConfig or None, the experiment config to match against
        the stored runs when experiment_date is not given.
    :returns: ExperimentParams, the params of the located experiment.
              Path, the located experiment directory.
    :raises: ValueError, if neither experiment_date nor test_config is provided,
        or if no run matches test_config.
    :raises: FileNotFoundError, if experiment_date is given but that run has no
        ``.hydra/config.yaml``.
    """
    # An explicit run directory wins over matching by config.
    if experiment_date is not None:
        return find_config_in_experiments_by_date(
            experiment_date=experiment_date, experiments_dir=experiments_dir
        )
    if test_config is not None:
        return find_config_in_experiments_by_config(
            config=test_config, experiments_dir=experiments_dir
        )
    # Nothing to search with: report the real problem instead of claiming a failed search.
    raise ValueError(
        "Either experiment_date or test_config must be provided to locate an "
        f"experiment in {experiments_dir}"
    )
def find_config_in_experiments_by_date(
    experiment_date: str, experiments_dir: Union[Path, str]
) -> Tuple[ExperimentParams, Path]:
    """
    Load the params of the run stored under ``experiments_dir / experiment_date``.

    :params experiment_date: str, name of the run directory inside experiments_dir.
    :params experiments_dir: Path or str, the directory that holds the experiment runs.
    :returns: ExperimentParams, params parsed from the run's ``.hydra/config.yaml``.
              Path, the run directory.
    :raises: FileNotFoundError, if the run directory has no ``.hydra/config.yaml``.
    """
    root = Path(experiments_dir) if isinstance(experiments_dir, str) else experiments_dir
    run_dir = root / experiment_date
    config_path = run_dir.joinpath(".hydra", "config.yaml")
    if config_path.exists():
        return load_experiment_params(OmegaConf.load(config_path)), run_dir
    raise FileNotFoundError(
        f"Experiment config file not found in {config_path}"
    )
def find_config_in_experiments_by_config(
    config: Union[DictConfig, ExperimentParams, Path, str],
    experiments_dir: Union[Path, str],
) -> Tuple[ExperimentParams, Path]:
    """
    Find the most recently modified experiment run whose params equal ``config``.

    :params config: DictConfig, ExperimentParams, Path or str.
        A Path/str is treated as a YAML config file and loaded first.
        A DictConfig is converted to ExperimentParams before comparison.
    :params experiments_dir: Path or str, the directory that holds the experiment runs.
        Entries that are not directories, or that have no ``.hydra/config.yaml``,
        are skipped.
    :returns: ExperimentParams, the params of the matching experiment.
              Path, the matching experiment directory.
    :raises: ValueError, if no run in experiments_dir matches config.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(Path(config))
    if isinstance(config, DictConfig):
        config = load_experiment_params(config)
    if isinstance(experiments_dir, str):
        experiments_dir = Path(experiments_dir)

    # Newest run first, so the latest matching experiment wins.
    # Kept in a separate name so experiments_dir still refers to the directory
    # (it is used in the error message below).
    run_dirs = sorted(
        (entry for entry in experiments_dir.iterdir() if entry.is_dir()),
        key=lambda entry: entry.stat().st_mtime,
        reverse=True,
    )
    for experiment_dir in run_dirs:
        experiment_config_file = experiment_dir / ".hydra" / "config.yaml"
        # Skip directories that are not (complete) runs instead of aborting the search.
        if not experiment_config_file.is_file():
            continue
        experiment_config = OmegaConf.load(experiment_config_file)
        experiment_params = load_experiment_params(experiment_config)
        if experiment_params == config:
            return experiment_params, experiment_dir
    raise ValueError(f"Experiment with params {config} not found in {experiments_dir}")
def _resolve_encoder_weights_path(
    weights_dir: Path, load_from_epoch: Union[int, None]
) -> Path:
    """
    Return the encoder checkpoint path inside ``weights_dir``.

    :params weights_dir: Path, directory holding ``model_<epoch>.pt`` checkpoints.
    :params load_from_epoch: int or None, the epoch to load. None selects the
        checkpoint with the highest numeric epoch suffix.
    :returns: Path, the checkpoint path. When an epoch is given the path is built
        directly and is not checked for existence here.
    :raises: FileNotFoundError, if load_from_epoch is None and no
        ``model_<epoch>.pt`` checkpoint exists in weights_dir.
    """
    if load_from_epoch is not None:
        return weights_dir / f"model_{load_from_epoch}.pt"

    def _epoch_suffix(checkpoint: Path) -> str:
        return checkpoint.stem.split("_")[-1]

    # The glob also matches names like model_best.pt; only numbered epochs are ranked.
    numbered = [
        checkpoint
        for checkpoint in weights_dir.glob("model_*.pt")
        if _epoch_suffix(checkpoint).isdecimal()
    ]
    if not numbered:
        raise FileNotFoundError(
            f"No model_<epoch>.pt checkpoints found in {weights_dir}"
        )
    return max(numbered, key=lambda checkpoint: int(_epoch_suffix(checkpoint)))


@hydra.main(config_path="configs", config_name="test_config", version_base="1.2")
def main(config: DictConfig) -> None:
    """
    Test the models on the test dataset.

    Locates the experiment run described by the Hydra test config, resolves the
    encoder checkpoint to evaluate, and runs the tester matching the experiment type.

    :params config: DictConfig, the Hydra test config. It must provide batch_size,
        device_idx, experiments_dir, experiment and experiment_date.
    """
    # NOTE(review): if Hydra has already configured the root logger's handlers,
    # basicConfig is a no-op and DEBUG messages stay hidden. Confirm this is intended.
    logging.basicConfig(level=logging.DEBUG)

    # Epoch of the saved encoder checkpoint to evaluate; None selects the latest one.
    load_from_epoch: Union[int, None] = 8

    # Set the seed and configure the torch backend
    seed: int = 42
    set_seed(seed)
    configure_torch_backend()

    # Test config parameters
    batch_size: int = config.batch_size
    device_idx: int = config.device_idx
    experiments_dir: str = config.experiments_dir
    test_experiment_config = config.experiment
    test_params = TestParams(max_epoch=-1, batch_size=batch_size, device_idx=device_idx)

    # Load experiment params and directory according to the test config
    experiment_params: ExperimentParams
    experiment_dir: Path
    experiment_params, experiment_dir = find_experiment_config(
        experiments_dir=experiments_dir,
        experiment_date=config.experiment_date,
        test_config=test_experiment_config,
    )
    logging.debug("Experiment config found at %s", experiment_dir)

    if experiment_params.dataset_params.dataset_name == DatasetName.MNLI:
        # NOTE(review): forces CONCAT combination for MNLI at test time regardless of
        # the combination type stored in the experiment config. Confirm this is intended.
        dataset_params: MNLIParams = experiment_params.dataset_params
        dataset_params.combination_type = MNLICombinationType.CONCAT

    experiment_type: ExperimentType = ExperimentType(experiment_params.experiment_type)

    encoder_weights_path = _resolve_encoder_weights_path(
        experiment_dir / "encoder_weights", load_from_epoch
    )

    # Create and run the tester matching the experiment type
    if experiment_type == ExperimentType.UTILITY:
        utility_privacy_tester = UtilityPrivacyTester(
            experiment_params=experiment_params,
            test_params=test_params,
            encoder_weights_path=encoder_weights_path,
        )
        utility_privacy_tester.run_all_tests()
    elif experiment_type == ExperimentType.COMPRESSION:
        compression_tester = CompressionTester(
            experiment_params=experiment_params,
            test_params=test_params,
            encoder_weights_path=encoder_weights_path,
        )
        compression_tester.run_all_tests()
    else:
        # Previously this case silently did nothing; surface it without failing the run.
        logging.warning("No tester available for experiment type %s", experiment_type)
if __name__ == "__main__":
    # hydra.main builds the config and passes it to main(), so it is called without
    # arguments here; the pylint pragma silences the resulting false positive.
    main()  # pylint: disable=no-value-for-parameter