-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
43 lines (36 loc) · 1.73 KB
/
main.py
File metadata and controls
43 lines (36 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from MultiFidelityBNN.active_learning.sequential_sampling import SequentialSampling
from MultiFidelityBNN.cases_config import Case_1, Case_2, Case_3, Case_4, Case_5, Case_6
def main(case: int, *, methods=None) -> dict:
    """Run sequential multi-fidelity sampling for one benchmark case.

    Args:
        case: Index of the benchmark configuration to run. Must be one of
            1-6, selecting ``Case_1`` .. ``Case_6`` from
            ``MultiFidelityBNN.cases_config``.
        methods: Optional mapping of display name -> method identifier
            forwarded to ``SequentialSampling(method=...)``. Defaults to
            ``{"DDP": "DDP"}``. Pass e.g.
            ``{"baseline": "baseline", "DDP": "DDP"}`` to compare several
            methods in one call instead of editing this function.

    Returns:
        Dict mapping each display name to the tuple
        ``(error_history, cost_history, train_data)`` unpacked from the
        value returned by ``SequentialSampling``.

    Raises:
        ValueError: If ``case`` is not in 1-6.
    """
    # Validate before looking up any config object, so an invalid case fails
    # fast with a clear message.
    if case not in range(1, 7):
        raise ValueError(f"Invalid case number: {case}. Must be between 1 and 6.")
    case_config = {
        1: Case_1,
        2: Case_2,
        3: Case_3,
        4: Case_4,
        5: Case_5,
        6: Case_6,
    }[case]

    if methods is None:
        methods = {"DDP": "DDP"}

    comparison_results = {}
    for name, method_str in methods.items():
        print(f"\n--- Running Case {case} with method: {name} ({method_str}) ---")
        results = SequentialSampling(
            funcs=case_config.Funcs,
            D=case_config.Dim,
            # One fidelity level per objective function in the case config.
            fidelity_num=len(case_config.Funcs),
            cost=case_config.COST,
            sample_sizes=case_config.SampleSize,
            pool_size=case_config.PoolSize,
            configs=case_config.NetConfigs,
            method=method_str,
            low_bounds=case_config.LOW,
            high_bounds=case_config.HIGH,
            warm_up_num=case_config.WarmUpNum,
            sample_num=case_config.SampleNum,
            seed=case_config.Seed,
            batch_size=case_config.BatchSize,
            iteration_num=case_config.IterationNum,
            test_set=case_config.TestSet,
            error_threshold=case_config.ErrorThreshold,
        )
        # The result is unpacked as a 5-tuple. The model and the sample history
        # are not used in the comparison.
        _model, train_data, error_history, cost_history, _sample_history = results
        comparison_results[name] = (error_history, cost_history, train_data)
    return comparison_results
if __name__ == "__main__":
    # Script entry point: run the default experiment, benchmark case 2.
    main(2)