-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathoptimization.py
More file actions
38 lines (26 loc) · 1.06 KB
/
optimization.py
File metadata and controls
38 lines (26 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
class Adagrad:
    """Adagrad optimizer supporting both dense and sparse (row-wise) updates.

    Each parameter element's effective learning rate is scaled by the
    inverse square root of its accumulated squared gradients, so frequently
    updated elements take smaller steps over time.
    """

    def __init__(self, lr, eps=1e-6):
        """
        Args:
            lr: global learning rate.
            eps: small constant added to the denominator for numerical
                stability. Defaults to 1e-6, matching the original behavior.
        """
        self._lr = lr
        self._eps = eps
        # gradient key => accumulated sum of squared gradients
        # (an array with the same shape as the gradient itself)
        self._sum_grad2 = {}

    def update(self, variables, gradients):
        """Apply one Adagrad step in place on the arrays in `variables`.

        Args:
            variables: dict mapping variable name -> numpy array.
            gradients: dict mapping gradient key -> numpy array. A key of
                the form 'varname@row' denotes a sparse gradient for a
                single row of variable `varname` (e.g. an embedding
                lookup); any other key is a dense gradient for the whole
                variable of the same name.
        """
        for gradname, gradient in gradients.items():
            # ------ accumulate squared gradient. Keyed by the full grad
            # name, so each sparse row keeps its own independent history.
            g2 = gradient * gradient
            self._sum_grad2[gradname] = self._sum_grad2.get(gradname, 0.0) + g2
            # ------ per-element adaptive step size
            delta = self._lr * gradient / (np.sqrt(self._sum_grad2[gradname]) + self._eps)
            # ------ apply the update in place
            if '@' in gradname:
                # Sparse input weights/gradients: keys in `gradients`
                # follow the 'vocab_name@feat_id' format and update a
                # single row of the variable.
                varname, row = gradname.split('@')
                variable = variables[varname]
                variable[int(row), :] -= delta
            else:
                variables[gradname] -= delta