Skip to content

Commit 9a51061

Browse files
author
Arthur Douillard
committed
[icarl] Improvement on iCaRL.
1 parent ba2359c commit 9a51061

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

inclearn/models/icarl.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
class ICarl(IncrementalLearner):
1212
"""Implementation of iCarl.
1313
14+
# References:
15+
- iCaRL: Incremental Classifier and Representation Learning
16+
Sylvestre-Alvise Rebuffi, Alexander Kolesnikov, Georg Sperl, Christoph H. Lampert
17+
https://arxiv.org/abs/1611.07725
18+
1419
:param args: An argparse parsed arguments object.
1520
"""
1621
def __init__(self, args):
1722
super().__init__()
1823

1924
self._device = args["device"]
20-
self._memory_size = args["memory_size"]
2125
self._opt_name = args["optimizer"]
2226
self._lr = args["lr"]
2327
self._weight_decay = args["weight_decay"]
@@ -76,6 +80,9 @@ def _before_task(self, train_loader, val_loader):
7680
)
7781

7882
def _train_task(self, train_loader, val_loader):
83+
for p in self.parameters():
84+
p.register_hook(lambda grad: torch.clamp(grad, -5, 5))
85+
7986
print("nb ", len(train_loader.dataset))
8087

8188
prog_bar = trange(self._n_epochs, desc="Losses.")
@@ -187,7 +194,7 @@ def _compute_loss(self, logits, targets, idxes, train=True):
187194
previous_preds = self._previous_preds if train else self._previous_preds_val
188195
distil_loss = self._distil_loss(
189196
logits[..., :self._new_task_index],
190-
previous_preds[idxes]
197+
previous_preds[idxes, :self._new_task_index]
191198
)
192199

193200
return clf_loss, distil_loss
@@ -232,20 +239,16 @@ def _m(self):
232239
return self._k // self._n_classes
233240

234241
def _add_n_classes(self, n):
235-
print("add n classes")
236242
self._n_classes += n
237243

238-
weight = self._classifier.weight.data
239-
# bias = self._classifier.bias.data
240-
244+
weights = self._classifier.weight.data
241245
self._classifier = nn.Linear(
242246
self._features_extractor.out_dim, self._n_classes,
243247
bias=False
244248
).to(self._device)
245249
torch.nn.init.kaiming_normal_(self._classifier.weight)
246250

247-
self._classifier.weight.data[: self._n_classes - n] = weight
248-
# self._classifier.bias.data[: self._n_classes - n] = bias
251+
self._classifier.weight.data[: self._n_classes - n] = weights
249252

250253
print("Now {} examplars per class.".format(self._m))
251254

@@ -258,7 +261,7 @@ def _extract_features(self, loader):
258261
features.append(self._features_extractor(inputs).detach())
259262
idxes.extend(real_idxes.numpy().tolist())
260263

261-
features = torch.cat(features)
264+
features = F.normalize(torch.cat(features), dim=1)
262265
mean = torch.mean(features, dim=0, keepdim=False)
263266

264267
return features, mean, idxes
@@ -314,9 +317,9 @@ def _build_examplars(self, loader):
314317
for i in range(min(self._m, features.shape[0])):
315318
tmp = F.normalize(
316319
(features + examplars_mean) / (i + 1),
317-
dim=0
320+
dim=1
318321
)
319-
distances = (class_mean - tmp).norm(2, 1)
322+
distances = self._dist(class_mean, tmp)
320323
idxes_winner = distances.argsort().cpu().numpy()
321324

322325
for idx in idxes_winner:

0 commit comments

Comments
 (0)