Skip to content

Commit 04a2a2f

Browse files
author
Dylan Huang
committed
reject absolute within path arg
1 parent 432021a commit 04a2a2f

File tree

2 files changed

+106
-64
lines changed

2 files changed

+106
-64
lines changed

eval_protocol/fireworks_api_client.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,116 +10,144 @@
1010

1111
class FireworksAPIClient:
1212
"""Client for making authenticated requests to Fireworks API with proper headers.
13-
13+
1414
This client automatically includes:
1515
- Authorization header (Bearer token)
1616
- User-Agent header for tracking eval-protocol CLI usage
1717
"""
18-
18+
1919
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
2020
"""Initialize the Fireworks API client.
21-
21+
2222
Args:
2323
api_key: Fireworks API key. If None, will be read from environment.
2424
api_base: API base URL. If None, defaults to https://api.fireworks.ai
2525
"""
2626
self.api_key = api_key
2727
self.api_base = api_base or os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai")
2828
self._session = requests.Session()
29-
30-
def _get_headers(self, content_type: Optional[str] = "application/json",
31-
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
29+
30+
def _validate_path_is_relative(self, path: str) -> None:
31+
"""Validate that the path is relative, not an absolute URL.
32+
33+
Args:
34+
path: The path to validate
35+
36+
Raises:
37+
ValueError: If path appears to be an absolute URL (starts with http:// or https://)
38+
"""
39+
if path.startswith(("http://", "https://")):
40+
raise ValueError(
41+
f"Absolute URL detected: '{path}'. FireworksAPIClient methods expect relative paths only. "
42+
f"Use a relative path like 'v1/path' instead of '{path}'. "
43+
f"The client will automatically prepend the api_base: '{self.api_base}'"
44+
)
45+
46+
def _get_headers(
47+
self, content_type: Optional[str] = "application/json", additional_headers: Optional[Dict[str, str]] = None
48+
) -> Dict[str, str]:
3249
"""Build headers for API requests.
33-
50+
3451
Args:
3552
content_type: Content-Type header value. If None, Content-Type won't be set.
3653
additional_headers: Additional headers to merge in.
37-
54+
3855
Returns:
3956
Dictionary of headers including authorization and user-agent.
4057
"""
4158
headers = {
4259
"User-Agent": get_user_agent(),
4360
}
44-
61+
4562
if self.api_key:
4663
headers["Authorization"] = f"Bearer {self.api_key}"
47-
64+
4865
if content_type:
4966
headers["Content-Type"] = content_type
50-
67+
5168
if additional_headers:
5269
headers.update(additional_headers)
53-
70+
5471
return headers
55-
56-
def get(self, path: str, params: Optional[Dict[str, Any]] = None,
57-
timeout: int = 30, **kwargs) -> requests.Response:
72+
73+
def get(
74+
self, path: str, params: Optional[Dict[str, Any]] = None, timeout: int = 30, **kwargs
75+
) -> requests.Response:
5876
"""Make a GET request to the Fireworks API.
59-
77+
6078
Args:
6179
path: API path (relative to api_base)
6280
params: Query parameters
6381
timeout: Request timeout in seconds
6482
**kwargs: Additional arguments passed to requests.get
65-
83+
6684
Returns:
6785
Response object
6886
"""
87+
self._validate_path_is_relative(path)
6988
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
7089
headers = self._get_headers(content_type=None)
7190
if "headers" in kwargs:
7291
headers.update(kwargs.pop("headers"))
7392
return self._session.get(url, params=params, headers=headers, timeout=timeout, **kwargs)
74-
75-
def post(self, path: str, json: Optional[Dict[str, Any]] = None,
76-
data: Optional[Any] = None, files: Optional[Dict[str, Any]] = None,
77-
timeout: int = 60, **kwargs) -> requests.Response:
93+
94+
def post(
95+
self,
96+
path: str,
97+
json: Optional[Dict[str, Any]] = None,
98+
data: Optional[Any] = None,
99+
files: Optional[Dict[str, Any]] = None,
100+
timeout: int = 60,
101+
**kwargs,
102+
) -> requests.Response:
78103
"""Make a POST request to the Fireworks API.
79-
104+
80105
Args:
81106
path: API path (relative to api_base)
82107
json: JSON payload
83108
data: Form data payload
84109
files: Files to upload
85110
timeout: Request timeout in seconds
86111
**kwargs: Additional arguments passed to requests.post
87-
112+
88113
Returns:
89114
Response object
90115
"""
116+
self._validate_path_is_relative(path)
91117
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
92-
118+
93119
# For file uploads, don't set Content-Type (let requests handle multipart/form-data)
94120
content_type = None if files else "application/json"
95121
headers = self._get_headers(content_type=content_type)
96-
122+
97123
if "headers" in kwargs:
98124
headers.update(kwargs.pop("headers"))
99-
100-
return self._session.post(url, json=json, data=data, files=files,
101-
headers=headers, timeout=timeout, **kwargs)
102-
103-
def put(self, path: str, json: Optional[Dict[str, Any]] = None,
104-
timeout: int = 60, **kwargs) -> requests.Response:
125+
126+
return self._session.post(url, json=json, data=data, files=files, headers=headers, timeout=timeout, **kwargs)
127+
128+
def put(self, path: str, json: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs) -> requests.Response:
105129
"""Make a PUT request to the Fireworks API."""
130+
self._validate_path_is_relative(path)
106131
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
107132
headers = self._get_headers()
108133
if "headers" in kwargs:
109134
headers.update(kwargs.pop("headers"))
110135
return self._session.put(url, json=json, headers=headers, timeout=timeout, **kwargs)
111-
112-
def patch(self, path: str, json: Optional[Dict[str, Any]] = None,
113-
timeout: int = 60, **kwargs) -> requests.Response:
136+
137+
def patch(
138+
self, path: str, json: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs
139+
) -> requests.Response:
114140
"""Make a PATCH request to the Fireworks API."""
141+
self._validate_path_is_relative(path)
115142
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
116143
headers = self._get_headers()
117144
if "headers" in kwargs:
118145
headers.update(kwargs.pop("headers"))
119146
return self._session.patch(url, json=json, headers=headers, timeout=timeout, **kwargs)
120-
147+
121148
def delete(self, path: str, timeout: int = 30, **kwargs) -> requests.Response:
122149
"""Make a DELETE request to the Fireworks API."""
150+
self._validate_path_is_relative(path)
123151
url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}"
124152
headers = self._get_headers(content_type=None)
125153
if "headers" in kwargs:

