From 42deaa6419d567af7cc3784e686ff65d6627c25a Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sat, 13 Jun 2026 16:21:56 +0200 Subject: [PATCH] Improve API usability; minor fixes Allow all types of multipart response types to trigger multipart response processing. Enforce chunked encoding for all multipart responses with a function callback as the response length is not known in advance. Normalize keys when setting response headers. Allow all callables to be used as response producers, not just closures. --- src/pyrobusta/bindings/http_connection.py | 2 +- src/pyrobusta/protocol/http.py | 8 ++++---- src/pyrobusta/protocol/http_file_server.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/pyrobusta/bindings/http_connection.py b/src/pyrobusta/bindings/http_connection.py index 876937d..7edba07 100644 --- a/src/pyrobusta/bindings/http_connection.py +++ b/src/pyrobusta/bindings/http_connection.py @@ -98,7 +98,7 @@ async def _run_state_machine(self): await self._response_handler(self._engine.resp_handler) async def _response_handler(self, resp_handler): - if "closure" == type(resp_handler).__name__: + if callable(resp_handler): if self._engine.get_response_header(b"transfer-encoding") == b"chunked": for is_finished in resp_handler(self._send_buf): await self.write(b"%x\r\n" % self._send_buf.size()) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 3a8e73e..12ff6d2 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -387,9 +387,7 @@ def _get_mp_boundary(headers: dict) -> str: and return the boundary value. """ content_type = headers.get("content-type") - if not content_type or not content_type.lower().startswith( - "multipart/form-data" - ): + if not content_type or not content_type.lower().startswith("multipart/"): return None parts = content_type.split(";") @@ -440,6 +438,7 @@ def set_response_header(self, key: bytes, value: bytes): :param key: HTTP header key :param value: HTTP header value """ + key = key.lower() if ( key in self.resp_headers and (index := self.resp_headers.index(key)) % 2 == 0 @@ -810,7 +809,8 @@ def _app_endpoint_st(self, rx): return dtype, data = callback_response - if dtype.startswith("multipart/"): + if dtype.startswith("multipart/") and callable(data): + self.set_response_header(b"transfer-encoding", b"chunked") self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) return diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index c8fa52e..83bf3b3 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -254,7 +254,7 @@ def _traverse_dir(tx): tx.write(b",") file_stat = stat(it) - obj = dumps( + data = dumps( { "path": it, "size": str(file_stat[6]), @@ -263,11 +263,11 @@ def _traverse_dir(tx): ).encode("ascii") written = 0 - while written < len(obj): + while written < len(data): to_write = tx.capacity - tx.size() if not to_write: raise BufferError() - tx.write(obj[written : written + to_write]) + tx.write(data[written : written + to_write]) written += to_write yield False tx.write(b"]\r\n")