-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmeta_learner.py
More file actions
152 lines (109 loc) · 4.77 KB
/
meta_learner.py
File metadata and controls
152 lines (109 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Treamy
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from copy import deepcopy
from learner import Learner
class MAML(nn.Module):
    """
    The Meta-Learner: Model-Agnostic Meta-Learning (MAML).

    Wraps a ``Learner`` network and performs the inner-loop (per-task)
    gradient updates plus the outer-loop (meta) update with Adam.
    Supports both the full second-order MAML update (``order == 2``)
    and the first-order approximation, FOMAML (``order == 1``).
    """
    def __init__(self, args, config):
        """
        :param args: namespace carrying the hyperparameters
                     (update_lr, meta_lr, n_way, k_spt, k_qry, tasks_num,
                      num_update, test_num_update, update_order)
        :param config: network configuration forwarded to Learner
        """
        super(MAML, self).__init__()
        # learning rates: inner-loop (task adaptation) and outer-loop (meta)
        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.tasks_num = args.tasks_num
        self.num_update = args.num_update
        self.test_num_update = args.test_num_update
        self.learner = Learner(config)
        self.meta_optim = optim.Adam(self.learner.parameters(), lr=self.meta_lr)
        self.order = args.update_order  # 2 = second-order MAML, 1 = first-order approximation
        # NOTE(review): assumes Learner.parameters() returns a sized container
        # (e.g. an nn.ParameterList), since len() fails on a plain generator.
        self.num_learnable_param = len(self.learner.parameters())
        # one inner-loop SGD step: theta' = theta - update_lr * grad
        self.update_fn = lambda p, g: p - self.update_lr * g
    def forward(self, tsk_xs, tsk_ys, tsk_xq, tsk_yq):
        """
        Run one meta-training step over a batch of tasks.

        :param tsk_xs: [tsk, n_way*k_spt, c, h, w] support images per task
        :param tsk_ys: [tsk, n_way*k_spt]          support labels per task
        :param tsk_xq: [tsk, n_way*k_qry, c, h, w] query images per task
        :param tsk_yq: [tsk, n_way*k_qry]          query labels per task
        :return: np.ndarray of query-set accuracy after each inner update step
        :raises ValueError: if ``self.order`` is neither 1 nor 2
        """
        tasks_num, xs_sz, ch, h, w = tsk_xs.shape
        xq_sz = tsk_xq.size(1)
        # BUGFIX: the original used `self.train`, which is nn.Module's bound
        # train() method and therefore always truthy; `self.training` is the
        # boolean mode flag. Second derivatives (create_graph) are only needed
        # for second-order MAML while training.
        create_graph = (self.order == 2) and self.training
        update_fn = self.update_fn
        # correct-prediction counts on the query set after each inner step,
        # used to see how the number of support-set updates affects accuracy
        corr_q_over_update = [0 for _ in range(self.num_update)]
        cum_loss = 0.  # cumulative query loss over tasks (second-order path)
        cum_grads = [0. for _ in range(self.num_learnable_param)]  # cumulative per-parameter grads (first-order path)
        for i in range(tasks_num):  # adapt to each task in turn
            fast_weights = self.learner.parameters()
            for k in range(self.num_update):
                logits_s = self.learner(tsk_xs[i], fast_weights)
                loss_s = F.cross_entropy(logits_s, tsk_ys[i])
                grads = torch.autograd.grad(loss_s, fast_weights, create_graph=create_graph)
                fast_weights = list( map(update_fn, fast_weights, grads) )
                with torch.no_grad():
                    logits_q = self.learner(tsk_xq[i], fast_weights)
                    pred_q = logits_q.argmax(dim=1)
                    corr_q = (pred_q == tsk_yq[i]).sum().item()
                    corr_q_over_update[k] += corr_q
            # adapted model's loss on this task's query set
            logits_q = self.learner(tsk_xq[i], fast_weights)
            loss_q = F.cross_entropy(logits_q, tsk_yq[i])
            if self.order == 2:
                cum_loss += loss_q  # accumulate; backward later differentiates through the inner loop
            elif self.order == 1:
                # first-order: take grads w.r.t. the adapted weights and apply
                # them directly to the initial parameters
                grads = torch.autograd.grad(loss_q, fast_weights)
                for j in range(self.num_learnable_param):
                    cum_grads[j] += grads[j]
            else:
                raise ValueError('Order must be either 1 or 2.')
        # all tasks processed: perform the meta-update
        if self.order == 2:
            loss_ = cum_loss / tasks_num  # mean query loss over tasks
            loss_.backward()
        elif self.order == 1:
            grads = [param_grads / tasks_num for param_grads in cum_grads]  # mean query grads over tasks
            for p, g in zip(self.learner.parameters(), grads):
                p.grad = g.clone()
        else:
            raise ValueError('Order must be either 1 or 2.')
        self.meta_optim.step()
        self.meta_optim.zero_grad()
        accs = np.array(corr_q_over_update) / (tasks_num * xq_sz)
        return accs
    def fine_tuning(self, xs, ys, xq, yq):
        """
        Adapt a copy of the meta-learned model on a single task and report
        query accuracy after each update step. The meta-parameters are not
        modified (the updates run on a deep copy).

        Shapes are flattened 4-D, as enforced by the assert below (the
        original docstring claimed 5-D per-class shapes, contradicting it):
        :param xs: [n_way*k_spt, c, h, w] support images
        :param ys: [n_way*k_spt]          support labels
        :param xq: [n_way*k_qry, c, h, w] query images
        :param yq: [n_way*k_qry]          query labels
        :return: np.ndarray of query-set accuracy after each update step
        """
        assert len(xs.shape) == 4
        xq_sz = xq.size(0)
        corr_q_over_update = [0 for _ in range(self.test_num_update)]
        # deep copy so fine-tuning never touches the meta-parameters
        model = deepcopy(self.learner)
        # BUGFIX: the original recomputed fast_weights from model.parameters()
        # on every iteration (and differentiated w.r.t. model.parameters()),
        # so the adaptation never accumulated beyond a single step. Update
        # fast_weights in place, matching the inner loop in forward().
        fast_weights = model.parameters()
        for k in range(self.test_num_update):
            logits_s = model(xs, fast_weights)
            loss_s = F.cross_entropy(logits_s, ys)
            grads = torch.autograd.grad(loss_s, fast_weights)
            fast_weights = list(map(self.update_fn, fast_weights, grads))
            with torch.no_grad():
                logits_q = model(xq, fast_weights)
                pred_q = logits_q.argmax(dim=1)
                corr_q = (pred_q == yq).sum().item()
                corr_q_over_update[k] += corr_q
        del model
        accs = np.array(corr_q_over_update) / xq_sz
        return accs