-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_utils.py
More file actions
188 lines (155 loc) · 7.82 KB
/
test_utils.py
File metadata and controls
188 lines (155 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

import pytest

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvaluationRow, ExecutionMetadata, InputMetadata, Status
from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry
from eval_protocol.pytest.types import RolloutProcessorConfig
class TestRolloutProcessorWithRetry:
    """Test the rollout_processor_with_retry function to ensure logging works correctly."""

    @staticmethod
    def _make_row(completion_params=None):
        """Build a minimal EvaluationRow in the rollout-finished state.

        Shared by the dataset fixtures and the mock rollout tasks so the row
        shape is defined in exactly one place.

        Args:
            completion_params: Explicit completion params for the row's
                InputMetadata. ``None`` (the default) selects the standard
                ``{"model": "test-model"}``; pass ``{}`` explicitly to get an
                empty mapping (the two are deliberately distinct).

        Returns:
            EvaluationRow: A row with no messages, a finished rollout status,
            empty execution metadata, and a fixed creation timestamp.
        """
        params = {"model": "test-model"} if completion_params is None else completion_params
        return EvaluationRow(
            messages=[],
            input_metadata=InputMetadata(completion_params=params),
            rollout_status=Status.rollout_finished(),
            execution_metadata=ExecutionMetadata(),
            created_at=datetime.fromisoformat("2024-01-01T00:00:00"),
        )

    @pytest.fixture
    def mock_rollout_processor(self):
        """Create a mock rollout processor that returns async tasks."""
        processor = MagicMock()
        processor.cleanup = MagicMock()
        processor.acleanup = AsyncMock()  # async cleanup method
        return processor

    @pytest.fixture
    def mock_config(self):
        """Create a mock config with a logger."""
        config = MagicMock(spec=RolloutProcessorConfig)
        config.logger = MagicMock(spec=DatasetLogger)
        config.logger.log = MagicMock()
        config.exception_handler_config = None
        config.kwargs = {}
        return config

    @pytest.fixture
    def sample_dataset(self):
        """Create a sample dataset for testing."""
        return [self._make_row()]

    @pytest.mark.asyncio
    async def test_logger_called_on_successful_execution(self, mock_rollout_processor, mock_config, sample_dataset):
        """Test that the logger is called when execution succeeds."""

        # Create a mock task that completes successfully with a finished row.
        async def mock_task():
            return self._make_row()

        # Mock the processor to return a list of tasks.
        mock_rollout_processor.return_value = [asyncio.create_task(mock_task())]

        # Drain the async generator under test.
        results = []
        async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
            results.append(result)

        # Verify that the logger was called exactly once, with the yielded row.
        assert mock_config.logger.log.call_count == 1
        mock_config.logger.log.assert_called_once_with(results[0])

    @pytest.mark.asyncio
    async def test_logger_called_on_failed_execution(self, mock_rollout_processor, mock_config, sample_dataset):
        """Test that the logger is called when execution fails."""

        # Mock the processor to return a task that raises an exception.
        async def failing_task():
            raise ValueError("Test error")

        mock_rollout_processor.return_value = [asyncio.create_task(failing_task())]

        # The failure should be converted into an error-status row, not raised.
        results = []
        async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
            results.append(result)

        # Verify that the logger was still called for the failed result.
        assert mock_config.logger.log.call_count == 1
        mock_config.logger.log.assert_called_once_with(results[0])

        # Verify the result carries the error details.
        assert results[0].rollout_status.code == 13  # INTERNAL error code
        assert "Test error" in results[0].rollout_status.message

    @pytest.mark.asyncio
    async def test_logger_called_on_retry_execution(self, mock_rollout_processor, mock_config, sample_dataset):
        """Test that the logger is called when execution succeeds after retry."""

        # Task fails on the first invocation, succeeds afterwards.
        call_count = 0

        async def flaky_task():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise ConnectionError("Connection failed")
            # Note: empty completion_params here, unlike the other tests.
            return self._make_row(completion_params={})

        mock_rollout_processor.return_value = [asyncio.create_task(flaky_task())]

        # Call the function - it should handle the retry internally.
        results = []
        async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
            results.append(result)

        # Only the final outcome is logged, not each attempt.
        assert mock_config.logger.log.call_count == 1
        mock_config.logger.log.assert_called_once_with(results[0])

    @pytest.mark.asyncio
    async def test_logger_called_for_multiple_rows(self, mock_rollout_processor, mock_config):
        """Test that the logger is called for each row in a multi-row dataset."""
        # Create a dataset with multiple rows.
        sample_dataset = [self._make_row(), self._make_row()]

        # Mock the processor to return one successful task per row.
        async def mock_task():
            return self._make_row()

        mock_rollout_processor.return_value = [asyncio.create_task(mock_task()), asyncio.create_task(mock_task())]

        results = []
        async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
            results.append(result)

        # One log call per yielded result.
        assert mock_config.logger.log.call_count == 2
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_logger_called_even_when_processor_fails_to_initialize(
        self, mock_rollout_processor, mock_config, sample_dataset
    ):
        """Test that cleanup is called even when the processor fails to initialize."""
        # Mock the processor to raise an exception during initialization
        # (i.e. when called to produce the task list, before any task runs).
        mock_rollout_processor.side_effect = RuntimeError("Processor failed to initialize")

        # Initialization failures are not converted to rows; they propagate.
        with pytest.raises(RuntimeError, match="Processor failed to initialize"):
            async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
                pass