-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparams_SE.py
More file actions
50 lines (39 loc) · 2.24 KB
/
params_SE.py
File metadata and controls
50 lines (39 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
from torch import nn
def parse() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="SE")
# Training parameters
parser.add_argument("--learning_rate", "-lr", type=float, default=0.01)
parser.add_argument("--learning_rate_min", "-lr_min", type=float, default=0.01)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--arch_learning_rate", type=float, default=1e-3)
parser.add_argument("--arch_weight_decay", type=float, default=0)
parser.add_argument("--arch_init_type", type=str, default="normal")
parser.add_argument("--arch_init_ratio", type=float, default=1e-3)
parser.add_argument("--gpu", type=int, default=0, help="activated GPU number")
parser.add_argument("--root_dir", type=str, default="./")
parser.add_argument("--savepath", type=str, default="SE", help="path for save folder")
parser.add_argument("--savepath_num", type=int, default=0, help="number for save folder")
parser.add_argument("--epochs", type=int, default=1000, help="number of training epochs")
parser.add_argument("--report_freq", type=int, default=10)
# Dataset
parser.add_argument("--train_size", type=int, default=100000)
parser.add_argument("--valid_size", type=int, default=100000)
parser.add_argument("--test_size", type=int, default=10000)
parser.add_argument("--batch_size", type=int, default=1000)
# Search space
parser.add_argument("--steps", type=int, default=5)
parser.add_argument("--num_random_RF", type=int, default=5)
parser.add_argument("--num_random_NO", type=int, default=3)
parser.add_argument("--dephasing_spin_num", type=int, default=256)
parser.add_argument("--grad_spin_num", type=int, default=1)
parser.add_argument("--Echo", type=int, default=1)
parser.add_argument("--noise_SD", type=float, default=0.0)
# Loss parameters
parser.add_argument("--mag_loss", default=nn.MSELoss())
parser.add_argument("--mag_weights", type=float, default=1.0)
parser.add_argument("--SAR_weights", type=float, default=0.0001)
parser.add_argument("--RFpenalty_weights", type=float, default=0.001)
args = parser.parse_args()
return args