diff --git a/docs/plans/2025-02-26-benchmark-unification.md b/docs/plans/2025-02-26-benchmark-unification.md new file mode 100644 index 00000000..4f5c1ea0 --- /dev/null +++ b/docs/plans/2025-02-26-benchmark-unification.md @@ -0,0 +1,219 @@ +# Benchmark PR Unification Plan + +> **Goal:** Align all benchmark datasets with PartiPrompts. Minimize per-dataset code by extracting shared logic into a single helper. Same underlying flow everywhere. + +**Principle:** PartiPrompts is the reference. Every dataset uses the same signature, same sampling logic, and the same `_prepare_test_only_prompt_dataset` helper. + +--- + +## PartiPrompts reference pattern + +```python +def setup_parti_prompts_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: PartiCategory | list[PartiCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + ds = load_dataset("nateraw/parti-prompts")["train"] + if category is not None: + categories = [category] if not isinstance(category, list) else category + ds = ds.filter(lambda x: x["Category"] in categories or x["Challenge"] in categories) + test_sample_size = define_sample_size_for_dataset(ds, fraction, test_sample_size) + ds = ds.select(range(min(test_sample_size, len(ds)))) + ds = ds.rename_column("Prompt", "text") + return _prepare_test_only_prompt_dataset(ds, seed, "PartiPrompts") +``` + +--- + +## 1. Shared helper in `pruna/data/utils.py` + +**Add:** + +```python +def _prepare_test_only_prompt_dataset( + ds: Dataset, + seed: int, + dataset_name: str, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Shared tail for test-only prompt datasets: shuffle, return dummy train/val + test. + All benchmark datasets use this. + """ + ds = ds.shuffle(seed=seed) + pruna_logger.info(f"{dataset_name} is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds +``` + +**Effect:** Removes repeated `ds.shuffle(seed=seed); pruna_logger.info(...); return ds.select([0]), ds.select([0]), ds` from every setup function. + +--- + +## 2. Unified signature (match PartiPrompts) + +**All prompt/benchmark datasets** use the same param names, but **each dataset has its own Literal type** for `category`: + +```python +# PartiPrompts (existing) +PartiCategory = Literal["Abstract", "Animals", ...] + +# HPS +HPSCategory = Literal["anime", "concept-art", "paintings", "photo"] + +# GenEval +GenEvalCategory = Literal["single_object", "two_object", "counting", "colors", "position", "color_attr"] + +# ImgEdit +ImgEditCategory = Literal["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"] + +# GEditBench +GEditBenchCategory = Literal["background_change", ...] # (full list per dataset) + +# DPG +DPGCategory = Literal["entity", "attribute", "relation", "global", "other"] + +# OneIG +OneIGCategory = Literal["text_rendering", "portrait_alignment"] +``` + +**Signature per dataset:** + +```python +def setup_*_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: XCategory | list[XCategory] | None = None, # dataset-specific Literal +) -> Tuple[Dataset, Dataset, Dataset]: +``` + +**Replace `num_samples` with `test_sample_size`** and use `define_sample_size_for_dataset(ds, fraction, test_sample_size)` everywhere. + +**OneIG:** Use `category` with `OneIGCategory` instead of `subset` (align naming). + +--- + +## 3. Per-dataset: minimal divergence + +Each dataset keeps only: + +1. **Load** – dataset-specific (HF, URL, JSON) +2. **Filter** – if category: filter by dataset-specific column(s) +3. **Sample** – `define_sample_size_for_dataset` + `ds.select` (same as PartiPrompts) +4. **Rename** – ensure `text` column (if needed) +5. **Return** – `_prepare_test_only_prompt_dataset(ds, seed, "DatasetName")` + +**Example – HPS aligned with PartiPrompts (Literal for categories):** + +```python +from typing import Literal, get_args + +HPSCategory = Literal["anime", "concept-art", "paintings", "photo"] + +def setup_hps_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + category: HPSCategory | list[HPSCategory] | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + categories_to_load = list(get_args(HPSCategory)) if category is None else ( + [category] if not isinstance(category, list) else category + ) + + all_prompts = [] + for cat in categories_to_load: + file_path = hf_hub_download("zhwang/HPDv2", f"{cat}.json", ...) + with open(file_path, "r", encoding="utf-8") as f: + for prompt in json.load(f): + all_prompts.append({"text": prompt, "category": cat}) + + ds = Dataset.from_list(all_prompts) + test_sample_size = define_sample_size_for_dataset(ds, fraction, test_sample_size) + ds = ds.select(range(min(test_sample_size, len(ds)))) + return _prepare_test_only_prompt_dataset(ds, seed, "HPS") +``` + +**Example – GenEval / ImgEdit / DPG / GEditBench:** Same pattern: load → filter by category → build Dataset → `define_sample_size_for_dataset` + select → `_prepare_test_only_prompt_dataset`. + +--- + +## 4. Datamodule passthrough + +**Extend `from_string`** to pass `fraction`, `train_sample_size`, `test_sample_size` when the setup fn accepts them: + +```python +for param in ("fraction", "train_sample_size", "test_sample_size"): + if param in inspect.signature(setup_fn).parameters: + setup_fn = partial(setup_fn, **{param: locals()[param]}) +``` + +**OneIG:** Use `category` for subsets; no separate `subset` param. If OneIG must keep `subset` for backward compat, add passthrough for it. + +--- + +## 5. Benchmark category registry + single test + +**Option A – derive from Literal:** Use `get_literal_values_from_param(setup_fn, "category")` to get categories from each setup function's Literal type. No separate registry for categories. + +**Option B – explicit registry** (for aux_keys mapping): + +```python +BENCHMARK_CATEGORY_CONFIG: dict[str, tuple[str, list[str]]] = { + # (first_category_for_test, aux_keys_in_batch) + "PartiPrompts": ("Animals", ["Category", "Challenge"]), + "HPS": ("anime", ["category"]), + "GenEval": ("counting", ["tag"]), + "DPG": ("entity", ["category_broad"]), + "ImgEdit": ("replace", ["category"]), + "GEditBench": ("background_change", ["category"]), + "OneIG": ("text_rendering", ["subset"]), # or ["category"] if aligned +} +``` + +**Test params:** Use `get_literal_values_from_param` to get `(dataset_name, categories[0])` for each setup that has a Literal `category` param. Categories come from the function signature, not the registry. + +**Single parametrized test** (replaces all per-dataset category tests): + +```python +@pytest.mark.parametrize("dataset_name, category", [ + (name, cat) for name, (cat, _) in BENCHMARK_CATEGORY_CONFIG.items() + if name in base_datasets +]) +def test_benchmark_category_filter(dataset_name: str, category: str) -> None: + dm = PrunaDataModule.from_string(dataset_name, category=category, dataloader_args={"batch_size": 4}) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + _, aux_keys = BENCHMARK_CATEGORY_CONFIG[dataset_name] + assert all(any(aux.get(k) == category for k in aux_keys) for aux in auxiliaries) +``` + +**Remove:** `test_geneval_with_category_filter`, `test_hps_with_category_filter`, `test_dpg_with_category_filter`, `test_imgedit_with_category_filter`, `test_geditbench_with_category_filter`, `test_oneig_*` (category variants). **Keep:** `test_long_text_bench_auxiliaries` (no category). + +--- + +## 6. Execution order + +1. **PartiPrompts branch:** Add `_prepare_test_only_prompt_dataset` in utils, refactor PartiPrompts to use it, add `BENCHMARK_CATEGORY_CONFIG`, add `test_benchmark_category_filter`, extend datamodule passthrough. +2. **Merge PartiPrompts into each branch.** +3. **Per branch:** Refactor each dataset to use the shared helper + unified signature; add its entries to `BENCHMARK_CATEGORY_CONFIG`; remove per-dataset tests. + +--- + +## Summary + +| Change | Where | Effect | +|--------|-------|--------| +| `_prepare_test_only_prompt_dataset` | utils.py | Single return path for all benchmarks | +| Unified signature | prompt.py | Same params everywhere | +| `define_sample_size_for_dataset` | All setups | Same sampling logic as PartiPrompts | +| `BENCHMARK_CATEGORY_CONFIG` | __init__.py | One registry, one test | +| Datamodule passthrough | pruna_datamodule.py | fraction, test_sample_size forwarded | + +**Result:** PartiPrompts-style flow everywhere, minimal per-dataset code, shared helper and tests. Each dataset keeps its specific categories as a Literal in the function signature (type-safe, discoverable via `get_literal_values_from_param`). diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index deda62cb..f296a91e 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -31,6 +31,7 @@ setup_genai_bench_dataset, setup_geneval_dataset, setup_hps_dataset, + setup_long_text_bench_dataset, setup_parti_prompts_dataset, ) from pruna.data.datasets.question_answering import setup_polyglot_dataset @@ -114,6 +115,7 @@ "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), "GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}), "HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}), + "LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } @@ -242,6 +244,19 @@ class BenchmarkInfo: task_type="text_to_image", subsets=["anime", "concept-art", "paintings", "photo"], ), + "LongTextBench": BenchmarkInfo( + name="long_text_bench", + display_name="Long Text Bench", + description=( + "Extended detail-rich prompts averaging 284.89 tokens with evaluation dimensions of " + "character attributes, structured locations, scene attributes, and spatial relationships " + "to test compositional reasoning under long prompt complexity." + ), + metrics=[ + # "text_score" not supported in Pruna + ], + task_type="text_to_image", + ), "COCO": BenchmarkInfo( name="coco", display_name="COCO", diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 3bb3000f..e095e6fa 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -250,6 +250,41 @@ def setup_hps_dataset( return _prepare_test_only_prompt_dataset(ds, seed, "HPS") +def setup_long_text_bench_dataset( + seed: int, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the Long Text Bench dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + fraction : float + The fraction of the dataset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + The sample size to use for the test dataset. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The Long Text Bench dataset (dummy train, dummy val, test). + """ + ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index] + ds = ds.rename_column("text", "text_content") + ds = ds.rename_column("prompt", "text") + test_sample_size = define_sample_size_for_dataset(ds, fraction, test_sample_size) + ds = ds.select(range(min(test_sample_size, len(ds)))) + return _prepare_test_only_prompt_dataset(ds, seed, "LongTextBench") + + def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the GenAI Bench dataset. diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index e0511ba4..5f55737d 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -47,6 +47,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("VBench", dict(), marks=pytest.mark.slow), pytest.param("GenEval", dict(), marks=pytest.mark.slow), pytest.param("HPS", dict(), marks=pytest.mark.slow), + pytest.param("LongTextBench", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -104,3 +105,18 @@ def test_benchmark_category_filter(dataset_name: str, category: str) -> None: assert all(isinstance(p, str) for p in prompts) _, aux_keys = BENCHMARK_CATEGORY_CONFIG[dataset_name] assert all(any(aux.get(k) == category for k in aux_keys) for aux in auxiliaries) + + +@pytest.mark.slow +def test_long_text_bench_auxiliaries() -> None: + """Test LongTextBench loading with auxiliaries.""" + dm = PrunaDataModule.from_string( + "LongTextBench", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all("text_content" in aux for aux in auxiliaries)