From 5c67a554296b1cb3d7f5bf51bf062a6bf054ebf6 Mon Sep 17 00:00:00 2001
From: ekansh-arora0
Date: Thu, 9 Apr 2026 18:12:02 -0400
Subject: [PATCH] Add EmbeddingCollapseMetric: detect representational collapse in medical imaging embeddings

Closes #8808

Signed-off-by: ekansh-arora0
---
 monai/metrics/__init__.py           |   1 +
 monai/metrics/embedding_collapse.py | 451 ++++++++++++++++++++++++++++
 tests/test_embedding_collapse.py    | 445 +++++++++++++++++++++++++++
 3 files changed, 897 insertions(+)
 create mode 100644 monai/metrics/embedding_collapse.py
 create mode 100644 tests/test_embedding_collapse.py

diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 702e3c48e2..2265dd3a3f 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -16,6 +16,7 @@
 from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
+from .embedding_collapse import EmbeddingCollapseMetric, compute_embedding_collapse
 from .f_beta_score import FBetaScore
 from .fid import FIDMetric, compute_frechet_distance
 from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
diff --git a/monai/metrics/embedding_collapse.py b/monai/metrics/embedding_collapse.py
new file mode 100644
index 0000000000..b427660b90
--- /dev/null
+++ b/monai/metrics/embedding_collapse.py
@@ -0,0 +1,451 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from collections.abc import Sequence
+
+import torch
+
+from monai.metrics.metric import Metric
+from monai.utils import optional_import
+
+sklearn_silhouette, has_sklearn = optional_import("sklearn.metrics", name="silhouette_score")
+sklearn_logistic, has_sklearn_linear = optional_import("sklearn.linear_model", name="LogisticRegression")
+
+__all__ = ["EmbeddingCollapseMetric", "compute_embedding_collapse", "linear_probe_accuracy"]
+
+_VALID_REDUCTIONS = ("max", "mean", "none")
+_VALID_INDICATORS = frozenset({"centroid_similarity", "effective_rank", "per_class_rank", "domain_shift", "separation"})
+
+
+class EmbeddingCollapseMetric(Metric):
+    """Measures representational collapse in neural network embedding spaces.
+
+    Representational collapse occurs when a model's internal embeddings lose
+    discriminative power - class centroids converge, the effective dimensionality
+    of the embedding space collapses, or source and target domains become
+    indistinguishable in representation space. This can happen silently: task
+    metrics like AUROC or Dice remain unchanged while the embedding space
+    becomes degenerate.
+
+    This metric computes a suite of collapse indicators from a batch of
+    embeddings and optional class labels. All scores are in **[0, 1]** where
+    **higher = more collapsed**.
+
+    Follows the ``FIDMetric`` / ``MMDMetric`` architectural pattern:
+    tensor-in, tensor-out, no I/O side effects. Core dependencies are
+    ``torch`` only. ``scikit-learn`` is imported lazily via
+    ``optional_import`` and is only required for ``separation`` and
+    ``linear_probe_accuracy``.
+
+    Args:
+        reduction: how to aggregate individual scores into ``aggregate``.
+            ``"max"`` returns the worst-case score (recommended for
+            safety-critical use). ``"mean"`` returns the average of
+            available scores. ``"none"`` omits the ``aggregate`` key.
+        include_indicators: optional list of indicator names to compute.
+            If ``None``, all applicable indicators are computed.
+            Valid names: ``"centroid_similarity"``, ``"effective_rank"``,
+            ``"per_class_rank"``, ``"domain_shift"``, ``"separation"``.
+
+    References:
+        - Roy, O. & Vetterli, M. (2007). The effective rank: A measure of
+          effective dimensionality. *EUSIPCO*.
+        - Kornblith, S. et al. (2019). Similarity of neural network
+          representations revisited. *ICML*.
+        - Hua, T. et al. (2021). On feature decorrelation in self-supervised
+          learning. *ICCV*.
+
+    Examples:
+        >>> metric = EmbeddingCollapseMetric()
+        >>> emb = torch.randn(100, 768)
+        >>> labels = torch.randint(0, 2, (100,))
+        >>> scores = metric(embeddings=emb, labels=labels)
+        >>> scores["centroid_similarity"]  # tensor scalar in [0, 1]
+        >>> scores["effective_rank_score"]  # tensor scalar in [0, 1]
+        >>> scores["aggregate"]  # worst-case score
+    """
+
+    def __init__(self, reduction: str = "max", include_indicators: Sequence[str] | None = None) -> None:
+        super().__init__()
+        if reduction not in _VALID_REDUCTIONS:
+            raise ValueError(f"reduction must be one of {_VALID_REDUCTIONS}, got '{reduction}'")
+        if include_indicators is not None:
+            unknown = set(include_indicators) - _VALID_INDICATORS
+            if unknown:
+                raise ValueError(f"Unknown include_indicators: {unknown}. Valid: {_VALID_INDICATORS}")
+        self.reduction = reduction
+        self.include_indicators = list(include_indicators) if include_indicators is not None else None
+
+    def __call__(
+        self,
+        embeddings: torch.Tensor,
+        labels: torch.Tensor | None = None,
+        target_embeddings: torch.Tensor | None = None,
+        target_labels: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor | None]:
+        """Compute collapse scores.
+
+        Args:
+            embeddings: float tensor of shape ``[N, D]``. Required.
+            labels: integer class labels of shape ``[N]``. Required for
+                ``centroid_similarity``, ``per_class_rank``, and
+                ``separation``. If ``None``, only label-free indicators
+                are computed.
+            target_embeddings: optional second embedding matrix ``[M, D]``
+                for cross-domain ``domain_shift`` (CKA) computation.
+            target_labels: unused; reserved for future class-conditional
+                domain shift computation.
+
+        Returns:
+            Dictionary mapping indicator name to a scalar ``torch.Tensor``
+            in ``[0, 1]``, or ``None`` when the indicator is not applicable.
+            Always includes ``"aggregate"`` unless ``reduction="none"``.
+
+        Raises:
+            ValueError: if ``embeddings`` has fewer than 2 samples, is not
+                2-D, or ``labels`` shape does not match ``embeddings``.
+        """
+        return compute_embedding_collapse(
+            embeddings=embeddings,
+            labels=labels,
+            target_embeddings=target_embeddings,
+            target_labels=target_labels,
+            reduction=self.reduction,
+            include_indicators=self.include_indicators,
+        )
+
+
+def compute_embedding_collapse(
+    embeddings: torch.Tensor,
+    labels: torch.Tensor | None = None,
+    target_embeddings: torch.Tensor | None = None,
+    target_labels: torch.Tensor | None = None,
+    reduction: str = "max",
+    include_indicators: Sequence[str] | None = None,
+) -> dict[str, torch.Tensor | None]:
+    """Functional form of :class:`EmbeddingCollapseMetric`.
+
+    Computes a suite of representational collapse indicators from embeddings
+    and optional class labels. All scores are in **[0, 1]** where higher
+    values indicate more severe collapse.
+
+    Args:
+        embeddings: float tensor of shape ``[N, D]``.
+        labels: integer class labels of shape ``[N]``, or ``None``.
+        target_embeddings: optional second embedding matrix ``[M, D]``
+            for cross-domain CKA computation.
+        target_labels: unused; reserved for future use.
+        reduction: ``"max"``, ``"mean"``, or ``"none"``.
+        include_indicators: subset of indicators to compute, or ``None``
+            for all applicable indicators.
+
+    Returns:
+        Dictionary with scalar tensor values or ``None``:
+
+        - ``centroid_similarity``: cosine similarity between L2-normalised
+          class centroids. ``None`` if fewer than 2 classes.
+        - ``effective_rank_score``: dimensional collapse score via SVD
+          effective rank. Present unless excluded by ``include_indicators``.
+        - ``per_class_rank_<label>``: per-class effective rank score (one key per label).
+        - ``domain_shift``: linear CKA between source and target domains.
+          ``None`` if ``target_embeddings`` not provided.
+        - ``separation``: silhouette-based inter-class separation score.
+          ``None`` if sklearn unavailable or fewer than 2 classes.
+        - ``aggregate``: reduced score. Omitted when ``reduction="none"``.
+
+    Raises:
+        ValueError: if inputs are invalid (shape, reduction, indicators).
+    """
+    _validate_embeddings(embeddings)
+    emb = embeddings.float()
+
+    if reduction not in _VALID_REDUCTIONS:
+        raise ValueError(f"reduction must be one of {_VALID_REDUCTIONS}, got '{reduction}'")
+
+    if include_indicators is not None:
+        unknown = set(include_indicators) - _VALID_INDICATORS
+        if unknown:
+            raise ValueError(f"Unknown include_indicators: {unknown}. Valid: {_VALID_INDICATORS}")
+
+    inc = set(include_indicators) if include_indicators is not None else None
+    scores: dict[str, torch.Tensor | None] = {}
+
+    # -- Label-dependent indicators
+    if labels is not None:
+        lbl = labels.long()
+        if lbl.ndim != 1 or lbl.shape[0] != emb.shape[0]:
+            raise ValueError(f"labels must be 1-D with shape [{emb.shape[0]}], got {tuple(lbl.shape)}")
+
+        if inc is None or "centroid_similarity" in inc:
+            scores["centroid_similarity"] = _centroid_similarity(emb, lbl)
+
+        if inc is None or "per_class_rank" in inc:
+            scores.update(_per_class_rank(emb, lbl))
+
+        if inc is None or "separation" in inc:
+            scores["separation"] = _separation(emb, lbl)
+    else:
+        if inc is None or "centroid_similarity" in inc:
+            scores["centroid_similarity"] = None
+        if inc is None or "separation" in inc:
+            scores["separation"] = None
+
+    # -- Label-free indicators
+    if inc is None or "effective_rank" in inc:
+        scores["effective_rank_score"] = _effective_rank_score(emb)
+
+    # -- Cross-domain indicator
+    if inc is None or "domain_shift" in inc:
+        if target_embeddings is not None:
+            if target_embeddings.ndim != 2:
+                raise ValueError(f"target_embeddings must be 2-D [M, D], got shape {tuple(target_embeddings.shape)}")
+            scores["domain_shift"] = _domain_shift(emb, target_embeddings.float())
+        else:
+            scores["domain_shift"] = None
+
+    # -- Aggregate
+    if reduction != "none":
+        primary = {"centroid_similarity", "effective_rank_score", "domain_shift", "separation"}
+        available = [v.to(device=emb.device) for k, v in scores.items() if k in primary and v is not None]
+        if not available:
+            scores["aggregate"] = None
+        elif reduction == "max":
+            scores["aggregate"] = torch.stack(available).max()
+        else:
+            scores["aggregate"] = torch.stack(available).mean()
+
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Individual indicator functions
+# ---------------------------------------------------------------------------
+
+
+def _centroid_similarity(emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor | None:
+    """Mean pairwise cosine similarity between L2-normalised class centroids.
+
+    Args:
+        emb: ``[N, D]`` float tensor.
+        labels: ``[N]`` integer class labels.
+
+    Returns:
+        Scalar tensor in ``[0, 1]``. ``None`` if fewer than 2 classes.
+        1.0 = centroids identical (full collapse).
+        0.5 = centroids orthogonal (cosine = 0).
+        0.0 = centroids anti-parallel (cosine = -1, maximum separation).
+    """
+    unique = torch.unique(labels)
+    if unique.numel() < 2:
+        return None
+
+    centroids = []
+    for cls in unique:
+        c = emb[labels == cls].mean(dim=0)
+        norm = c.norm(p=2)
+        centroids.append(c / norm if norm > 0 else c)
+
+    ct = torch.stack(centroids)
+    n = ct.shape[0]
+
+    if n == 2:
+        raw = torch.dot(ct[0], ct[1])
+    else:
+        sim = ct @ ct.T
+        pairs = torch.stack([sim[i, j] for i in range(n) for j in range(i + 1, n)])
+        raw = pairs.mean()
+
+    return ((raw + 1.0) / 2.0).clamp(0.0, 1.0)
+
+
+def _effective_rank_score(emb: torch.Tensor) -> torch.Tensor:
+    """Dimensional collapse score via SVD effective rank.
+
+    Uses the entropy-based effective rank from Roy & Vetterli (2007):
+    ``eff_rank = exp(-sum(p * log(p)))`` with ``p = sv / sum(sv)``, where ``sv``
+    are the singular values of the mean-centred embedding matrix.
+
+    The score is ``1 - eff_rank / min(N - 1, D)``; centring removes one degree
+    of freedom, so the centred matrix has rank at most ``min(N - 1, D)``.
+    Constant embeddings (all singular values zero) are rank-0 and score ``1.0``.
+
+    Args:
+        emb: ``[N, D]`` float tensor.
+
+    Returns:
+        Scalar tensor in ``[0, 1]``. 1.0 = full dimensional collapse.
+        0.0 = full-rank uniform spectrum (no collapse).
+    """
+    centered = emb - emb.mean(dim=0, keepdim=True)
+    _, sv, _ = torch.linalg.svd(centered, full_matrices=False)
+    sv_sum = sv.sum()
+    if sv_sum == 0:
+        return emb.new_tensor(1.0)
+    probs = sv / sv_sum
+    safe_probs = probs.clamp_min(torch.finfo(probs.dtype).tiny)
+    eff_rank: torch.Tensor = torch.exp(-(probs * safe_probs.log()).sum())
+    max_rank: torch.Tensor = emb.new_tensor(float(min(emb.shape[0] - 1, emb.shape[1])))  # rank cap after centring
+    return (emb.new_tensor(1.0) - eff_rank / max_rank).clamp(0.0, 1.0)
+
+
+def _per_class_rank(emb: torch.Tensor, labels: torch.Tensor) -> dict[str, torch.Tensor | None]:
+    """Effective rank score computed separately per class.
+
+    Detects asymmetric collapse: one class may use 400 dimensions while
+    another collapses to 6, which global SVD would average away.
+
+    Args:
+        emb: ``[N, D]`` float tensor.
+        labels: ``[N]`` integer class labels.
+
+    Returns:
+        Dict mapping ``"per_class_rank_<label>"`` to scalar tensor, or ``None`` for single-sample classes.
+    """
+    result: dict[str, torch.Tensor | None] = {}
+    for cls in torch.unique(labels):
+        cls_emb = emb[labels == cls]
+        key = f"per_class_rank_{cls.item()}"
+        if cls_emb.shape[0] < 2:
+            result[key] = None
+        else:
+            result[key] = _effective_rank_score(cls_emb)
+    return result
+
+
+def _domain_shift(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor | None:
+    """Linear CKA between source and target embedding matrices.
+
+    Reference: Kornblith et al. (2019). Row ``i`` of ``source`` is paired with row ``i`` of
+    ``target``; if ``N != M`` the larger set is randomly subsampled (fixed seed) first.
+
+    Args:
+        source: ``[N, D]`` float tensor.
+        target: ``[M, D]`` float tensor.
+
+    Returns:
+        Scalar tensor in ``[0, 1]``, or ``None`` if either set has < 2 samples.
+        1.0 = representations identical. 0.0 = representations orthogonal.
+    """
+    if source.shape[0] < 2 or target.shape[0] < 2:
+        return None
+
+    # The Gram matrices must have the same size; the fixed seed keeps the subsample (and score) reproducible.
+    if source.shape[0] != target.shape[0]:
+        n = min(source.shape[0], target.shape[0])
+        g = torch.Generator().manual_seed(42)
+        if source.shape[0] > n:
+            source = source[torch.randperm(source.shape[0], generator=g)[:n]]
+        else:
+            target = target[torch.randperm(target.shape[0], generator=g)[:n]]
+
+    hsic_xy = _hsic(source, target)
+    hsic_xx = _hsic(source, source)
+    hsic_yy = _hsic(target, target)
+    denom = (hsic_xx * hsic_yy).sqrt()
+    if denom == 0.0:
+        return source.new_tensor(0.0)
+    return (hsic_xy / denom).clamp(0.0, 1.0)
+
+
+def _separation(emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor | None:
+    """Silhouette-based inter-class separation score.
+
+    Requires scikit-learn. Returns ``None`` if unavailable.
+
+    Args:
+        emb: ``[N, D]`` float tensor.
+        labels: ``[N]`` integer class labels.
+
+    Returns:
+        Scalar tensor in ``[0, 1]``. 1.0 = no separation (collapsed).
+        0.0 = perfect separation. ``None`` if sklearn unavailable,
+        fewer than 2 classes, or ``silhouette_score`` fails (a warning is issued).
+    """
+    if not has_sklearn:
+        warnings.warn(
+            "scikit-learn is not installed; 'separation' score skipped. Install with: pip install scikit-learn",
+            stacklevel=3,
+        )
+        return None
+
+    unique = torch.unique(labels)
+    if unique.numel() < 2:
+        return None
+
+    try:
+        sil = sklearn_silhouette(  # type: ignore[operator]
+            emb.detach().cpu().numpy(), labels.detach().cpu().numpy(), metric="cosine"
+        )
+        return torch.tensor((1.0 - float(sil)) / 2.0, dtype=emb.dtype, device=emb.device).clamp(0.0, 1.0)
+    except Exception as exc:
+        warnings.warn(f"separation: silhouette_score failed: {exc}", stacklevel=3)
+        return None
+
+
+def _hsic(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Biased linear-kernel HSIC estimator: ``sum((H K H) * (H L H)) / (n - 1) ** 2``."""
+    n = x.shape[0]
+    gram_x = x @ x.T
+    gram_y = y @ y.T
+    centering = torch.eye(n, dtype=x.dtype, device=x.device) - torch.ones(n, n, dtype=x.dtype, device=x.device) / n
+    return torch.sum((centering @ gram_x @ centering) * (centering @ gram_y @ centering)) / ((n - 1) ** 2)
+
+
+def _validate_embeddings(embeddings: torch.Tensor) -> None:
+    """Raise ``ValueError`` unless ``embeddings`` is 2-D with at least 2 rows."""
+    if embeddings.ndim != 2:
+        raise ValueError(f"embeddings must be 2-D [N, D], got shape {tuple(embeddings.shape)}")
+    if embeddings.shape[0] < 2:
+        raise ValueError(f"Need at least 2 samples, got {embeddings.shape[0]}.")
+
+
+# ---------------------------------------------------------------------------
+# Utility: linear probe accuracy
+# ---------------------------------------------------------------------------
+
+
+def linear_probe_accuracy(
+    train_embeddings: torch.Tensor,
+    train_labels: torch.Tensor,
+    test_embeddings: torch.Tensor,
+    test_labels: torch.Tensor,
+    max_iter: int = 1000,
+) -> torch.Tensor:
+    """Fit a linear classifier on embeddings and return test accuracy.
+
+    Used to validate that collapse scores predict downstream task performance.
+
+    Requires scikit-learn.
+
+    Args:
+        train_embeddings: ``[N_train, D]`` float tensor.
+        train_labels: ``[N_train]`` integer labels.
+        test_embeddings: ``[N_test, D]`` float tensor.
+        test_labels: ``[N_test]`` integer labels.
+        max_iter: max iterations for ``LogisticRegression``.
+
+    Returns:
+        Scalar tensor: test accuracy in ``[0, 1]``.
+
+    Raises:
+        ImportError: if scikit-learn is not installed.
+    """
+    if not has_sklearn_linear:
+        raise ImportError("scikit-learn is required for linear_probe_accuracy. Install with: pip install scikit-learn")
+
+    clf = sklearn_logistic(max_iter=max_iter, random_state=42)  # type: ignore[operator]
+    clf.fit(train_embeddings.detach().float().cpu().numpy(), train_labels.detach().cpu().numpy())
+    preds = clf.predict(test_embeddings.detach().float().cpu().numpy())
+    acc = (preds == test_labels.detach().cpu().numpy()).mean()
+    return torch.tensor(float(acc))
diff --git a/tests/test_embedding_collapse.py b/tests/test_embedding_collapse.py
new file mode 100644
index 0000000000..b406f92215
--- /dev/null
+++ b/tests/test_embedding_collapse.py
@@ -0,0 +1,445 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest.mock import patch
+
+import torch
+
+from monai.metrics.embedding_collapse import (
+    EmbeddingCollapseMetric,
+    _centroid_similarity,
+    _domain_shift,
+    _effective_rank_score,
+    _per_class_rank,
+    compute_embedding_collapse,
+    linear_probe_accuracy,
+)
+
+
+class TestEmbeddingCollapseMetricInit(unittest.TestCase):
+    def test_valid_reductions(self):
+        for r in ("max", "mean", "none"):
+            m = EmbeddingCollapseMetric(reduction=r)
+            self.assertEqual(m.reduction, r)
+
+    def test_invalid_reduction_raises(self):
+        with self.assertRaises(ValueError):
+            EmbeddingCollapseMetric(reduction="sum")
+
+    def test_include_indicators_stored(self):
+        m = EmbeddingCollapseMetric(include_indicators=["centroid_similarity"])
+        self.assertEqual(m.include_indicators, ["centroid_similarity"])
+
+    def test_include_indicators_none_by_default(self):
+        m = EmbeddingCollapseMetric()
+        self.assertIsNone(m.include_indicators)
+
+
+class TestComputeEmbeddingCollapseReturnTypes(unittest.TestCase):
+    """All returned values must be torch.Tensor scalars or None."""
+
+    def setUp(self):
+        torch.manual_seed(0)
+        self.emb = torch.randn(20, 64)
+        self.lbl = torch.randint(0, 2, (20,))
+
+    def test_all_values_are_tensor_or_none(self):
+        scores = compute_embedding_collapse(self.emb, self.lbl)
+        for k, v in scores.items():
+            self.assertTrue(
+                v is None or isinstance(v, torch.Tensor),
+                f"Key '{k}' has type {type(v).__name__}, expected Tensor or None",
+            )
+
+    def test_tensor_values_are_scalar(self):
+        scores = compute_embedding_collapse(self.emb, self.lbl)
+        for k, v in scores.items():
+            if isinstance(v, torch.Tensor):
+                self.assertEqual(v.ndim, 0, f"Key '{k}' should be a scalar tensor, got shape {v.shape}")
+
+    def test_tensor_values_in_unit_interval(self):
+        scores = compute_embedding_collapse(self.emb, self.lbl)
+        for k, v in scores.items():
+            if isinstance(v, torch.Tensor):
+                self.assertGreaterEqual(float(v), 0.0, f"Key '{k}' = {float(v):.4f} < 0")
+                self.assertLessEqual(float(v), 1.0, f"Key '{k}' = {float(v):.4f} > 1")
+
+
+class TestAggregateReduction(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(1)
+        self.emb = torch.randn(20, 32)
+        self.lbl = torch.randint(0, 2, (20,))
+
+    def test_reduction_max_present(self):
+        scores = compute_embedding_collapse(self.emb, self.lbl, reduction="max")
+        self.assertIn("aggregate", scores)
+        self.assertIsInstance(scores["aggregate"], torch.Tensor)
+
+    def test_reduction_mean_present(self):
+        scores = compute_embedding_collapse(self.emb, self.lbl, reduction="mean")
+        self.assertIn("aggregate", scores)
+
+    def test_reduction_none_absent(self):
+        scores = compute_embedding_collapse(self.emb, self.lbl, reduction="none")
+        self.assertNotIn("aggregate", scores)
+
+    def test_reduction_max_geq_mean(self):
+        scores_max = compute_embedding_collapse(self.emb, self.lbl, reduction="max")
+        scores_mean = compute_embedding_collapse(self.emb, self.lbl, reduction="mean")
+        self.assertGreaterEqual(float(scores_max["aggregate"]), float(scores_mean["aggregate"]))
+
+
+class TestNoLabels(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(2)
+        self.emb = torch.randn(20, 32)
+
+    def test_centroid_similarity_none_without_labels(self):
+        scores = compute_embedding_collapse(self.emb)
+        self.assertIsNone(scores["centroid_similarity"])
+
+    def test_separation_none_without_labels(self):
+        scores = compute_embedding_collapse(self.emb)
+        self.assertIsNone(scores["separation"])
+
+    def test_effective_rank_present_without_labels(self):
+        scores = compute_embedding_collapse(self.emb)
+        self.assertIsNotNone(scores["effective_rank_score"])
+
+    def test_aggregate_uses_effective_rank_only(self):
+        scores = compute_embedding_collapse(self.emb, reduction="max")
+        self.assertAlmostEqual(float(scores["aggregate"]), float(scores["effective_rank_score"]), places=5)
+
+
+class TestCentroidSimilarity(unittest.TestCase):
+    def test_identical_centroids_give_one(self):
+        # All embeddings identical -> both centroids identical -> score = 1.0
+        emb = torch.ones(20, 16)
+        lbl = torch.tensor([0] * 10 + [1] * 10, dtype=torch.long)
+        score = _centroid_similarity(emb, lbl)
+        self.assertIsNotNone(score)
+        self.assertAlmostEqual(float(score), 1.0, places=4)
+
+    def test_orthogonal_centroids_give_half(self):
+        # Class 0 centroid = e1, class 1 centroid = e2 -> cosine = 0 -> score = 0.5
+        emb = torch.zeros(4, 4)
+        emb[0, 0] = 1.0
+        emb[1, 0] = 1.0
+        emb[2, 1] = 1.0
+        emb[3, 1] = 1.0
+        lbl = torch.tensor([0, 0, 1, 1], dtype=torch.long)
+        score = _centroid_similarity(emb, lbl)
+        self.assertIsNotNone(score)
+        self.assertAlmostEqual(float(score), 0.5, places=4)
+
+    def test_single_class_returns_none(self):
+        emb = torch.randn(10, 8)
+        lbl = torch.zeros(10, dtype=torch.long)
+        score = _centroid_similarity(emb, lbl)
+        self.assertIsNone(score)
+
+    def test_score_in_unit_interval(self):
+        torch.manual_seed(3)
+        emb = torch.randn(30, 64)
+        lbl = torch.randint(0, 3, (30,))
+        score = _centroid_similarity(emb, lbl)
+        if score is not None:
+            self.assertGreaterEqual(float(score), 0.0)
+            self.assertLessEqual(float(score), 1.0)
+
+    def test_formula_matches_manual(self):
+        torch.manual_seed(4)
+        n = 10
+        emb_a = torch.randn(n, 16)
+        emb_b = torch.randn(n, 16)
+        emb = torch.cat([emb_a, emb_b])
+        lbl = torch.tensor([0] * n + [1] * n, dtype=torch.long)
+
+        c0 = emb_a.mean(0)
+        c1 = emb_b.mean(0)
+        c0 = c0 / c0.norm()
+        c1 = c1 / c1.norm()
+        expected = ((torch.dot(c0, c1) + 1.0) / 2.0).clamp(0.0, 1.0)
+
+        score = _centroid_similarity(emb, lbl)
+        self.assertAlmostEqual(float(score), float(expected), places=5)
+
+
+class TestEffectiveRankScore(unittest.TestCase):
+    def test_zero_variance_embeddings_give_one(self):
+        """All identical embeddings after centering -> sv_sum == 0 -> score = 1.0 (maximal collapse)."""
+        emb = torch.ones(10, 8)  # all identical, centered = all zeros
+        score = _effective_rank_score(emb)
+        self.assertIsInstance(score, torch.Tensor)
+        self.assertAlmostEqual(float(score), 1.0, places=5)
+
+    def test_rank_one_matrix_gives_high_score(self):
+        # Near-rank-1: one dominant direction + tiny noise.
+        # After mean-centering, the dominant direction survives.
+        torch.manual_seed(99)
+        n, d = 20, 16
+        # Random unit vector as the dominant direction
+        direction = torch.randn(1, d)
+        direction = direction / direction.norm()
+        # All rows point in the same direction with small perturbation
+        scales = torch.randn(n, 1) * 5.0  # varying magnitudes
+        emb = scales * direction + torch.randn(n, d) * 0.01
+        score = _effective_rank_score(emb)
+        self.assertGreater(float(score), 0.8)
+
+    def test_random_matrix_lower_score(self):
+        torch.manual_seed(5)
+        emb = torch.randn(50, 64)
+        score = _effective_rank_score(emb)
+        # A full-rank random matrix should have low collapse score
+        self.assertLess(float(score), 0.8)
+
+    def test_score_in_unit_interval(self):
+        torch.manual_seed(6)
+        for _ in range(10):
+            n = torch.randint(5, 50, (1,)).item()
+            d = torch.randint(4, 32, (1,)).item()
+            emb = torch.randn(int(n), int(d))
+            score = _effective_rank_score(emb)
+            self.assertGreaterEqual(float(score), 0.0)
+            self.assertLessEqual(float(score), 1.0)
+
+    def test_formula_matches_manual(self):
+        torch.manual_seed(7)
+        emb = torch.randn(20, 16)
+        centered = emb - emb.mean(0)
+        _, sv, _ = torch.linalg.svd(centered, full_matrices=False)
+        sv_sum = sv.sum()
+        probs = sv / sv_sum
+        safe_probs = probs.clamp_min(torch.finfo(probs.dtype).tiny)
+        eff_rank = (-(probs * safe_probs.log()).sum()).exp()
+        expected = (1.0 - eff_rank / min(20, 16)).clamp(0.0, 1.0)
+        score = _effective_rank_score(emb)
+        self.assertAlmostEqual(float(score), float(expected), places=5)
+
+
+class TestPerClassRank(unittest.TestCase):
+    def test_keys_present_for_each_class(self):
+        torch.manual_seed(8)
+        emb = torch.randn(20, 16)
+        lbl = torch.tensor([0] * 10 + [1] * 10, dtype=torch.long)
+        result = _per_class_rank(emb, lbl)
+        self.assertIn("per_class_rank_0", result)
+        self.assertIn("per_class_rank_1", result)
+
+    def test_single_sample_class_returns_none(self):
+        emb = torch.randn(5, 8)
+        lbl = torch.tensor([0, 0, 0, 0, 1], dtype=torch.long)
+        result = _per_class_rank(emb, lbl)
+        self.assertIsNone(result["per_class_rank_1"])
+
+    def test_asymmetric_collapse_detectable(self):
+        # Class 0: random full-rank embeddings
+        # Class 1: near-rank-1 (one dominant direction + tiny noise)
+        torch.manual_seed(9)
+        emb_0 = torch.randn(10, 16)
+
+        direction = torch.randn(1, 16)
+        direction = direction / direction.norm()
+        scales = torch.randn(10, 1) * 5.0
+        emb_1 = scales * direction + torch.randn(10, 16) * 0.01
+
+        emb = torch.cat([emb_0, emb_1])
+        lbl = torch.tensor([0] * 10 + [1] * 10, dtype=torch.long)
+        result = _per_class_rank(emb, lbl)
+        score_0 = float(result["per_class_rank_0"])
+        score_1 = float(result["per_class_rank_1"])
+        # Class 1 (near-rank-1) should be more collapsed than class 0 (random)
+        self.assertGreater(score_1, score_0)
+
+
+class TestDomainShift(unittest.TestCase):
+    def test_identical_matrices_give_one(self):
+        torch.manual_seed(10)
+        src = torch.randn(10, 8)
+        score = _domain_shift(src, src.clone())
+        self.assertIsNotNone(score)
+        self.assertAlmostEqual(float(score), 1.0, places=4)
+
+    def test_score_in_unit_interval(self):
+        torch.manual_seed(11)
+        src = torch.randn(10, 8)
+        tgt = torch.randn(10, 8)
+        score = _domain_shift(src, tgt)
+        self.assertIsNotNone(score)
+        self.assertGreaterEqual(float(score), 0.0)
+        self.assertLessEqual(float(score), 1.0)
+
+    def test_single_sample_returns_none(self):
+        src = torch.randn(1, 8)
+        tgt = torch.randn(10, 8)
+        self.assertIsNone(_domain_shift(src, tgt))
+        self.assertIsNone(_domain_shift(tgt, src))
+
+    def test_single_sample_target_returns_none_not_raises(self):
+        """target_embeddings with 1 sample should return None, not raise ValueError."""
+        torch.manual_seed(13)
+        emb = torch.randn(10, 8)
+        lbl = torch.randint(0, 2, (10,))
+        single_target = torch.randn(1, 8)
+        # Should not raise — _domain_shift handles n<2 gracefully
+        scores = compute_embedding_collapse(emb, lbl, target_embeddings=single_target)
+        self.assertIsNone(scores["domain_shift"])
+
+    def test_3d_target_embeddings_raises(self):
+        """target_embeddings with wrong ndim should still raise ValueError."""
+        emb = torch.randn(10, 8)
+        bad_target = torch.randn(5, 8, 2)
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(emb, target_embeddings=bad_target)
+
+    def test_unequal_sample_counts_are_subsampled(self):
+        torch.manual_seed(12)
+        src = torch.randn(20, 8)
+        tgt = torch.randn(10, 8)
+        score = _domain_shift(src, tgt)
+        self.assertIsNotNone(score)
+        self.assertGreaterEqual(float(score), 0.0)
+        self.assertLessEqual(float(score), 1.0)
+
+
+class TestValidation(unittest.TestCase):
+    def test_1d_input_raises(self):
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(torch.randn(10))
+
+    def test_single_sample_raises(self):
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(torch.randn(1, 8))
+
+    def test_3d_input_raises(self):
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(torch.randn(4, 8, 8))
+
+    def test_invalid_reduction_raises(self):
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(torch.randn(10, 8), reduction="sum")
+
+    def test_unknown_include_indicator_raises(self):
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(torch.randn(10, 8), include_indicators=["typo_metric"])
+
+    def test_mismatched_labels_shape_raises(self):
+        emb = torch.randn(10, 8)
+        bad_labels = torch.zeros(5, dtype=torch.long)  # wrong length
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(emb, labels=bad_labels)
+
+    def test_2d_labels_raises(self):
+        emb = torch.randn(10, 8)
+        bad_labels = torch.zeros(10, 2, dtype=torch.long)  # 2D
+        with self.assertRaises(ValueError):
+            compute_embedding_collapse(emb, labels=bad_labels)
+
+
+class TestIncludeIndicators(unittest.TestCase):
+    def test_only_requested_indicators_computed(self):
+        torch.manual_seed(13)
+        emb = torch.randn(20, 16)
+        lbl = torch.randint(0, 2, (20,))
+        scores = compute_embedding_collapse(emb, lbl, include_indicators=["centroid_similarity"])
+        self.assertIn("centroid_similarity", scores)
+        self.assertNotIn("effective_rank_score", scores)
+        self.assertNotIn("separation", scores)
+
+    def test_effective_rank_only(self):
+        torch.manual_seed(14)
+        emb = torch.randn(20, 16)
+        scores = compute_embedding_collapse(emb, include_indicators=["effective_rank"])
+        self.assertIn("effective_rank_score", scores)
+        self.assertNotIn("centroid_similarity", scores)
+
+
+class TestCollapsedEmbeddingsDetected(unittest.TestCase):
+    """End-to-end: a deliberately collapsed embedding space scores high."""
+
+    def test_collapsed_mlp_scores_above_threshold(self):
+        # Simulate a model whose final layer always outputs the same vector
+        # regardless of input — perfect collapse.
+        n = 40
+        emb = torch.ones(n, 32)  # all identical
+        lbl = torch.tensor([0] * (n // 2) + [1] * (n // 2), dtype=torch.long)
+
+        scores = compute_embedding_collapse(emb, lbl, reduction="max")
+
+        # centroid_similarity should be 1.0 (identical centroids)
+        self.assertAlmostEqual(float(scores["centroid_similarity"]), 1.0, places=4)
+        # aggregate should be critical
+        self.assertGreater(float(scores["aggregate"]), 0.8)
+
+    def test_healthy_embeddings_score_low(self):
+        # Well-separated classes: class 0 = +e1, class 1 = -e1
+        torch.manual_seed(15)
+        n = 20
+        emb = torch.zeros(n, 16)
+        emb[: n // 2, 0] = 1.0  # class 0 centroid = +e1
+        emb[n // 2 :, 0] = -1.0  # class 1 centroid = -e1
+        lbl = torch.tensor([0] * (n // 2) + [1] * (n // 2), dtype=torch.long)
+
+        scores = compute_embedding_collapse(emb, lbl, reduction="max")
+
+        # centroid_similarity should be 0.0 (opposite centroids -> raw cosine = -1 -> score = 0)
+        self.assertAlmostEqual(float(scores["centroid_similarity"]), 0.0, places=4)
+
+
+class TestLinearProbeAccuracy(unittest.TestCase):
+    def test_linearly_separable_data_high_accuracy(self):
+        # Class 0: positive first dim, Class 1: negative first dim
+        n = 50
+        train_emb = torch.zeros(n, 8)
+        train_emb[: n // 2, 0] = 1.0
+        train_emb[n // 2 :, 0] = -1.0
+        train_lbl = torch.tensor([0] * (n // 2) + [1] * (n // 2), dtype=torch.long)
+
+        test_emb = train_emb.clone()
+        test_lbl = train_lbl.clone()
+
+        try:
+            acc = linear_probe_accuracy(train_emb, train_lbl, test_emb, test_lbl)
+            self.assertIsInstance(acc, torch.Tensor)
+            self.assertGreater(float(acc), 0.9)
+        except ImportError:
+            self.skipTest("scikit-learn not installed")
+
+    def test_returns_tensor(self):
+        torch.manual_seed(16)
+        emb = torch.randn(20, 8)
+        lbl = torch.randint(0, 2, (20,))
+        try:
+            acc = linear_probe_accuracy(emb, lbl, emb, lbl)
+            self.assertIsInstance(acc, torch.Tensor)
+            self.assertEqual(acc.ndim, 0)
+            self.assertGreaterEqual(float(acc), 0.0)
+            self.assertLessEqual(float(acc), 1.0)
+        except ImportError:
+            self.skipTest("scikit-learn not installed")
+
+    def test_raises_without_sklearn(self):
+        with patch("monai.metrics.embedding_collapse.has_sklearn_linear", False), self.assertRaises(ImportError):
+            linear_probe_accuracy(
+                torch.randn(10, 4),
+                torch.zeros(10, dtype=torch.long),
+                torch.randn(5, 4),
+                torch.zeros(5, dtype=torch.long),
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()