diff --git a/tests/e2e/mnist.py b/tests/e2e/mnist.py index bc9101a8..176d0a2f 100644 --- a/tests/e2e/mnist.py +++ b/tests/e2e/mnist.py @@ -99,8 +99,8 @@ def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4): nn.Linear(hidden_size, self.num_classes), ) - self.val_accuracy = Accuracy() - self.test_accuracy = Accuracy() + self.val_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes) + self.test_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes) def forward(self, x): x = self.model(x) diff --git a/tests/e2e/mnist_pip_requirements.txt b/tests/e2e/mnist_pip_requirements.txt index 6afcba78..ec87256b 100644 --- a/tests/e2e/mnist_pip_requirements.txt +++ b/tests/e2e/mnist_pip_requirements.txt @@ -2,5 +2,5 @@ torch==2.7.1+cu118 torchvision==0.22.1+cu118 pytorch_lightning==1.9.5 -torchmetrics==0.9.1 +torchmetrics==1.8.2 minio