Skip to content

Commit 911daca

Browse files
committed
testing script
1 parent bce982d commit 911daca

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed

eval_protocol/proxy/test_trail.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""
2+
Tests for Trail Management System proxy implementation.
3+
"""
4+
5+
import pytest
6+
from unittest.mock import Mock, AsyncMock, patch
7+
from fastapi.testclient import TestClient
8+
import redis
9+
10+
from proxy_core.models import ChatParams, ProxyConfig
11+
from proxy_core.app import create_app
12+
from proxy_core.auth import NoAuthProvider
13+
14+
15+
@pytest.fixture
16+
def mock_config():
17+
"""Mock ProxyConfig."""
18+
return ProxyConfig(
19+
litellm_url="http://mock-litellm:8000",
20+
langfuse_host="https://mock-langfuse.com",
21+
langfuse_keys={
22+
"test-project": {
23+
"public_key": "pk-test",
24+
"secret_key": "sk-test"
25+
}
26+
},
27+
default_project_id="test-project",
28+
request_timeout=300.0
29+
)
30+
31+
32+
@pytest.fixture
33+
def mock_redis():
34+
"""Mock Redis client."""
35+
mock = Mock(spec=redis.Redis)
36+
mock.ping.return_value = True
37+
mock.close.return_value = None
38+
mock.sadd = Mock()
39+
return mock
40+
41+
42+
@pytest.fixture
43+
def app(mock_config, mock_redis):
44+
"""Create test app."""
45+
app = create_app(auth_provider=NoAuthProvider())
46+
app.state.config = mock_config
47+
app.state.redis = mock_redis
48+
return app
49+
50+
51+
@pytest.fixture
52+
def client(app):
53+
"""Create test client."""
54+
return TestClient(app)
55+
56+
57+
class TestTrailModels:
58+
"""Test data models."""
59+
60+
def test_chat_params_trail_id(self):
61+
"""ChatParams accepts trail_id."""
62+
params = ChatParams(trail_id="test-trail-123", project_id="my-project")
63+
assert params.trail_id == "test-trail-123"
64+
assert params.project_id == "my-project"
65+
assert params.rollout_id is None
66+
67+
def test_chat_params_backward_compatibility(self):
68+
"""ChatParams still works with rollout_id."""
69+
params = ChatParams(
70+
rollout_id="rollout-123",
71+
invocation_id="inv-1",
72+
experiment_id="exp-1",
73+
run_id="run-1",
74+
row_id="row-1"
75+
)
76+
assert params.rollout_id == "rollout-123"
77+
assert params.trail_id is None
78+
79+
80+
class TestTrailRoutes:
81+
"""Test trail routes."""
82+
83+
def test_trail_chat_routes_registered(self, client):
84+
"""Trail chat completion routes exist."""
85+
routes = [route.path for route in client.app.routes]
86+
assert "/trails/{trail_id}/chat/completions" in routes
87+
assert "/v1/trails/{trail_id}/chat/completions" in routes
88+
assert "/project_id/{project_id}/trails/{trail_id}/chat/completions" in routes
89+
90+
def test_trail_traces_routes_registered(self, client):
91+
"""Trail traces routes exist."""
92+
routes = [route.path for route in client.app.routes]
93+
assert "/trails/{trail_id}/traces" in routes
94+
assert "/v1/trails/{trail_id}/traces" in routes
95+
assert "/trails/{trail_id}/traces/pointwise" in routes
96+
97+
def test_legacy_routes_preserved(self, client):
98+
"""Legacy rollout routes still exist."""
99+
routes_str = " ".join([route.path for route in client.app.routes])
100+
assert "rollout_id" in routes_str
101+
assert "invocation_id" in routes_str
102+
103+
def test_health_endpoint(self, client):
104+
"""Health endpoint works."""
105+
response = client.get("/health")
106+
assert response.status_code == 200
107+
data = response.json()
108+
assert data["status"] == "healthy"
109+
110+
111+
class TestTrailTagInjection:
112+
"""Test tag injection logic."""
113+
114+
@pytest.mark.asyncio
115+
async def test_trail_simple_tags(self, mock_config, mock_redis):
116+
"""Trail requests inject simple tags (2 tags)."""
117+
from proxy_core.litellm import handle_chat_completion
118+
from fastapi import Request
119+
120+
mock_request = Mock(spec=Request)
121+
mock_request.headers = {"authorization": "Bearer test-key"}
122+
mock_request.body = AsyncMock(return_value=b'{"model": "test", "messages": []}')
123+
124+
params = ChatParams(trail_id="test-trail-123")
125+
126+
with patch('proxy_core.litellm.httpx.AsyncClient') as mock_client:
127+
mock_response = Mock()
128+
mock_response.status_code = 200
129+
mock_response.content = b'{"choices": []}'
130+
mock_response.headers = {}
131+
132+
mock_post = AsyncMock(return_value=mock_response)
133+
mock_client.return_value.__aenter__.return_value.post = mock_post
134+
135+
await handle_chat_completion(mock_config, mock_redis, mock_request, params)
136+
137+
sent_data = mock_post.call_args.kwargs['json']
138+
tags = sent_data['metadata']['tags']
139+
140+
# Trail system: only 2 tags
141+
assert len(tags) == 2
142+
trail_tags = [t for t in tags if t.startswith('trail_id:')]
143+
assert len(trail_tags) == 1
144+
assert trail_tags[0] == 'trail_id:test-trail-123'
145+
146+
insertion_tags = [t for t in tags if t.startswith('insertion_id:')]
147+
assert len(insertion_tags) == 1
148+
149+
@pytest.mark.asyncio
150+
async def test_rollout_complex_tags(self, mock_config, mock_redis):
151+
"""Rollout requests inject complex tags (6 tags) - backward compat."""
152+
from proxy_core.litellm import handle_chat_completion
153+
from fastapi import Request
154+
155+
mock_request = Mock(spec=Request)
156+
mock_request.headers = {"authorization": "Bearer test-key"}
157+
mock_request.body = AsyncMock(return_value=b'{"model": "test", "messages": []}')
158+
159+
params = ChatParams(
160+
rollout_id="rollout-123",
161+
invocation_id="inv-1",
162+
experiment_id="exp-1",
163+
run_id="run-1",
164+
row_id="row-1"
165+
)
166+
167+
with patch('proxy_core.litellm.httpx.AsyncClient') as mock_client:
168+
mock_response = Mock()
169+
mock_response.status_code = 200
170+
mock_response.content = b'{"choices": []}'
171+
mock_response.headers = {}
172+
173+
mock_post = AsyncMock(return_value=mock_response)
174+
mock_client.return_value.__aenter__.return_value.post = mock_post
175+
176+
await handle_chat_completion(mock_config, mock_redis, mock_request, params)
177+
178+
sent_data = mock_post.call_args.kwargs['json']
179+
tags = sent_data['metadata']['tags']
180+
181+
# Legacy system: 6 tags
182+
assert len(tags) == 6
183+
tag_prefixes = [t.split(':')[0] for t in tags]
184+
assert 'rollout_id' in tag_prefixes
185+
assert 'invocation_id' in tag_prefixes
186+
assert 'experiment_id' in tag_prefixes
187+
188+
189+
class TestRedisTracking:
190+
"""Test Redis tracking."""
191+
192+
@pytest.mark.asyncio
193+
async def test_redis_uses_trail_id_as_key(self, mock_config, mock_redis):
194+
"""Redis uses trail_id as key."""
195+
from proxy_core.litellm import handle_chat_completion
196+
from fastapi import Request
197+
198+
mock_request = Mock(spec=Request)
199+
mock_request.headers = {"authorization": "Bearer test-key"}
200+
mock_request.body = AsyncMock(return_value=b'{"model": "test", "messages": []}')
201+
202+
params = ChatParams(trail_id="my-trail-456")
203+
204+
with patch('proxy_core.litellm.httpx.AsyncClient') as mock_client:
205+
mock_response = Mock()
206+
mock_response.status_code = 200
207+
mock_response.content = b'{"choices": []}'
208+
mock_response.headers = {}
209+
210+
mock_post = AsyncMock(return_value=mock_response)
211+
mock_client.return_value.__aenter__.return_value.post = mock_post
212+
213+
await handle_chat_completion(mock_config, mock_redis, mock_request, params)
214+
215+
# Verify Redis sadd was called with trail_id
216+
assert mock_redis.sadd.called
217+
call_args = mock_redis.sadd.call_args[0]
218+
assert call_args[0] == "my-trail-456"
219+
220+
# Second arg should be insertion_id
221+
insertion_id = call_args[1]
222+
assert isinstance(insertion_id, str)
223+
assert len(insertion_id) > 0
224+
225+
226+
227+
228+
229+

0 commit comments

Comments
 (0)