diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ed6d0e..4d3cf48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,19 @@ +## Unreleased + +### Feat + +- graded relevance support for ranking metrics: ``RankingDataset`` accepts an + optional ``target_relevance`` field aligned 1-to-1 with ``target_indices``. + ``ndcg@k`` uses a (2^rel - 1) gain when graded labels are provided; binary + metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, ``rp@k``) ignore the + field. Binary nDCG behavior is preserved when ``target_relevance`` is + ``None``. See ``examples/custom_task_graded_relevance_example.py``. +- ``RankingTask.binary_relevance_threshold`` (default ``1e-9``) lets a graded + task choose which grades count as positives for binary metrics. Items with + relevance below the threshold are dropped from the binary positive set but + still contribute to graded metrics like ``ndcg@k``. Has no effect when + ``target_relevance`` is ``None``. + ## v0.5.1 (2026-03-13) ### Feat diff --git a/README.md b/README.md index 5deeb16..2ed1b21 100644 --- a/README.md +++ b/README.md @@ -237,11 +237,23 @@ Each aggregation provides 95% confidence intervals (replace `mean` with `ci_marg | --- | --- | | `map` | Mean Average Precision | | `mrr` | Mean Reciprocal Rank | -| `ndcg@k` | Normalized Discounted Cumulative Gain with support for top-k cutoff.| +| `ndcg@k` | Normalized Discounted Cumulative Gain with support for top-k cutoff. Uses a (2^rel - 1) gain when the dataset provides graded `target_relevance`; falls back to binary positives otherwise. See [Graded relevance](#graded-relevance-optional) for the `binary_relevance_threshold` interaction. | | `recall@k` | Recall at k (e.g. `recall@5`, `recall@10`) | | `hit@k` | Hit rate at k — binary: is any relevant item in the top-k? | | `rp@k` | R-Precision at k — precision relative to total relevant items | +##### Graded relevance (optional) + +`RankingDataset` accepts an optional `target_relevance` aligned 1-to-1 with `target_indices`. When set, `ndcg@k` consumes the grades via the standard `(2^rel - 1)` gain. The other metrics (`map`, `mrr`, `recall@k`, `hit@k`, `rp@k`) remain binary — they decide what counts as a positive by thresholding the grades. + +That threshold is `RankingTask.binary_relevance_threshold`. **It is the knob that decides what "relevant" means for the binary metrics on a graded dataset, so changing it changes their values.** + +- Default `1e-9` — every listed grade > 0 is a positive. Numbers match the `target_relevance=None` case exactly, so legacy tasks need no migration. +- Override on a graded task to express a stricter cutoff (e.g. `2.0` on a 1-3 scale: only secondary/primary count as positives; nice-to-haves still help nDCG but no longer count toward MAP). +- Always ignored when `target_relevance is None`. + +`recall@k` denominator on a graded dataset is the *thresholded* positive count, not the count of all listed items — so a graded dataset's binary numbers are not directly comparable to a fully-binary version of the same data. See [examples/custom_task_graded_relevance_example.py](examples/custom_task_graded_relevance_example.py) for a runnable side-by-side of how the threshold shifts MAP/MRR/recall while leaving nDCG unchanged. + **Classification metrics** (used in `ClassificationTask`): | Metric | Description | diff --git a/examples/custom_task_example.py b/examples/custom_task_example.py index 99edb3e..910cb8e 100644 --- a/examples/custom_task_example.py +++ b/examples/custom_task_example.py @@ -4,6 +4,10 @@ This example demonstrates how to create a custom ranking task that can be used with the WorkRB framework. Custom tasks should inherit from workrb.tasks.RankingTask and implement the required abstract methods. + +Uses binary relevance: every entry in ``target_indices`` is treated as relevant=1 +by all metrics. For graded relevance (e.g. nDCG@k with a 1-2-3 scale), see +``custom_task_graded_relevance_example.py``. """ import workrb diff --git a/examples/custom_task_graded_relevance_example.py b/examples/custom_task_graded_relevance_example.py new file mode 100644 index 0000000..d24472e --- /dev/null +++ b/examples/custom_task_graded_relevance_example.py @@ -0,0 +1,236 @@ +""" +Custom Graded-Relevance Ranking Task + +Extends ``custom_task_example.py`` with graded relevance: each (query, target) +pair carries an ordinal grade instead of a binary relevant/not-relevant flag. + +How a graded task differs from a binary one: + +1. Pass an additional ``target_relevance`` list to ``RankingDataset``, aligned + 1-to-1 with ``target_indices``. Items NOT listed in ``target_indices`` are + implicitly grade 0 (irrelevant). +2. Add ``ndcg@k`` to ``default_metrics``. The (2^rel - 1) gain in nDCG uses the + grades. +3. Choose ``binary_relevance_threshold`` on the task to control which graded + items count as positives for binary metrics (``map``, ``mrr``, ``recall@k``, + ``hit@k``, ``rp@k``). Items with relevance >= threshold are positives; + items below are dropped from the binary positive set but still contribute + to nDCG. + +Why the threshold matters +------------------------- +The threshold *defines* what "relevant" means for the binary metrics on a +graded dataset. Changing it changes the values of MAP/MRR/recall/hit/rp. +nDCG is unaffected — it always sees the full graded list. + +- Default ``1e-9``: every listed grade > 0 counts. Numbers match the + ``target_relevance=None`` (binary-only) case exactly. Safe default if you + just want graded nDCG without disturbing legacy binary metrics. +- Stricter (e.g. ``2.0`` on a 1-3 scale): only secondary/primary count. + Nice-to-haves stop counting toward MAP but still help nDCG. + +The bottom of this file runs the same dataset under two thresholds so you +can see nDCG stay constant while MAP/MRR/recall shift. + +The grading scale is up to the task. Common choices: {1, 2, 3} (this example), +{1, 2, 3, 4} (TREC-style), or fractional in [0, 1]. Values must be non-negative. +""" + +import numpy as np + +import workrb +from workrb.metrics.ranking import calculate_ranking_metrics +from workrb.registry import register_task +from workrb.tasks.abstract.base import DatasetSplit, LabelType, Language +from workrb.tasks.abstract.ranking_base import RankingDataset, RankingTaskGroup +from workrb.types import ModelInputType + + +@register_task() +class GradedJob2SkillTask(workrb.tasks.RankingTask): + """Job-to-skill ranking with three relevance grades. + + Grade 3 — primary skill (the role is defined by it). + Grade 2 — secondary skill (clearly expected). + Grade 1 — nice-to-have (mentioned but not required). + Grade 0 — irrelevant; never appears in ``target_indices`` (implicit). + """ + + @property + def name(self) -> str: + return "GradedJob2SkillTask" + + @property + def description(self) -> str: + return "Job-to-skill ranking with graded relevance for nDCG evaluation." + + @property + def query_input_type(self) -> ModelInputType: + return ModelInputType.JOB_TITLE + + @property + def target_input_type(self) -> ModelInputType: + return ModelInputType.SKILL_NAME + + @property + def default_metrics(self) -> list[str]: + # nDCG@k is the headline metric here. MAP/MRR/recall are kept as a + # binary sanity check on the high-grade subset (see the threshold + # below): they treat every passing graded positive as relevant=1. + return ["ndcg@5", "ndcg@10", "map", "mrr", "recall@5"] + + @property + def binary_relevance_threshold(self) -> float: + # Only "secondary" (grade 2) and "primary" (grade 3) skills count as + # positives for binary metrics; "nice-to-have" (grade 1) items drop + # out of MAP/MRR/recall but still contribute to nDCG. + return 2.0 + + @property + def task_group(self) -> RankingTaskGroup: + return RankingTaskGroup.JOB2SKILL + + @property + def label_type(self) -> LabelType: + return LabelType.MULTI_LABEL + + @property + def supported_query_languages(self) -> list[Language]: + return [Language.EN] + + @property + def supported_target_languages(self) -> list[Language]: + return [Language.EN] + + def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset: + """Load a tiny in-memory graded dataset. + + target_indices and target_relevance are aligned 1-to-1: position i in + ``target_relevance[q]`` is the grade for ``target_indices[q][i]``. + """ + queries = [ + "Machine learning engineer", + "Data scientist", + "Backend software developer", + ] + + targets = [ + "Python programming", # 0 + "Machine learning", # 1 + "Data analysis", # 2 + "Software engineering", # 3 + "Statistics", # 4 + "Deep learning", # 5 + "Web development", # 6 + "Database management", # 7 + ] + + # Each list pairs (skill_idx -> grade). Skills not listed for a query + # are implicitly grade 0. + target_indices = [ + [1, 5, 0, 3], # ML engineer: ML(3), DL(3), Python(2), SWE(1) + [1, 2, 4, 5], # Data scientist: ML(3), DA(3), Stats(2), DL(2) + [0, 3, 7, 6], # Backend dev: Python(3), SWE(3), DB(2), Web(1) + ] + target_relevance = [ + [3.0, 3.0, 2.0, 1.0], + [3.0, 3.0, 2.0, 2.0], + [3.0, 3.0, 2.0, 1.0], + ] + + return RankingDataset( + query_texts=queries, + target_indices=target_indices, + target_space=targets, + dataset_id=dataset_id, + target_relevance=target_relevance, + ) + + +def _demo_threshold_effect() -> None: + """Show how ``binary_relevance_threshold`` changes binary metrics on the same data. + + Construction is deliberate: a single query with positives at grades {3, 3, 2, 1}, + and a hand-picked ranking that places the grade-1 "nice-to-have" at rank 1. + That makes the trade-off legible: + + - At threshold ``1e-9`` (default), the nice-to-have counts as a positive, so + MRR is a perfect 1.0 — even though the model put a low-value item at the top. + - At threshold ``2.0``, the nice-to-have is dropped from the positive set, so + MRR collapses to the rank of the first surviving positive. + - nDCG is invariant: it always sees the full graded list, and the (2^rel - 1) + gain already discounts the nice-to-have appropriately. + + Same ``prediction_matrix``, same ``target_relevance`` — only the threshold differs. + """ + target_indices = [[0, 1, 2, 3]] + target_relevance = [[3.0, 3.0, 2.0, 1.0]] # primary, primary, secondary, nice-to-have + n_targets = 8 + + # Forced ranking (best -> worst): + # rank 1: idx 3 (grade 1, nice-to-have ranked first) + # rank 2: idx 0 (grade 3) + # rank 5: idx 1 (grade 3, just inside top-5) + # rank 7: idx 2 (grade 2, outside top-5) + order = [3, 0, 5, 6, 1, 7, 2, 4] + prediction_matrix = np.zeros((1, n_targets), dtype=float) + for rank, idx in enumerate(order): + prediction_matrix[0, idx] = float(n_targets - rank) + + metrics = ["ndcg@5", "map", "mrr", "recall@5"] + thresholds = { + "1e-9 (every listed grade counts)": 1e-9, + "2.0 (only grade >= 2 counts)": 2.0, + } + + print("\n--- Effect of binary_relevance_threshold on the same dataset ---") + print("Same prediction_matrix, same target_relevance — only the threshold differs.\n") + header = f"{'threshold':<36}" + "".join(f"{m:>10}" for m in metrics) + print(header) + print("-" * len(header)) + for label, threshold in thresholds.items(): + results = calculate_ranking_metrics( + prediction_matrix=prediction_matrix, + pos_label_idxs=target_indices, + metrics=metrics, + pos_label_relevance=target_relevance, + binary_relevance_threshold=threshold, + ) + row = f"{label:<36}" + "".join(f"{results[m]:>10.4f}" for m in metrics) + print(row) + print( + "\nnDCG@5 is identical across thresholds — it always consumes the full graded list." + "\nMAP / MRR / recall@5 shift because the threshold redefines the binary positive set:" + "\n - 1e-9: the nice-to-have (grade 1) at rank 1 counts as a positive, so MRR=1.0" + "\n and the model 'looks great' on binary metrics despite ranking a" + "\n low-value item first." + "\n - 2.0: the nice-to-have is dropped from the positive set; MRR drops to the" + "\n rank of the first surviving positive, MAP drops, and recall@5 is" + "\n computed against a smaller denominator (3 instead of 4)." + "\nLesson: on a graded dataset, binary metrics answer 'how does the model rank" + "\nthe items I declared relevant', and *what counts as relevant* is exactly the" + "\nthreshold. Pick it deliberately." + ) + + +if __name__ == "__main__": + print("🚀 Custom Graded-Relevance Task Example") + print("=" * 50) + + model = workrb.models.BiEncoderModel("all-MiniLM-L6-v2") + tasks = [GradedJob2SkillTask(languages=["en"], split="test")] + + # Run with the task's default_metrics, which mixes nDCG (graded) with + # MAP/MRR/recall (binary). The same dataset feeds both — graded metrics + # consult target_relevance, binary metrics see only the thresholded set. + results = workrb.evaluate( + model, + tasks, + output_folder="results/graded_task_demo", + description="Custom ranking task with graded relevance.", + force_restart=True, + ) + print(results) + + # Show why binary_relevance_threshold matters: same data, two thresholds. + _demo_threshold_effect() diff --git a/src/workrb/metrics/ranking.py b/src/workrb/metrics/ranking.py index 317d001..414430d 100644 --- a/src/workrb/metrics/ranking.py +++ b/src/workrb/metrics/ranking.py @@ -13,18 +13,36 @@ def calculate_ranking_metrics( "map", "rp@10", ), + pos_label_relevance: list[list[float]] | None = None, + binary_relevance_threshold: float = 1e-9, ) -> dict[str, float]: - """ - Calculate ranking metrics for evaluation. - - Args: - prediction_matrix: Similarity/prediction matrix of shape (n_queries, n_targets) - pos_label_idxs: List of lists containing positive label indices for each query - metrics: List of metric names to compute + """Calculate ranking metrics for evaluation. + + Parameters + ---------- + prediction_matrix : torch.Tensor or np.ndarray + Similarity/prediction matrix of shape (n_queries, n_targets). + pos_label_idxs : list[list[int]] + Positive label indices for each query. + metrics : Sequence[str] + Metric names to compute. + pos_label_relevance : list[list[float]] or None, optional + Optional graded relevance per positive, aligned 1-to-1 with + ``pos_label_idxs``. When ``None``, every positive is treated as relevance 1.0 + (binary fallback, identical to current behavior). Used by graded metrics + (``ndcg``); binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, + ``rp@k``) consult it only to apply ``binary_relevance_threshold``. + binary_relevance_threshold : float, optional + Minimum graded relevance for an item to count as a positive for binary + metrics. Items with relevance below this threshold are dropped from the + binary positive set but still contribute to graded metrics. Ignored when + ``pos_label_relevance`` is ``None``. Defaults to ``1e-9``: any non-zero + grade counts, matching the current binary-only behavior. Returns ------- - Dictionary mapping metric names to values + dict[str, float] + Dictionary mapping metric names to values. """ # Convert to numpy if needed if isinstance(prediction_matrix, torch.Tensor): @@ -33,6 +51,21 @@ def calculate_ranking_metrics( # Sort indices by prediction scores (descending) sorted_indices = np.argsort(-prediction_matrix, axis=1) + # When graded relevance is provided, derive the binary positive set by + # thresholding so binary metrics (map/mrr/recall/hit/rp) consume only items + # with relevance >= threshold. Graded nDCG continues to use the full list. + if pos_label_relevance is None: + binary_pos_label_idxs = pos_label_idxs + else: + binary_pos_label_idxs = [ + [ + idx + for idx, rel in zip(idx_list, rel_list, strict=True) + if rel >= binary_relevance_threshold + ] + for idx_list, rel_list in zip(pos_label_idxs, pos_label_relevance, strict=True) + ] + results = {} def _metric_k_split(metric: str) -> tuple[str, int | None]: @@ -50,30 +83,28 @@ def _metric_k_split(metric: str) -> tuple[str, int | None]: base_metric, k = _metric_k_split(metric) if metric == "map": - results[metric] = _calculate_map(sorted_indices, pos_label_idxs) + results[metric] = _calculate_map(sorted_indices, binary_pos_label_idxs) elif base_metric == "rp": assert k is not None, "k must be provided for rp@k metrics" - results[metric] = _calculate_rp_at_k(sorted_indices, pos_label_idxs, k) + results[metric] = _calculate_rp_at_k(sorted_indices, binary_pos_label_idxs, k) elif metric == "mrr": - results[metric] = _calculate_mrr(sorted_indices, pos_label_idxs) + results[metric] = _calculate_mrr(sorted_indices, binary_pos_label_idxs) elif base_metric == "recall": assert k is not None, "k must be provided for recall@k metrics" - results[metric] = _calculate_recall_at_k(sorted_indices, pos_label_idxs, k) + results[metric] = _calculate_recall_at_k(sorted_indices, binary_pos_label_idxs, k) elif base_metric == "hit": assert k is not None, "k must be provided for hit@k metrics" - results[metric] = _calculate_hit_at_k(sorted_indices, pos_label_idxs, k) + results[metric] = _calculate_hit_at_k(sorted_indices, binary_pos_label_idxs, k) elif base_metric == "ndcg": - if k is not None: - results[metric] = _calculate_ndcg(sorted_indices, pos_label_idxs, k) - else: - results[metric] = _calculate_ndcg( - sorted_indices, pos_label_idxs, sorted_indices.shape[1] - ) + cutoff = k if k is not None else sorted_indices.shape[1] + results[metric] = _calculate_ndcg( + sorted_indices, pos_label_idxs, pos_label_relevance, cutoff + ) else: raise ValueError(f"Unknown ranking metric '{metric}'") @@ -204,24 +235,46 @@ def _calculate_rp_at_k( return float(np.mean(rp_scores)) if rp_scores else 0.0 -def _calculate_ndcg(sorted_indices: np.ndarray, pos_label_idxs: list[list[int]], k: int) -> float: - """Calculate Normalized Discounted Cumulative Gain@K (binary relevance).""" +def _calculate_ndcg( + sorted_indices: np.ndarray, + pos_label_idxs: list[list[int]], + pos_label_relevance: list[list[float]] | None, + k: int, +) -> float: + """Calculate nDCG@K with the (2^rel - 1) gain and log2(i+2) discount. + + When ``pos_label_relevance`` is None, every positive is treated as + relevance 1.0 (binary fallback). With binary positives the gain + ``(2^1 - 1) = 1`` matches the binary implementation exactly. + """ ndcg_scores = [] for i, pos_labels in enumerate(pos_label_idxs): if len(pos_labels) == 0: continue - pos_labels_set = set(pos_labels) - - # DCG@k: sum of 1/log2(rank+1) for relevant items in top-k - dcg = 0.0 - for rank_idx in range(min(k, len(sorted_indices[i]))): - if sorted_indices[i][rank_idx] in pos_labels_set: - dcg += 1.0 / np.log2(rank_idx + 2) # 0-based -> log2(pos+1) - - # IDCG@k: ideal DCG with all relevant items ranked first - idcg = sum(1.0 / np.log2(j + 2) for j in range(min(k, len(pos_labels)))) + if pos_label_relevance is None: + relevance_by_idx = dict.fromkeys(pos_labels, 1.0) + else: + relevance_by_idx = { + idx: float(rel) for idx, rel in zip(pos_labels, pos_label_relevance[i], strict=True) + } + + cutoff = min(k, len(sorted_indices[i])) + + gains = np.array( + [ + (2.0 ** relevance_by_idx.get(int(idx), 0.0)) - 1.0 + for idx in sorted_indices[i][:cutoff] + ] + ) + discounts = 1.0 / np.log2(np.arange(cutoff) + 2) + dcg = float(np.sum(gains * discounts)) + + ideal_relevances = sorted(relevance_by_idx.values(), reverse=True)[:cutoff] + ideal_gains = np.array([(2.0**rel) - 1.0 for rel in ideal_relevances]) + ideal_discounts = 1.0 / np.log2(np.arange(len(ideal_gains)) + 2) + idcg = float(np.sum(ideal_gains * ideal_discounts)) if idcg > 0: ndcg_scores.append(dcg / idcg) diff --git a/src/workrb/tasks/abstract/ranking_base.py b/src/workrb/tasks/abstract/ranking_base.py index ed99200..701369e 100644 --- a/src/workrb/tasks/abstract/ranking_base.py +++ b/src/workrb/tasks/abstract/ranking_base.py @@ -61,6 +61,7 @@ def __init__( target_indices: list[list[int]], target_space: list[str], dataset_id: str, + target_relevance: list[list[float]] | None = None, duplicate_query_strategy: DuplicateStrategy = DuplicateStrategy.RESOLVE, duplicate_target_strategy: DuplicateStrategy = DuplicateStrategy.RESOLVE, ): @@ -71,11 +72,20 @@ def __init__( query_texts : list[str] List of query strings. target_indices : list[list[int]] - List of lists containing indices into the target vocabulary. + List of lists containing indices into the target vocabulary. Items not + listed for a query are treated as having relevance 0 (unjudged / + irrelevant) by graded metrics. target_space : list[str] List of target vocabulary strings. dataset_id : str Unique identifier for this dataset. + target_relevance : list[list[float]] or None, optional + Optional graded relevance per positive, aligned 1-to-1 with + ``target_indices``. When ``None``, every entry in ``target_indices`` is + treated as binary relevance 1.0. Used by graded metrics such as + ``ndcg@k``; binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, + ``rp@k``) ignore this field. Values must be non-negative; the scale is + up to the task (e.g. {1, 2, 3} or {0.0..1.0}). duplicate_query_strategy : DuplicateStrategy How to handle duplicate query texts. ALLOW silently accepts them, RAISE raises on duplicates, RESOLVE merges their target_indices via union. @@ -84,7 +94,9 @@ def __init__( RAISE raises on duplicates, RESOLVE keeps first occurrence and remaps indices. """ self.query_texts = self._postprocess_texts(query_texts) - self.target_indices = self._postprocess_indices(target_indices) + self.target_indices, self.target_relevance = self._postprocess_indices( + target_indices, target_relevance + ) self.target_space = self._postprocess_texts(target_space) self.dataset_id = dataset_id @@ -121,27 +133,62 @@ def _resolve_duplicate_targets(self) -> None: duplicates, ) self.target_space = new_target_space - self.target_indices = [ - sorted(set(old_to_new[idx] for idx in idx_list)) for idx_list in self.target_indices - ] + if self.target_relevance is None: + self.target_indices = [ + sorted(set(old_to_new[idx] for idx in idx_list)) + for idx_list in self.target_indices + ] + else: + # Remap indices, dedup, and keep relevance from the first occurrence + # of each remapped index so (idx, rel) pairs stay aligned. + new_indices: list[list[int]] = [] + new_relevance: list[list[float]] = [] + for idx_list, rel_list in zip( + self.target_indices, self.target_relevance, strict=True + ): + seen_idx: dict[int, float] = {} + for idx, rel in zip(idx_list, rel_list, strict=True): + new_idx = old_to_new[idx] + if new_idx not in seen_idx: + seen_idx[new_idx] = rel + sorted_pairs = sorted(seen_idx.items()) + new_indices.append([idx for idx, _ in sorted_pairs]) + new_relevance.append([rel for _, rel in sorted_pairs]) + self.target_indices = new_indices + self.target_relevance = new_relevance def _resolve_duplicate_queries(self) -> None: - """Deduplicate query_texts, merging target_indices via union.""" + """Deduplicate query_texts, merging target_indices via union. + + When ``target_relevance`` is set, the merge keeps the relevance from the + first query occurrence for each index; relevance values from later + duplicates of the same (query, index) pair are dropped. + """ seen: dict[str, int] = {} new_queries: list[str] = [] - new_indices: list[list[int]] = [] + new_pairs: list[dict[int, float]] = [] duplicates: list[str] = [] - for query, idx_list in zip(self.query_texts, self.target_indices): + graded = self.target_relevance is not None + relevance_iter = ( + self.target_relevance + if graded + else [[1.0] * len(idx_list) for idx_list in self.target_indices] + ) + + for query, idx_list, rel_list in zip( + self.query_texts, self.target_indices, relevance_iter, strict=True + ): if query in seen: pos = seen[query] - merged = set(new_indices[pos]) | set(idx_list) - new_indices[pos] = sorted(merged) + for idx, rel in zip(idx_list, rel_list, strict=True): + if idx not in new_pairs[pos]: + new_pairs[pos][idx] = rel duplicates.append(query) else: seen[query] = len(new_queries) new_queries.append(query) - new_indices.append(sorted(idx_list)) + new_pairs.append(dict(zip(idx_list, rel_list, strict=True))) if duplicates: logger.warning( @@ -151,7 +198,15 @@ def _resolve_duplicate_queries(self) -> None: duplicates, ) self.query_texts = new_queries + new_indices: list[list[int]] = [] + new_relevance: list[list[float]] = [] + for pairs in new_pairs: + sorted_pairs = sorted(pairs.items()) + new_indices.append([idx for idx, _ in sorted_pairs]) + new_relevance.append([rel for _, rel in sorted_pairs]) self.target_indices = new_indices + if graded: + self.target_relevance = new_relevance def _validate_dataset( self, @@ -194,11 +249,54 @@ def _validate_dataset( ) assert isinstance(idx, int), f"Target index {idx} is not an integer" - def _postprocess_indices(self, indices: list[list[int]]) -> list[list[int]]: - """Postprocess indices.""" - # Remove duplicates in target_label - indices = [sorted(set(label_list)) for label_list in indices] - return indices + # Check target_relevance alignment and non-negativity + if self.target_relevance is not None: + assert len(self.target_relevance) == len(self.target_indices), ( + f"target_relevance has {len(self.target_relevance)} queries but " + f"target_indices has {len(self.target_indices)}" + ) + for q_i, (rel_list, idx_list) in enumerate( + zip(self.target_relevance, self.target_indices, strict=True) + ): + assert len(rel_list) == len(idx_list), ( + f"target_relevance[{q_i}] has length {len(rel_list)} but " + f"target_indices[{q_i}] has length {len(idx_list)}" + ) + for rel in rel_list: + assert rel >= 0, f"Negative relevance value {rel} at query {q_i}" + + def _postprocess_indices( + self, + indices: list[list[int]], + relevance: list[list[float]] | None, + ) -> tuple[list[list[int]], list[list[float]] | None]: + """Postprocess indices and aligned relevance, dropping duplicate indices. + + Indices are sorted; relevance is permuted in lockstep so each (idx, rel) + pair stays aligned. When duplicate indices appear within a query, the + relevance from the first occurrence is kept. + """ + if relevance is None: + return [sorted(set(label_list)) for label_list in indices], None + + assert len(relevance) == len(indices), ( + f"target_relevance has {len(relevance)} queries but target_indices has {len(indices)}" + ) + deduped_indices: list[list[int]] = [] + deduped_relevance: list[list[float]] = [] + for idx_list, rel_list in zip(indices, relevance, strict=True): + assert len(idx_list) == len(rel_list), ( + f"target_indices and target_relevance must align per query " + f"(got {len(idx_list)} vs {len(rel_list)})" + ) + seen: dict[int, float] = {} + for idx, rel in zip(idx_list, rel_list, strict=True): + if idx not in seen: + seen[idx] = float(rel) + sorted_pairs = sorted(seen.items()) + deduped_indices.append([idx for idx, _ in sorted_pairs]) + deduped_relevance.append([rel for _, rel in sorted_pairs]) + return deduped_indices, deduped_relevance def _postprocess_texts(self, texts: list[str]) -> list[str]: """Postprocess texts.""" @@ -223,6 +321,25 @@ def task_type(self) -> TaskType: def default_metrics(self) -> list[str]: return ["map", "rp@10", "mrr"] + @property + def binary_relevance_threshold(self) -> float: + """Minimum graded relevance for an item to count as a positive. + + Used by binary metrics (``map``, ``mrr``, ``recall@k``, ``hit@k``, + ``rp@k``) when the dataset provides ``target_relevance``: items with + relevance ``>= threshold`` are treated as positives, items below it + are dropped from the binary positive set but still contribute to + graded metrics such as ``ndcg@k``. + + Default is ``1e-9`` so any listed item with a non-zero grade counts + as a positive — equivalent to today's behavior where every entry in + ``target_indices`` is a positive. Override on the task to express + a stricter threshold (e.g. ``2.0`` on a 0-3 scale). + + Has no effect when ``target_relevance`` is ``None``. + """ + return 1e-9 + def __init__( self, **kwargs, @@ -327,9 +444,15 @@ def evaluate( if isinstance(prediction_matrix, torch.Tensor): prediction_matrix = prediction_matrix.cpu().float().numpy() - # Calculate metrics + # Calculate metrics. When the dataset provides graded relevance, binary + # metrics consume only positives with relevance >= binary_relevance_threshold; + # nDCG still sees the full graded label list. metric_results = calculate_ranking_metrics( - prediction_matrix=prediction_matrix, pos_label_idxs=labels, metrics=metrics + prediction_matrix=prediction_matrix, + pos_label_idxs=labels, + metrics=metrics, + pos_label_relevance=dataset.target_relevance, + binary_relevance_threshold=self.binary_relevance_threshold, ) return metric_results diff --git a/tests/test_duplicate_strategy.py b/tests/test_duplicate_strategy.py index e81e8cf..5414826 100644 --- a/tests/test_duplicate_strategy.py +++ b/tests/test_duplicate_strategy.py @@ -2,7 +2,7 @@ import pytest -from workrb.tasks.abstract.ranking_base import DuplicateStrategy, RankingDataset +from workrb.tasks.abstract.ranking_base import DuplicateStrategy, RankingDataset, RankingTask class TestDuplicateStrategyRaise: @@ -208,3 +208,125 @@ def test_default_resolves_duplicate_targets(self): ) assert ds.target_space == ["a"] assert ds.target_indices == [[0]] + + +class TestTargetRelevance: + """Test optional target_relevance carried through validation and dedup.""" + + def test_default_is_none(self): + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + ) + assert ds.target_relevance is None + + def test_aligned_relevance_preserved(self): + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[3.0, 1.0]], + ) + # Indices are sorted; relevance is permuted in lockstep + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 3.0, 1: 1.0} + + def test_dedup_keeps_first_relevance(self): + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1, 0]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[2.0, 1.0, 9.0]], + ) + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 2.0, 1: 1.0} + + def test_indices_sorted_with_relevance_permuted(self): + """When indices are reordered, relevance must follow.""" + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[2, 0, 1]], + target_space=["a", "b", "c"], + dataset_id="test", + target_relevance=[[1.0, 3.0, 2.0]], + ) + assert ds.target_indices == [[0, 1, 2]] + assert ds.target_relevance == [[3.0, 2.0, 1.0]] + + def test_misaligned_lengths_raise(self): + with pytest.raises(AssertionError): + RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[1.0]], + ) + + def test_negative_relevance_raises(self): + with pytest.raises(AssertionError, match="Negative relevance"): + RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1]], + target_space=["a", "b"], + dataset_id="test", + target_relevance=[[1.0, -0.5]], + ) + + def test_resolve_targets_carries_relevance(self): + """When duplicate targets collapse, relevance from the first occurrence wins.""" + ds = RankingDataset( + query_texts=["q1"], + target_indices=[[0, 1, 2]], + target_space=["x", "y", "x"], + dataset_id="test", + target_relevance=[[3.0, 2.0, 9.0]], + duplicate_target_strategy=DuplicateStrategy.RESOLVE, + ) + # target_space dedups to ["x", "y"]; idx 2 ("x") remaps to 0 (already present) + # First-occurrence relevance for idx 0 is 3.0; the 9.0 is dropped. + assert ds.target_space == ["x", "y"] + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 3.0, 1: 2.0} + + def test_resolve_queries_carries_relevance(self): + """When duplicate queries merge, relevance is unioned with first-wins for ties.""" + ds = RankingDataset( + query_texts=["q1", "q1"], + target_indices=[[0], [1]], + target_space=["a", "b", "c"], + dataset_id="test", + target_relevance=[[3.0], [2.0]], + duplicate_query_strategy=DuplicateStrategy.RESOLVE, + ) + assert ds.query_texts == ["q1"] + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + assert pairs == {0: 3.0, 1: 2.0} + + def test_resolve_queries_first_relevance_wins_on_overlap(self): + """If duplicate queries share an index, the first query's relevance is kept.""" + ds = RankingDataset( + query_texts=["q1", "q1"], + target_indices=[[0, 1], [1, 2]], + target_space=["a", "b", "c"], + dataset_id="test", + target_relevance=[[3.0, 2.0], [9.0, 1.0]], + duplicate_query_strategy=DuplicateStrategy.RESOLVE, + ) + assert ds.query_texts == ["q1"] + pairs = dict(zip(ds.target_indices[0], ds.target_relevance[0], strict=True)) + # idx 1 appears in both queries; relevance from the first (2.0) wins, not 9.0 + assert pairs == {0: 3.0, 1: 2.0, 2: 1.0} + + +class TestBinaryRelevanceThresholdDefault: + """RankingTask exposes binary_relevance_threshold with a sensible default.""" + + def test_default_value_is_small_positive(self): + """Default threshold treats any non-zero grade as positive (current behavior).""" + # Read directly off the abstract class attribute — no instantiation needed. + assert RankingTask.binary_relevance_threshold.fget(None) == pytest.approx(1e-9) diff --git a/tests/test_ranking_metrics.py b/tests/test_ranking_metrics.py index 1153812..d4c0546 100644 --- a/tests/test_ranking_metrics.py +++ b/tests/test_ranking_metrics.py @@ -1,11 +1,22 @@ """Tests for ranking metrics (workrb.metrics.ranking).""" +import math + import numpy as np import pytest import torch from workrb.metrics.ranking import calculate_ranking_metrics + +def _prediction_matrix_from_order(order: list[int], n_targets: int) -> np.ndarray: + """Build a 1-query prediction matrix that ranks targets in the given order (best first).""" + scores = np.zeros((1, n_targets), dtype=float) + for rank, idx in enumerate(order): + scores[0, idx] = float(n_targets - rank) + return scores + + # Shared test fixtures: 2 queries x 5 targets # Query 0 sorted order: [0, 2, 4, 3, 1] (scores: 0.9, 0.1, 0.8, 0.2, 0.5) # Query 1 sorted order: [3, 1, 2, 0, 4] (scores: 0.3, 0.7, 0.5, 0.9, 0.1) @@ -152,3 +163,172 @@ def test_existing_metrics_still_work(self): for metric_name in ["map", "mrr", "recall@3", "hit@3", "rp@3"]: assert isinstance(results[metric_name], float) assert 0.0 <= results[metric_name] <= 1.0 + + +class TestNDCGGradedRelevance: + """nDCG with graded relevance (pos_label_relevance not None).""" + + def test_explicit_relevance_one_matches_binary(self): + """Passing all-1.0 relevance must match the binary fallback exactly.""" + binary = calculate_ranking_metrics(PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["ndcg@3"]) + graded = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@3"], + pos_label_relevance=[[1.0, 1.0], [1.0, 1.0]], + ) + assert graded["ndcg@3"] == pytest.approx(binary["ndcg@3"]) + + def test_higher_grade_first_outscores_lower_grade_first(self): + """Ranking the higher-graded positive first must beat the reverse.""" + preds_good = _prediction_matrix_from_order([0, 1, 2, 3], n_targets=4) + preds_bad = _prediction_matrix_from_order([1, 0, 2, 3], n_targets=4) + good = calculate_ranking_metrics( + prediction_matrix=preds_good, + pos_label_idxs=[[0, 1]], + pos_label_relevance=[[3.0, 1.0]], + metrics=["ndcg@4"], + ) + bad = calculate_ranking_metrics( + prediction_matrix=preds_bad, + pos_label_idxs=[[0, 1]], + pos_label_relevance=[[3.0, 1.0]], + metrics=["ndcg@4"], + ) + assert good["ndcg@4"] == pytest.approx(1.0) + assert bad["ndcg@4"] < good["ndcg@4"] + + def test_graded_perfect_ranking_returns_one(self): + """Ideal ordering of grades [3, 2, 1] gives nDCG = 1.""" + preds = _prediction_matrix_from_order([0, 1, 2, 3], n_targets=4) + result = calculate_ranking_metrics( + prediction_matrix=preds, + pos_label_idxs=[[0, 1, 2]], + pos_label_relevance=[[3.0, 2.0, 1.0]], + metrics=["ndcg@4"], + ) + assert result["ndcg@4"] == pytest.approx(1.0) + + def test_graded_known_value(self): + """Hand-computed value with grades 2 and 1 over 3 targets. + + Positives: idx 0 grade 2, idx 2 grade 1. Ranking [0, 1, 2]. + gains @ ranks 0,1,2 = (2^2-1, 0, 2^1-1) = (3, 0, 1) + discounts = 1/log2(2), 1/log2(3), 1/log2(4) + DCG = 3 + 0 + 0.5 = 3.5 + Ideal grades [2, 1] -> gains (3, 1), discounts (1, 1/log2(3)) + IDCG = 3 + 1/log2(3) + """ + preds = _prediction_matrix_from_order([0, 1, 2], n_targets=3) + result = calculate_ranking_metrics( + prediction_matrix=preds, + pos_label_idxs=[[0, 2]], + pos_label_relevance=[[2.0, 1.0]], + metrics=["ndcg@3"], + ) + dcg = 3.0 + 1.0 / math.log2(4) + idcg = 3.0 + 1.0 / math.log2(3) + assert result["ndcg@3"] == pytest.approx(dcg / idcg) + + def test_relevance_ignored_by_binary_metrics(self): + """Binary metrics (map, mrr, ...) ignore pos_label_relevance entirely.""" + without = calculate_ranking_metrics( + PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["map", "mrr", "recall@3"] + ) + with_rel = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + pos_label_relevance=[[3.0, 1.0], [2.0, 2.0]], + ) + for metric in ["map", "mrr", "recall@3"]: + assert with_rel[metric] == pytest.approx(without[metric]) + + def test_misaligned_relevance_raises(self): + """A relevance list whose lengths don't match pos_label_idxs raises.""" + with pytest.raises(ValueError, match="zip"): + calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@3"], + pos_label_relevance=[[1.0], [1.0, 1.0]], + ) + + +class TestBinaryRelevanceThreshold: + """Threshold filters the binary positive set without affecting nDCG.""" + + def test_default_threshold_keeps_all_listed_positives(self): + """Default 1e-9 threshold passes any non-zero grade — same as no relevance.""" + without_rel = calculate_ranking_metrics( + PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["map", "mrr", "recall@3"] + ) + with_rel = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + pos_label_relevance=[[3.0, 1.0], [2.0, 2.0]], + ) + for metric in ["map", "mrr", "recall@3"]: + assert with_rel[metric] == pytest.approx(without_rel[metric]) + + def test_threshold_drops_subthreshold_positives_for_binary(self): + """A threshold of 2 drops grade-1 items from the binary positive set.""" + # Query 0 sorted: [0, 2, 4, 3, 1]. Positives: {2 (grade 3), 4 (grade 1)}. + # With threshold=2, only idx 2 stays positive for binary metrics. + # MRR: rank of first surviving positive (idx 2) is 2 -> 0.5 + # Query 1 sorted: [3, 1, 2, 0, 4]. Positives: {1 (grade 3), 3 (grade 1)}. + # With threshold=2, only idx 1 stays. Rank of idx 1 = 2 -> 0.5 + result = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["mrr"], + pos_label_relevance=[[3.0, 1.0], [3.0, 1.0]], + binary_relevance_threshold=2.0, + ) + assert result["mrr"] == pytest.approx(0.5) + + def test_threshold_does_not_affect_ndcg(self): + """NDCG always sees the full graded list regardless of threshold.""" + relevance = [[3.0, 1.0], [3.0, 1.0]] + low = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@5"], + pos_label_relevance=relevance, + binary_relevance_threshold=1e-9, + ) + high = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["ndcg@5"], + pos_label_relevance=relevance, + binary_relevance_threshold=2.0, + ) + assert low["ndcg@5"] == pytest.approx(high["ndcg@5"]) + + def test_threshold_ignored_when_relevance_is_none(self): + """Threshold is a no-op when no graded relevance is supplied.""" + ref = calculate_ranking_metrics( + PREDICTION_MATRIX, POS_LABEL_IDXS, metrics=["map", "mrr", "recall@3"] + ) + with_threshold = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + binary_relevance_threshold=99.0, # would drop everything if it applied + ) + for metric in ["map", "mrr", "recall@3"]: + assert with_threshold[metric] == pytest.approx(ref[metric]) + + def test_threshold_above_all_grades_empties_binary_set(self): + """A threshold above every grade leaves no binary positives -> 0.0.""" + result = calculate_ranking_metrics( + PREDICTION_MATRIX, + POS_LABEL_IDXS, + metrics=["map", "mrr", "recall@3"], + pos_label_relevance=[[3.0, 1.0], [3.0, 1.0]], + binary_relevance_threshold=10.0, + ) + for metric in ["map", "mrr", "recall@3"]: + assert result[metric] == 0.0