Skip to content

Commit 7e9e4dc

Browse files
committed
add tests for grabbing the dataset
Signed-off-by: dalthecow <dalcowboiz@gmail.com>
1 parent f0ad9f7 commit 7e9e4dc

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

src/guidellm/presentation/data_models.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,7 @@ def from_data(cls, request_loader: Any):
8989
if creator == SyntheticDatasetCreator:
9090
data_dict = SyntheticDatasetConfig.parse_str(data)
9191
dataset_name = data_dict.source
92-
if creator == FileDatasetCreator or isinstance(
93-
creator, HFDatasetsCreator
94-
):
92+
if creator == FileDatasetCreator or creator == HFDatasetsCreator:
9593
dataset_name = data
9694
if creator == InMemoryDatasetCreator:
9795
dataset_name = "In-memory"

tests/unit/presentation/test_data_models.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import pytest
2+
from unittest.mock import MagicMock, patch
23

3-
from guidellm.presentation.data_models import Bucket
4+
from guidellm.dataset.file import FileDatasetCreator
5+
from guidellm.dataset.hf_datasets import HFDatasetsCreator
6+
from guidellm.dataset.in_memory import InMemoryDatasetCreator
7+
from guidellm.dataset.synthetic import SyntheticDatasetCreator
8+
from guidellm.presentation.data_models import Bucket, Dataset
9+
from tests.unit.mock_benchmark import mock_generative_benchmark
410

511

612
@pytest.mark.smoke
@@ -18,3 +24,80 @@ def test_bucket_from_data():
1824
assert buckets[1].value == 8.0
1925
assert buckets[1].count == 5
2026
assert bucket_width == 1
27+
28+
def mock_processor(cls):
    """Return the processor attached to a fresh mock benchmark's request loader.

    Intended as a patch replacement: ``cls`` is accepted only so the function
    can stand in for a class-level attribute and is otherwise ignored.

    NOTE(review): the only visible reference to this helper is a commented-out
    ``patch.object`` call further down; confirm it is still needed.
    """
    return mock_generative_benchmark().request_loader.processor
30+
31+
def new_handle_create(cls, *args, **kwargs) -> MagicMock:
    """Stub for a dataset creator's ``handle_create`` that builds nothing real.

    Every argument is accepted and ignored so the stub works no matter how the
    patched attribute is invoked. A new ``MagicMock`` is returned on each call.
    """
    return MagicMock()
33+
34+
def new_extract_dataset_name(cls, *args, **kwargs) -> str:
    """Stub for ``extract_dataset_name`` that always reports a fixed name.

    All arguments are accepted and ignored; the returned string is the value
    the tests compare ``Dataset.name`` against.
    """
    return "data:prideandprejudice.txt.gz"
36+
37+
@pytest.mark.smoke
def test_dataset_from_data_uses_extracted_dataset_name():
    """``Dataset.from_data`` should use the name the creator extracts.

    ``SyntheticDatasetCreator.handle_create`` is stubbed so no dataset is
    built, and ``extract_dataset_name`` is stubbed to return a fixed string,
    which is the name the resulting ``Dataset`` is expected to carry.
    """
    mock_benchmark = mock_generative_benchmark()
    with (
        patch.object(
            SyntheticDatasetCreator,
            "handle_create",
            new=new_handle_create,
        ),
        patch.object(
            SyntheticDatasetCreator,
            "extract_dataset_name",
            new=new_extract_dataset_name,
        ),
    ):
        dataset = Dataset.from_data(mock_benchmark.request_loader)
        assert dataset.name == "data:prideandprejudice.txt.gz"
46+
# with unittest.mock.patch.object(PreTrainedTokenizerBase, 'processor', new=mock_processor):
47+
48+
def new_is_supported(cls, *args, **kwargs) -> bool:
    """Stub for a creator's ``is_supported`` that always answers ``True``.

    All arguments are accepted and ignored, so patching this onto a creator
    makes that creator claim support for any input.
    """
    return True
50+
51+
@pytest.mark.smoke
def test_dataset_from_data_with_in_memory_dataset():
    """An in-memory dataset should be reported under the name ``"In-memory"``.

    ``InMemoryDatasetCreator.is_supported`` is forced to ``True`` so that
    creator is the one selected for the mock benchmark's request loader.
    """
    mock_benchmark = mock_generative_benchmark()
    with patch.object(
        InMemoryDatasetCreator,
        "is_supported",
        new=new_is_supported,
    ):
        dataset = Dataset.from_data(mock_benchmark.request_loader)
        assert dataset.name == "In-memory"
