Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions banana_dev/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

class API():
"The Banana API class interacts with the Banana API."
def __init__(self, api_key):
def __init__(self, api_key, request_timeout = 300):
self.base_url = "https://api.banana.dev/v1"
self.api_key = api_key.strip()
self.request_timeout = request_timeout

"Get all projects under the team account"
def list_projects(self, query: dict = {}) -> Tuple[dict, int]:
Expand All @@ -30,10 +31,10 @@ def __call(self, method: str, route: str, data: dict = {}) -> Tuple[dict, int]:
endpoint = f"{self.base_url}/{route}"

if method == "POST":
res = requests.post(endpoint, json=data, headers=headers)
res = requests.post(endpoint, json=data, headers=headers, timeout=self.request_timeout)
elif method == "PUT":
res = requests.put(endpoint, json=data, headers=headers)
res = requests.put(endpoint, json=data, headers=headers, timeout=self.request_timeout)
elif method == "GET":
res = requests.get(endpoint, params=data, headers=headers)
res = requests.get(endpoint, params=data, headers=headers, timeout=self.request_timeout)

return res.json(), res.status_code
15 changes: 9 additions & 6 deletions banana_dev/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@ def __init__(self, message = "" , res: requests.Response = None):

class Client():
"The Banana client class is for interacting with a specific project on Banana."
def __init__(self, api_key, url, verbosity = "DEBUG"):
def __init__(self, api_key, url, verbosity = "DEBUG", request_timeout = 300):
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review: I kept the default at 300 seconds to line up with the existing retry_timeout and 5-minute timeout behavior, while still exposing request_timeout for callers that need a stricter bound.

self.api_key = api_key
self.url = url
self.verbosity = verbosity
self.request_timeout = request_timeout

def warmup(self) -> Tuple[dict, dict]:
"Warm up the Potassium server"
return self.call("/_k/warmup", json={}, headers={}, retry=False)

"Call a route on the Banana server with a POST request"
def call(self, route: str, json: dict = {}, headers: dict = {}, retry=True, retry_timeout = 300) -> Tuple[dict, dict]:
headers["Content-Type"] = "application/json"
headers['X-BANANA-API-KEY'] = self.api_key
headers['X-BANANA-REQUEST-ID'] = str(uuid4()) # we use the same uuid to track all retries
def call(self, route: str, json: dict = {}, headers: dict = {}, retry=True, retry_timeout = 300, request_timeout = None) -> Tuple[dict, dict]:
request_timeout = self.request_timeout if request_timeout is None else request_timeout
request_headers = dict(headers)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review: Copying the caller-provided headers keeps the timeout change from also mutating user-owned dictionaries when Banana auth headers are injected.

request_headers["Content-Type"] = "application/json"
request_headers['X-BANANA-API-KEY'] = self.api_key
request_headers['X-BANANA-REQUEST-ID'] = str(uuid4()) # we use the same uuid to track all retries

endpoint = self.url.rstrip("/") + "/" + route.lstrip("/")

Expand All @@ -47,7 +50,7 @@ def call(self, route: str, json: dict = {}, headers: dict = {}, retry=True, retr
print("Retrying...")

backoff_interval = min(backoff_interval*2, 3)
res = requests.post(endpoint, json=json, headers=headers)
res = requests.post(endpoint, json=json, headers=request_headers, timeout=request_timeout)

if self.verbosity == "DEBUG":
if res.status_code != 200:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_request_timeouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest
from unittest.mock import Mock, patch

from banana_dev import API, Client


class RequestTimeoutTests(unittest.TestCase):
def test_client_call_passes_timeout_and_preserves_headers(self):
response = Mock(status_code=200, headers={"x-test": "ok"})
response.json.return_value = {"ok": True}
headers = {"X-CALLER": "present"}

with patch("banana_dev.client.requests.post", return_value=response) as post:
result, meta = Client("secret", "https://example.test", request_timeout=12).call(
"/run",
json={"input": "value"},
headers=headers,
retry=False,
)

self.assertEqual(result, {"ok": True})
self.assertEqual(meta, {"headers": {"x-test": "ok"}})
self.assertEqual(headers, {"X-CALLER": "present"})
self.assertEqual(post.call_args.kwargs["timeout"], 12)
self.assertEqual(post.call_args.kwargs["headers"]["X-BANANA-API-KEY"], "secret")
self.assertEqual(post.call_args.kwargs["headers"]["X-CALLER"], "present")

def test_client_call_allows_per_call_timeout_override(self):
response = Mock(status_code=200, headers={})
response.json.return_value = {"ok": True}

with patch("banana_dev.client.requests.post", return_value=response) as post:
Client("secret", "https://example.test", request_timeout=12).call(
"/run",
retry=False,
request_timeout=3,
)

self.assertEqual(post.call_args.kwargs["timeout"], 3)

def test_api_methods_pass_timeout(self):
response = Mock(status_code=200)
response.json.return_value = {"results": []}
api = API(" secret ", request_timeout=7)

with patch("banana_dev.api.requests.post", return_value=response) as post:
api._API__call("POST", "projects", {"name": "example"})

with patch("banana_dev.api.requests.put", return_value=response) as put:
api._API__call("PUT", "projects/example", {"name": "renamed"})

with patch("banana_dev.api.requests.get", return_value=response) as get:
result, status = api.list_projects()

self.assertEqual(result, {"results": []})
self.assertEqual(status, 200)
self.assertEqual(post.call_args.kwargs["timeout"], 7)
self.assertEqual(put.call_args.kwargs["timeout"], 7)
self.assertEqual(get.call_args.kwargs["timeout"], 7)
self.assertEqual(get.call_args.kwargs["headers"]["X-BANANA-API-KEY"], "secret")


if __name__ == "__main__":
unittest.main()