File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -72,7 +72,14 @@ def setup_parti_prompts_dataset(
7272 ds = load_dataset ("nateraw/parti-prompts" )["train" ] # type: ignore[index]
7373
7474 if category is not None :
75- ds = ds .filter (lambda x : x ["Category" ] == category or x ["Challenge" ] == category )
75+ if isinstance (category , list ):
76+ ds = ds .filter (
77+ lambda x : x ["Category" ] in category or x ["Challenge" ] in category
78+ )
79+ else :
80+ ds = ds .filter (
81+ lambda x : x ["Category" ] == category or x ["Challenge" ] == category
82+ )
7683
7784 ds = ds .shuffle (seed = seed )
7885
Original file line number Diff line number Diff line change 44import torch
55from transformers import AutoTokenizer
66
7- from pruna .data import BenchmarkInfo , benchmark_info
87from pruna .data .datasets .image import setup_imagenet_dataset
98from pruna .data .pruna_datamodule import PrunaDataModule
109
You can’t perform that action at this time.
0 commit comments