From 9b63968cc1ed74df60b7ed0c9907994f48c7087a Mon Sep 17 00:00:00 2001 From: "warre.veys" Date: Wed, 6 May 2026 15:31:59 +0200 Subject: [PATCH 1/3] feat: graded relevance support for ranking metrics RankingDataset now accepts an optional target_relevance field aligned 1-to-1 with target_indices. ndcg@k uses a (2^rel - 1) gain when graded labels are provided; binary metrics (map, mrr, recall@k, hit@k, rp@k) ignore the field. When target_relevance is None the binary nDCG output is numerically identical to the previous implementation, so existing tasks behave unchanged. Indices are sorted on dedup and relevance is permuted in lockstep so (idx, rel) pairs stay aligned through both _postprocess_indices and the RESOLVE strategies for duplicate queries / targets. Adds custom_task_graded_relevance_example.py demonstrating how to define a graded task and how the same dataset can serve nDCG (graded) and MAP/MRR/recall (binary) in one evaluate() call. --- CHANGELOG.md | 11 ++ README.md | 2 +- examples/custom_task_example.py | 4 + .../custom_task_graded_relevance_example.py | 138 ++++++++++++++++++ src/workrb/metrics/ranking.py | 84 +++++++---- src/workrb/tasks/abstract/ranking_base.py | 135 ++++++++++++++--- tests/test_duplicate_strategy.py | 113 ++++++++++++++ tests/test_ranking_metrics.py | 100 +++++++++++++ 8 files changed, 543 insertions(+), 44 deletions(-) create mode 100644 examples/custom_task_graded_relevance_example.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ed6d0e..c3a8207 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +## Unreleased + +### Feat + +- graded relevance support for ranking metrics: ``RankingDataset`` accepts an + optional ``target_relevance`` field aligned 1-to-1 with ``target_indices``. + ``ndcg@k`` uses a (2^rel - 1) gain when graded labels are provided; binary + metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, ``rp@k``) ignore the + field. Binary nDCG behavior is preserved when ``target_relevance`` is + ``None``. See ``examples/custom_task_graded_relevance_example.py``. + ## v0.5.1 (2026-03-13) ### Feat diff --git a/README.md b/README.md index 5deeb16..ba4d548 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,7 @@ Each aggregation provides 95% confidence intervals (replace `mean` with `ci_marg | --- | --- | | `map` | Mean Average Precision | | `mrr` | Mean Reciprocal Rank | -| `ndcg@k` | Normalized Discounted Cumulative Gain with support for top-k cutoff.| +| `ndcg@k` | Normalized Discounted Cumulative Gain with support for top-k cutoff. Uses a (2^rel - 1) gain when the dataset provides graded `target_relevance`; falls back to binary positives otherwise. See `examples/custom_task_graded_relevance_example.py`. | | `recall@k` | Recall at k (e.g. `recall@5`, `recall@10`) | | `hit@k` | Hit rate at k — binary: is any relevant item in the top-k? | | `rp@k` | R-Precision at k — precision relative to total relevant items | diff --git a/examples/custom_task_example.py b/examples/custom_task_example.py index 99edb3e..910cb8e 100644 --- a/examples/custom_task_example.py +++ b/examples/custom_task_example.py @@ -4,6 +4,10 @@ This example demonstrates how to create a custom ranking task that can be used with the WorkRB framework. Custom tasks should inherit from workrb.tasks.RankingTask and implement the required abstract methods. + +Uses binary relevance: every entry in ``target_indices`` is treated as relevant=1 +by all metrics. For graded relevance (e.g. nDCG@k with a 1-2-3 scale), see +``custom_task_graded_relevance_example.py``. """ import workrb diff --git a/examples/custom_task_graded_relevance_example.py b/examples/custom_task_graded_relevance_example.py new file mode 100644 index 0000000..9afc606 --- /dev/null +++ b/examples/custom_task_graded_relevance_example.py @@ -0,0 +1,138 @@ +""" +Custom Graded-Relevance Ranking Task + +Extends ``custom_task_example.py`` with graded relevance: each (query, target) +pair carries an ordinal grade instead of a binary relevant/not-relevant flag. + +How a graded task differs from a binary one: + +1. Pass an additional ``target_relevance`` list to ``RankingDataset``, aligned + 1-to-1 with ``target_indices``. Items NOT listed in ``target_indices`` are + implicitly grade 0 (irrelevant). +2. Add ``ndcg@k`` to ``default_metrics``. The (2^rel - 1) gain in nDCG uses the + grades; binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, ``rp@k``) + ignore them and treat every entry in ``target_indices`` as relevant=1. + +The grading scale is up to the task. Common choices: {1, 2, 3} (this example), +{1, 2, 3, 4} (TREC-style), or fractional in [0, 1]. Values must be non-negative. +""" + +import workrb +from workrb.registry import register_task +from workrb.tasks.abstract.base import DatasetSplit, LabelType, Language +from workrb.tasks.abstract.ranking_base import RankingDataset, RankingTaskGroup +from workrb.types import ModelInputType + + +@register_task() +class GradedJob2SkillTask(workrb.tasks.RankingTask): + """Job-to-skill ranking with three relevance grades. + + Grade 3 — primary skill (the role is defined by it). + Grade 2 — secondary skill (clearly expected). + Grade 1 — nice-to-have (mentioned but not required). + Grade 0 — irrelevant; never appears in ``target_indices`` (implicit). + """ + + @property + def name(self) -> str: + return "GradedJob2SkillTask" + + @property + def description(self) -> str: + return "Job-to-skill ranking with graded relevance for nDCG evaluation." + + @property + def query_input_type(self) -> ModelInputType: + return ModelInputType.JOB_TITLE + + @property + def target_input_type(self) -> ModelInputType: + return ModelInputType.SKILL + + @property + def default_metrics(self) -> list[str]: + # nDCG@k is the headline metric here. MAP/MRR/recall are kept as a + # binary sanity check: they treat every graded positive as relevant=1 + # and ignore the grades, so they still produce well-defined numbers. + return ["ndcg@5", "ndcg@10", "map", "mrr", "recall@5"] + + @property + def task_group(self) -> RankingTaskGroup: + return RankingTaskGroup.JOB2SKILL + + @property + def label_type(self) -> LabelType: + return LabelType.MULTI_LABEL + + @property + def supported_query_languages(self) -> list[Language]: + return [Language.EN] + + @property + def supported_target_languages(self) -> list[Language]: + return [Language.EN] + + def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset: + """Load a tiny in-memory graded dataset. + + target_indices and target_relevance are aligned 1-to-1: position i in + ``target_relevance[q]`` is the grade for ``target_indices[q][i]``. + """ + queries = [ + "Machine learning engineer", + "Data scientist", + "Backend software developer", + ] + + targets = [ + "Python programming", # 0 + "Machine learning", # 1 + "Data analysis", # 2 + "Software engineering", # 3 + "Statistics", # 4 + "Deep learning", # 5 + "Web development", # 6 + "Database management", # 7 + ] + + # Each list pairs (skill_idx -> grade). Skills not listed for a query + # are implicitly grade 0. + target_indices = [ + [1, 5, 0, 3], # ML engineer: ML(3), DL(3), Python(2), SWE(1) + [1, 2, 4, 5], # Data scientist: ML(3), DA(3), Stats(2), DL(2) + [0, 3, 7, 6], # Backend dev: Python(3), SWE(3), DB(2), Web(1) + ] + target_relevance = [ + [3.0, 3.0, 2.0, 1.0], + [3.0, 3.0, 2.0, 2.0], + [3.0, 3.0, 2.0, 1.0], + ] + + return RankingDataset( + query_texts=queries, + target_indices=target_indices, + target_space=targets, + dataset_id=dataset_id, + target_relevance=target_relevance, + ) + + +if __name__ == "__main__": + print("🚀 Custom Graded-Relevance Task Example") + print("=" * 50) + + model = workrb.models.BiEncoderModel("all-MiniLM-L6-v2") + tasks = [GradedJob2SkillTask(languages=["en"], split="test")] + + # Run with the task's default_metrics, which mixes nDCG (graded) with + # MAP/MRR/recall (binary). The same dataset feeds both — graded metrics + # consult target_relevance, binary metrics ignore it. + results = workrb.evaluate( + model, + tasks, + output_folder="results/graded_task_demo", + description="Custom ranking task with graded relevance.", + force_restart=True, + ) + print(results) diff --git a/src/workrb/metrics/ranking.py b/src/workrb/metrics/ranking.py index 317d001..d5afbce 100644 --- a/src/workrb/metrics/ranking.py +++ b/src/workrb/metrics/ranking.py @@ -13,18 +13,29 @@ def calculate_ranking_metrics( "map", "rp@10", ), + pos_label_relevance: list[list[float]] | None = None, ) -> dict[str, float]: - """ - Calculate ranking metrics for evaluation. - - Args: - prediction_matrix: Similarity/prediction matrix of shape (n_queries, n_targets) - pos_label_idxs: List of lists containing positive label indices for each query - metrics: List of metric names to compute + """Calculate ranking metrics for evaluation. + + Parameters + ---------- + prediction_matrix : torch.Tensor or np.ndarray + Similarity/prediction matrix of shape (n_queries, n_targets). + pos_label_idxs : list[list[int]] + Positive label indices for each query. + metrics : Sequence[str] + Metric names to compute. + pos_label_relevance : list[list[float]] or None, optional + Optional graded relevance per positive, aligned 1-to-1 with + ``pos_label_idxs``. When ``None``, every positive is treated as relevance 1.0 + (binary fallback, identical to current behavior). Only consulted by graded + metrics (``ndcg``); binary metrics (``map``, ``mrr``, ``recall@k``, + ``hit@k``, ``rp@k``) ignore it. Returns ------- - Dictionary mapping metric names to values + dict[str, float] + Dictionary mapping metric names to values. """ # Convert to numpy if needed if isinstance(prediction_matrix, torch.Tensor): @@ -68,12 +79,10 @@ def _metric_k_split(metric: str) -> tuple[str, int | None]: results[metric] = _calculate_hit_at_k(sorted_indices, pos_label_idxs, k) elif base_metric == "ndcg": - if k is not None: - results[metric] = _calculate_ndcg(sorted_indices, pos_label_idxs, k) - else: - results[metric] = _calculate_ndcg( - sorted_indices, pos_label_idxs, sorted_indices.shape[1] - ) + cutoff = k if k is not None else sorted_indices.shape[1] + results[metric] = _calculate_ndcg( + sorted_indices, pos_label_idxs, pos_label_relevance, cutoff + ) else: raise ValueError(f"Unknown ranking metric '{metric}'") @@ -204,24 +213,47 @@ def _calculate_rp_at_k( return float(np.mean(rp_scores)) if rp_scores else 0.0 -def _calculate_ndcg(sorted_indices: np.ndarray, pos_label_idxs: list[list[int]], k: int) -> float: - """Calculate Normalized Discounted Cumulative Gain@K (binary relevance).""" +def _calculate_ndcg( + sorted_indices: np.ndarray, + pos_label_idxs: list[list[int]], + pos_label_relevance: list[list[float]] | None, + k: int, +) -> float: + """Calculate nDCG@K with the (2^rel - 1) gain and log2(i+2) discount. + + When ``pos_label_relevance`` is None, every positive is treated as + relevance 1.0 (binary fallback). With binary positives the gain + ``(2^1 - 1) = 1`` matches the binary implementation exactly. + """ ndcg_scores = [] for i, pos_labels in enumerate(pos_label_idxs): if len(pos_labels) == 0: continue - pos_labels_set = set(pos_labels) - - # DCG@k: sum of 1/log2(rank+1) for relevant items in top-k - dcg = 0.0 - for rank_idx in range(min(k, len(sorted_indices[i]))): - if sorted_indices[i][rank_idx] in pos_labels_set: - dcg += 1.0 / np.log2(rank_idx + 2) # 0-based -> log2(pos+1) - - # IDCG@k: ideal DCG with all relevant items ranked first - idcg = sum(1.0 / np.log2(j + 2) for j in range(min(k, len(pos_labels)))) + if pos_label_relevance is None: + relevance_by_idx = dict.fromkeys(pos_labels, 1.0) + else: + relevance_by_idx = { + idx: float(rel) + for idx, rel in zip(pos_labels, pos_label_relevance[i], strict=True) + } + + cutoff = min(k, len(sorted_indices[i])) + + gains = np.array( + [ + (2.0 ** relevance_by_idx.get(int(idx), 0.0)) - 1.0 + for idx in sorted_indices[i][:cutoff] + ] + ) + discounts = 1.0 / np.log2(np.arange(cutoff) + 2) + dcg = float(np.sum(gains * discounts)) + + ideal_relevances = sorted(relevance_by_idx.values(), reverse=True)[:cutoff] + ideal_gains = np.array([(2.0**rel) - 1.0 for rel in ideal_relevances]) + ideal_discounts = 1.0 / np.log2(np.arange(len(ideal_gains)) + 2) + idcg = float(np.sum(ideal_gains * ideal_discounts)) if idcg > 0: ndcg_scores.append(dcg / idcg) diff --git a/src/workrb/tasks/abstract/ranking_base.py b/src/workrb/tasks/abstract/ranking_base.py index ed99200..28371ba 100644 --- a/src/workrb/tasks/abstract/ranking_base.py +++ b/src/workrb/tasks/abstract/ranking_base.py @@ -61,6 +61,7 @@ def __init__( target_indices: list[list[int]], target_space: list[str], dataset_id: str, + target_relevance: list[list[float]] | None = None, duplicate_query_strategy: DuplicateStrategy = DuplicateStrategy.RESOLVE, duplicate_target_strategy: DuplicateStrategy = DuplicateStrategy.RESOLVE, ): @@ -71,11 +72,20 @@ def __init__( query_texts : list[str] List of query strings. target_indices : list[list[int]] - List of lists containing indices into the target vocabulary. + List of lists containing indices into the target vocabulary. Items not + listed for a query are treated as having relevance 0 (unjudged / + irrelevant) by graded metrics. target_space : list[str] List of target vocabulary strings. dataset_id : str Unique identifier for this dataset. + target_relevance : list[list[float]] or None, optional + Optional graded relevance per positive, aligned 1-to-1 with + ``target_indices``. When ``None``, every entry in ``target_indices`` is + treated as binary relevance 1.0. Used by graded metrics such as + ``ndcg@k``; binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, + ``rp@k``) ignore this field. Values must be non-negative; the scale is + up to the task (e.g. {1, 2, 3} or {0.0..1.0}). duplicate_query_strategy : DuplicateStrategy How to handle duplicate query texts. ALLOW silently accepts them, RAISE raises on duplicates, RESOLVE merges their target_indices via union. @@ -84,7 +94,9 @@ def __init__( RAISE raises on duplicates, RESOLVE keeps first occurrence and remaps indices. """ self.query_texts = self._postprocess_texts(query_texts) - self.target_indices = self._postprocess_indices(target_indices) + self.target_indices, self.target_relevance = self._postprocess_indices( + target_indices, target_relevance + ) self.target_space = self._postprocess_texts(target_space) self.dataset_id = dataset_id @@ -121,27 +133,62 @@ def _resolve_duplicate_targets(self) -> None: duplicates, ) self.target_space = new_target_space - self.target_indices = [ - sorted(set(old_to_new[idx] for idx in idx_list)) for idx_list in self.target_indices - ] + if self.target_relevance is None: + self.target_indices = [ + sorted(set(old_to_new[idx] for idx in idx_list)) + for idx_list in self.target_indices + ] + else: + # Remap indices, dedup, and keep relevance from the first occurrence + # of each remapped index so (idx, rel) pairs stay aligned. + new_indices: list[list[int]] = [] + new_relevance: list[list[float]] = [] + for idx_list, rel_list in zip( + self.target_indices, self.target_relevance, strict=True + ): + seen_idx: dict[int, float] = {} + for idx, rel in zip(idx_list, rel_list, strict=True): + new_idx = old_to_new[idx] + if new_idx not in seen_idx: + seen_idx[new_idx] = rel + sorted_pairs = sorted(seen_idx.items()) + new_indices.append([idx for idx, _ in sorted_pairs]) + new_relevance.append([rel for _, rel in sorted_pairs]) + self.target_indices = new_indices + self.target_relevance = new_relevance def _resolve_duplicate_queries(self) -> None: - """Deduplicate query_texts, merging target_indices via union.""" + """Deduplicate query_texts, merging target_indices via union. + + When ``target_relevance`` is set, the merge keeps the relevance from the + first query occurrence for each index; relevance values from later + duplicates of the same (query, index) pair are dropped. + """ seen: dict[str, int] = {} new_queries: list[str] = [] - new_indices: list[list[int]] = [] + new_pairs: list[dict[int, float]] = [] duplicates: list[str] = [] - for query, idx_list in zip(self.query_texts, self.target_indices): + graded = self.target_relevance is not None + relevance_iter = ( + self.target_relevance + if graded + else [[1.0] * len(idx_list) for idx_list in self.target_indices] + ) + + for query, idx_list, rel_list in zip( + self.query_texts, self.target_indices, relevance_iter, strict=True + ): if query in seen: pos = seen[query] - merged = set(new_indices[pos]) | set(idx_list) - new_indices[pos] = sorted(merged) + for idx, rel in zip(idx_list, rel_list, strict=True): + if idx not in new_pairs[pos]: + new_pairs[pos][idx] = rel duplicates.append(query) else: seen[query] = len(new_queries) new_queries.append(query) - new_indices.append(sorted(idx_list)) + new_pairs.append(dict(zip(idx_list, rel_list, strict=True))) if duplicates: logger.warning( @@ -151,7 +198,15 @@ def _resolve_duplicate_queries(self) -> None: duplicates, ) self.query_texts = new_queries + new_indices: list[list[int]] = [] + new_relevance: list[list[float]] = [] + for pairs in new_pairs: + sorted_pairs = sorted(pairs.items()) + new_indices.append([idx for idx, _ in sorted_pairs]) + new_relevance.append([rel for _, rel in sorted_pairs]) self.target_indices = new_indices + if graded: + self.target_relevance = new_relevance def _validate_dataset( self, @@ -194,11 +249,54 @@ def _validate_dataset( ) assert isinstance(idx, int), f"Target index {idx} is not an integer" - def _postprocess_indices(self, indices: list[list[int]]) -> list[list[int]]: - """Postprocess indices.""" - # Remove duplicates in target_label - indices = [sorted(set(label_list)) for label_list in indices] - return indices + # Check target_relevance alignment and non-negativity + if self.target_relevance is not None: + assert len(self.target_relevance) == len(self.target_indices), ( + f"target_relevance has {len(self.target_relevance)} queries but " + f"target_indices has {len(self.target_indices)}" + ) + for q_i, (rel_list, idx_list) in enumerate( + zip(self.target_relevance, self.target_indices, strict=True) + ): + assert len(rel_list) == len(idx_list), ( + f"target_relevance[{q_i}] has length {len(rel_list)} but " + f"target_indices[{q_i}] has length {len(idx_list)}" + ) + for rel in rel_list: + assert rel >= 0, f"Negative relevance value {rel} at query {q_i}" + + def _postprocess_indices( + self, + indices: list[list[int]], + relevance: list[list[float]] | None, + ) -> tuple[list[list[int]], list[list[float]] | None]: + """Postprocess indices and aligned relevance, dropping duplicate indices. + + Indices are sorted; relevance is permuted in lockstep so each (idx, rel) + pair stays aligned. When duplicate indices appear within a query, the + relevance from the first occurrence is kept. + """ + if relevance is None: + return [sorted(set(label_list)) for label_list in indices], None + + assert len(relevance) == len(indices), ( + f"target_relevance has {len(relevance)} queries but target_indices has {len(indices)}" + ) + deduped_indices: list[list[int]] = [] + deduped_relevance: list[list[float]] = [] + for idx_list, rel_list in zip(indices, relevance, strict=True): + assert len(idx_list) == len(rel_list), ( + f"target_indices and target_relevance must align per query " + f"(got {len(idx_list)} vs {len(rel_list)})" + ) + seen: dict[int, float] = {} + for idx, rel in zip(idx_list, rel_list, strict=True): + if idx not in seen: + seen[idx] = float(rel) + sorted_pairs = sorted(seen.items()) + deduped_indices.append([idx for idx, _ in sorted_pairs]) + deduped_relevance.append([rel for _, rel in sorted_pairs]) + return deduped_indices, deduped_relevance def _postprocess_texts(self, texts: list[str]) -> list[str]: """Postprocess texts.""" @@ -329,7 +427,10 @@ def evaluate( # Calculate metrics metric_results = calculate_ranking_metrics( - prediction_matrix=prediction_matrix, pos_label_idxs=labels, metrics=metrics + prediction_matrix=prediction_matrix, + pos_label_idxs=labels, + metrics=metrics, + pos_label_relevance=dataset.target_relevance, ) return metric_results diff --git a/tests/test_duplicate_strategy.py b/tests/test_duplicate_strategy.py index e81e8cf..af172a3 100644 --- a/tests/test_duplicate_strategy.py +++ b/tests/test_duplicate_strategy.py @@ -208,3 +208,116 @@ def test_default_resolves_duplicate_targets(self): ) assert ds.target_space == ["a"] assert ds.target_indices == [[0]] + + +class TestTargetRelevance: + """Test optional target_relevance carried through validation and dedup.""" + + def test_default_is_none(self): + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + ) + assert ds.target_relevance is None + + def test_aligned_relevance_preserved(self): + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[3.0, 1.0]], + ) + # Indices are sorted; relevance is permuted in lockstep + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 3.0, 1: 1.0} + + def test_dedup_keeps_first_relevance(self): + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1, 0]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[2.0, 1.0, 9.0]], + ) + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 2.0, 1: 1.0} + + def test_indices_sorted_with_relevance_permuted(self): + """When indices are reordered, relevance must follow.""" + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[2, 0, 1]], + target_space=["a", "b", "c"], + dataset_id="test", + target_relevance=[[1.0, 3.0, 2.0]], + ) + assert ds.target_indices == [[0, 1, 2]] + assert ds.target_relevance == [[3.0, 2.0, 1.0]] + + def test_misaligned_lengths_raise(self): + with pytest.raises(AssertionError): + RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[1.0]], + ) + + def test_negative_relevance_raises(self): + with pytest.raises(AssertionError, match="Negative relevance"): + RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[1.0, -0.5]], + ) + + def test_resolve_targets_carries_relevance(self): + """When duplicate targets collapse, relevance from the first occurrence wins.""" + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1, 2]], + target_space=["x", "y", "x"], + dataset_id="test", + target_relevance=[[3.0, 2.0, 9.0]], + duplicate_target_strategy=DuplicateStrategy.RESOLVE, + ) + # target_space dedups to ["x", "y"]; idx 2 ("x") remaps to 0 (already present) + # First-occurrence relevance for idx 0 is 3.0; the 9.0 is dropped. + assert ds.target_space == ["x", "y"] + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 3.0, 1: 2.0} + + def test_resolve_queries_carries_relevance(self): + """When duplicate queries merge, relevance is unioned with first-wins for ties.""" + ds = RankingDataset( + query_texts=["q1", "q1"], + target_indices=[[0], [1]], + target_space=["a", "b", "c"], + dataset_id="test", + target_relevance=[[3.0], [2.0]], + duplicate_query_strategy=DuplicateStrategy.RESOLVE, + ) + assert ds.query_texts == ["q1"] + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 3.0, 1: 2.0} + + def test_resolve_queries_first_relevance_wins_on_overlap(self): + """If duplicate queries share an index, the first query's relevance is kept.""" + ds = RankingDataset( + query_texts=["q1", "q1"], + target_indices=[[0, 1], [1, 2]], + target_space=["a", "b", "c"], + dataset_id="test", + target_relevance=[[3.0, 2.0], [9.0, 1.0]], + duplicate_query_strategy=DuplicateStrategy.RESOLVE, + ) + assert ds.query_texts == ["q1"] + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + # idx 1 appears in both queries; relevance from the first (2.0) wins, not 9.0 + assert pairs == {0: 3.0, 1: 2.0, 2: 1.0} diff --git a/tests/test_ranking_metrics.py b/tests/test_ranking_metrics.py index 1153812..32cf4e3 100644 --- a/tests/test_ranking_metrics.py +++ b/tests/test_ranking_metrics.py @@ -1,11 +1,21 @@ """Tests for ranking metrics (workrb.metrics.ranking).""" +import math + import numpy as np import pytest import torch from workrb.metrics.ranking import calculate_ranking_metrics + +def _prediction_matrix_from_order(order: list[int], n_targets: int) -> np.ndarray: + """Build a 1-query prediction matrix that ranks targets in the given order (best first).""" + scores = np.zeros((1, n_targets), dtype=float) + for rank, idx in enumerate(order): + scores[0, idx] = float(n_targets - rank) + return scores + # Shared test fixtures: 2 queries x 5 targets # Query 0 sorted order: [0, 2, 4, 3, 1] (scores: 0.9, 0.1, 0.8, 0.2, 0.5) # Query 1 sorted order: [3, 1, 2, 0, 4] (scores: 0.3, 0.7, 0.5, 0.9, 0.1) @@ -152,3 +162,93 @@ def test_existing_metrics_still_work(self): for metric_name in ["map", "mrr", "recall@3", "hit@3", "rp@3"]: assert isinstance(results[metric_name], float) assert 0.0 <= results[metric_name] <= 1.0 + + +class TestNDCGGradedRelevance: + """nDCG with graded relevance (pos_label_relevance not None).""" + + def test_explicit_relevance_one_matches_binary(self): + """Passing all-1.0 relevance must match the binary fallback exactly.""" + binary = calculate_ranking_metrics(PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["ndcg@3"]) + graded = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@3"], + pos_label_relevance=[[1.0, 1.0], [1.0, 1.0]], + ) + assert graded["ndcg@3"] == pytest.approx(binary["ndcg@3"]) + + def test_higher_grade_first_outscores_lower_grade_first(self): + """Ranking the higher-graded positive first must beat the reverse.""" + preds_good = _prediction_matrix_from_order([0, 1, 2, 3], n_targets=4) + preds_bad = _prediction_matrix_from_order([1, 0, 2, 3], n_targets=4) + good = calculate_ranking_metrics( + prediction_matrix=preds_good, + pos_label_idxs=[[0, 1]], + pos_label_relevance=[[3.0, 1.0]], + metrics=["ndcg@4"], + ) + bad = calculate_ranking_metrics( + prediction_matrix=preds_bad, + pos_label_idxs=[[0, 1]], + pos_label_relevance=[[3.0, 1.0]], + metrics=["ndcg@4"], + ) + assert good["ndcg@4"] == pytest.approx(1.0) + assert bad["ndcg@4"] < good["ndcg@4"] + + def test_graded_perfect_ranking_returns_one(self): + """Ideal ordering of grades [3, 2, 1] gives nDCG = 1.""" + preds = _prediction_matrix_from_order([0, 1, 2, 3], n_targets=4) + result = calculate_ranking_metrics( + prediction_matrix=preds, + pos_label_idxs=[[0, 1, 2]], + pos_label_relevance=[[3.0, 2.0, 1.0]], + metrics=["ndcg@4"], + ) + assert result["ndcg@4"] == pytest.approx(1.0) + + def test_graded_known_value(self): + """Hand-computed value with grades 2 and 1 over 3 targets. + + Positives: idx 0 grade 2, idx 2 grade 1. Ranking [0, 1, 2]. + gains @ ranks 0,1,2 = (2^2-1, 0, 2^1-1) = (3, 0, 1) + discounts = 1/log2(2), 1/log2(3), 1/log2(4) + DCG = 3 + 0 + 0.5 = 3.5 + Ideal grades [2, 1] -> gains (3, 1), discounts (1, 1/log2(3)) + IDCG = 3 + 1/log2(3) + """ + preds = _prediction_matrix_from_order([0, 1, 2], n_targets=3) + result = calculate_ranking_metrics( + prediction_matrix=preds, + pos_label_idxs=[[0, 2]], + pos_label_relevance=[[2.0, 1.0]], + metrics=["ndcg@3"], + ) + dcg = 3.0 + 1.0 / math.log2(4) + idcg = 3.0 + 1.0 / math.log2(3) + assert result["ndcg@3"] == pytest.approx(dcg / idcg) + + def test_relevance_ignored_by_binary_metrics(self): + """Binary metrics (map, mrr, ...) ignore pos_label_relevance entirely.""" + without = calculate_ranking_metrics( + PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["map", "mrr", "recall@3"] + ) + with_rel = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + pos_label_relevance=[[3.0, 1.0], [2.0, 2.0]], + ) + for metric in ["map", "mrr", "recall@3"]: + assert with_rel[metric] == pytest.approx(without[metric]) + + def test_misaligned_relevance_raises(self): + """A relevance list whose lengths don't match pos_label_idxs raises.""" + with pytest.raises(ValueError, match="zip"): + calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@3"], + pos_label_relevance=[[1.0], [1.0, 1.0]], + ) From ce26d7f82dc5d070c825a6aad81a3010d4dbcac7 Mon Sep 17 00:00:00 2001 From: "warre.veys" Date: Wed, 6 May 2026 16:04:00 +0200 Subject: [PATCH 2/3] feat: binary_relevance_threshold for graded ranking tasks Adds RankingTask.binary_relevance_threshold (default 1e-9) so a graded task can choose which grades count as positives for binary metrics (map, mrr, recall@k, hit@k, rp@k). Items with relevance >= threshold are positives; items below are dropped from the binary set but still contribute to graded metrics like ndcg@k. The threshold is plumbed through calculate_ranking_metrics, where the binary positive set is derived on the fly from the (indices, relevance) pair when graded labels are present. Default of 1e-9 keeps every listed item as a positive when the dataset provides target_relevance, so existing graded tasks behave identically. The threshold has no effect when target_relevance is None. The graded example task now sets binary_relevance_threshold=2.0 to demonstrate dropping nice-to-have skills from MAP/MRR while keeping them as gain-1 contributions to nDCG. --- CHANGELOG.md | 5 ++ .../custom_task_graded_relevance_example.py | 20 ++++- src/workrb/metrics/ranking.py | 41 +++++++--- src/workrb/tasks/abstract/ranking_base.py | 24 +++++- tests/test_duplicate_strategy.py | 11 ++- tests/test_ranking_metrics.py | 80 +++++++++++++++++++ 6 files changed, 165 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a8207..4d3cf48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, ``rp@k``) ignore the field. Binary nDCG behavior is preserved when ``target_relevance`` is ``None``. See ``examples/custom_task_graded_relevance_example.py``. +- ``RankingTask.binary_relevance_threshold`` (default ``1e-9``) lets a graded + task choose which grades count as positives for binary metrics. Items with + relevance below the threshold are dropped from the binary positive set but + still contribute to graded metrics like ``ndcg@k``. Has no effect when + ``target_relevance`` is ``None``. ## v0.5.1 (2026-03-13) diff --git a/examples/custom_task_graded_relevance_example.py b/examples/custom_task_graded_relevance_example.py index 9afc606..a126e30 100644 --- a/examples/custom_task_graded_relevance_example.py +++ b/examples/custom_task_graded_relevance_example.py @@ -10,8 +10,13 @@ 1-to-1 with ``target_indices``. Items NOT listed in ``target_indices`` are implicitly grade 0 (irrelevant). 2. Add ``ndcg@k`` to ``default_metrics``. The (2^rel - 1) gain in nDCG uses the - grades; binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, ``rp@k``) - ignore them and treat every entry in ``target_indices`` as relevant=1. + grades. +3. Optionally override ``binary_relevance_threshold`` on the task to control + which graded items count as positives for binary metrics (``map``, ``mrr``, + ``recall@k``, ``hit@k``, ``rp@k``). Items with relevance >= threshold are + positives; items below are dropped from the binary positive set but still + contribute to nDCG. Default is ``1e-9``, which keeps every listed item as + a positive (matches the binary-only behavior). The grading scale is up to the task. Common choices: {1, 2, 3} (this example), {1, 2, 3, 4} (TREC-style), or fractional in [0, 1]. Values must be non-negative. @@ -53,10 +58,17 @@ def target_input_type(self) -> ModelInputType: @property def default_metrics(self) -> list[str]: # nDCG@k is the headline metric here. MAP/MRR/recall are kept as a - # binary sanity check: they treat every graded positive as relevant=1 - # and ignore the grades, so they still produce well-defined numbers. + # binary sanity check on the high-grade subset (see the threshold + # below): they treat every passing graded positive as relevant=1. return ["ndcg@5", "ndcg@10", "map", "mrr", "recall@5"] + @property + def binary_relevance_threshold(self) -> float: + # Only "secondary" (grade 2) and "primary" (grade 3) skills count as + # positives for binary metrics; "nice-to-have" (grade 1) items drop + # out of MAP/MRR/recall but still contribute to nDCG. + return 2.0 + @property def task_group(self) -> RankingTaskGroup: return RankingTaskGroup.JOB2SKILL diff --git a/src/workrb/metrics/ranking.py b/src/workrb/metrics/ranking.py index d5afbce..414430d 100644 --- a/src/workrb/metrics/ranking.py +++ b/src/workrb/metrics/ranking.py @@ -14,6 +14,7 @@ def calculate_ranking_metrics( "rp@10", ), pos_label_relevance: list[list[float]] | None = None, + binary_relevance_threshold: float = 1e-9, ) -> dict[str, float]: """Calculate ranking metrics for evaluation. @@ -28,9 +29,15 @@ def calculate_ranking_metrics( pos_label_relevance : list[list[float]] or None, optional Optional graded relevance per positive, aligned 1-to-1 with ``pos_label_idxs``. When ``None``, every positive is treated as relevance 1.0 - (binary fallback, identical to current behavior). Only consulted by graded - metrics (``ndcg``); binary metrics (``map``, ``mrr``, ``recall@k``, - ``hit@k``, ``rp@k``) ignore it. + (binary fallback, identical to current behavior). Used by graded metrics + (``ndcg``); binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, + ``rp@k``) consult it only to apply ``binary_relevance_threshold``. + binary_relevance_threshold : float, optional + Minimum graded relevance for an item to count as a positive for binary + metrics. Items with relevance below this threshold are dropped from the + binary positive set but still contribute to graded metrics. Ignored when + ``pos_label_relevance`` is ``None``. Defaults to ``1e-9``: any non-zero + grade counts, matching the current binary-only behavior. Returns ------- @@ -44,6 +51,21 @@ def calculate_ranking_metrics( # Sort indices by prediction scores (descending) sorted_indices = np.argsort(-prediction_matrix, axis=1) + # When graded relevance is provided, derive the binary positive set by + # thresholding so binary metrics (map/mrr/recall/hit/rp) consume only items + # with relevance >= threshold. Graded nDCG continues to use the full list. + if pos_label_relevance is None: + binary_pos_label_idxs = pos_label_idxs + else: + binary_pos_label_idxs = [ + [ + idx + for idx, rel in zip(idx_list, rel_list, strict=True) + if rel >= binary_relevance_threshold + ] + for idx_list, rel_list in zip(pos_label_idxs, pos_label_relevance, strict=True) + ] + results = {} def _metric_k_split(metric: str) -> tuple[str, int | None]: @@ -61,22 +83,22 @@ def _metric_k_split(metric: str) -> tuple[str, int | None]: base_metric, k = _metric_k_split(metric) if metric == "map": - results[metric] = _calculate_map(sorted_indices, pos_label_idxs) + results[metric] = _calculate_map(sorted_indices, binary_pos_label_idxs) elif base_metric == "rp": assert k is not None, "k must be provided for rp@k metrics" - results[metric] = _calculate_rp_at_k(sorted_indices, pos_label_idxs, k) + results[metric] = _calculate_rp_at_k(sorted_indices, binary_pos_label_idxs, k) elif metric == "mrr": - results[metric] = _calculate_mrr(sorted_indices, pos_label_idxs) + results[metric] = _calculate_mrr(sorted_indices, binary_pos_label_idxs) elif base_metric == "recall": assert k is not None, "k must be provided for recall@k metrics" - results[metric] = _calculate_recall_at_k(sorted_indices, pos_label_idxs, k) + results[metric] = _calculate_recall_at_k(sorted_indices, binary_pos_label_idxs, k) elif base_metric == "hit": assert k is not None, "k must be provided for hit@k metrics" - results[metric] = _calculate_hit_at_k(sorted_indices, pos_label_idxs, k) + results[metric] = _calculate_hit_at_k(sorted_indices, binary_pos_label_idxs, k) elif base_metric == "ndcg": cutoff = k if k is not None else sorted_indices.shape[1] @@ -235,8 +257,7 @@ def _calculate_ndcg( relevance_by_idx = dict.fromkeys(pos_labels, 1.0) else: relevance_by_idx = { - idx: float(rel) - for idx, rel in zip(pos_labels, pos_label_relevance[i], strict=True) + idx: float(rel) for idx, rel in zip(pos_labels, pos_label_relevance[i], strict=True) } cutoff = min(k, len(sorted_indices[i])) diff --git a/src/workrb/tasks/abstract/ranking_base.py b/src/workrb/tasks/abstract/ranking_base.py index 28371ba..701369e 100644 --- a/src/workrb/tasks/abstract/ranking_base.py +++ b/src/workrb/tasks/abstract/ranking_base.py @@ -321,6 +321,25 @@ def task_type(self) -> TaskType: def default_metrics(self) -> list[str]: return ["map", "rp@10", "mrr"] + @property + def binary_relevance_threshold(self) -> float: + """Minimum graded relevance for an item to count as a positive. + + Used by binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, + ``rp@k``) when the dataset provides ``target_relevance``: items with + relevance ``>= threshold`` are treated as positives, items below it + are dropped from the binary positive set but still contribute to + graded metrics such as ``ndcg@k``. + + Default is ``1e-9`` so any listed item with a non-zero grade counts + as a positive — equivalent to today's behavior where every entry in + ``target_indices`` is a positive. Override on the task to express + a stricter threshold (e.g. ``2.0`` on a 0-3 scale). + + Has no effect when ``target_relevance`` is ``None``. + """ + return 1e-9 + def __init__( self, **kwargs, @@ -425,12 +444,15 @@ def evaluate( if isinstance(prediction_matrix, torch.Tensor): prediction_matrix = prediction_matrix.cpu().float().numpy() - # Calculate metrics + # Calculate metrics. When the dataset provides graded relevance, binary + # metrics consume only positives with relevance >= binary_relevance_threshold; + # nDCG still sees the full graded label list. metric_results = calculate_ranking_metrics( prediction_matrix=prediction_matrix, pos_label_idxs=labels, metrics=metrics, pos_label_relevance=dataset.target_relevance, + binary_relevance_threshold=self.binary_relevance_threshold, ) return metric_results diff --git a/tests/test_duplicate_strategy.py b/tests/test_duplicate_strategy.py index af172a3..5414826 100644 --- a/tests/test_duplicate_strategy.py +++ b/tests/test_duplicate_strategy.py @@ -2,7 +2,7 @@ import pytest -from workrb.tasks.abstract.ranking_base import DuplicateStrategy, RankingDataset +from workrb.tasks.abstract.ranking_base import DuplicateStrategy, RankingDataset, RankingTask class TestDuplicateStrategyRaise: @@ -321,3 +321,12 @@ def test_resolve_queries_first_relevance_wins_on_overlap(self): pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) # idx 1 appears in both queries; relevance from the first (2.0) wins, not 9.0 assert pairs == {0: 3.0, 1: 2.0, 2: 1.0} + + +class TestBinaryRelevanceThresholdDefault: + """RankingTask exposes binary_relevance_threshold with a sensible default.""" + + def test_default_value_is_small_positive(self): + """Default threshold treats any non-zero grade as positive (current behavior).""" + # Read directly off the abstract class attribute — no instantiation needed. + assert RankingTask.binary_relevance_threshold.fget(None) == pytest.approx(1e-9) diff --git a/tests/test_ranking_metrics.py b/tests/test_ranking_metrics.py index 32cf4e3..d4c0546 100644 --- a/tests/test_ranking_metrics.py +++ b/tests/test_ranking_metrics.py @@ -16,6 +16,7 @@ def _prediction_matrix_from_order(order: list[int], n_targets: int) -> np.ndarra scores[0, idx] = float(n_targets - rank) return scores + # Shared test fixtures: 2 queries x 5 targets # Query 0 sorted order: [0, 2, 4, 3, 1] (scores: 0.9, 0.1, 0.8, 0.2, 0.5) # Query 1 sorted order: [3, 1, 2, 0, 4] (scores: 0.3, 0.7, 0.5, 0.9, 0.1) @@ -252,3 +253,82 @@ def test_misaligned_relevance_raises(self): metrics=["ndcg@3"], pos_label_relevance=[[1.0], [1.0, 1.0]], ) + + +class TestBinaryRelevanceThreshold: + """Threshold filters the binary positive set without affecting nDCG.""" + + def test_default_threshold_keeps_all_listed_positives(self): + """Default 1e-9 threshold passes any non-zero grade — same as no relevance.""" + without_rel = calculate_ranking_metrics( + PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["map", "mrr", "recall@3"] + ) + with_rel = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + pos_label_relevance=[[3.0, 1.0], [2.0, 2.0]], + ) + for metric in ["map", "mrr", "recall@3"]: + assert with_rel[metric] == pytest.approx(without_rel[metric]) + + def test_threshold_drops_subthreshold_positives_for_binary(self): + """A threshold of 2 drops grade-1 items from the binary positive set.""" + # Query 0 sorted: [0, 2, 4, 3, 1]. Positives: {2 (grade 3), 4 (grade 1)}. + # With threshold=2, only idx 2 stays positive for binary metrics. + # MRR: rank of first surviving positive (idx 2) is 2 -> 0.5 + # Query 1 sorted: [3, 1, 2, 0, 4]. Positives: {1 (grade 3), 3 (grade 1)}. + # With threshold=2, only idx 1 stays. Rank of idx 1 = 2 -> 0.5 + result = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["mrr"], + pos_label_relevance=[[3.0, 1.0], [3.0, 1.0]], + binary_relevance_threshold=2.0, + ) + assert result["mrr"] == pytest.approx(0.5) + + def test_threshold_does_not_affect_ndcg(self): + """NDCG always sees the full graded list regardless of threshold.""" + relevance = [[3.0, 1.0], [3.0, 1.0]] + low = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@5"], + pos_label_relevance=relevance, + binary_relevance_threshold=1e-9, + ) + high = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@5"], + pos_label_relevance=relevance, + binary_relevance_threshold=2.0, + ) + assert low["ndcg@5"] == pytest.approx(high["ndcg@5"]) + + def test_threshold_ignored_when_relevance_is_none(self): + """Threshold is a no-op when no graded relevance is supplied.""" + ref = calculate_ranking_metrics( + PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["map", "mrr", "recall@3"] + ) + with_threshold = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + binary_relevance_threshold=99.0, # would drop everything if it applied + ) + for metric in ["map", "mrr", "recall@3"]: + assert with_threshold[metric] == pytest.approx(ref[metric]) + + def test_threshold_above_all_grades_empties_binary_set(self): + """A threshold above every grade leaves no binary positives -> 0.0.""" + result = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + pos_label_relevance=[[3.0, 1.0], [3.0, 1.0]], + binary_relevance_threshold=10.0, + ) + for metric in ["map", "mrr", "recall@3"]: + assert result[metric] == 0.0 From 31e00f1679908dbf68b9f6cb77a2831b0318af9a Mon Sep 17 00:00:00 2001 From: "warre.veys" Date: Wed, 6 May 2026 16:31:10 +0200 Subject: [PATCH 3/3] More clear documentation on the relevancy feature --- README.md | 14 ++- .../custom_task_graded_relevance_example.py | 102 ++++++++++++++++-- 2 files changed, 107 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ba4d548..2ed1b21 100644 --- a/README.md +++ b/README.md @@ -237,11 +237,23 @@ Each aggregation provides 95% confidence intervals (replace `mean` with `ci_marg | --- | --- | | `map` | Mean Average Precision | | `mrr` | Mean Reciprocal Rank | -| `ndcg@k` | Normalized Discounted Cumulative Gain with support for top-k cutoff. Uses a (2^rel - 1) gain when the dataset provides graded `target_relevance`; falls back to binary positives otherwise. See `examples/custom_task_graded_relevance_example.py`. | +| `ndcg@k` | Normalized Discounted Cumulative Gain with support for top-k cutoff. Uses a (2^rel - 1) gain when the dataset provides graded `target_relevance`; falls back to binary positives otherwise. See [Graded relevance](#graded-relevance-optional) for the `binary_relevance_threshold` interaction. | | `recall@k` | Recall at k (e.g. `recall@5`, `recall@10`) | | `hit@k` | Hit rate at k — binary: is any relevant item in the top-k? | | `rp@k` | R-Precision at k — precision relative to total relevant items | +##### Graded relevance (optional) + +`RankingDataset` accepts an optional `target_relevance` aligned 1-to-1 with `target_indices`. When set, `ndcg@k` consumes the grades via the standard `(2^rel - 1)` gain. The other metrics (`map`, `mrr`, `recall@k`, `hit@k`, `rp@k`) remain binary — they decide what counts as a positive by thresholding the grades. + +That threshold is `RankingTask.binary_relevance_threshold`. **It is the knob that decides what "relevant" means for the binary metrics on a graded dataset, so changing it changes their values.** + +- Default `1e-9` — every listed grade > 0 is a positive. Numbers match the `target_relevance=None` case exactly, so legacy tasks need no migration. +- Override on a graded task to express a stricter cutoff (e.g. `2.0` on a 1-3 scale: only secondary/primary count as positives; nice-to-haves still help nDCG but no longer count toward MAP). +- Always ignored when `target_relevance is None`. + +`recall@k` denominator on a graded dataset is the *thresholded* positive count, not the count of all listed items — so a graded dataset's binary numbers are not directly comparable to a fully-binary version of the same data. See [examples/custom_task_graded_relevance_example.py](examples/custom_task_graded_relevance_example.py) for a runnable side-by-side of how the threshold shifts MAP/MRR/recall while leaving nDCG unchanged. + **Classification metrics** (used in `ClassificationTask`): | Metric | Description | diff --git a/examples/custom_task_graded_relevance_example.py b/examples/custom_task_graded_relevance_example.py index a126e30..d24472e 100644 --- a/examples/custom_task_graded_relevance_example.py +++ b/examples/custom_task_graded_relevance_example.py @@ -11,18 +11,35 @@ implicitly grade 0 (irrelevant). 2. Add ``ndcg@k`` to ``default_metrics``. The (2^rel - 1) gain in nDCG uses the grades. -3. Optionally override ``binary_relevance_threshold`` on the task to control - which graded items count as positives for binary metrics (``map``, ``mrr``, - ``recall@k``, ``hit@k``, ``rp@k``). Items with relevance >= threshold are - positives; items below are dropped from the binary positive set but still - contribute to nDCG. Default is ``1e-9``, which keeps every listed item as - a positive (matches the binary-only behavior). +3. Choose ``binary_relevance_threshold`` on the task to control which graded + items count as positives for binary metrics (``map``, ``mrr``, ``recall@k``, + ``hit@k``, ``rp@k``). Items with relevance >= threshold are positives; + items below are dropped from the binary positive set but still contribute + to nDCG. + +Why the threshold matters +------------------------- +The threshold *defines* what "relevant" means for the binary metrics on a +graded dataset. Changing it changes the values of MAP/MRR/recall/hit/rp. +nDCG is unaffected — it always sees the full graded list. + +- Default ``1e-9``: every listed grade > 0 counts. Numbers match the + ``target_relevance=None`` (binary-only) case exactly. Safe default if you + just want graded nDCG without disturbing legacy binary metrics. +- Stricter (e.g. ``2.0`` on a 1-3 scale): only secondary/primary count. + Nice-to-haves stop counting toward MAP but still help nDCG. + +The bottom of this file runs the same dataset under two thresholds so you +can see nDCG stay constant while MAP/MRR/recall shift. The grading scale is up to the task. Common choices: {1, 2, 3} (this example), {1, 2, 3, 4} (TREC-style), or fractional in [0, 1]. Values must be non-negative. """ +import numpy as np + import workrb +from workrb.metrics.ranking import calculate_ranking_metrics from workrb.registry import register_task from workrb.tasks.abstract.base import DatasetSplit, LabelType, Language from workrb.tasks.abstract.ranking_base import RankingDataset, RankingTaskGroup @@ -53,7 +70,7 @@ def query_input_type(self) -> ModelInputType: @property def target_input_type(self) -> ModelInputType: - return ModelInputType.SKILL + return ModelInputType.SKILL_NAME @property def default_metrics(self) -> list[str]: @@ -130,6 +147,72 @@ def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset: ) +def _demo_threshold_effect() -> None: + """Show how ``binary_relevance_threshold`` changes binary metrics on the same data. + + Construction is deliberate: a single query with positives at grades {3, 3, 2, 1}, + and a hand-picked ranking that places the grade-1 "nice-to-have" at rank 1. + That makes the trade-off legible: + + - At threshold ``1e-9`` (default), the nice-to-have counts as a positive, so + MRR is a perfect 1.0 — even though the model put a low-value item at the top. + - At threshold ``2.0``, the nice-to-have is dropped from the positive set, so + MRR collapses to the rank of the first surviving positive. + - nDCG is invariant: it always sees the full graded list, and the (2^rel - 1) + gain already discounts the nice-to-have appropriately. + + Same ``prediction_matrix``, same ``target_relevance`` — only the threshold differs. + """ + target_indices = [[0, 1, 2, 3]] + target_relevance = [[3.0, 3.0, 2.0, 1.0]] # primary, primary, secondary, nice-to-have + n_targets = 8 + + # Forced ranking (best -> worst): + # rank 1: idx 3 (grade 1, nice-to-have ranked first) + # rank 2: idx 0 (grade 3) + # rank 5: idx 1 (grade 3, just inside top-5) + # rank 7: idx 2 (grade 2, outside top-5) + order = [3, 0, 5, 6, 1, 7, 2, 4] + prediction_matrix = np.zeros((1, n_targets), dtype=float) + for rank, idx in enumerate(order): + prediction_matrix[0, idx] = float(n_targets - rank) + + metrics = ["ndcg@5", "map", "mrr", "recall@5"] + thresholds = { + "1e-9 (every listed grade counts)": 1e-9, + "2.0 (only grade >= 2 counts)": 2.0, + } + + print("\n--- Effect of binary_relevance_threshold on the same dataset ---") + print("Same prediction_matrix, same target_relevance — only the threshold differs.\n") + header = f"{'threshold':<36}" + "".join(f"{m:>10}" for m in metrics) + print(header) + print("-" * len(header)) + for label, threshold in thresholds.items(): + results = calculate_ranking_metrics( + prediction_matrix=prediction_matrix, + pos_label_idxs=target_indices, + metrics=metrics, + pos_label_relevance=target_relevance, + binary_relevance_threshold=threshold, + ) + row = f"{label:<36}" + "".join(f"{results[m]:>10.4f}" for m in metrics) + print(row) + print( + "\nnDCG@5 is identical across thresholds — it always consumes the full graded list." + "\nMAP / MRR / recall@5 shift because the threshold redefines the binary positive set:" + "\n - 1e-9: the nice-to-have (grade 1) at rank 1 counts as a positive, so MRR=1.0" + "\n and the model 'looks great' on binary metrics despite ranking a" + "\n low-value item first." + "\n - 2.0: the nice-to-have is dropped from the positive set; MRR drops to the" + "\n rank of the first surviving positive, MAP drops, and recall@5 is" + "\n computed against a smaller denominator (3 instead of 4)." + "\nLesson: on a graded dataset, binary metrics answer 'how does the model rank" + "\nthe items I declared relevant', and *what counts as relevant* is exactly the" + "\nthreshold. Pick it deliberately." + ) + + if __name__ == "__main__": print("🚀 Custom Graded-Relevance Task Example") print("=" * 50) @@ -139,7 +222,7 @@ def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset: # Run with the task's default_metrics, which mixes nDCG (graded) with # MAP/MRR/recall (binary). The same dataset feeds both — graded metrics - # consult target_relevance, binary metrics ignore it. + # consult target_relevance, binary metrics see only the thresholded set. results = workrb.evaluate( model, tasks, @@ -148,3 +231,6 @@ def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset: force_restart=True, ) print(results) + + # Show why binary_relevance_threshold matters: same data, two thresholds. + _demo_threshold_effect()