From 5529eb0fb11e3139ebd2df95c4a2bd1aee5c8f98 Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sat, 20 Jun 2026 19:17:54 +0200 Subject: [PATCH 1/3] Add new helper for parsing path segments Simplify applications by adding a helper function for obtaining URL path segments by index. --- dist/pyrobusta/assets/www/examples.html | 2 +- example/demo_app/app.py | 2 +- src/pyrobusta/protocol/http.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html index 1ade6de..f081947 100644 --- a/dist/pyrobusta/assets/www/examples.html +++ b/dist/pyrobusta/assets/www/examples.html @@ -115,7 +115,7 @@

Demo Application

@HttpEngine.route("/{app_or_server}/version", "GET") def version(http_ctx, _): include_server_version = False - resource = http_ctx.url.split(b"/")[1] + resource = http_ctx.path_segment(0) if resource not in (b"app", b"server"): http_ctx.terminate(404, True) diff --git a/example/demo_app/app.py b/example/demo_app/app.py index 8cf96db..f3c346d 100644 --- a/example/demo_app/app.py +++ b/example/demo_app/app.py @@ -37,7 +37,7 @@ def version(http_ctx, _): @HttpEngine.route("/{app_or_server}/version", "GET") def version(http_ctx, _): include_server_version = False - resource = http_ctx.url.split(b"/")[1] + resource = http_ctx.path_segment(0) if resource not in (b"app", b"server"): http_ctx.terminate(404, True) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 8e7146b..921068e 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -255,6 +255,16 @@ def percent_decode(s: str): i += 1 return "".join(out) + def path_segment(self, idx: int): + """ + Return the nth path segment of the URL path. + The index is shifted by one to ignore the first + empty segment before the leading slash ('/'). + :param idx: index of the segment + :return: string path segment + """ + return self.url.split(b"/")[idx + 1].decode("ascii") + def get_query_param(self, key: str, default: str = None) -> str: """ Parse a query and return the value belonging to a key From 80ff82f6bc89cb5656781c5bea74b65e592688f8 Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sat, 20 Jun 2026 19:54:27 +0200 Subject: [PATCH 2/3] Simplify keep-alive connection handling The public terminate() method exposes an optional switch (request_complete) to indicate the completeness of request processing. Based on this switch, a keep-alive connection is forcefully closed if the request body is partially processed. Such an approach is needed to avoid incorrectly parsing a subsequent requests with invalid framing. This commit removes the requirement for applications to know if the processing of a request is complete. Instead, rely on existing content-length validation logic to determine if a payload is fully processed. The following changes were added: - new state machine variable `_is_req_complete` to indicate completeness - update usages of the `terminate()` method to only accept status codes - new terminal state for updating response headers for keep-alive connections - the state machine is considered to reach a terminal state regardless of status code - new helper for resetting response body producer - remove member `aborted`; keep-alive connection is only based on request process completeness --- dist/pyrobusta/assets/www/examples.html | 4 +- example/demo_app/app.py | 6 +- src/pyrobusta/protocol/http.py | 86 ++++++++++++++-------- src/pyrobusta/protocol/http_file_server.py | 26 +++---- src/pyrobusta/protocol/http_multipart.py | 2 +- tests/unit/test_http.py | 24 +++--- tests/unit/test_http_file_server.py | 20 ++--- tests/unit/test_http_multipart.py | 2 +- 8 files changed, 95 insertions(+), 75 deletions(-) diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html index f081947..99e0481 100644 --- a/dist/pyrobusta/assets/www/examples.html +++ b/dist/pyrobusta/assets/www/examples.html @@ -95,7 +95,7 @@

Demo Application

).lower() if is_detailed not in ("true", "false"): - http_ctx.terminate(400, True) + http_ctx.terminate(400) return "text/plain", "Invalid query" include_server_version = is_detailed.lower() == "true" @@ -118,7 +118,7 @@

Demo Application

resource = http_ctx.path_segment(0) if resource not in (b"app", b"server"): - http_ctx.terminate(404, True) + http_ctx.terminate(404) return "text/plain", "Not found" version_string = APP_VERSION if resource == b"app" else PYROBUSTA_VERSION diff --git a/example/demo_app/app.py b/example/demo_app/app.py index f3c346d..12426df 100644 --- a/example/demo_app/app.py +++ b/example/demo_app/app.py @@ -17,7 +17,7 @@ def version(http_ctx, _): ).lower() if is_detailed not in ("true", "false"): - http_ctx.terminate(400, True) + http_ctx.terminate(400) return "text/plain", "Invalid query" include_server_version = is_detailed.lower() == "true" @@ -40,7 +40,7 @@ def version(http_ctx, _): resource = http_ctx.path_segment(0) if resource not in (b"app", b"server"): - http_ctx.terminate(404, True) + http_ctx.terminate(404) return "text/plain", "Not found" version_string = APP_VERSION if resource == b"app" else PYROBUSTA_VERSION @@ -59,4 +59,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 921068e..67d993b 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -57,7 +57,6 @@ class HttpEngine: "status_code", "resp_headers", "resp_handler", - "aborted", "version", "headers", "method", @@ -66,6 +65,7 @@ class HttpEngine: "content_len_cnt", "recv_chunk_size", "is_req_empty", + "_is_req_complete", "mp_boundary", "mp_is_first", "mp_is_last", @@ -151,7 +151,6 @@ def __init__(self): self.status_code = None self.resp_headers = [] self.resp_handler = None - self.aborted = False # [Recived request] self.version = None @@ -162,6 +161,7 @@ def __init__(self): self.content_len_cnt = 0 self.recv_chunk_size = 0 self.is_req_empty = True + self._is_req_complete = False # [Multipart state] self.mp_boundary = None @@ -175,7 +175,6 @@ def reset(self): self.status_code = None self.resp_headers.clear() self.resp_handler = None - self.aborted = False self.version = None self.headers.clear() self.method = None @@ -184,6 +183,7 @@ def reset(self): self.content_len_cnt = 0 self.recv_chunk_size = 0 self.is_req_empty = True + self._is_req_complete = False self.mp_boundary = None # ========================================= @@ -503,8 +503,10 @@ def set_response_body( :param body: body to be sent in the response :param content_type: content-type of the body """ - if body is None: - return + self.resp_handler = None + + if not body: + body_encoded = b"" if isinstance(body, (bytes, bytearray, memoryview)): body_encoded = body elif isinstance(body, str): @@ -520,12 +522,20 @@ def set_response_body( if self.method != self.HEAD: self.resp_handler = BytesIO(body_encoded) + def _unset_response_handler(self): + """ + Unset the response handler (if set). + """ + if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"): + self.resp_handler.close() + self.resp_handler = None + def do_keep_alive(self): """ Determine if the connection should be kept alive depending on the HTTP version and headers sent in the request. """ - if self.aborted: + if self.is_terminated() and not self._is_req_complete: return False connection_tokens = [ @@ -543,8 +553,7 @@ def _handle_route_response(self, callback_response: tuple | None): set a status code, default to HTTP 200. If the handler returns a response body and content type, set them accordingly. """ - if not self.is_terminated(): - self.terminate(200, True) + self.terminate(self.status_code or 200) if callback_response is None: return @@ -557,36 +566,25 @@ def _handle_route_response(self, callback_response: tuple | None): self.set_response_body(data, content_type=dtype) - def terminate(self, status_code: int, request_complete: bool = False): + def terminate(self, status_code: int): """ Regular state machine termination with a specific status code. :param status_code: HTTP status code - :param request_complete: true if the complete request is processed """ - self.state = None + self.state = self._terminal_st + if not isinstance(status_code, int) or status_code not in self.RESP_HEADERS: + raise ValueError("Invalid status") self.status_code = status_code - if self.version == b"HTTP/1.0" and self.do_keep_alive() and request_complete: - self.set_response_header(b"connection", b"keep-alive") - elif ( - self.version == b"HTTP/1.1" - and not self.do_keep_alive() - and not request_complete - ): - self.set_response_header(b"connection", b"close") - def abort(self, status_code: int): """ Abort state machine due to runtime errors. Reset any header or response body set earlier. :param status_code: HTTP status code """ - self.aborted = True self.resp_headers = [] - if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"): - self.resp_handler.close() - self.resp_handler = None - self.terminate(status_code, False) + self._unset_response_handler() + self.terminate(status_code) def is_request_empty(self): """ @@ -598,7 +596,7 @@ def is_terminated(self): """ Returns true if the state machine is terminated. """ - return self.state is None and self.status_code + return self.state is None def run(self, rx): """ @@ -661,6 +659,7 @@ def _consume_payload(self, rx, size, last=False): of content length when the last flag is set. When the request is chunked, the content length should not be set, otherwise it is ignored. """ + assert not self._is_req_complete if ( not self.is_chunked() and "content-length" in self.headers @@ -675,6 +674,7 @@ def _consume_payload(self, rx, size, last=False): raise InvalidContentLength() self.content_len_cnt += size rx.consume(size) + self._is_req_complete = last # ================================================================================ # Parser states @@ -747,7 +747,7 @@ def _route_request_st(self, _): if self.method == self.OPTIONS: supported_methods = self._supported_methods(self.url) self.set_response_header(b"allow", b", ".join(supported_methods)) - self.terminate(204, True) + self.terminate(204) return if self.has_payload(): if self.method in (self.GET, self.HEAD): @@ -830,7 +830,7 @@ def _app_endpoint_st(self, rx): self, bytes(rx.peek(self.recv_chunk_size)) ) self._consume_payload(rx, self.recv_chunk_size + 2) - if not self.is_terminated(): + if not self.state == self._terminal_st: # pylint: disable=W0143 self.state = self._recv_chunk_size_st return else: @@ -844,6 +844,7 @@ def _app_endpoint_st(self, rx): self._consume_payload(rx, self.headers["content-length"], last=True) else: callback_response = callback(self, b"") + self._is_req_complete = True self._handle_route_response(callback_response) @@ -865,7 +866,7 @@ def _fs_retrieve_st(self, _): try: if not is_path_served: stat(norm_path) - self.terminate(403, True) + self.terminate(403) return try: @@ -880,12 +881,12 @@ def _fs_retrieve_st(self, _): b"content-length", str(stat(norm_path)[6]).encode("ascii") ) self.set_response_header(b"content-type", content_type) - self.terminate(200, True) + self.terminate(200) if self.method != self.HEAD: self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732 return except OSError: - self.terminate(404, True) + self.terminate(404) def _start_multipart_parser_st(self, rx): # pylint: disable=W0613 """ @@ -899,6 +900,29 @@ def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613 """ self.abort(503) + def _terminal_st(self, rx): # pylint: disable=W0613 + """ + Terminal state for finalizing request/response processing. + """ + if ( + self.version == b"HTTP/1.0" + and self.do_keep_alive() + and self._is_req_complete + ): + self.set_response_header(b"connection", b"keep-alive") + elif self.version == b"HTTP/1.1" and ( + not self.do_keep_alive() or not self._is_req_complete + ): + self.set_response_header(b"connection", b"close") + + if ( + self.get_response_header(b"transfer-encoding") != b"chunked" + and self.get_response_header(b"content-length") is None + ): + self.set_response_header(b"content-length", b"0") + + self.state = None + def enable_optional_features(): """ diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index a674d31..2fd2df1 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -43,14 +43,14 @@ def fs_retrieve(http_ctx, _): try: if not is_path_served: stat(norm_path) - http_ctx.terminate(403, True) + http_ctx.terminate(403) return "text/plain", "Forbidden" # Retrieve directory structure if stat(norm_path)[0] & 0x4000: http_ctx.set_response_header(b"content-type", b"application/json") http_ctx.set_response_header(b"transfer-encoding", b"chunked") - http_ctx.terminate(200, True) + http_ctx.terminate(200) http_ctx.resp_handler = _traverse_dir_factory(norm_path) return @@ -67,11 +67,11 @@ def fs_retrieve(http_ctx, _): b"content-length", str(stat(norm_path)[6]).encode("ascii") ) http_ctx.set_response_header(b"content-type", content_type) - http_ctx.terminate(200, True) + http_ctx.terminate(200) if http_ctx.method != http_ctx.HEAD: http_ctx.resp_handler = open(norm_path, "rb") # pylint: disable=R1732 except OSError: - http_ctx.terminate(404, True) + http_ctx.terminate(404) return "text/plain", "Not found" @@ -87,24 +87,24 @@ def delete_file(http_ctx, _): try: if not fs_path.startswith(_UPLOAD_ROOT): stat(fs_path) - http_ctx.terminate(403, True) + http_ctx.terminate(403) return "text/plain", "Forbidden" # Delete directory structure if stat(fs_path)[0] & 0x4000: if listdir(fs_path): - http_ctx.terminate(400, True) + http_ctx.terminate(400) return "text/plain", "Directory not empty" rmdir(fs_path) - http_ctx.terminate(204, True) + http_ctx.terminate(204) return "text/plain", "Deleted" # Delete file remove(fs_path) - http_ctx.terminate(204, True) + http_ctx.terminate(204) return "text/plain", "Deleted" except OSError: - http_ctx.terminate(404, True) + http_ctx.terminate(404) return "text/plain", "Not found" @@ -120,7 +120,7 @@ def upload_file(http_ctx, payload: bytes): return "text/plain", "Bad request" if not normalize_path(target_path).startswith(_UPLOAD_ROOT): - http_ctx.terminate(403, True) + http_ctx.terminate(403) return "text/plain", "Forbidden" try: @@ -142,10 +142,10 @@ def upload_file(http_ctx, payload: bytes): with open(normalize_path(target_path), "wb") as f: f.write(payload) - http_ctx.terminate(201, True) + http_ctx.terminate(201) return "text/plain", "OK" except OSError: - http_ctx.terminate(404, True) + http_ctx.terminate(404) return "text/plain", "Not found" @@ -191,7 +191,7 @@ def bulk_upload_file(http_ctx, payload: tuple): if file.endswith(suffix): rename(_TMP_DIR + "/" + file, _UPLOAD_ROOT + "/" + file[: -len(suffix)]) - http_ctx.terminate(201, True) + http_ctx.terminate(201) return "text/plain", "OK" diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index 86af41b..52985d7 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -139,7 +139,7 @@ def _parse_complete_part_st(self, rx): raise http.MalformedRequest() self._consume_payload(rx, len(self.mp_delimiter)) self.mp_is_first = False - if not self.is_terminated(): + if not self.state == self._terminal_st: # Proceed to next part if there is no early termination self.state = self._parse_boundary_st elif callback_response: diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index db8ad5d..cd194a8 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -141,7 +141,7 @@ def test_status_parsing_unsupported_method(self): self.assertEqual(self.engine.method, b"NOTSUPORTED") self.assertEqual(self.engine.url, b"/index.html") self.assertEqual(self.engine.version, b"HTTP/1.1") - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) self.assertEqual(self.engine.status_code, 405) def test_status_parsing_unsupported_version(self): @@ -156,7 +156,7 @@ def test_status_parsing_unsupported_version(self): self.assertEqual(self.engine.method, b"GET") self.assertEqual(self.engine.url, b"/index.html") self.assertEqual(self.engine.version, b"HTTP/2") - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) self.assertEqual(self.engine.status_code, 505) def test_header_parsing_valid(self): @@ -219,7 +219,7 @@ def test_routing_unsupported_method(self): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 405) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) self.assertIn(b"allow", self.engine.resp_headers) self.assertIn(b"POST", self.engine.resp_headers) @@ -237,7 +237,7 @@ def test_routing_options_method(self): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 204) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) self.assertIn(b"allow", self.engine.resp_headers) self.assertIn(b"GET, POST, PUT", self.engine.resp_headers) @@ -256,7 +256,6 @@ def test_routing_get_method(self): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) self.assertEqual( int(self.engine._lookup(self.engine.resp_headers, b"content-length")), len(test_response), @@ -278,7 +277,6 @@ def test_routing_head_method(self): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) self.assertEqual( int(self.engine._lookup(self.engine.resp_headers, b"content-length")), len(test_response), @@ -445,7 +443,7 @@ def test_chunked_transfer_encoding_valid(self): ) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) def test_chunked_transfer_encoding_invalid_chunk_size_smaller(self): self.engine.url = b"/api/test" @@ -502,7 +500,6 @@ def test_payload_length_matches_content_length(self): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) test_callback.assert_called_with(self.engine, payload) def test_payload_length_exceeds_content_length(self): @@ -533,7 +530,6 @@ def test_payload_length_exceeds_content_length(self): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) test_callback.assert_called_with(self.engine, b"hello world") self.assertEqual( self.rx.peek(), b"!" @@ -610,7 +606,7 @@ def test_file_serving_root(self, *_): self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_subdir(self, *_): @@ -626,7 +622,7 @@ def test_file_serving_subdir(self, *_): self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_missing_file(self, *_): @@ -638,7 +634,7 @@ def test_file_serving_missing_file(self, *_): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 404) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_known_content_type(self, *_): @@ -658,7 +654,7 @@ def test_file_serving_known_content_type(self, *_): b"application/javascript", ) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_fallback_content_type(self, *_): @@ -678,7 +674,7 @@ def test_file_serving_fallback_content_type(self, *_): b"application/octet-stream", ) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) if __name__ == "__main__": diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py index 597f7ed..4131cac 100644 --- a/tests/unit/test_http_file_server.py +++ b/tests/unit/test_http_file_server.py @@ -44,7 +44,7 @@ def test_file_serving_missing_file(self, *_): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 404) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_files_endpoint(self, *_): @@ -60,7 +60,7 @@ def test_file_serving_files_endpoint(self, *_): self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_known_content_type(self, *_): @@ -80,7 +80,7 @@ def test_file_serving_known_content_type(self, *_): b"application/javascript", ) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_fallback_content_type(self, *_): @@ -100,7 +100,7 @@ def test_file_serving_fallback_content_type(self, *_): b"application/octet-stream", ) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_unserved_content_rejected(self, *_): @@ -116,7 +116,7 @@ def test_file_serving_unserved_content_rejected(self, *_): self.assertNotEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 403) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat(stat_is_file=False) def test_file_serving_directory_path(self): @@ -133,7 +133,7 @@ def test_file_serving_directory_path(self): b"application/json", ) self.assertEqual(self.engine.status_code, 200) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_directory_traversal(self): @@ -179,7 +179,7 @@ def test_file_serving_missing_file(self, *_): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 404) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_non_user_data_rejected(self, *_): @@ -191,7 +191,7 @@ def test_file_serving_non_user_data_rejected(self, *_): self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 403) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat() def test_file_serving_user_data_deleted(self, *_): @@ -205,7 +205,7 @@ def test_file_serving_user_data_deleted(self, *_): m.assert_called_once_with("/www/user_data/user_content.json") self.assertEqual(self.engine.status_code, 204) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) @patch_os_stat(stat_is_file=True) def test_file_serving_user_directory_deleted(self, *_): @@ -219,7 +219,7 @@ def test_file_serving_user_directory_deleted(self, *_): m.assert_called_once_with("/www/user_data/user_dir") self.assertEqual(self.engine.status_code, 204) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) class TestFileServerUpload(TestHttpBase): diff --git a/tests/unit/test_http_multipart.py b/tests/unit/test_http_multipart.py index 740c488..406a2e3 100644 --- a/tests/unit/test_http_multipart.py +++ b/tests/unit/test_http_multipart.py @@ -150,7 +150,7 @@ def test_multipart_receiver_last_part(self): self.engine.state(self.rx) - self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.state, self.engine._terminal_st) self.assertEqual(self.engine.status_code, 200) test_callback.assert_called_once_with( self.engine, From ea6f903e2939ebc903af0a4ad4d0a77d844fc6c7 Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sat, 20 Jun 2026 20:00:16 +0200 Subject: [PATCH 3/3] Enforce stateless behavior in response body setter `set_response_body` may be invoked multiple times during the lifecycle of a connection. Ensure that the response handler is reset every time the setter is called, to avoid a stale handler to be reused. Additionally, accept an empty/undefined response body, and set content-length to 0 accordingly. --- src/pyrobusta/protocol/http.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 67d993b..81ca0dc 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -503,7 +503,7 @@ def set_response_body( :param body: body to be sent in the response :param content_type: content-type of the body """ - self.resp_handler = None + self._unset_response_handler() if not body: body_encoded = b"" @@ -515,12 +515,16 @@ def set_response_body( body_encoded = dumps(body).encode() else: raise ValueError("Unhandled body type") + self.set_response_header( b"content-length", str(len(body_encoded)).encode("ascii") ) - self.set_response_header(b"content-type", content_type.encode("ascii")) - if self.method != self.HEAD: - self.resp_handler = BytesIO(body_encoded) + + if len(body_encoded): + self.set_response_header(b"content-type", content_type.encode("ascii")) + + if self.method != self.HEAD: + self.resp_handler = BytesIO(body_encoded) def _unset_response_handler(self): """