From 2a7fb9c9842f02f66c0f298685ca8ce539906af7 Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sun, 14 Jun 2026 17:11:36 +0200 Subject: [PATCH 1/4] Simplify the usage of query parameter parsing Rename `get_url_encoded_query_param` to `get_query_param` and make it an instance method to simplify the usage. --- dist/pyrobusta/assets/www/examples.html | 4 ++-- example/demo_app/app.py | 4 ++-- example/mem_usage/app.py | 6 ++---- src/pyrobusta/protocol/http.py | 20 ++++++++++--------- tests/unit/test_http.py | 26 +++++++++++-------------- 5 files changed, 28 insertions(+), 32 deletions(-) diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html index e0cdfd9..1ade6de 100644 --- a/dist/pyrobusta/assets/www/examples.html +++ b/dist/pyrobusta/assets/www/examples.html @@ -90,8 +90,8 @@

Demo Application

include_server_version = False if http_ctx.query: - is_detailed = http_ctx.get_url_encoded_query_param( - http_ctx.query, "detailed", default="false" + is_detailed = http_ctx.get_query_param( + "detailed", default="false" ).lower() if is_detailed not in ("true", "false"): diff --git a/example/demo_app/app.py b/example/demo_app/app.py index 6817989..8cf96db 100644 --- a/example/demo_app/app.py +++ b/example/demo_app/app.py @@ -12,8 +12,8 @@ def version(http_ctx, _): include_server_version = False if http_ctx.query: - is_detailed = http_ctx.get_url_encoded_query_param( - http_ctx.query, "detailed", default="false" + is_detailed = http_ctx.get_query_param( + "detailed", default="false" ).lower() if is_detailed not in ("true", "false"): diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py index bfde590..f4bb084 100644 --- a/example/mem_usage/app.py +++ b/example/mem_usage/app.py @@ -13,13 +13,11 @@ def mem_usage(http_ctx, _): usage_percentage = 100 * used / (free + used) if http_ctx.query: - value_format = http_ctx.get_url_encoded_query_param( - http_ctx.query, "format", "bytes" - ) + value_format = http_ctx.get_query_param("format", "bytes") if value_format not in ("%", "bytes"): raise ValueError("invalid format") - selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "") + selector = http_ctx.get_query_param("key", "") if selector == "free": if value_format == "%": free = round(100 * free / (used + free), 2) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 12ff6d2..4683141 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -236,7 +236,7 @@ def decorator(func): return decorator # ========================================= - # Static helpers for parsing + # Helpers for parsing # ========================================= @staticmethod @@ -255,18 +255,20 @@ def percent_decode(s: str): i += 1 return "".join(out) - @staticmethod - def get_url_encoded_query_param(query: str, key: str, default: str = None): + def get_query_param(self, key: str, default: str = None) -> str: """ Parse a query and return the value belonging to a key according to the x-www-form-urlencoded format. - :param query: query part :param key: key to parse from the query :param default: default value to return when key is not present + :return: value of the key or default """ - if query.startswith(key + "="): + if not self.query or not key: + return default + + if self.query.startswith(key + "="): idx_start = 0 - elif (idx_start := query.find("&" + key + "=")) != -1: + elif (idx_start := self.query.find("&" + key + "=")) != -1: idx_start += 1 elif default is None: raise KeyError() @@ -274,10 +276,10 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None): return default idx_end = -1 - idx_end = query.find("&", idx_start) + idx_end = self.query.find("&", idx_start) if idx_end > -1: - return query[idx_start + len(key) + 1 : idx_end] - return query[idx_start + len(key) + 1 :] + return self.query[idx_start + len(key) + 1 : idx_end] + return self.query[idx_start + len(key) + 1 :] @staticmethod def _is_matching_url_path(path: bytes, pattern: bytes) -> bool: diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index a02907c..db8ad5d 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -319,9 +319,7 @@ def test_single_url_encoded_query_parameter(self): self.rx.write(request[i : i + 1]) self.engine.state(self.rx) - self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param"), "value" - ) + self.assertEqual(self.engine.get_query_param("param"), "value") def test_multiple_url_encoded_query_parameter(self): request = ( @@ -333,15 +331,15 @@ def test_multiple_url_encoded_query_parameter(self): self.engine.state(self.rx) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param1"), + self.engine.get_query_param("param1"), "value1", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param2"), + self.engine.get_query_param("param2"), "value2", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param3"), + self.engine.get_query_param("param3"), "value3", ) @@ -353,22 +351,20 @@ def test_empty_or_missing_url_encoded_query_parameter(self): self.engine.state(self.rx) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param1"), + self.engine.get_query_param("param1"), "", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param2"), + self.engine.get_query_param("param2"), "", ) self.assertEqual( - self.engine.get_url_encoded_query_param( - self.engine.query, "param3", "default" - ), + self.engine.get_query_param("param3", "default"), "default", ) with self.assertRaises(KeyError): - self.engine.get_url_encoded_query_param(self.engine.query, "param3") + self.engine.get_query_param("param3") def test_overlapping_url_encoded_query_parameter(self): request = b"GET /api/test?data=value1&ta=value2&a=value3 HTTP/1.1\r\n" @@ -378,15 +374,15 @@ def test_overlapping_url_encoded_query_parameter(self): self.engine.state(self.rx) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "data"), + self.engine.get_query_param("data"), "value1", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "ta"), + self.engine.get_query_param("ta"), "value2", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "a"), + self.engine.get_query_param("a"), "value3", ) From c629e972bd6ad550aeb12cbe7ccb29a72373ca8c Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sun, 14 Jun 2026 17:22:18 +0200 Subject: [PATCH 2/4] Unify early termination behavior in streaming and multipart requests Handle early termination requested by the user application when processing chunked or multipart requests. Stop processing the payload immediately when the state machine is terminated by the routing handler. Restructure route handler for file uploads, and create a helper function for identifying multipart requests. Move path validation logic to the beginning. --- src/pyrobusta/protocol/http.py | 59 ++++++++++++++-------- src/pyrobusta/protocol/http_file_server.py | 50 ++++++++---------- src/pyrobusta/protocol/http_multipart.py | 15 +++--- 3 files changed, 68 insertions(+), 56 deletions(-) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 4683141..4f0544a 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -526,6 +526,27 @@ def do_keep_alive(self): self.version == b"HTTP/1.1" and "close" not in connection_tokens ) + def _handle_route_response(self, callback_response: tuple | None): + """ + Terminate the state machine based on the return value of a + user-defined route handler. If the handler does not explicitly + set a status code, default to HTTP 200. If the handler returns + a response body and content type, set them accordingly. + """ + if not self.is_terminated(): + self.terminate(200, True) + + if callback_response is None: + return + + dtype, data = callback_response + if dtype.startswith("multipart/") and callable(data): + self.set_response_header(b"transfer-encoding", b"chunked") + self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) + return + + self.set_response_body(data, content_type=dtype) + def terminate(self, status_code: int, request_complete: bool = False): """ Regular state machine termination with a specific status code. @@ -607,7 +628,13 @@ def is_chunked(self): """ Determines if the request has a payload with chunked transfer-encoding. """ - return self.headers.get("transfer-encoding") == "chunked" + return self.headers.get("transfer-encoding", "").lower() == "chunked" + + def is_multipart(self): + """ + Determines if the request has a multipart payload. + """ + return self.headers.get("content-type", "").lower().startswith("multipart/") def has_payload(self): """ @@ -789,13 +816,17 @@ def _app_endpoint_st(self, rx): if self.has_payload(): if self.is_chunked(): if self.recv_chunk_size: - callback(self, bytes(rx.peek(self.recv_chunk_size))) + callback_response = callback( + self, bytes(rx.peek(self.recv_chunk_size)) + ) self._consume_payload(rx, self.recv_chunk_size + 2) - self.state = self._recv_chunk_size_st - return - # Last chunk, callback with empty body to signal end of request body - callback_response = callback(self, b"") - self._consume_payload(rx, self.recv_chunk_size + 2, last=True) + if not self.is_terminated(): + self.state = self._recv_chunk_size_st + return + else: + # Last chunk, callback with empty body to signal end of request body + callback_response = callback(self, b"") + self._consume_payload(rx, self.recv_chunk_size + 2, last=True) else: callback_response = callback( self, bytes(rx.peek(self.headers["content-length"])) @@ -804,19 +835,7 @@ def _app_endpoint_st(self, rx): else: callback_response = callback(self, b"") - if not self.is_terminated(): - self.terminate(200, True) - - if callback_response is None: - return - - dtype, data = callback_response - if dtype.startswith("multipart/") and callable(data): - self.set_response_header(b"transfer-encoding", b"chunked") - self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) - return - - self.set_response_body(data, content_type=dtype) + self._handle_route_response(callback_response) def _fs_retrieve_st(self, _): """ diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index 83bf3b3..a674d31 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -113,42 +113,33 @@ def upload_file(http_ctx, payload: bytes): Callback function for handling single file uploads, supporting chunked transfer encoding. Uploads are saved to _UPLOAD_ROOT, with the name determined by the URL path. """ - content_type = http_ctx.headers.get("content-type") - if content_type and content_type.lower().startswith("multipart/"): + target_path = http_ctx.url.decode("ascii")[6:] + + if http_ctx.is_multipart() or not is_file_path_valid(target_path): http_ctx.terminate(400) return "text/plain", "Bad request" - is_chunked = http_ctx.headers.get("transfer-encoding") == "chunked" - - if is_chunked: - url_path = http_ctx.url.decode("ascii") - file_name_idx = url_path.rfind("/") + 1 - if not file_name_idx: - http_ctx.terminate(400) - return "text/plain", "Bad request" - file_path = _TMP_DIR + "/" + f"{url_path[file_name_idx:]}.{http_ctx.id}" - else: - file_path = normalize_path(http_ctx.url.decode("ascii")[6:]) - - if not is_file_path_valid(file_path): - http_ctx.terminate(400) - return "text/plain", "Invalid or missing filename" + if not normalize_path(target_path).startswith(_UPLOAD_ROOT): + http_ctx.terminate(403, True) + return "text/plain", "Forbidden" try: - if not file_path.startswith(_UPLOAD_ROOT) and not file_path.startswith( - _TMP_DIR - ): - http_ctx.terminate(403, True) - return "text/plain", "Forbidden" + if http_ctx.is_chunked(): + file_name_idx = target_path.rfind("/") + 1 + if not file_name_idx: + http_ctx.terminate(400) + return "text/plain", "Bad request" + + tmp_path = _TMP_DIR + "/" + f"{target_path[file_name_idx:]}.{http_ctx.id}" - if is_chunked: - if not payload: # Last chunk received, finalize upload - rename(file_path, normalize_path(http_ctx.url.decode("ascii")[6:])) - else: - with open(file_path, "ab") as f: + if payload: # Wait for more chunks before setting response status + with open(tmp_path, "ab") as f: f.write(payload) + return + # Last chunk received, finalize upload + rename(tmp_path, normalize_path(target_path)) else: - with open(file_path, "wb") as f: + with open(normalize_path(target_path), "wb") as f: f.write(payload) http_ctx.terminate(201, True) @@ -166,8 +157,7 @@ def bulk_upload_file(http_ctx, payload: tuple): same file name, the content of the second part is appended to the first part. Split files to multiple parts for chunking large files to avoid HTTP 413 errors. """ - content_type = http_ctx.headers.get("content-type") - if not content_type or not content_type.lower().startswith("multipart/form-data"): + if not http_ctx.is_multipart(): http_ctx.terminate(400) return "text/plain", "Bad request" diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index c21f264..f37e64f 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -134,12 +134,17 @@ def _parse_complete_part_st(self, rx): # Process complete part if not is_final: - callback(self, (part_headers, part_body)) + callback_response = callback(self, (part_headers, part_body)) if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter: raise http.MalformedRequest() self._consume_payload(rx, len(self.mp_delimiter)) self.mp_is_first = False - self.state = self._parse_boundary_st + if not self.is_terminated(): + # Proceed to next part if there is no early termination + self.state = self._parse_boundary_st + elif callback_response: + dtype, data = callback_response + self.set_response_body(data, dtype) return # Process last part @@ -155,11 +160,9 @@ def _parse_complete_part_st(self, rx): self._consume_payload(rx, 0, last=True) self.mp_is_last = True - dtype, data = callback(self, (part_headers, part_body)) + callback_response = callback(self, (part_headers, part_body)) - if not self.is_terminated(): - self.terminate(200, True) - self.set_response_body(data, dtype) + self._handle_route_response(callback_response) def apply_patches(): From f12870befb824fe0e8cacaedf02290c258d4db7f Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sun, 14 Jun 2026 18:25:37 +0200 Subject: [PATCH 3/4] Refactor multipart response generation Convert `generate_multipart_response()` from a state-machine state into a helper function. The previous design required wrapper lambdas to adapt multipart response generation to the state handler interface, which expects only an `rx` argument. By moving multipart generation into a helper, applications can invoke it directly without returning an additional callback. This also unifies multipart route responses and early state machine termination under the same response handling path, eliminating ambiguous state transitions. --- .pylintrc | 3 ++- src/pyrobusta/protocol/http.py | 10 ++++------ src/pyrobusta/protocol/http_multipart.py | 11 +++++------ 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.pylintrc b/.pylintrc index 60a1103..ce7476a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -6,4 +6,5 @@ disable=E0611, R1710 [DESIGN] -max-attributes=15 \ No newline at end of file +max-attributes=15 +max-public-methods=25 diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 4f0544a..8e7146b 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -542,7 +542,7 @@ def _handle_route_response(self, callback_response: tuple | None): dtype, data = callback_response if dtype.startswith("multipart/") and callable(data): self.set_response_header(b"transfer-encoding", b"chunked") - self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) + self.generate_multipart_response(data, dtype) return self.set_response_body(data, content_type=dtype) @@ -881,15 +881,13 @@ def _start_multipart_parser_st(self, rx): # pylint: disable=W0613 """ Initial state for processing multipart requests (placeholder). """ - self.terminate(503) + self.abort(503) - def _generate_multipart_response( - self, rx, callback, dtype - ): # pylint: disable=W0613 + def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613 """ Generate multipart response depening on the exact content type (placeholder). """ - self.terminate(503, True) + self.abort(503) def enable_optional_features(): diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index f37e64f..86af41b 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -15,16 +15,16 @@ from pyrobusta.utils.helpers import add_method -def _generate_multipart_response(self, _, callback: callable, dtype: str): +def generate_multipart_response(self, callback: callable, dtype: str): """ Generate multipart response depening on the exact content type. The callback function is called without arguments, and it must return bytes-like objects. :param callback: function for part generation, each call generates a separate part :param dtype: exact multipart content-type (multipart/*) """ - if type(callback).__name__ not in ("function", "closure"): + if not callable(callback): raise ValueError("Invalid response handler") - self.terminate(200, True) + boundary = self.MULTIPART_BOUNDARY self.set_response_header( b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary @@ -143,8 +143,7 @@ def _parse_complete_part_st(self, rx): # Proceed to next part if there is no early termination self.state = self._parse_boundary_st elif callback_response: - dtype, data = callback_response - self.set_response_body(data, dtype) + self._handle_route_response(callback_response) return # Process last part @@ -181,7 +180,7 @@ def new_init(self, *args, **kwargs): http.HttpEngine.__init__ = new_init http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary" - add_method(http.HttpEngine, _generate_multipart_response) + add_method(http.HttpEngine, generate_multipart_response) add_method(http.HttpEngine, _multipart_wrapper_factory, "static") add_method(http.HttpEngine, _start_multipart_parser_st) add_method(http.HttpEngine, _parse_boundary_st) From 7b35ab5a239f7a9bf9dacd31cdf9c2060760016f Mon Sep 17 00:00:00 2001 From: szeka9 Date: Sun, 14 Jun 2026 18:54:54 +0200 Subject: [PATCH 4/4] Add tests for file paths with trailing / Reject file paths with a trailing / character, create tests for the validation function and the file server module. --- tests/unit/test_helpers.py | 9 ++++++++- tests/unit/test_http_file_server.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 2d85271..7214a41 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -121,7 +121,14 @@ def test_path_segment_validation(self): def test_file_path_validation(self): valid_paths = ["/file", "/dir1/file", "/dir-2/file", "/dir_3/file"] - invalid_paths = ["file", "dir1/file", "/dir\\segment/file"] + invalid_paths = [ + "file", + "dir1/file", + "/dir\\segment/file", + "/", + "/dir/", + "/dir/file/", + ] for path in valid_paths: self.assertTrue(self.helpers_module.is_file_path_valid(path)) diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py index 2bfb210..597f7ed 100644 --- a/tests/unit/test_http_file_server.py +++ b/tests/unit/test_http_file_server.py @@ -276,6 +276,23 @@ def test_file_serving_complete_file_invalid_name(self, *_): self.assertEqual(self.engine.status_code, 400) + def test_file_serving_complete_file_invalid_path(self, *_): + self.engine.url = b"/files/www/user_data/file/" + self.engine.method = b"PUT" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 28 + self.engine.headers["content-type"] = "application/octet-stream" + + self.engine.state = self.engine._app_endpoint_st + body_part = b"File uploaded for testing.\r\n" + self.rx.write(body_part) + + while self.engine.state is not None: + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 400) + def test_file_serving_chunked_file_upload(self, *_): self.engine.url = b"/files/www/user_data/chunked.txt" self.engine.method = b"PUT"