import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import trange

from inclearn import factory, utils
from inclearn.models.base import IncrementalLearner


class End2End(IncrementalLearner):
    """Implementation of End-to-End Incremental Learning.

    :param args: An argparse parsed arguments object.
    """

    def __init__(self, args):
        super().__init__()

        self._device = args["device"]
        self._opt_name = args["optimizer"]
        self._lr = args["lr"]
        self._weight_decay = args["weight_decay"]
        self._n_epochs = args["epochs"]

        self._scheduling = args["scheduling"]
        self._lr_decay = args["lr_decay"]

        self._k = args["memory_size"]
        self._n_classes = args["increment"]

        self._temperature = args["temperature"]

        self._features_extractor = factory.get_resnet(
            args["convnet"], nf=64, zero_init_residual=True
        )
        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False)
        torch.nn.init.kaiming_normal_(self._classifier.weight)

        self._examplars = {}
        self._means = None

        self.to(self._device)

    def forward(self, x):
        x = self._features_extractor(x)
        x = self._classifier(x)
        return x

    # ----------
    # Public API
    # ----------

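    # Note: _before_task / _train_task / _after_task / _eval_task are assumed
    # to be invoked in that order, once per incremental task, by the base
    # IncrementalLearner / trainer; self._task, self._task_size,
    # self._new_task_index and self._n_train_data are assumed to be maintained
    # by that base class.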
    def _before_task(self, train_loader, val_loader):
        """Sets up the model before a new task's training begins.

        1. Precomputes the previous model's predictions, used later for distillation.
        2. Extends the classifier to support the new classes.

        :param train_loader: The training dataloader.
        :param val_loader: The validation dataloader.
        """
        if self._task == 0:
            self._previous_preds = None
        else:
            print("Computing previous predictions...")
            self._previous_preds = self._compute_predictions(train_loader)
            if val_loader:
                self._previous_preds_val = self._compute_predictions(val_loader)

        self._add_n_classes(self._task_size)

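    # Two training stages per task, following the End-to-End Incremental
    # Learning recipe: a full pass over new data + exemplars, then a balanced
    # fine-tuning pass at a lower learning rate. The learning rates, schedule
    # milestones and epoch count below are hard-coded (self._lr,
    # self._n_epochs and self._scheduling are parsed but not used here).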
    def _train_task(self, train_loader, val_loader):
        # Stage 1: training on all new data + exemplars.
        self.foo = 0
        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.0001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
        self._train(train_loader, 1, optimizer, scheduler)

        if self._task == 0:
            return

        # Stage 2: fine-tuning on a balanced subset (exemplars of old and new classes).
        self._build_examplars(train_loader)
        train_loader.dataset.set_idxes(self.examplars)  # Fine-tune only on a balanced dataset.
        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.0001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)
        self.foo = 1
        self._train(train_loader, 1, optimizer, scheduler)

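    # Once a task is done, the fixed memory budget is re-balanced: old
    # per-class exemplar lists are truncated to the new per-class quota, then
    # exemplars are built for the classes that were just learned.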
    def _after_task(self, data_loader):
        self._reduce_examplars()
        self._build_examplars(data_loader)

    def _eval_task(self, data_loader):
        ypred, ytrue = self._classify(data_loader)
        assert ypred.shape == ytrue.shape

        return ypred, ytrue

    def get_memory_indexes(self):
        return self.examplars

    # -----------
    # Private API
    # -----------

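    # The train loader is expected to yield ((real_index, dataset_index),
    # inputs, targets) triplets. Calling scheduler.step() at the top of each
    # epoch follows the pre-1.1 PyTorch convention; on newer versions it would
    # be called after the epoch's optimizer steps.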
    def _train(self, train_loader, n_epochs, optimizer, scheduler):
        print("Train set size:", len(train_loader.dataset))

        prog_bar = trange(n_epochs, desc="Losses.")

        for epoch in prog_bar:
            _clf_loss, _distil_loss = 0., 0.
            c = 0

            scheduler.step()

            for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1):
                optimizer.zero_grad()

                c += len(idxes)
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self.forward(inputs)

                clf_loss, distil_loss = self._compute_loss(
                    logits,
                    targets,
                    idxes,
                )

                if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss):
                    import pdb
                    pdb.set_trace()

                loss = clf_loss + distil_loss

                loss.backward()
                optimizer.step()

                _clf_loss += clf_loss.item()
                _distil_loss += distil_loss.item()

                if i % 10 == 0 or i >= len(train_loader):
                    prog_bar.set_description(
                        "Clf loss: {}; Distill loss: {}".format(
                            round(clf_loss.item(), 3), round(distil_loss.item(), 3)
                        )
                    )

            prog_bar.set_description(
                "Clf loss: {}; Distill loss: {}".format(
                    round(_clf_loss / c, 3), round(_distil_loss / c, 3)
                )
            )

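    # The distillation term compares temperature-softened softmaxes of the
    # old-class logits with the logits recorded before this task started:
    #
    #     p = softmax(logits[..., :old] / T)
    #     q = softmax(previous_logits / T)
    #     distil_loss = binary_cross_entropy(p, q)
    #
    # with T = self._temperature and old = self._new_task_index.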
    def _compute_loss(self, logits, targets, idxes):
        """Computes the classification loss & the distillation loss.

        The distillation loss is zero for the first task.

        :param logits: Logits produced by the model.
        :param targets: The targets.
        :param idxes: The dataset indexes of the just-processed images, needed
                      to look up the matching previous predictions.
        :return: A tuple of the classification loss and the distillation loss.
        """
        if self._task == 0:
            clf_loss = F.cross_entropy(logits, targets)
            distil_loss = torch.zeros(1, device=self._device)
        else:
            # Disable the cross-entropy loss for the old targets:
            for i in range(self._new_task_index):
                targets[targets == i] = -1
            clf_loss = F.cross_entropy(logits, targets, ignore_index=-1)

            distil_loss = F.binary_cross_entropy(
                F.softmax(logits[..., :self._new_task_index] / self._temperature, dim=1),
                F.softmax(self._previous_preds[idxes] / self._temperature, dim=1)
            )

        return clf_loss, distil_loss

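    # Logits are stored at each sample's dataset index (the second element of
    # the index tuple yielded by the loader) so that _compute_loss can look
    # them up per batch during the next task.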
    def _compute_predictions(self, loader):
        """Precomputes the logits before a task.

        :param loader: A DataLoader.
        :return: A tensor storing the whole current dataset logits.
        """
        logits = torch.zeros(self._n_train_data, self._n_classes, device=self._device)

        for idxes, inputs, _ in loader:
            inputs = inputs.to(self._device)
            idxes = idxes[1].to(self._device)

            logits[idxes] = self.forward(inputs).detach()

        return logits

    def _classify(self, loader):
        """Classifies the images given by the data loader.

        :param loader: A DataLoader.
        :return: A numpy array of the predicted targets and a numpy array of
                 the ground-truth targets.
        """
        ypred = []
        ytrue = []

        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            logits = self.forward(inputs)
            preds = F.softmax(logits, dim=1).argmax(dim=1)

            ypred.extend(preds.cpu().numpy())
            ytrue.extend(targets.numpy())

        return np.array(ypred), np.array(ytrue)

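    # The total memory budget self._k is fixed for the whole run, so the
    # per-class quota shrinks as classes accumulate: e.g. with a budget of
    # 2000 and 20 known classes, _m = 2000 // 20 = 100 exemplars per class.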
    @property
    def _m(self):
        """Returns the number of exemplars per class."""
        return self._k // self._n_classes

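    # Growing the classifier: a new, larger Linear layer replaces the old one
    # and the weight rows of the already-known classes are copied over, so
    # previously learned class scores are preserved.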
    def _add_n_classes(self, n):
        self._n_classes += n

        weights = self._classifier.weight.data
        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes,
                                     bias=False).to(self._device)
        torch.nn.init.kaiming_normal_(self._classifier.weight)

        self._classifier.weight.data[:self._n_classes - n] = weights

        print("Now {} exemplars per class.".format(self._m))

    def _extract_features(self, loader):
        features = []
        idxes = []

        for (real_idxes, _), inputs, _ in loader:
            inputs = inputs.to(self._device)
            features.append(self._features_extractor(inputs).detach())
            idxes.extend(real_idxes.numpy().tolist())

        features = torch.cat(features)
        mean = torch.mean(features, dim=0, keepdim=False)

        return features, mean, idxes

    @staticmethod
    def _get_closest(centers, features):
        """Returns, for each feature, the index of its closest center.

        :param centers: Centers to compare, in this case the class means.
        :param features: A tensor of features extracted by the convnet.
        :return: A numpy array of the closest centers indexes.
        """
        pred_labels = []

        for feature in features:
            distances = End2End._dist(centers, feature)
            pred_labels.append(distances.argmin().item())

        return np.array(pred_labels)

    @staticmethod
    def _dist(a, b):
        """Computes the squared L2 distance between two tensors.

        :param a: A tensor.
        :param b: A tensor.
        :return: A tensor of distances, following the broadcast shape of the
                 inputs with the last dimension summed out.
        """
        return torch.pow(a - b, 2).sum(-1)

    def _build_examplars(self, loader):
        """Builds new exemplars.

        :param loader: A DataLoader.
        """
        lo, hi = self._task * self._task_size, self._n_classes
        print("Building exemplars for classes {} -> {}.".format(lo, hi))
        for class_idx in range(lo, hi):
            loader.dataset.set_classes_range(class_idx, class_idx)
            self._examplars[class_idx] = self._build_class_examplars(loader)

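    # Exemplar selection keeps the samples whose features are closest to the
    # class mean; this is a simpler heuristic than iCaRL-style herding.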
    def _build_class_examplars(self, loader):
        """Builds exemplars for a single class.

        Exemplars are selected as the samples closest to the class mean.

        :param loader: DataLoader that provides images of a single class.
        :return: The real indexes of the chosen exemplars.
        """
        features, class_mean, idxes = self._extract_features(loader)

        class_mean = F.normalize(class_mean, dim=0)
        distances_to_mean = self._dist(class_mean, features)

        nb_examplars = min(self._m, len(features))

        fake_idxes = distances_to_mean.argsort().cpu().numpy()[:nb_examplars]
        return [idxes[idx] for idx in fake_idxes]

    @property
    def examplars(self):
        """Returns all the real exemplar indexes.

        :return: A numpy array of indexes.
        """
        return np.array(
            [
                examplar_idx for class_examplars in self._examplars.values()
                for examplar_idx in class_examplars
            ]
        )

    def _reduce_examplars(self):
        print("Reducing exemplars.")
        for class_idx in range(len(self._examplars)):
            self._examplars[class_idx] = self._examplars[class_idx][:self._m]
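
# Minimal usage sketch (illustrative only). The real training loop lives in the
# inclearn trainer; the argument values below are assumptions, not repository
# defaults.
#
#     args = {
#         "device": torch.device("cuda"), "optimizer": "sgd",
#         "lr": 0.1, "weight_decay": 1e-4, "epochs": 70,
#         "scheduling": [50, 60], "lr_decay": 0.1,
#         "memory_size": 2000, "increment": 10, "temperature": 2,
#         "convnet": "resnet18",
#     }
#     model = End2End(args)
#     model._before_task(train_loader, val_loader)
#     model._train_task(train_loader, val_loader)
#     model._after_task(train_loader)
#     ypred, ytrue = model._eval_task(test_loader)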