config.py
from torch import nn
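# Checkpoint paths for saving and loading the base and pretrained ViT weights.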
BASE_SAVE_TO = "bins/baseViT.bin"
BASE_LOAD_FROM = "bins/baseViT.bin"
PRETRAINED_SAVE_TO = "bins/pretrainedViT.bin"
PRETRAINED_LOAD_FROM = "bins/pretrainedViT.bin"
def init_weights(m):
    # Xavier-initialize Linear weights and zero their biases; defined here,
    # before BaseConfig, which stores a reference to it.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
# init_weights


class BaseConfig:
    def __init__(self):
        # data / training loop
        self.iters = 50
        self.batch_size = 16
        self.dataset_len, self.testset_len = 1000, 500
        self.dummy = None  # placeholder
        # model architecture
        self.n_heads = 3
        self.n_stacks = 6
        self.n_hidden = 3
        self.dim = 900
        self.output_dim = 10
        self.bias = True
        self.dropout = 0.1
        self.attention_dropout = 0.1
        # optimization
        self.eps = 1e-3
        self.betas = (0.9, 0.98)
        self.epochs = 5
        self.lr = 1e-4
        self.clip_grad = False
        # masking and weight init
        self.mask_prob = 0.3
        self.init_weights = init_weights
        self.mask_val = -1e-9  # value substituted at masked positions
        self.mask_ratio = 768  # note: not a 0-1 fraction despite the name
    # __init__
# BaseConfig


class AdapterConfig(BaseConfig):
    def __init__(self):
        super().__init__()
        self.output_dim = 10  # same value as BaseConfig; override here for a different head size
    # __init__
# AdapterConfig
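

# A minimal usage sketch, not part of the original file: the stand-in
# nn.Sequential model and the Adam optimizer are assumptions; the real ViT
# and training loop live elsewhere in this repo. It shows how a config
# object, its init_weights hook, and the optimizer hyperparameters are
# typically consumed.
if __name__ == "__main__":
    from torch import optim

    cfg = BaseConfig()

    # Hypothetical stand-in for the real model built from cfg.dim,
    # cfg.n_heads, cfg.n_stacks, etc.
    model = nn.Sequential(
        nn.Linear(cfg.dim, cfg.dim, bias=cfg.bias),
        nn.ReLU(),
        nn.Linear(cfg.dim, cfg.output_dim, bias=cfg.bias),
    )

    # nn.Module.apply walks every submodule, so each nn.Linear is
    # Xavier-initialized with a zeroed bias.
    model.apply(cfg.init_weights)

    # Adam wired from the config (assumed; the training script may differ).
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr,
                           betas=cfg.betas, eps=cfg.eps)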