# pcgrad.py (forked from WeiChengTseng/Pytorch-PCGrad)
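"""PCGrad: gradient surgery for multi-task learning (Yu et al., 2020,
"Gradient Surgery for Multi-Task Learning").

This docstring is a summary added for readability: the class wraps a
standard PyTorch optimizer; per-task gradients that conflict (negative
inner product) are projected onto the normal plane of the conflicting
gradient before being combined into a single update.
"""
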
import random
import torch
from torch import autograd


class PCGrad:
    def __init__(self, optimizer, reduction="sum"):
        # reduction: how projected per-task gradients combine ("sum" or "mean").
        self.optimizer = optimizer
        self.reduction = reduction

    def zero_grad(self):
        # Thin passthroughs to the wrapped optimizer.
        return self.optimizer.zero_grad(set_to_none=True)

    def step(self):
        return self.optimizer.step()

    def pc_backward(self, model, losses):
        """Use in place of loss.backward(): computes per-task gradients,
        projects away conflicts, and writes the result into param.grad."""
        pc_gradients = self.get_gradients(model, losses)
        pc_gradients = self.project_conflicting(pc_gradients)
        self.set_gradients(model, pc_gradients)

    @torch.no_grad()
    def project_conflicting(self, grads):
        # PCGrad surgery: whenever g_i . g_j < 0, project
        #     g_i <- g_i - (g_i . g_j / ||g_j||^2) * g_j,
        # removing the component of g_i that conflicts with task j.
        for name, grad in grads.items():
            b, *_ = grad.shape  # b = number of tasks with a gradient for this parameter
            if b < 2:  # a single task touched this parameter: nothing to project
                grads[name] = grads[name][0]
                continue
            # Project against the unmodified task gradients, as in the paper;
            # reading g_j from `grad` directly would use already-projected rows.
            originals = grad.clone()
            for i in range(b):
                g_i = grad[i].view(-1).clone()
                for j in range(b):
                    if i == j:
                        continue
                    g_j = originals[j].view(-1)
                    sim = torch.dot(g_i, g_j)
                    if sim < 0:
                        g_i = g_i - sim * g_j / (g_j.norm() ** 2)
                grad[i] = g_i.view_as(grad[i])
            if self.reduction == "mean":
                grads[name] = grads[name].mean(dim=0)
            elif self.reduction == "sum":
                grads[name] = grads[name].sum(dim=0)
            else:
                raise NotImplementedError(f"Reduction ({self.reduction}) not implemented")
        return grads

    def set_gradients(self, model, grads):
        # Write the combined gradients back; parameters no task reached get None.
        for name, param in model.named_parameters():
            param.grad = grads.get(name)

    def get_gradients(self, model, losses):
        # One gradient dict per task; retain the graph on all but the last
        # backward pass so every loss can still be differentiated.
        task_gradients = [
            self.get_task_gradients(model, loss, retain_graph=(i < len(losses) - 1))
            for i, loss in enumerate(losses)
        ]
        # Randomize task order, as the PCGrad paper prescribes for the projection.
        random.shuffle(task_gradients)
        # Stack parameter-wise into shape (num_tasks, *param.shape).
        gradients = {}
        for name, _ in model.named_parameters():
            grads = [g[name] for g in task_gradients if name in g]
            if grads:
                gradients[name] = torch.stack(grads)
        return gradients

    def get_task_gradients(self, model, loss, retain_graph=False):
        # Differentiate one task loss w.r.t. every trainable parameter;
        # allow_unused=True tolerates parameters the loss does not reach.
        params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
        grads = autograd.grad(
            loss, [p for _, p in params], retain_graph=retain_graph, allow_unused=True
        )
        return {name: g.clone() for (name, _), g in zip(params, grads) if g is not None}
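

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module):
# the toy model, targets, and learning rate below are hypothetical
# placeholders showing the call sequence zero_grad -> pc_backward -> step.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(8, 2)  # hypothetical shared network
    pcgrad = PCGrad(torch.optim.SGD(model.parameters(), lr=1e-2), reduction="sum")

    x = torch.randn(4, 8)
    y1, y2 = torch.randn(4, 2), torch.randn(4, 2)  # two tasks' targets

    out = model(x)
    losses = [F.mse_loss(out, y1),  # task-1 loss
              F.mse_loss(out, y2)]  # task-2 loss

    pcgrad.zero_grad()
    pcgrad.pc_backward(model, losses)  # in place of loss.backward()
    pcgrad.step()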