ChildTuningOptimizer.py
# original code credit - https://github.com/RunxinXu/ChildTuning
# implements ChildTuningAdamW from the Child-Tuning paper: https://arxiv.org/pdf/2109.05687.pdf
import math
from typing import Callable, Iterable, Optional, Tuple

import torch
from torch.distributions.bernoulli import Bernoulli
from torch.optim import Optimizer


class ChildTuningAdamW(Optimizer):
    """AdamW with Child-Tuning gradient masking.

    mode="task" (ChildTuning-D): multiply each gradient by a fixed mask supplied
        via ``set_gradient_mask`` (e.g. derived from Fisher Information).
    mode="taskfree" (ChildTuning-F): sample a fresh Bernoulli(reserve_p) mask every
        step and rescale the surviving gradients by 1/reserve_p.
    mode=None: behaves like plain AdamW.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        reserve_p: float = 0.5,
        mode: Optional[str] = None,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1])
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)
        self.gradient_mask = None
        self.reserve_p = reserve_p
        self.mode = mode
        if mode == "taskfree":
            print("--> optimizer running with taskfree Child Tuning (ChildTuning-F)")
        elif mode == "task":
            print("--> optimizer running with fim masking active (ChildTuning-D)")
        else:
            print(
                "--> WARNING: no mode set! This is the same as running regular AdamW"
            )

    def set_gradient_mask(self, gradient_mask):
        """Set the per-parameter {0, 1} mask used when mode="task"."""
        self.gradient_mask = gradient_mask

    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.
        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                # =================== Child Tuning BEGIN =======================
                if self.mode is not None:
                    if self.mode == "task":
                        # ChildTuning-D: apply the fixed, task-derived mask
                        if p in self.gradient_mask:
                            grad *= self.gradient_mask[p]
                    elif self.mode == "taskfree":
                        # ChildTuning-F: sample a fresh Bernoulli mask each step;
                        # dividing by reserve_p keeps the expected gradient unchanged
                        grad_mask = Bernoulli(
                            grad.new_full(size=grad.size(), fill_value=self.reserve_p)
                        )
                        grad *= grad_mask.sample() / self.reserve_p
                    else:
                        raise ValueError(
                            "running Child Tuning optimizer but unrecognized mode set...aborting"
                        )
                # =================== Child Tuning END =======================
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                state["step"] += 1
                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])
                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = (
                        step_size * math.sqrt(bias_correction2) / bias_correction1
                    )
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
        return loss
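

# --------------------------------------------------------------------------
# Usage sketch (not part of the original repo): one way to build the
# task-driven ("task" mode) gradient mask, assuming a model whose forward
# pass returns a loss and an iterable `dataloader` of input batches. The
# helper name `build_fisher_mask`, the batch handling, and the thresholding
# are illustrative assumptions; the Child-Tuning paper derives the mask from
# the Fisher Information of each parameter on the downstream task.
#
# def build_fisher_mask(model, dataloader, reserve_p=0.3, num_batches=8):
#     # Accumulate squared gradients as a Fisher Information estimate.
#     fisher = {
#         p: torch.zeros_like(p.data)
#         for p in model.parameters()
#         if p.requires_grad
#     }
#     model.train()
#     for i, batch in enumerate(dataloader):
#         if i >= num_batches:
#             break
#         loss = model(**batch).loss  # assumes a HuggingFace-style output
#         loss.backward()
#         for p in fisher:
#             if p.grad is not None:
#                 fisher[p] += p.grad.data.pow(2)
#         model.zero_grad()
#     # Keep the top reserve_p fraction of entries across all parameters.
#     scores = torch.cat([f.flatten() for f in fisher.values()])
#     threshold = torch.quantile(scores, 1.0 - reserve_p)
#     return {p: (f >= threshold).float() for p, f in fisher.items()}
#
# optimizer = ChildTuningAdamW(
#     model.parameters(), lr=2e-5, mode="task", reserve_p=0.3
# )
# optimizer.set_gradient_mask(build_fisher_mask(model, dataloader, reserve_p=0.3))
#
# For mode="taskfree" no mask is needed; the optimizer samples a Bernoulli
# mask with keep probability reserve_p at every step.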