Skip to content

Commit 1a14b73

Browse files
author
Dylan Huang
committed
added tests
1 parent 1b10a8f commit 1a14b73

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed

tests/test_fireworks_api_client.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""Tests for FireworksAPIClient user-agent header functionality."""
2+
3+
import re
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
from eval_protocol.common_utils import get_user_agent
9+
from eval_protocol.fireworks_api_client import FireworksAPIClient
10+
11+
12+
class TestFireworksAPIClientUserAgent:
13+
"""Test that FireworksAPIClient correctly sets the User-Agent header."""
14+
15+
def test_get_user_agent_format(self):
16+
"""Test that get_user_agent returns the expected format."""
17+
user_agent = get_user_agent()
18+
# Should match format: eval-protocol/{version}
19+
# Version can be actual version or "unknown"
20+
assert user_agent.startswith("eval-protocol/")
21+
assert len(user_agent) > len("eval-protocol/")
22+
23+
def test_get_user_agent_fallback_logic(self):
24+
"""Test that get_user_agent has fallback logic for when version can't be imported.
25+
26+
This test verifies the code structure, since actually triggering an import
27+
failure during the import statement is difficult to test reliably.
28+
The important behavior (User-Agent header being set) is verified in other tests.
29+
"""
30+
# Verify the function exists and can be called normally
31+
user_agent = get_user_agent()
32+
# The function should always return a valid user agent string
33+
assert isinstance(user_agent, str)
34+
assert user_agent.startswith("eval-protocol/")
35+
36+
# The actual fallback ("eval-protocol/unknown") happens when the import
37+
# fails, which is hard to simulate without patching at a very low level.
38+
# The try/except block in the implementation handles this gracefully.
39+
40+
def test_get_headers_includes_user_agent(self):
41+
"""Test that _get_headers includes the User-Agent header."""
42+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
43+
headers = client._get_headers()
44+
45+
assert "User-Agent" in headers
46+
assert headers["User-Agent"] == get_user_agent()
47+
48+
def test_get_request_includes_user_agent(self):
49+
"""Test that GET requests include the User-Agent header."""
50+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
51+
52+
mock_response = MagicMock()
53+
mock_response.status_code = 200
54+
55+
with patch.object(client._session, "get", return_value=mock_response) as mock_get:
56+
client.get("test/path")
57+
58+
mock_get.assert_called_once()
59+
call_kwargs = mock_get.call_args[1]
60+
headers = call_kwargs["headers"]
61+
62+
assert "User-Agent" in headers
63+
assert headers["User-Agent"] == get_user_agent()
64+
assert headers["Authorization"] == "Bearer test_key"
65+
66+
def test_post_request_includes_user_agent(self):
67+
"""Test that POST requests include the User-Agent header."""
68+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
69+
70+
mock_response = MagicMock()
71+
mock_response.status_code = 200
72+
73+
with patch.object(client._session, "post", return_value=mock_response) as mock_post:
74+
client.post("test/path", json={"key": "value"})
75+
76+
mock_post.assert_called_once()
77+
call_kwargs = mock_post.call_args[1]
78+
headers = call_kwargs["headers"]
79+
80+
assert "User-Agent" in headers
81+
assert headers["User-Agent"] == get_user_agent()
82+
assert headers["Authorization"] == "Bearer test_key"
83+
assert headers["Content-Type"] == "application/json"
84+
85+
def test_post_with_files_excludes_content_type(self):
86+
"""Test that POST requests with files exclude Content-Type header."""
87+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
88+
89+
mock_response = MagicMock()
90+
mock_response.status_code = 200
91+
92+
with patch.object(client._session, "post", return_value=mock_response) as mock_post:
93+
client.post("test/path", files={"file": MagicMock()})
94+
95+
mock_post.assert_called_once()
96+
call_kwargs = mock_post.call_args[1]
97+
headers = call_kwargs["headers"]
98+
99+
assert "User-Agent" in headers
100+
assert headers["User-Agent"] == get_user_agent()
101+
# Content-Type should not be set when files are present
102+
assert "Content-Type" not in headers
103+
104+
def test_put_request_includes_user_agent(self):
105+
"""Test that PUT requests include the User-Agent header."""
106+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
107+
108+
mock_response = MagicMock()
109+
mock_response.status_code = 200
110+
111+
with patch.object(client._session, "put", return_value=mock_response) as mock_put:
112+
client.put("test/path", json={"key": "value"})
113+
114+
mock_put.assert_called_once()
115+
call_kwargs = mock_put.call_args[1]
116+
headers = call_kwargs["headers"]
117+
118+
assert "User-Agent" in headers
119+
assert headers["User-Agent"] == get_user_agent()
120+
assert headers["Authorization"] == "Bearer test_key"
121+
122+
def test_patch_request_includes_user_agent(self):
123+
"""Test that PATCH requests include the User-Agent header."""
124+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
125+
126+
mock_response = MagicMock()
127+
mock_response.status_code = 200
128+
129+
with patch.object(client._session, "patch", return_value=mock_response) as mock_patch:
130+
client.patch("test/path", json={"key": "value"})
131+
132+
mock_patch.assert_called_once()
133+
call_kwargs = mock_patch.call_args[1]
134+
headers = call_kwargs["headers"]
135+
136+
assert "User-Agent" in headers
137+
assert headers["User-Agent"] == get_user_agent()
138+
assert headers["Authorization"] == "Bearer test_key"
139+
140+
def test_delete_request_includes_user_agent(self):
141+
"""Test that DELETE requests include the User-Agent header."""
142+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
143+
144+
mock_response = MagicMock()
145+
mock_response.status_code = 200
146+
147+
with patch.object(client._session, "delete", return_value=mock_response) as mock_delete:
148+
client.delete("test/path")
149+
150+
mock_delete.assert_called_once()
151+
call_kwargs = mock_delete.call_args[1]
152+
headers = call_kwargs["headers"]
153+
154+
assert "User-Agent" in headers
155+
assert headers["User-Agent"] == get_user_agent()
156+
assert headers["Authorization"] == "Bearer test_key"
157+
# DELETE requests shouldn't have Content-Type
158+
assert "Content-Type" not in headers
159+
160+
def test_additional_headers_merged(self):
161+
"""Test that additional headers passed to requests are merged with User-Agent."""
162+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
163+
164+
mock_response = MagicMock()
165+
mock_response.status_code = 200
166+
167+
with patch.object(client._session, "get", return_value=mock_response) as mock_get:
168+
client.get("test/path", headers={"X-Custom-Header": "custom-value"})
169+
170+
mock_get.assert_called_once()
171+
call_kwargs = mock_get.call_args[1]
172+
headers = call_kwargs["headers"]
173+
174+
assert "User-Agent" in headers
175+
assert headers["User-Agent"] == get_user_agent()
176+
assert headers["X-Custom-Header"] == "custom-value"
177+
178+
def test_user_agent_consistent_across_methods(self):
179+
"""Test that User-Agent is consistent across all HTTP methods."""
180+
client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai")
181+
182+
mock_response = MagicMock()
183+
mock_response.status_code = 200
184+
185+
expected_user_agent = get_user_agent()
186+
187+
# Test all methods
188+
methods = [
189+
("get", lambda: client.get("test/path")),
190+
("post", lambda: client.post("test/path", json={})),
191+
("put", lambda: client.put("test/path", json={})),
192+
("patch", lambda: client.patch("test/path", json={})),
193+
("delete", lambda: client.delete("test/path")),
194+
]
195+
196+
for method_name, method_call in methods:
197+
with patch.object(client._session, method_name, return_value=mock_response) as mock_method:
198+
method_call()
199+
200+
call_kwargs = mock_method.call_args[1]
201+
headers = call_kwargs["headers"]
202+
203+
assert "User-Agent" in headers, f"{method_name} should include User-Agent"
204+
assert headers["User-Agent"] == expected_user_agent, (
205+
f"{method_name} User-Agent should match expected value"
206+
)
207+
208+
def test_user_agent_without_api_key(self):
209+
"""Test that User-Agent is still included even without API key."""
210+
client = FireworksAPIClient(api_key=None, api_base="https://api.fireworks.ai")
211+
212+
mock_response = MagicMock()
213+
mock_response.status_code = 200
214+
215+
with patch.object(client._session, "get", return_value=mock_response) as mock_get:
216+
client.get("test/path")
217+
218+
mock_get.assert_called_once()
219+
call_kwargs = mock_get.call_args[1]
220+
headers = call_kwargs["headers"]
221+
222+
assert "User-Agent" in headers
223+
assert headers["User-Agent"] == get_user_agent()
224+
# Authorization should not be present
225+
assert "Authorization" not in headers
226+
227+
228+
if __name__ == "__main__":
229+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)