# test_parameterized_ids.py
# Tests for the parameter IDs produced by eval_protocol.pytest parametrization.
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.parameterize import DefaultParameterIdGenerator, pytest_parametrize
from eval_protocol.pytest.generate_parameter_combinations import generate_parameter_combinations
def test_parameterized_ids():
    """Check that ``evaluation_test`` leaves a parametrize mark carrying model IDs.

    A throwaway evaluation function is decorated with three completion-param
    sets: two that name a model and one that only sets a temperature. The test
    then inspects the ``pytestmark`` attribute of the decorated function and,
    when the parametrize mark carries a non-None ``ids`` value, expects exactly
    the two model-derived IDs.
    """
    collected_ids = []

    # NOTE: the inner function reuses the outer test's name; this only rebinds
    # the name inside this function's scope.
    @evaluation_test(
        input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
        completion_params=[
            {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
            {"model": "gpt-4"},
            {"temperature": 0.5},  # No model - should not generate ID
        ],
    )
    def test_parameterized_ids(row: EvaluationRow) -> EvaluationRow:
        # Record the row ID of every row handed to the evaluation function.
        collected_ids.append(row.input_metadata.row_id)
        return row

    # The decorated object should exist and be callable.
    assert test_parameterized_ids is not None
    assert callable(test_parameterized_ids)

    # The decorator is expected to have attached pytest marks to the function.
    marks = getattr(test_parameterized_ids, "pytestmark", [])
    assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator"

    # At least one of those marks must be a parametrize mark.
    parametrize_marks = [mark for mark in marks if hasattr(mark, "name") and mark.name == "parametrize"]
    assert len(parametrize_marks) > 0, "Should have parametrize mark"

    # The first parametrize mark must expose an "ids" keyword argument.
    parametrize_mark = parametrize_marks[0]
    assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs"
    assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs"

    ids = parametrize_mark.kwargs.get("ids")
    # NOTE(review): when ``ids`` is None the check below is skipped and the
    # test passes without verifying any ID — confirm that is intended.
    if ids is not None:
        # Should have IDs for models but not for temperature-only params
        expected_ids = ["model-gpt-oss-120b", "model-gpt-4"]
        assert list(ids) == expected_ids, f"Expected {expected_ids}, got {list(ids)}"
def test_default_id_generator():
    """Check ``DefaultParameterIdGenerator.generate_id`` against four combinations.

    Each case places a completion-params value in the second slot of a
    five-element combination tuple (all other slots are None) and states the
    ID the generator is expected to return for it.
    """
    id_generator = DefaultParameterIdGenerator()

    # (completion_params, expected ID) pairs, checked in this order.
    cases = [
        # Full model path
        ({"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, "model-gpt-oss-120b"),
        # Simple model name
        ({"model": "gpt-4"}, "model-gpt-4"),
        # Params without a model
        ({"temperature": 0.5}, None),
        # No completion params at all
        (None, None),
    ]

    for completion_params, expected_id in cases:
        combination = (None, completion_params, None, None, None)
        generated_id = id_generator.generate_id(combination)
        if expected_id is None:
            # Identity check: the generator must return None itself here.
            assert generated_id is None
        else:
            assert generated_id == expected_id
def test_pytest_parametrize_with_custom_id_generator():
    """Check the argnames, argvalues and ids returned by ``pytest_parametrize``.

    No explicit ID generator is passed in this call. Three combinations are
    supplied, two with a model and one with only a temperature; the expected
    ``ids`` list holds just the two model-derived IDs.
    """
    completion_params = [{"model": "gpt-4"}, {"model": "claude-3"}, {"temperature": 0.5}]

    # Each combination carries one completion-params dict in its second slot.
    # The dicts are copied so the combinations and the completion_params
    # argument do not share objects.
    combinations = [(None, dict(params), None, None, None) for params in completion_params]

    parametrize_args = pytest_parametrize(
        combinations=combinations,
        input_dataset=None,
        completion_params=completion_params,
        input_messages=None,
        input_rows=None,
        evaluation_test_kwargs=None,
    )

    argnames = parametrize_args["argnames"]
    assert argnames == ["completion_params"]

    # One argvalue entry is expected per combination.
    argvalues = list(parametrize_args["argvalues"])
    assert len(argvalues) == 3

    # The temperature-only combination contributes no entry to the expected IDs.
    expected_ids = ["model-gpt-4", "model-claude-3"]
    assert parametrize_args["ids"] == expected_ids
def test_id_generator_max_length():
    """Check that a generator built with ``max_length=10`` shortens a long ID.

    The model name used here is far longer than the limit; the expected
    result is the 10-character string ``"model-v..."``.
    """
    length_limit = 10
    id_generator = DefaultParameterIdGenerator(max_length=length_limit)

    # Only the completion-params slot of the combination is populated.
    long_model_params = {"model": "very-long-model-name-that-exceeds-max-length"}
    combination = (None, long_model_params, None, None, None)

    shortened_id = id_generator.generate_id(combination)

    assert shortened_id == "model-v..."
    assert len(shortened_id) <= length_limit