-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_dynamic_data_loader.py
More file actions
75 lines (60 loc) · 3.27 KB
/
test_dynamic_data_loader.py
File metadata and controls
75 lines (60 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from eval_protocol.data_loader import DynamicDataLoader
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
def my_factory() -> list[EvaluationRow]:
    """Factory function that generates evaluation rows dynamically."""
    # NOTE: the docstring above is compared verbatim against the loader's
    # "data_loader_variant_description" in test_dynamic_data_loader, so it
    # must not be reworded independently of that assertion.
    question = Message(role="user", content="What is 2 + 2?")
    return [EvaluationRow(messages=[question])]
@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[my_factory],
    ),
)
def test_dynamic_data_loader(row: EvaluationRow) -> EvaluationRow:
    """Dynamic data loader should feed dynamically generated message bundles."""
    # The single prompt produced by my_factory must reach the test intact.
    assert row.messages[0].content == "What is 2 + 2?"

    info = row.input_metadata.dataset_info
    assert info is not None

    # Variant id / description are expected to match my_factory's name and
    # docstring text respectively.
    assert info.get("data_loader_variant_id") == "my_factory"
    assert (
        info.get("data_loader_variant_description")
        == "Factory function that generates evaluation rows dynamically."
    )
    assert info.get("data_loader_type") == "DynamicDataLoader"

    # One row generated; no preprocessing configured, so the count is unchanged.
    assert info.get("data_loader_num_rows") == 1
    assert info.get("data_loader_num_rows_after_preprocessing") == 1
    assert info.get("data_loader_preprocessed") is False
    return row
@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[lambda: [EvaluationRow(messages=[Message(role="user", content="What is 3 * 3?")])]],
    ),
)
def test_dynamic_data_loader_lambda(row: EvaluationRow) -> EvaluationRow:
    """Dynamic data loader should work with lambda functions."""
    # The generator is kept as an inline lambda so its variant id is "<lambda>".
    assert row.messages[0].content == "What is 3 * 3?"

    info = row.input_metadata.dataset_info
    assert info is not None
    assert info.get("data_loader_variant_id") == "<lambda>"
    assert info.get("data_loader_type") == "DynamicDataLoader"

    # One row in, one row out, with no preprocessing step applied.
    assert info.get("data_loader_num_rows") == 1
    assert info.get("data_loader_num_rows_after_preprocessing") == 1
    assert info.get("data_loader_preprocessed") is False
    return row
def generate_many_rows() -> list[EvaluationRow]:
    """Factory function that generates many evaluation rows for testing max_dataset_rows."""
    # Ten single-message rows: "What is 0 + 0?" through "What is 9 + 9?".
    rows = []
    for n in range(10):
        prompt = f"What is {n} + {n}?"
        rows.append(EvaluationRow(messages=[Message(role="user", content=prompt)]))
    return rows
@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[generate_many_rows],
    ),
    max_dataset_rows=3,
)
def test_dynamic_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
    """Dynamic data loader should respect max_dataset_rows parameter."""
    # generate_many_rows yields 10 rows, but max_dataset_rows=3 should cap the
    # run to the first three; any later prompt here means the cap was ignored.
    allowed_prompts = [f"What is {i} + {i}?" for i in range(3)]
    assert row.messages[0].content in allowed_prompts

    info = row.input_metadata.dataset_info
    assert info is not None
    assert info.get("data_loader_variant_id") == "generate_many_rows"
    assert info.get("data_loader_type") == "DynamicDataLoader"
    assert info.get("data_loader_preprocessed") is False
    return row