Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions docs/plans/2025-02-26-benchmark-unification.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Benchmark PR Unification Plan

> **Goal:** Align all benchmark datasets with PartiPrompts. Minimize per-dataset code by extracting shared logic into a single helper. Same underlying flow everywhere.

**Principle:** PartiPrompts is the reference. Every dataset uses the same signature, same sampling logic, and the same `_prepare_test_only_prompt_dataset` helper.

---

## PartiPrompts reference pattern

```python
def setup_parti_prompts_dataset(
seed: int,
fraction: float = 1.0,
train_sample_size: int | None = None,
test_sample_size: int | None = None,
category: PartiCategory | list[PartiCategory] | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
ds = load_dataset("nateraw/parti-prompts")["train"]
if category is not None:
categories = [category] if not isinstance(category, list) else category
ds = ds.filter(lambda x: x["Category"] in categories or x["Challenge"] in categories)
test_sample_size = define_sample_size_for_dataset(ds, fraction, test_sample_size)
ds = ds.select(range(min(test_sample_size, len(ds))))
ds = ds.rename_column("Prompt", "text")
return _prepare_test_only_prompt_dataset(ds, seed, "PartiPrompts")
```

---

## 1. Shared helper in `pruna/data/utils.py`

**Add:**

```python
def _prepare_test_only_prompt_dataset(
ds: Dataset,
seed: int,
dataset_name: str,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Shared tail for test-only prompt datasets: shuffle, return dummy train/val + test.
All benchmark datasets use this.
"""
ds = ds.shuffle(seed=seed)
pruna_logger.info(f"{dataset_name} is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
```

**Effect:** Removes repeated `ds.shuffle(seed=seed); pruna_logger.info(...); return ds.select([0]), ds.select([0]), ds` from every setup function.

---

## 2. Unified signature (match PartiPrompts)

**All prompt/benchmark datasets** use the same param names, but **each dataset has its own Literal type** for `category`:

```python
# PartiPrompts (existing)
PartiCategory = Literal["Abstract", "Animals", ...]

# HPS
HPSCategory = Literal["anime", "concept-art", "paintings", "photo"]

# GenEval
GenEvalCategory = Literal["single_object", "two_object", "counting", "colors", "position", "color_attr"]

# ImgEdit
ImgEditCategory = Literal["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"]

# GEditBench
GEditBenchCategory = Literal["background_change", ...] # (full list per dataset)

# DPG
DPGCategory = Literal["entity", "attribute", "relation", "global", "other"]

# OneIG
OneIGCategory = Literal["text_rendering", "portrait_alignment"]
```

**Signature per dataset:**

```python
def setup_*_dataset(
seed: int,
fraction: float = 1.0,
train_sample_size: int | None = None,
test_sample_size: int | None = None,
category: XCategory | list[XCategory] | None = None, # dataset-specific Literal
) -> Tuple[Dataset, Dataset, Dataset]:
```

**Replace `num_samples` with `test_sample_size`** and use `define_sample_size_for_dataset(ds, fraction, test_sample_size)` everywhere.

**OneIG:** Use `category` with `OneIGCategory` instead of `subset` (align naming).

---

## 3. Per-dataset: minimal divergence

Each dataset keeps only:

1. **Load** – dataset-specific (HF, URL, JSON)
2. **Filter** – if category: filter by dataset-specific column(s)
3. **Sample** – `define_sample_size_for_dataset` + `ds.select` (same as PartiPrompts)
4. **Rename** – ensure `text` column (if needed)
5. **Return** – `_prepare_test_only_prompt_dataset(ds, seed, "DatasetName")`

**Example – HPS aligned with PartiPrompts (Literal for categories):**

