172 changes: 171 additions & 1 deletion src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Tuple

@@ -28,6 +29,7 @@
from pruna.data.datasets.prompt import (
setup_drawbench_dataset,
setup_genai_bench_dataset,
setup_geneval_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -97,8 +99,176 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}
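
# Hedged usage sketch: each entry above maps a dataset name to a
# (setup_fn, collate_fn_name, default_collate_kwargs) triple that
# PrunaDataModule.from_string resolves. The registry name `base_datasets`
# below is an assumption (its assignment line is outside this hunk):
#
#   setup_fn, collate_name, collate_kwargs = base_datasets["GenEval"]
#   train_ds, val_ds, test_ds = setup_fn(seed=42)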


@dataclass
class BenchmarkInfo:
"""
Metadata for a benchmark dataset.

Parameters
----------
name : str
Internal identifier for the benchmark.
display_name : str
Human-readable name for display purposes.
description : str
Description of what the benchmark evaluates.
metrics : list[str]
List of metric names used for evaluation.
task_type : str
Type of task the benchmark evaluates (e.g., 'text_to_image').
subsets : list[str]
Optional list of benchmark subset names.
"""

name: str
display_name: str
description: str
metrics: list[str]
task_type: str
subsets: list[str] = field(default_factory=list)


benchmark_info: dict[str, BenchmarkInfo] = {
"PartiPrompts": BenchmarkInfo(
name="parti_prompts",
display_name="Parti Prompts",
description=(
"Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
subsets=[
"Abstract",
"Animals",
"Artifacts",
"Arts",
"Food & Beverage",
"Illustrations",
"Indoor Scenes",
"Outdoor Scenes",
"People",
"Produce & Plants",
"Vehicles",
"World Knowledge",
"Basic",
"Complex",
"Fine-grained Detail",
"Imagination",
"Linguistic Structures",
"Perspective",
"Properties & Positioning",
"Quantity",
"Simple Detail",
"Style & Format",
"Writing & Symbols",
],
),
"DrawBench": BenchmarkInfo(
name="drawbench",
display_name="DrawBench",
description="A comprehensive benchmark for evaluating text-to-image generation models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"GenAIBench": BenchmarkInfo(
name="genai_bench",
display_name="GenAI Bench",
description="A benchmark for evaluating generative AI models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"VBench": BenchmarkInfo(
name="vbench",
display_name="VBench",
description="A benchmark for evaluating video generation models.",
metrics=["clip_score"],
task_type="text_to_video",
),
"GenEval": BenchmarkInfo(
name="geneval",
display_name="GenEval",
description=(
"Fine-grained compositional evaluation across object co-occurrence, positioning, "
"counting, and color binding to identify specific failure modes in text-to-image alignment."
),
metrics=["accuracy"],
task_type="text_to_image",
subsets=["single_object", "two_object", "counting", "colors", "position", "color_attr"],
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
metrics=["fid", "clip_score", "clipiqa"],
task_type="text_to_image",
),
"ImageNet": BenchmarkInfo(
name="imagenet",
display_name="ImageNet",
description="Large-scale image classification benchmark with 1000 classes.",
metrics=["accuracy"],
task_type="image_classification",
),
"WikiText": BenchmarkInfo(
name="wikitext",
display_name="WikiText",
description="Language modeling benchmark based on Wikipedia articles.",
metrics=["perplexity"],
task_type="text_generation",
),
}


def list_benchmarks(task_type: str | None = None) -> list[str]:
"""
List available benchmark names.

Parameters
----------
task_type : str | None
Filter by task type (e.g., 'text_to_image', 'text_to_video').
If None, returns all benchmarks.

Returns
-------
list[str]
List of benchmark names.
"""
if task_type is None:
return list(benchmark_info.keys())
return [name for name, info in benchmark_info.items() if info.task_type == task_type]


def get_benchmark_info(name: str) -> BenchmarkInfo:
"""
Get benchmark metadata by name.

Parameters
----------
name : str
The benchmark name.

Returns
-------
BenchmarkInfo
The benchmark metadata.

Raises
------
KeyError
If benchmark name is not found.
"""
if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
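
# Example usage, with values read off the benchmark_info table above:
#
#   >>> list_benchmarks(task_type="text_to_video")
#   ['VBench']
#   >>> get_benchmark_info("GenEval").metrics
#   ['accuracy']
#   >>> get_benchmark_info("Unknown")  # raises KeyError listing available names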
117 changes: 115 additions & 2 deletions src/pruna/data/datasets/prompt.py
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
return ds.select([0]), ds.select([0]), ds


def setup_parti_prompts_dataset(
seed: int,
category: str | list[str] | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Parti Prompts dataset.

@@ -51,18 +55,127 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
category : str | list[str] | None
Filter by Category or Challenge; a list keeps prompts matching any of its entries.
Available categories: Abstract, Animals, Artifacts, Arts, Food & Beverage,
Illustrations, Indoor Scenes, Outdoor Scenes, People, Produce & Plants, Vehicles,
World Knowledge. Available challenges: Basic, Complex, Fine-grained Detail,
Imagination, Linguistic Structures, Perspective, Properties & Positioning,
Quantity, Simple Detail, Style & Format, Writing & Symbols.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The Parti Prompts dataset (dummy train, dummy val, test).
"""
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]

if category is not None:
if isinstance(category, list):
ds = ds.filter(
lambda x: x["Category"] in category or x["Challenge"] in category
)
else:
ds = ds.filter(
lambda x: x["Category"] == category or x["Challenge"] == category
)

# Note: Not shuffling since this is a test-only dataset

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

ds = ds.rename_column("Prompt", "text")
pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
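
# Minimal usage sketch (assumes network access to the Hugging Face hub):
#
#   _, _, test_ds = setup_parti_prompts_dataset(seed=42, category="Animals", num_samples=8)
#   test_ds[0]["text"]      # prompt string (renamed from "Prompt")
#   test_ds[0]["Category"]  # "Animals" for every retained row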


GENEVAL_CATEGORIES = ["single_object", "two_object", "counting", "colors", "position", "color_attr"]


def _generate_geneval_question(entry: dict) -> list[str]:
"""Generate evaluation questions from GenEval metadata."""
tag = entry.get("tag", "")
include = entry.get("include", [])
questions = []

for obj in include:
cls = obj.get("class", "")
if "color" in obj:
questions.append(f"Does the image contain a {obj['color']} {cls}?")
elif "count" in obj:
questions.append(f"Does the image contain exactly {obj['count']} {cls}(s)?")
else:
questions.append(f"Does the image contain a {cls}?")

if tag == "position" and len(include) >= 2:
a_cls = include[0].get("class", "")
b_cls = include[1].get("class", "")
pos = include[1].get("position")
if pos and pos[0]:
questions.append(f"Is the {b_cls} {pos[0]} the {a_cls}?")

return questions
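
# Worked example (entry shape follows GenEval's evaluation_metadata.jsonl;
# the values below are illustrative, not taken from the file):
#
#   entry = {"tag": "color_attr", "include": [{"class": "car", "color": "red"},
#                                             {"class": "dog", "count": 2}]}
#   _generate_geneval_question(entry)
#   # -> ["Does the image contain a red car?",
#   #     "Does the image contain exactly 2 dog(s)?"]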


def setup_geneval_dataset(
seed: int,
category: str | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the GenEval benchmark dataset.

License: MIT

Parameters
----------
seed : int
The seed to use.
category : str | None
Filter by category. Available: single_object, two_object, counting, colors, position, color_attr.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The GenEval dataset (dummy train, dummy val, test).
"""
import json

import requests

url = "https://raw.githubusercontent.com/djghosh13/geneval/d927da8e42fde2b1b5cd743da4df5ff83c1654ff/prompts/evaluation_metadata.jsonl"
response = requests.get(url, timeout=60)
response.raise_for_status()
data = [json.loads(line) for line in response.text.splitlines()]

if category is not None:
if category not in GENEVAL_CATEGORIES:
raise ValueError(f"Invalid category: {category}. Must be one of {GENEVAL_CATEGORIES}")
data = [entry for entry in data if entry.get("tag") == category]

records = []
for entry in data:
questions = _generate_geneval_question(entry)
records.append({
"text": entry["prompt"],
"tag": entry.get("tag", ""),
"questions": questions,
"include": entry.get("include", []),
})

ds = Dataset.from_list(records)
# Note: Not shuffling since this is a test-only dataset

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

pruna_logger.info("GenEval is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
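
# Minimal usage sketch (fetches the metadata JSONL from GitHub, so network
# access is required):
#
#   _, _, test_ds = setup_geneval_dataset(seed=42, category="counting", num_samples=16)
#   test_ds[0]["text"]       # prompt string from the GenEval metadata
#   test_ds[0]["questions"]  # auto-generated yes/no checks for evaluation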


def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the GenAI Bench dataset.
38 changes: 35 additions & 3 deletions tests/data/test_datamodule.py
@@ -1,10 +1,9 @@
from typing import Any, Callable

import pytest
from datasets import Dataset
from torch.utils.data import TensorDataset
import torch
from transformers import AutoTokenizer

from pruna.data.datasets.image import setup_imagenet_dataset
from pruna.data.pruna_datamodule import PrunaDataModule

@@ -45,6 +44,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("GenAIBench", dict(), marks=pytest.mark.slow),
pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),
pytest.param("VBench", dict(), marks=pytest.mark.slow),
pytest.param("GenEval", dict(), marks=pytest.mark.slow),
],
)
def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None:
@@ -80,3 +80,35 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar
assert labels.dtype == torch.int64
# iterate through the dataloaders
iterate_dataloaders(datamodule)


@pytest.mark.slow
def test_parti_prompts_with_category_filter() -> None:
"""Test PartiPrompts loading with category filter."""
dm = PrunaDataModule.from_string(
"PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all(aux["Category"] == "Animals" for aux in auxiliaries)


@pytest.mark.slow
def test_geneval_with_category_filter() -> None:
"""Test GenEval loading with category filter."""
dm = PrunaDataModule.from_string(
"GenEval", category="counting", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all(aux["tag"] == "counting" for aux in auxiliaries)
assert all("questions" in aux for aux in auxiliaries)