From e72d8b269b673b581a40a87154ebfb7c20884bc5 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Tue, 2 Jun 2026 21:49:28 +0200 Subject: [PATCH 1/2] Fix mantis: removed max_iter arg / remove penalty, C, alpha parameters / rename n_iterators to n_estimators --- solvers/mantis.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/solvers/mantis.py b/solvers/mantis.py index 1ad023c..a3e819e 100644 --- a/solvers/mantis.py +++ b/solvers/mantis.py @@ -40,10 +40,7 @@ class Solver(BaseSolver): "batch_size": [32], "interpolate_to": [512], "classifier": ["random_forest"], - "penalty": ["l2"], - "C": [1.0], - "alpha": [1.0], - "n_iterators": [100], + "n_estimators": [100], } def _extract_embeddings(self, X): @@ -84,7 +81,8 @@ def _extract_embeddings(self, X): else: embedding_dim = 128 all_embeddings.append( - np.zeros((batch_end - batch_idx, embedding_dim), dtype=np.float32) + np.zeros((batch_end - batch_idx, embedding_dim), + dtype=np.float32) ) # Concatenate all embeddings @@ -157,7 +155,8 @@ def set_objective(self, task, X_train, y_train, **meta): network = network.from_pretrained(self.checkpoint) self._network = network - self._trainer = MantisTrainer(device=device, network=self._network) + self._trainer = MantisTrainer( + device=device, network=self._network) self._loaded_checkpoint = self.checkpoint print( f"✓ Mantis checkpoint loaded: {self.checkpoint} on device: {device}" @@ -176,7 +175,6 @@ def run(self, _): self._adapter = LinearProbeAdapter( encoder=self, task=self.task, - max_iter=self.max_iter, n_estimators=self.n_estimators, classifier=self.classifier, ) From 7af482d1d9463f712c3c7459a96309c2d93b6180 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Tue, 2 Jun 2026 22:12:10 +0200 Subject: [PATCH 2/2] FIX mantis: align solver with updated LinearProbeAdapter API --- solvers/mantis.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/solvers/mantis.py b/solvers/mantis.py index a3e819e..2bee70b 100644 --- a/solvers/mantis.py +++ b/solvers/mantis.py @@ -40,6 +40,9 @@ class Solver(BaseSolver): "batch_size": [32], "interpolate_to": [512], "classifier": ["random_forest"], + "penalty": ["l2"], + "C": [1.0], + "alpha": [1.0], "n_estimators": [100], } @@ -175,8 +178,11 @@ def run(self, _): self._adapter = LinearProbeAdapter( encoder=self, task=self.task, - n_estimators=self.n_estimators, classifier=self.classifier, + penalty=self.penalty, + C=self.C, + alpha=self.alpha, + n_estimators=self.n_estimators, ) self._adapter.fit(self.X_train, self.y_train)