Skip to content

Commit 1f1724a

Browse files
fix: address PR comments - category filter, unused imports
- Fix category filter to handle list[str] in setup_*_dataset functions - Remove unused BenchmarkInfo imports in test files - Align with PR #502 fixes
1 parent 9669db3 commit 1f1724a

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/pruna/data/datasets/prompt.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,14 @@ def setup_parti_prompts_dataset(
7272
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]
7373

7474
if category is not None:
75-
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)
75+
if isinstance(category, list):
76+
ds = ds.filter(
77+
lambda x: x["Category"] in category or x["Challenge"] in category
78+
)
79+
else:
80+
ds = ds.filter(
81+
lambda x: x["Category"] == category or x["Challenge"] == category
82+
)
7683

7784
ds = ds.shuffle(seed=seed)
7885

tests/data/test_datamodule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55
from transformers import AutoTokenizer
66

7-
from pruna.data import BenchmarkInfo, benchmark_info
87
from pruna.data.datasets.image import setup_imagenet_dataset
98
from pruna.data.pruna_datamodule import PrunaDataModule
109

0 commit comments

Comments
 (0)