diff --git a/banana_dev/api.py b/banana_dev/api.py index 67376d2..c4b8372 100644 --- a/banana_dev/api.py +++ b/banana_dev/api.py @@ -4,9 +4,10 @@ class API(): "The Banana API class interacts with the Banana API." - def __init__(self, api_key): + def __init__(self, api_key, request_timeout = 300): self.base_url = "https://api.banana.dev/v1" self.api_key = api_key.strip() + self.request_timeout = request_timeout "Get all projects under the team account" def list_projects(self, query: dict = {}) -> Tuple[dict, int]: @@ -30,10 +31,10 @@ def __call(self, method: str, route: str, data: dict = {}) -> Tuple[dict, int]: endpoint = f"{self.base_url}/{route}" if method == "POST": - res = requests.post(endpoint, json=data, headers=headers) + res = requests.post(endpoint, json=data, headers=headers, timeout=self.request_timeout) elif method == "PUT": - res = requests.put(endpoint, json=data, headers=headers) + res = requests.put(endpoint, json=data, headers=headers, timeout=self.request_timeout) elif method == "GET": - res = requests.get(endpoint, params=data, headers=headers) + res = requests.get(endpoint, params=data, headers=headers, timeout=self.request_timeout) return res.json(), res.status_code diff --git a/banana_dev/client.py b/banana_dev/client.py index 51fa85d..ce56f95 100644 --- a/banana_dev/client.py +++ b/banana_dev/client.py @@ -14,20 +14,23 @@ def __init__(self, message = "" , res: requests.Response = None): class Client(): "The Banana client class is for interacting with a specific project on Banana." - def __init__(self, api_key, url, verbosity = "DEBUG"): + def __init__(self, api_key, url, verbosity = "DEBUG", request_timeout = 300): self.api_key = api_key self.url = url self.verbosity = verbosity + self.request_timeout = request_timeout def warmup(self) -> Tuple[dict, dict]: "Warm up the Potassium server" return self.call("/_k/warmup", json={}, headers={}, retry=False) "Call a route on the Banana server with a POST request" - def call(self, route: str, json: dict = {}, headers: dict = {}, retry=True, retry_timeout = 300) -> Tuple[dict, dict]: - headers["Content-Type"] = "application/json" - headers['X-BANANA-API-KEY'] = self.api_key - headers['X-BANANA-REQUEST-ID'] = str(uuid4()) # we use the same uuid to track all retries + def call(self, route: str, json: dict = {}, headers: dict = {}, retry=True, retry_timeout = 300, request_timeout = None) -> Tuple[dict, dict]: + request_timeout = self.request_timeout if request_timeout is None else request_timeout + request_headers = dict(headers) + request_headers["Content-Type"] = "application/json" + request_headers['X-BANANA-API-KEY'] = self.api_key + request_headers['X-BANANA-REQUEST-ID'] = str(uuid4()) # we use the same uuid to track all retries endpoint = self.url.rstrip("/") + "/" + route.lstrip("/") @@ -47,7 +50,7 @@ def call(self, route: str, json: dict = {}, headers: dict = {}, retry=True, retr print("Retrying...") backoff_interval = min(backoff_interval*2, 3) - res = requests.post(endpoint, json=json, headers=headers) + res = requests.post(endpoint, json=json, headers=request_headers, timeout=request_timeout) if self.verbosity == "DEBUG": if res.status_code != 200: diff --git a/tests/test_request_timeouts.py b/tests/test_request_timeouts.py new file mode 100644 index 0000000..a082b58 --- /dev/null +++ b/tests/test_request_timeouts.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import Mock, patch + +from banana_dev import API, Client + + +class RequestTimeoutTests(unittest.TestCase): + def test_client_call_passes_timeout_and_preserves_headers(self): + response = Mock(status_code=200, headers={"x-test": "ok"}) + response.json.return_value = {"ok": True} + headers = {"X-CALLER": "present"} + + with patch("banana_dev.client.requests.post", return_value=response) as post: + result, meta = Client("secret", "https://example.test", request_timeout=12).call( + "/run", + json={"input": "value"}, + headers=headers, + retry=False, + ) + + self.assertEqual(result, {"ok": True}) + self.assertEqual(meta, {"headers": {"x-test": "ok"}}) + self.assertEqual(headers, {"X-CALLER": "present"}) + self.assertEqual(post.call_args.kwargs["timeout"], 12) + self.assertEqual(post.call_args.kwargs["headers"]["X-BANANA-API-KEY"], "secret") + self.assertEqual(post.call_args.kwargs["headers"]["X-CALLER"], "present") + + def test_client_call_allows_per_call_timeout_override(self): + response = Mock(status_code=200, headers={}) + response.json.return_value = {"ok": True} + + with patch("banana_dev.client.requests.post", return_value=response) as post: + Client("secret", "https://example.test", request_timeout=12).call( + "/run", + retry=False, + request_timeout=3, + ) + + self.assertEqual(post.call_args.kwargs["timeout"], 3) + + def test_api_methods_pass_timeout(self): + response = Mock(status_code=200) + response.json.return_value = {"results": []} + api = API(" secret ", request_timeout=7) + + with patch("banana_dev.api.requests.post", return_value=response) as post: + api._API__call("POST", "projects", {"name": "example"}) + + with patch("banana_dev.api.requests.put", return_value=response) as put: + api._API__call("PUT", "projects/example", {"name": "renamed"}) + + with patch("banana_dev.api.requests.get", return_value=response) as get: + result, status = api.list_projects() + + self.assertEqual(result, {"results": []}) + self.assertEqual(status, 200) + self.assertEqual(post.call_args.kwargs["timeout"], 7) + self.assertEqual(put.call_args.kwargs["timeout"], 7) + self.assertEqual(get.call_args.kwargs["timeout"], 7) + self.assertEqual(get.call_args.kwargs["headers"]["X-BANANA-API-KEY"], "secret") + + +if __name__ == "__main__": + unittest.main()