```python
from typing import Literal, get_args

HPSCategory = Literal["anime", "concept-art", "paintings", "photo"]

def setup_hps_dataset(
seed: int,
fraction: float = 1.0,
train_sample_size: int | None = None,
test_sample_size: int | None = None,
category: HPSCategory | list[HPSCategory] | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
categories_to_load = list(get_args(HPSCategory)) if category is None else (
[category] if not isinstance(category, list) else category
)

all_prompts = []
for cat in categories_to_load:
file_path = hf_hub_download("zhwang/HPDv2", f"{cat}.json", ...)
with open(file_path, "r", encoding="utf-8") as f:
for prompt in json.load(f):
all_prompts.append({"text": prompt, "category": cat})

ds = Dataset.from_list(all_prompts)
test_sample_size = define_sample_size_for_dataset(ds, fraction, test_sample_size)
ds = ds.select(range(min(test_sample_size, len(ds))))
return _prepare_test_only_prompt_dataset(ds, seed, "HPS")
```

**Example – GenEval / ImgEdit / DPG / GEditBench:** Same pattern: load → filter by category → build Dataset → `define_sample_size_for_dataset` + select → `_prepare_test_only_prompt_dataset`.

---

## 4. Datamodule passthrough

**Extend `from_string`** to pass `fraction`, `train_sample_size`, `test_sample_size` when the setup fn accepts them:

```python
forwarded = {"fraction": fraction, "train_sample_size": train_sample_size, "test_sample_size": test_sample_size}
for param, value in forwarded.items():
    if param in inspect.signature(setup_fn).parameters:
        setup_fn = partial(setup_fn, **{param: value})
```

**OneIG:** Use `category` for subsets; no separate `subset` param. If OneIG must keep `subset` for backward compat, add passthrough for it.

---

## 5. Benchmark category registry + single test

**Option A – derive from Literal:** Use `get_literal_values_from_param(setup_fn, "category")` to get categories from each setup function's Literal type. No separate registry for categories.

**Option B – explicit registry** (for aux_keys mapping):

```python
BENCHMARK_CATEGORY_CONFIG: dict[str, tuple[str, list[str]]] = {
# (first_category_for_test, aux_keys_in_batch)
"PartiPrompts": ("Animals", ["Category", "Challenge"]),
"HPS": ("anime", ["category"]),
"GenEval": ("counting", ["tag"]),
"DPG": ("entity", ["category_broad"]),
"ImgEdit": ("replace", ["category"]),
"GEditBench": ("background_change", ["category"]),
"OneIG": ("text_rendering", ["subset"]), # or ["category"] if aligned
}
```

**Test params:** Use `get_literal_values_from_param` to get `(dataset_name, categories[0])` for each setup that has a Literal `category` param. Categories come from the function signature, not the registry.

**Single parametrized test** (replaces all per-dataset category tests):

```python
@pytest.mark.parametrize("dataset_name, category", [
(name, cat) for name, (cat, _) in BENCHMARK_CATEGORY_CONFIG.items()
if name in base_datasets
])
def test_benchmark_category_filter(dataset_name: str, category: str) -> None:
dm = PrunaDataModule.from_string(dataset_name, category=category, dataloader_args={"batch_size": 4})
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch
assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
_, aux_keys = BENCHMARK_CATEGORY_CONFIG[dataset_name]
assert all(any(aux.get(k) == category for k in aux_keys) for aux in auxiliaries)
```

**Remove:** `test_geneval_with_category_filter`, `test_hps_with_category_filter`, `test_dpg_with_category_filter`, `test_imgedit_with_category_filter`, `test_geditbench_with_category_filter`, `test_oneig_*` (category variants). **Keep:** `test_long_text_bench_auxiliaries` (no category).

---

## 6. Execution order

1. **PartiPrompts branch:** Add `_prepare_test_only_prompt_dataset` in utils, refactor PartiPrompts to use it, add `BENCHMARK_CATEGORY_CONFIG`, add `test_benchmark_category_filter`, extend datamodule passthrough.
2. **Merge PartiPrompts into each branch.**
3. **Per branch:** Refactor each dataset to use the shared helper + unified signature; add its entries to `BENCHMARK_CATEGORY_CONFIG`; remove per-dataset tests.

---

## Summary