57+
58+
def hardcoded_isnt_supported(cls, *args, **kwargs) -> bool:
    """Stub for a creator's ``is_supported`` that always answers ``False``.

    All arguments are accepted and ignored; patching this onto a creator
    takes that creator out of consideration for any input.
    """
    return False
60+
61+
def new_extract_dataset_name_none(cls, *args, **kwargs) -> None:
    """Stub for ``extract_dataset_name`` that reports no name at all.

    All arguments are accepted and ignored. Returning ``None`` lets the tests
    exercise the per-creator naming used when no name is extracted.
    """
    return None
63+
64+
@pytest.mark.smoke
def test_dataset_from_data_with_synthetic_dataset():
    """A synthetic dataset with no extracted name is named after its source.

    The in-memory creator is forced to report unsupported and the synthetic
    creator supported; ``extract_dataset_name`` returns ``None`` and
    ``handle_create`` is stubbed out.
    """
    mock_benchmark = mock_generative_benchmark()
    with (
        patch.object(
            SyntheticDatasetCreator,
            "handle_create",
            new=new_handle_create,
        ),
        patch.object(
            InMemoryDatasetCreator,
            "is_supported",
            new=hardcoded_isnt_supported,
        ),
        patch.object(
            SyntheticDatasetCreator,
            "is_supported",
            new=new_is_supported,
        ),
        patch.object(
            SyntheticDatasetCreator,
            "extract_dataset_name",
            new=new_extract_dataset_name_none,
        ),
    ):
        dataset = Dataset.from_data(mock_benchmark.request_loader)
        # NOTE(review): presumably the source of the mock benchmark's default
        # synthetic config -- confirm against tests.unit.mock_benchmark.
        assert dataset.name == "data:prideandprejudice.txt.gz"
75+
76+
@pytest.mark.smoke
def test_dataset_from_data_with_file_dataset():
    """A file-backed dataset with no extracted name is named after its data.

    The request loader's ``data`` is set to a file name. The in-memory and
    synthetic creators are forced to report unsupported and the file creator
    supported; ``extract_dataset_name`` returns ``None``, so the raw ``data``
    string is the expected name.
    """
    mock_benchmark = mock_generative_benchmark()
    mock_benchmark.request_loader.data = "dataset.yaml"
    with (
        patch.object(
            FileDatasetCreator,
            "handle_create",
            new=new_handle_create,
        ),
        patch.object(
            InMemoryDatasetCreator,
            "is_supported",
            new=hardcoded_isnt_supported,
        ),
        patch.object(
            SyntheticDatasetCreator,
            "is_supported",
            new=hardcoded_isnt_supported,
        ),
        patch.object(
            FileDatasetCreator,
            "is_supported",
            new=new_is_supported,
        ),
        patch.object(
            FileDatasetCreator,
            "extract_dataset_name",
            new=new_extract_dataset_name_none,
        ),
    ):
        dataset = Dataset.from_data(mock_benchmark.request_loader)
        assert dataset.name == "dataset.yaml"
89+
90+
@pytest.mark.smoke
def test_dataset_from_data_with_hf_dataset():
    """A Hugging Face dataset with no extracted name is named after its data.

    The request loader's ``data`` is set to a hub-style identifier. The
    in-memory, synthetic and file creators are forced to report unsupported
    and the HF creator supported; ``extract_dataset_name`` returns ``None``,
    so the raw ``data`` string is the expected name.
    """
    mock_benchmark = mock_generative_benchmark()
    mock_benchmark.request_loader.data = "openai/gsm8k"
    with (
        patch.object(
            HFDatasetsCreator,
            "handle_create",
            new=new_handle_create,
        ),
        patch.object(
            InMemoryDatasetCreator,
            "is_supported",
            new=hardcoded_isnt_supported,
        ),
        patch.object(
            SyntheticDatasetCreator,
            "is_supported",
            new=hardcoded_isnt_supported,
        ),
        patch.object(
            FileDatasetCreator,
            "is_supported",
            new=hardcoded_isnt_supported,
        ),
        patch.object(
            HFDatasetsCreator,
            "is_supported",
            new=new_is_supported,
        ),
        patch.object(
            HFDatasetsCreator,
            "extract_dataset_name",
            new=new_extract_dataset_name_none,
        ),
    ):
        dataset = Dataset.from_data(mock_benchmark.request_loader)
        assert dataset.name == "openai/gsm8k"

0 commit comments

Comments
 (0)