tests/test_fireworks_api_client.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -415,23 +415,19 @@ def test_paths_containing_v1_pattern(self):
415415

416416
mock_post.reset_mock()
417417

418-
def test_full_url_passed_by_mistake_detected(self):
419-
"""Test that accidentally passing a full URL instead of relative path is detected.
418+
def test_full_url_passed_by_mistake_raises_error(self):
419+
"""Test that accidentally passing a full URL instead of relative path raises ValueError.
420420
421-
This test documents the bug pattern: if a full URL like '{api_base}/v1/path'
422-
is passed instead of a relative path like 'v1/path', it will result in a
423-
malformed URL like '{api_base}/{api_base}/v1/path'.
424-
425-
This test verifies that our code correctly handles relative paths (which prevents
426-
the bug), and documents what would happen if the bug occurred.
421+
This test verifies that our code correctly catches the bug early by raising an error
422+
when an absolute URL is passed instead of a relative path.
427423
"""
428424
api_base = "https://api.fireworks.ai"
429425
client = FireworksAPIClient(api_key="test_key", api_base=api_base)
430426

427+
# CORRECT: Relative path (what we should use) - should work fine
431428
mock_response = MagicMock()
432429
mock_response.status_code = 200
433430

434-
# CORRECT: Relative path (what we should use)
435431
with patch.object(client._session, "post", return_value=mock_response) as mock_post:
436432
correct_relative_path = "v1/test-evaluator:getUploadEndpoint"
437433
client.post(correct_relative_path, json={})
@@ -441,27 +437,45 @@ def test_full_url_passed_by_mistake_detected(self):
441437
expected_correct_url = f"{api_base}/{correct_relative_path}"
442438
assert correct_url == expected_correct_url
443439

444-
# INCORRECT: Full URL (this would cause the bug - but we're not actually testing this,
445-
# just documenting that our current implementation would create a malformed URL)
446-
# If someone accidentally did: client.post(f"{api_base}/v1/path", ...)
447-
# The result would be: f"{api_base}/{api_base}/v1/path" which is wrong.
448-
# Our tests above ensure we use relative paths, preventing this bug.
449-
mock_post.reset_mock()
450-
with patch.object(client._session, "post", return_value=mock_response) as mock_post:
451-
# Simulating what WOULD happen if buggy code passed full URL
452-
buggy_full_url = f"{api_base}/v1/test-evaluator:getUploadEndpoint"
453-
client.post(buggy_full_url, json={})
440+
# INCORRECT: Full URL should raise ValueError
441+
full_url_with_http = "https://api.fireworks.ai/v1/test-evaluator:getUploadEndpoint"
442+
with pytest.raises(ValueError, match="Absolute URL detected"):
443+
client.post(full_url_with_http, json={})
444+
445+
full_url_with_http_scheme = "http://api.fireworks.ai/v1/test-evaluator:getUploadEndpoint"
446+
with pytest.raises(ValueError, match="Absolute URL detected"):
447+
client.post(full_url_with_http_scheme, json={})
448+
449+
# Test that error message is helpful
450+
with pytest.raises(ValueError) as exc_info:
451+
client.post(full_url_with_http, json={})
452+
error_msg = str(exc_info.value)
453+
assert "Absolute URL detected" in error_msg
454+
assert full_url_with_http in error_msg
455+
assert "relative paths only" in error_msg
456+
assert api_base in error_msg # Should mention api_base in the help message
457+
458+
def test_all_methods_reject_absolute_urls(self):
459+
"""Test that all HTTP methods reject absolute URLs."""
460+
api_base = "https://api.fireworks.ai"
461+
client = FireworksAPIClient(api_key="test_key", api_base=api_base)
454462

455-
call_args = mock_post.call_args
456-
buggy_url = call_args[0][0]
457-
# This shows what the buggy URL would look like
458-
buggy_expected = f"{api_base}/{buggy_full_url}"
459-
460-
# This assertion documents the bug pattern - the URL would be malformed
461-
assert buggy_url == buggy_expected
462-
assert buggy_url.startswith(f"{api_base}/{api_base}"), (
463-
"This documents the bug: passing full URL creates double-prefix. Always use relative paths!"
464-
)
463+
absolute_url = f"{api_base}/v1/test/path"
464+
465+
methods = [
466+
("get", lambda url: client.get(url)),
467+
("post", lambda url: client.post(url, json={})),
468+
("put", lambda url: client.put(url, json={})),
469+
("patch", lambda url: client.patch(url, json={})),
470+
("delete", lambda url: client.delete(url)),
471+
]
472+
473+
for method_name, method_call in methods:
474+
with pytest.raises(ValueError, match="Absolute URL detected") as exc_info:
475+
method_call(absolute_url)
476+
error_msg = str(exc_info.value)
477+
assert "Absolute URL detected" in error_msg, f"{method_name.upper()} should reject absolute URL"
478+
assert absolute_url in error_msg
465479

466480

467481
if __name__ == "__main__":

0 commit comments

Comments
 (0)