diff --git a/solvers/mantis.py b/solvers/mantis.py index 5a114d8..2bee70b 100644 --- a/solvers/mantis.py +++ b/solvers/mantis.py @@ -84,7 +84,8 @@ def _extract_embeddings(self, X): else: embedding_dim = 128 all_embeddings.append( - np.zeros((batch_end - batch_idx, embedding_dim), dtype=np.float32) + np.zeros((batch_end - batch_idx, embedding_dim), + dtype=np.float32) ) # Concatenate all embeddings @@ -157,7 +158,8 @@ def set_objective(self, task, X_train, y_train, **meta): network = network.from_pretrained(self.checkpoint) self._network = network - self._trainer = MantisTrainer(device=device, network=self._network) + self._trainer = MantisTrainer( + device=device, network=self._network) self._loaded_checkpoint = self.checkpoint print( f"✓ Mantis checkpoint loaded: {self.checkpoint} on device: {device}" @@ -176,8 +178,11 @@ def run(self, _): self._adapter = LinearProbeAdapter( encoder=self, task=self.task, - n_estimators=self.n_estimators, classifier=self.classifier, + penalty=self.penalty, + C=self.C, + alpha=self.alpha, + n_estimators=self.n_estimators, ) self._adapter.fit(self.X_train, self.y_train)