From 8cd8d7d36a51f3073cd55eb11f2fd64813f7405d Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:07:40 +0100 Subject: [PATCH 1/7] feat: add LongTextBench benchmark Add LongTextBench for evaluating text-to-image with long, complex prompts. Uses X-Omni/LongText-Bench dataset from HuggingFace. - Add setup_long_text_bench_dataset with num_samples filtering - Register in base_datasets with prompt_with_auxiliaries_collate - Add BenchmarkInfo with metrics: ["text_score"] - Preserve text_content field in auxiliaries for evaluation - Add tests Co-authored-by: Cursor --- src/pruna/data/__init__.py | 31 ++++++++++++++++++++++++++++- src/pruna/data/datasets/prompt.py | 33 +++++++++++++++++++++++++++++++ tests/data/test_datamodule.py | 16 +++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 820d1262..d0323688 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -29,6 +29,7 @@ from pruna.data.datasets.prompt import ( setup_drawbench_dataset, setup_genai_bench_dataset, + setup_long_text_bench_dataset, setup_parti_prompts_dataset, ) from pruna.data.datasets.question_answering import setup_polyglot_dataset @@ -100,6 +101,7 @@ "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), + "LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } @@ -107,7 +109,23 @@ @dataclass class BenchmarkInfo: - """Metadata for a benchmark dataset.""" + """Metadata for a benchmark dataset. + + Parameters + ---------- + name : str + Internal identifier for the benchmark. + display_name : str + Human-readable name for display purposes. + description : str + Description of what the benchmark evaluates. + metrics : list[str] + List of metric names used for evaluation. + task_type : str + Type of task the benchmark evaluates (e.g., 'text_to_image'). + subsets : list[str] + Optional list of benchmark subset names. + """ name: str display_name: str @@ -175,4 +193,15 @@ class BenchmarkInfo: metrics=["clip", "fvd"], task_type="text_to_video", ), + "LongTextBench": BenchmarkInfo( + name="long_text_bench", + display_name="Long Text Bench", + description=( + "Extended detail-rich prompts averaging 284.89 tokens with evaluation dimensions of " + "character attributes, structured locations, scene attributes, and spatial relationships " + "to test compositional reasoning under long prompt complexity." + ), + metrics=["text_score"], + task_type="text_to_image", + ), } diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 4f275675..8b8ccde4 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -84,6 +84,39 @@ def setup_parti_prompts_dataset( return ds.select([0]), ds.select([0]), ds +def setup_long_text_bench_dataset( + seed: int, + num_samples: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the Long Text Bench dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + num_samples : int | None + Maximum number of samples to return. If None, returns all 160 samples. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The Long Text Bench dataset (dummy train, dummy val, test). + """ + ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index] + ds = ds.rename_column("text", "text_content") + ds = ds.rename_column("prompt", "text") + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + + pruna_logger.info("LongTextBench is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the GenAI Bench dataset. diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 61550698..be8922f7 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -45,6 +45,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("GenAIBench", dict(), marks=pytest.mark.slow), pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), + pytest.param("LongTextBench", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -96,3 +97,18 @@ def test_parti_prompts_with_category_filter(): assert len(prompts) == 4 assert all(isinstance(p, str) for p in prompts) assert all(aux["Category"] == "Animals" for aux in auxiliaries) + + +@pytest.mark.slow +def test_long_text_bench_auxiliaries(): + """Test LongTextBench loading with auxiliaries.""" + dm = PrunaDataModule.from_string( + "LongTextBench", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all("text_content" in aux for aux in auxiliaries) From e3002b8779b2054eb636686359db4714499b76d4 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:12:36 +0100 Subject: [PATCH 2/7] fix: correct Numpydoc format for BenchmarkInfo docstring Move summary to new line after opening quotes per Numpydoc GL01. Co-authored-by: Cursor --- src/pruna/data/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index d0323688..5cc70b06 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -109,7 +109,8 @@ @dataclass class BenchmarkInfo: - """Metadata for a benchmark dataset. + """ + Metadata for a benchmark dataset. Parameters ---------- From 54c53cf74ecf10717558c95452502dd493d73453 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:26:42 +0100 Subject: [PATCH 3/7] feat: add benchmark discovery functions and expand benchmark registry - Add list_benchmarks() to filter benchmarks by task type - Add get_benchmark_info() to retrieve benchmark metadata - Add COCO, ImageNet, WikiText to benchmark_info registry - Fix metric names to match MetricRegistry (clip_score, clipiqa) Co-authored-by: Cursor --- src/pruna/data/__init__.py | 76 +++++++++++++++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 5 deletions(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 5cc70b06..c9e70969 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -145,7 +145,7 @@ class BenchmarkInfo: "ranging from basic to complex, enabling comprehensive assessment of model capabilities " "across different domains and difficulty levels." ), - metrics=["arniqa", "clip", "clip_iqa", "sharpness"], + metrics=["arniqa", "clip_score", "clipiqa", "sharpness"], task_type="text_to_image", subsets=[ "Abstract", @@ -177,21 +177,21 @@ class BenchmarkInfo: name="drawbench", display_name="DrawBench", description="A comprehensive benchmark for evaluating text-to-image generation models.", - metrics=["clip", "clip_iqa", "sharpness"], + metrics=["clip_score", "clipiqa", "sharpness"], task_type="text_to_image", ), "GenAIBench": BenchmarkInfo( name="genai_bench", display_name="GenAI Bench", description="A benchmark for evaluating generative AI models.", - metrics=["clip", "clip_iqa", "sharpness"], + metrics=["clip_score", "clipiqa", "sharpness"], task_type="text_to_image", ), "VBench": BenchmarkInfo( name="vbench", display_name="VBench", description="A benchmark for evaluating video generation models.", - metrics=["clip", "fvd"], + metrics=["clip_score"], task_type="text_to_video", ), "LongTextBench": BenchmarkInfo( @@ -202,7 +202,73 @@ class BenchmarkInfo: "character attributes, structured locations, scene attributes, and spatial relationships " "to test compositional reasoning under long prompt complexity." ), - metrics=["text_score"], + metrics=["clip_score", "clipiqa"], task_type="text_to_image", ), + "COCO": BenchmarkInfo( + name="coco", + display_name="COCO", + description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.", + metrics=["fid", "clip_score", "clipiqa"], + task_type="text_to_image", + ), + "ImageNet": BenchmarkInfo( + name="imagenet", + display_name="ImageNet", + description="Large-scale image classification benchmark with 1000 classes.", + metrics=["accuracy"], + task_type="image_classification", + ), + "WikiText": BenchmarkInfo( + name="wikitext", + display_name="WikiText", + description="Language modeling benchmark based on Wikipedia articles.", + metrics=["perplexity"], + task_type="text_generation", + ), } + + +def list_benchmarks(task_type: str | None = None) -> list[str]: + """ + List available benchmark names. + + Parameters + ---------- + task_type : str | None + Filter by task type (e.g., 'text_to_image', 'text_to_video'). + If None, returns all benchmarks. + + Returns + ------- + list[str] + List of benchmark names. + """ + if task_type is None: + return list(benchmark_info.keys()) + return [name for name, info in benchmark_info.items() if info.task_type == task_type] + + +def get_benchmark_info(name: str) -> BenchmarkInfo: + """ + Get benchmark metadata by name. + + Parameters + ---------- + name : str + The benchmark name. + + Returns + ------- + BenchmarkInfo + The benchmark metadata. + + Raises + ------ + KeyError + If benchmark name is not found. + """ + if name not in benchmark_info: + available = ", ".join(benchmark_info.keys()) + raise KeyError(f"Benchmark '{name}' not found. Available: {available}") + return benchmark_info[name] From bf93c09eedd675afd7a17f9a0649d3109d7b12c8 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 27 Feb 2026 10:09:08 +0100 Subject: [PATCH 4/7] chore: apply ruff format to data module, add lint-before-push script Made-with: Cursor --- scripts/lint-before-push.sh | 11 +++++++++++ src/pruna/data/datasets/text_generation.py | 11 +++++------ 2 files changed, 16 insertions(+), 6 deletions(-) create mode 100755 scripts/lint-before-push.sh diff --git a/scripts/lint-before-push.sh b/scripts/lint-before-push.sh new file mode 100755 index 00000000..e580f958 --- /dev/null +++ b/scripts/lint-before-push.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -e +echo "=== Ruff check ===" +uv run ruff check src/pruna/ +echo "=== Ruff format check ===" +uv run ruff format --check src/pruna/ +echo "=== Ty type checker ===" +uv run ty check src/pruna +echo "=== Pytest style ===" +uv run pytest -m "style" -q +echo "=== All lint checks passed ===" diff --git a/src/pruna/data/datasets/text_generation.py b/src/pruna/data/datasets/text_generation.py index e53ccc7c..6eb3dbc4 100644 --- a/src/pruna/data/datasets/text_generation.py +++ b/src/pruna/data/datasets/text_generation.py @@ -33,8 +33,7 @@ def setup_wikitext_dataset() -> Tuple[Dataset, Dataset, Dataset]: The WikiText dataset. """ train_dataset, val_dataset, test_dataset = load_dataset( - path="mikasenghaas/wikitext-2", - split=["train", "validation", "test"] + path="mikasenghaas/wikitext-2", split=["train", "validation", "test"] ) return train_dataset, val_dataset, test_dataset # type: ignore[return-value] @@ -57,15 +56,15 @@ def setup_wikitext_tiny_dataset(seed: int = 42, num_rows: int = 960) -> Tuple[Da Tuple[Dataset, Dataset, Dataset] The TinyWikiText dataset split .8/.1/.1 into train/val/test subsets, respectively. """ - assert 10 <= num_rows < 1000, 'the total number of rows, r, for the tiny wikitext dataset must be 10 <= r < 1000' + assert 10 <= num_rows < 1000, "the total number of rows, r, for the tiny wikitext dataset must be 10 <= r < 1000" # load the 'mikasenghaas/wikitext-2' dataset with a total of 21,580 rows using the setup_wikitext_dataset() function train_ds, val_ds, test_ds = setup_wikitext_dataset() # assert the wikitext dataset train/val/test splits each have enough rows for reducing to .8/.1/.1, respectively - assert train_ds.num_rows >= int(num_rows * 0.8), f'wikitext cannot be reduced to {num_rows} rows, train too small' - assert val_ds.num_rows >= int(num_rows * 0.1), f'wikitext cannot be reduced to {num_rows} rows, val too small' - assert test_ds.num_rows >= int(num_rows * 0.1), f'wikitext cannot be reduced to {num_rows} rows, test too small' + assert train_ds.num_rows >= int(num_rows * 0.8), f"wikitext cannot be reduced to {num_rows} rows, train too small" + assert val_ds.num_rows >= int(num_rows * 0.1), f"wikitext cannot be reduced to {num_rows} rows, val too small" + assert test_ds.num_rows >= int(num_rows * 0.1), f"wikitext cannot be reduced to {num_rows} rows, test too small" # randomly select from the wikitext dataset a total number of rows below 1000 split .8/.1/.1 between train/val/test train_dataset_tiny = train_ds.shuffle(seed=seed).select(range(int(num_rows * 0.8))) From 9df608ffd2f92ffe18d625592e9e569cf8d05347 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 27 Feb 2026 10:13:25 +0100 Subject: [PATCH 5/7] chore: fix get_literal_values_from_param docstring, add SCOPE to lint script Made-with: Cursor --- src/pruna/data/utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/pruna/data/utils.py b/src/pruna/data/utils.py index 104c9636..f05cd5bd 100644 --- a/src/pruna/data/utils.py +++ b/src/pruna/data/utils.py @@ -40,7 +40,21 @@ def __init__(self, message: str = "Tokenizer is missing. Please provide a valid def get_literal_values_from_param(func: Callable[..., Any], param_name: str) -> list[str] | None: - """Extract Literal values from a function parameter's type annotation (handles Union).""" + """ + Extract Literal values from a function parameter's type annotation (handles Union). + + Parameters + ---------- + func : Callable[..., Any] + The function to inspect. + param_name : str + The parameter name to extract Literal values from. + + Returns + ------- + list[str] | None + List of string values if the parameter is a Literal type, None otherwise. + """ unwrapped = getattr(func, "func", func) sig = inspect.signature(unwrapped) if param_name not in sig.parameters: From a2d25fb258af25e9b2956760c16b681e9e00c2bc Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 27 Feb 2026 10:15:14 +0100 Subject: [PATCH 6/7] chore: remove scripts/lint-before-push.sh Made-with: Cursor --- scripts/lint-before-push.sh | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100755 scripts/lint-before-push.sh diff --git a/scripts/lint-before-push.sh b/scripts/lint-before-push.sh deleted file mode 100755 index e580f958..00000000 --- a/scripts/lint-before-push.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env bash -set -e -echo "=== Ruff check ===" -uv run ruff check src/pruna/ -echo "=== Ruff format check ===" -uv run ruff format --check src/pruna/ -echo "=== Ty type checker ===" -uv run ty check src/pruna -echo "=== Pytest style ===" -uv run pytest -m "style" -q -echo "=== All lint checks passed ===" From 89d006ef381e37a6b58fe86500e9fb75096a85d4 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 27 Feb 2026 10:32:19 +0100 Subject: [PATCH 7/7] chore: align metrics with Pruna, comment unsupported InferBench metrics Made-with: Cursor --- src/pruna/data/__init__.py | 20 +++++++++++++++++--- src/pruna/evaluation/benchmarks/__init__.py | 14 ++++++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 21251711..4202a389 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -185,14 +185,24 @@ class BenchmarkInfo: name="drawbench", display_name="DrawBench", description="A comprehensive benchmark for evaluating text-to-image generation models.", - metrics=["clip_score", "clipiqa", "sharpness"], + metrics=[ + "clip_score", + "clipiqa", + "sharpness", + # "image_reward" not supported in Pruna + ], task_type="text_to_image", ), "GenAIBench": BenchmarkInfo( name="genai_bench", display_name="GenAI Bench", description="A benchmark for evaluating generative AI models.", - metrics=["clip_score", "clipiqa", "sharpness"], + metrics=[ + "clip_score", + "clipiqa", + "sharpness", + # "vqa" not supported in Pruna + ], task_type="text_to_image", ), "VBench": BenchmarkInfo( @@ -210,7 +220,11 @@ class BenchmarkInfo: "character attributes, structured locations, scene attributes, and spatial relationships " "to test compositional reasoning under long prompt complexity." ), - metrics=["clip_score", "clipiqa"], + metrics=[ + "clip_score", + "clipiqa", + # "text_score" not supported in Pruna + ], task_type="text_to_image", ), "COCO": BenchmarkInfo( diff --git a/src/pruna/evaluation/benchmarks/__init__.py b/src/pruna/evaluation/benchmarks/__init__.py index e6255fdb..ca8cb9a9 100644 --- a/src/pruna/evaluation/benchmarks/__init__.py +++ b/src/pruna/evaluation/benchmarks/__init__.py @@ -78,13 +78,23 @@ class BenchmarkRegistry: Benchmark( name="DrawBench", description="A comprehensive benchmark for evaluating text-to-image generation models.", - metrics=["clip_score", "clipiqa", "sharpness"], + metrics=[ + "clip_score", + "clipiqa", + "sharpness", + # "image_reward" not supported in Pruna + ], task_type="text_to_image", ), Benchmark( name="GenAI Bench", description="A benchmark for evaluating generative AI models.", - metrics=["clip_score", "clipiqa", "sharpness"], + metrics=[ + "clip_score", + "clipiqa", + "sharpness", + # "vqa" not supported in Pruna + ], task_type="text_to_image", ), Benchmark(