| Change | Where | Effect |
|--------|-------|--------|
| `_prepare_test_only_prompt_dataset` | utils.py | Single return path for all benchmarks |
| Unified signature | prompt.py | Same params everywhere |
| `define_sample_size_for_dataset` | All setups | Same sampling logic as PartiPrompts |
| `BENCHMARK_CATEGORY_CONFIG` | __init__.py | One registry, one test |
| Datamodule passthrough | pruna_datamodule.py | fraction, test_sample_size forwarded |

**Result:** PartiPrompts-style flow everywhere, minimal per-dataset code, shared helper and tests. Each dataset keeps its specific categories as a Literal in the function signature (type-safe, discoverable via `get_literal_values_from_param`).
15 changes: 15 additions & 0 deletions src/pruna/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
setup_genai_bench_dataset,
setup_geneval_dataset,
setup_hps_dataset,
setup_long_text_bench_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
Expand Down Expand Up @@ -114,6 +115,7 @@
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}),
"HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}),
"LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}
Expand Down Expand Up @@ -242,6 +244,19 @@ class BenchmarkInfo:
task_type="text_to_image",
subsets=["anime", "concept-art", "paintings", "photo"],
),
"LongTextBench": BenchmarkInfo(
name="long_text_bench",
display_name="Long Text Bench",
description=(
"Extended detail-rich prompts averaging 284.89 tokens with evaluation dimensions of "
"character attributes, structured locations, scene attributes, and spatial relationships "
"to test compositional reasoning under long prompt complexity."
),
metrics=[
# "text_score" not supported in Pruna
],
task_type="text_to_image",
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
Expand Down
35 changes: 35 additions & 0 deletions src/pruna/data/datasets/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,41 @@ def setup_hps_dataset(
return _prepare_test_only_prompt_dataset(ds, seed, "HPS")


def setup_long_text_bench_dataset(
    seed: int,
    fraction: float = 1.0,
    train_sample_size: int | None = None,
    test_sample_size: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Setup the Long Text Bench dataset.

    License: Apache 2.0

    Parameters
    ----------
    seed : int
        The seed to use.
    fraction : float
        The fraction of the dataset to use.
    train_sample_size : int | None
        Unused; train/val are dummy.
    test_sample_size : int | None
        The sample size to use for the test dataset.

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The Long Text Bench dataset (dummy train, dummy val, test).
    """
    dataset = load_dataset("X-Omni/LongText-Bench")["train"]  # type: ignore[index]
    # Free the "text" name first so the prompt column can take it over;
    # the original long-form text is kept as the "text_content" auxiliary.
    dataset = dataset.rename_column("text", "text_content")
    dataset = dataset.rename_column("prompt", "text")
    resolved_size = define_sample_size_for_dataset(dataset, fraction, test_sample_size)
    dataset = dataset.select(range(min(resolved_size, len(dataset))))
    return _prepare_test_only_prompt_dataset(dataset, seed, "LongTextBench")


def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the GenAI Bench dataset.
Expand Down
16 changes: 16 additions & 0 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("VBench", dict(), marks=pytest.mark.slow),
pytest.param("GenEval", dict(), marks=pytest.mark.slow),
pytest.param("HPS", dict(), marks=pytest.mark.slow),
pytest.param("LongTextBench", dict(), marks=pytest.mark.slow),
],
)
def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None:
Expand Down Expand Up @@ -104,3 +105,18 @@ def test_benchmark_category_filter(dataset_name: str, category: str) -> None:
assert all(isinstance(p, str) for p in prompts)
_, aux_keys = BENCHMARK_CATEGORY_CONFIG[dataset_name]
assert all(any(aux.get(k) == category for k in aux_keys) for aux in auxiliaries)


@pytest.mark.slow
def test_long_text_bench_auxiliaries() -> None:
    """Test LongTextBench loading with auxiliaries."""
    datamodule = PrunaDataModule.from_string(
        "LongTextBench", dataloader_args={"batch_size": 4}
    )
    datamodule.limit_datasets(10)
    prompts, auxiliaries = next(iter(datamodule.test_dataloader()))

    assert len(prompts) == 4
    assert all(isinstance(prompt, str) for prompt in prompts)
    # Each auxiliary dict must carry the original long-form text.
    assert all("text_content" in auxiliary for auxiliary in auxiliaries)