 class ICarl(IncrementalLearner):
     """Implementation of iCaRL.
 
+    # References:
+    - iCaRL: Incremental Classifier and Representation Learning
+      Sylvestre-Alvise Rebuffi, Alexander Kolesnikov, Georg Sperl, Christoph H. Lampert
+      https://arxiv.org/abs/1611.07725
+
     :param args: An argparse parsed arguments object.
     """
     def __init__(self, args):
         super().__init__()
 
         self._device = args["device"]
-        self._memory_size = args["memory_size"]
         self._opt_name = args["optimizer"]
         self._lr = args["lr"]
         self._weight_decay = args["weight_decay"]
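For orientation, here is a minimal sketch of the `args` mapping this constructor reads. Only the keys visible in this hunk are shown; the values and the direct-dict construction are illustrative (a real run would pass something like `vars(parser.parse_args())` from argparse, as the docstring says).

# Hypothetical sketch: keys taken from the hunk above, values illustrative.
args = {
    "device": "cuda",       # torch device string
    "optimizer": "sgd",     # optimizer name, resolved elsewhere in the class
    "lr": 2.0,              # learning rate
    "weight_decay": 5e-5,   # L2 penalty
}
icarl = ICarl(args)  # plus whatever other keys the full class expects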
@@ -76,6 +80,9 @@ def _before_task(self, train_loader, val_loader):
         )
 
     def _train_task(self, train_loader, val_loader):
+        for p in self.parameters():
+            p.register_hook(lambda grad: torch.clamp(grad, -5, 5))
+
         print("nb ", len(train_loader.dataset))
 
         prog_bar = trange(self._n_epochs, desc="Losses.")
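The three added lines register a gradient hook on every parameter so each gradient is clamped to [-5, 5] during the backward pass. A self-contained sketch of the same mechanism, with an illustrative module standing in for the repo's network:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
for p in net.parameters():
    # The hook receives the parameter's gradient during backward,
    # and the tensor it returns replaces that gradient.
    p.register_hook(lambda grad: torch.clamp(grad, -5, 5))

loss = (net(torch.randn(8, 4)) * 1e6).sum()  # inflate to force large gradients
loss.backward()
assert all((p.grad.abs() <= 5).all() for p in net.parameters())

One caveat worth noting: `_train_task` runs once per task, so hooks accumulate across tasks; clamping is idempotent, so the result is unchanged, but each extra hook adds a little backward-pass overhead.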
@@ -187,7 +194,7 @@ def _compute_loss(self, logits, targets, idxes, train=True):
         previous_preds = self._previous_preds if train else self._previous_preds_val
         distil_loss = self._distil_loss(
             logits[..., :self._new_task_index],
-            previous_preds[idxes]
+            previous_preds[idxes, :self._new_task_index]
         )
 
         return clf_loss, distil_loss
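The one-line fix slices the stored predictions to the old-class columns, so both arguments to `self._distil_loss` share the shape `(batch, new_task_index)`; before, `previous_preds[idxes]` kept every column, and the shapes disagreed once new classes were added. A shape-only sketch, under the assumption that the distillation term is a sigmoid binary cross-entropy over old classes as in the iCaRL paper (all names and sizes here are illustrative):

import torch
import torch.nn.functional as F

batch, n_old, n_new = 4, 10, 5
new_task_index = n_old

logits = torch.randn(batch, n_old + n_new)         # current model, all heads
previous_preds = torch.rand(batch, n_old + n_new)  # targets stored at full width
idxes = torch.arange(batch)                        # which stored rows to use

# Both operands restricted to the first new_task_index columns: (batch, n_old).
distil_loss = F.binary_cross_entropy_with_logits(
    logits[..., :new_task_index],
    previous_preds[idxes, :new_task_index],
)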
@@ -232,20 +239,16 @@ def _m(self):
         return self._k // self._n_classes
 
     def _add_n_classes(self, n):
-        print("add n classes")
         self._n_classes += n
 
-        weight = self._classifier.weight.data
-        # bias = self._classifier.bias.data
-
+        weights = self._classifier.weight.data
         self._classifier = nn.Linear(
             self._features_extractor.out_dim, self._n_classes,
             bias=False
         ).to(self._device)
         torch.nn.init.kaiming_normal_(self._classifier.weight)
 
-        self._classifier.weight.data[: self._n_classes - n] = weight
-        # self._classifier.bias.data[: self._n_classes - n] = bias
+        self._classifier.weight.data[: self._n_classes - n] = weights
 
         print("Now {} examplars per class.".format(self._m))
 
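The cleaned-up `_add_n_classes` follows a standard grow-the-head pattern: snapshot the trained weight matrix, build a wider bias-free linear layer, re-initialize it, then copy the old rows back so previously learned classes keep their weights. A standalone sketch of that pattern with illustrative dimensions:

import torch
import torch.nn as nn

out_dim, n_old, n_new = 64, 10, 5
clf = nn.Linear(out_dim, n_old, bias=False)  # stands in for self._classifier

weights = clf.weight.data
clf = nn.Linear(out_dim, n_old + n_new, bias=False)
torch.nn.init.kaiming_normal_(clf.weight)

# Rows 0..n_old-1 recover the already-trained class weights;
# only the n_new freshly initialized rows stay random.
clf.weight.data[:n_old] = weights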
@@ -258,7 +261,7 @@ def _extract_features(self, loader):
             features.append(self._features_extractor(inputs).detach())
             idxes.extend(real_idxes.numpy().tolist())
 
-        features = torch.cat(features)
+        features = F.normalize(torch.cat(features), dim=1)
         mean = torch.mean(features, dim=0, keepdim=False)
 
         return features, mean, idxes
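After this change, every feature row leaving `_extract_features` is L2-normalized (`dim=1` normalizes each sample, not each feature dimension), which matches iCaRL's nearest-mean-of-exemplars classifier operating on unit-length features. A quick sketch of the effect:

import torch
import torch.nn.functional as F

features = F.normalize(torch.randn(32, 64), dim=1)  # (n_samples, feature_dim)
print(features.norm(dim=1))   # every row now has norm ~1.0
mean = torch.mean(features, dim=0, keepdim=False)
# A mean of unit vectors is generally shorter than unit length, which is
# why the iCaRL paper re-normalizes the result of any feature averaging.
print(mean.norm())            # typically < 1.0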
@@ -314,9 +317,9 @@ def _build_examplars(self, loader):
         for i in range(min(self._m, features.shape[0])):
             tmp = F.normalize(
                 (features + examplars_mean) / (i + 1),
-                dim=0
+                dim=1
             )
-            distances = (class_mean - tmp).norm(2, 1)
+            distances = self._dist(class_mean, tmp)
             idxes_winner = distances.argsort().cpu().numpy()
 
             for idx in idxes_winner:
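The loop above is iCaRL's herding construction: at step i it asks which remaining sample, if added next, would keep the mean of the selected exemplars closest to the true class mean. The diff makes two fixes inside it: normalize the candidate means per sample (`dim=1`, not `dim=0`) and route the comparison through `self._dist`. A compact, self-contained sketch of the same greedy selection, assuming `self._dist` is a per-row Euclidean distance and that `examplars_mean` in the repo holds the running sum of the features already chosen:

import torch
import torch.nn.functional as F

def herd(features: torch.Tensor, m: int) -> list:
    """Greedily pick m rows whose running mean tracks the class mean."""
    features = F.normalize(features, dim=1)
    class_mean = F.normalize(features.mean(dim=0, keepdim=True), dim=1)

    chosen, running_sum = [], torch.zeros_like(features[0])
    for i in range(min(m, features.shape[0])):
        # Mean the exemplar set would have if each sample were added next.
        tmp = F.normalize((features + running_sum) / (i + 1), dim=1)
        distances = (class_mean - tmp).norm(2, dim=1)
        for idx in distances.argsort().tolist():
            if idx not in chosen:          # skip already-selected exemplars
                chosen.append(idx)
                running_sum += features[idx]
                break
    return chosen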