🐛 Bug
The retrieval metrics (RetrievalMAP, RetrievalRecall, etc.) crash or allocate excessive memory when the indexes tensor contains sparse or high-valued integers, even if the number of unique queries is small.
This is because torchmetrics.utilities.data._bincount() relies on index.max() to determine the size of internal tensors. When deterministic mode is enabled (or on XLA/MPS), the fallback implementation can allocate massive [len(indexes), index.max()] tensors, leading to out-of-memory (OOM) errors.
✅ Expected behavior
The metrics should group predictions by query regardless of the numerical values in indexes. The actual values shouldn't impact performance or memory.
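Since grouping depends only on equality of the ids, remapping them to a dense `0..k-1` range preserves the groups exactly. A minimal sketch of this equivalence (variable names here are illustrative, not part of the TorchMetrics API):

```python
import torch

indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])
# return_inverse assigns each element the position of its value among the
# sorted unique values -> a dense 0..k-1 relabelling of the same groups
_, dense = torch.unique(indexes, return_inverse=True)
print(dense)  # tensor([0, 0, 1, 1, 2, 2]) -- same grouping, max is now 2
```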
To Reproduce
Steps:
- Simulate a retrieval task with a few queries using high-value IDs
- Use RetrievalMAP() or similar
- Call .update() and .compute()
Code sample
import torch
from torchmetrics.retrieval import RetrievalMAP
# Simulate predictions and labels for 3 queries with sparse/high index values
preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000]) # only 3 unique queries
# Enable deterministic mode (triggers fallback path)
torch.use_deterministic_algorithms(True)
metric = RetrievalMAP()
metric.update(preds, target, indexes)
# This line will likely cause a crash or massive memory use due to high index values
result = metric.compute()
print(result)
🔥 What happens
With torch.use_deterministic_algorithms(True) enabled:
- The _bincount() fallback tries to allocate a tensor of shape [len(indexes), index.max()] = [6, 90000001]
- This results in >67 GB memory allocation and often crashes
Without deterministic mode, torch.bincount() is used directly, which also scales poorly if index.max() is large.
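The scaling can be seen directly: torch.bincount's output length is index.max() + 1, regardless of how few distinct values are present. A small illustration (deliberately modest values so the allocation stays harmless):

```python
import torch

x = torch.tensor([1000, 1000, 50000])  # 2 unique values, max = 50000
counts = torch.bincount(x)
print(counts.numel())  # 50001 -- allocation tracks the max value, not the 2 groups
```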
Environment
- TorchMetrics version: 1.8.1 (reproduced on latest)
- Python version: 3.12.10
- PyTorch version: 2.8.0
- OS: Ubuntu 22.04 / Windows 11
- Device: CPU and CUDA (but also applies to MPS/XLA)
Additional context
This is especially common in real-world retrieval problems where indexes come from:
- Row numbers or IDs in large datasets
- Sparse query IDs (e.g., from database keys)
Since the metric only uses indexes to group elements, their actual values are irrelevant — only equality matters.
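One possible direction for a fix, sketched here under the assumption that only equality of ids matters (this is illustrative, not the library's actual implementation), is to densify the ids before counting so memory scales with the number of unique queries rather than their magnitude:

```python
import torch

def dense_bincount(x: torch.Tensor) -> torch.Tensor:
    """Count occurrences per group without allocating up to x.max().

    Sketch only: torch.unique relabels arbitrary integer ids implicitly,
    so the result has one entry per unique query, independent of id values.
    """
    _, counts = torch.unique(x, return_counts=True)
    return counts

indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])
print(dense_bincount(indexes))  # tensor([2, 2, 2]) -- one count per unique query
```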