diff --git a/.pylintrc b/.pylintrc index 60a1103..ce7476a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -6,4 +6,5 @@ disable=E0611, R1710 [DESIGN] -max-attributes=15 \ No newline at end of file +max-attributes=15 +max-public-methods=25 diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html index e0cdfd9..1ade6de 100644 --- a/dist/pyrobusta/assets/www/examples.html +++ b/dist/pyrobusta/assets/www/examples.html @@ -90,8 +90,8 @@

Demo Application

include_server_version = False if http_ctx.query: - is_detailed = http_ctx.get_url_encoded_query_param( - http_ctx.query, "detailed", default="false" + is_detailed = http_ctx.get_query_param( + "detailed", default="false" ).lower() if is_detailed not in ("true", "false"): diff --git a/example/demo_app/app.py b/example/demo_app/app.py index 6817989..8cf96db 100644 --- a/example/demo_app/app.py +++ b/example/demo_app/app.py @@ -12,8 +12,8 @@ def version(http_ctx, _): include_server_version = False if http_ctx.query: - is_detailed = http_ctx.get_url_encoded_query_param( - http_ctx.query, "detailed", default="false" + is_detailed = http_ctx.get_query_param( + "detailed", default="false" ).lower() if is_detailed not in ("true", "false"): diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py index bfde590..f4bb084 100644 --- a/example/mem_usage/app.py +++ b/example/mem_usage/app.py @@ -13,13 +13,11 @@ def mem_usage(http_ctx, _): usage_percentage = 100 * used / (free + used) if http_ctx.query: - value_format = http_ctx.get_url_encoded_query_param( - http_ctx.query, "format", "bytes" - ) + value_format = http_ctx.get_query_param("format", "bytes") if value_format not in ("%", "bytes"): raise ValueError("invalid format") - selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "") + selector = http_ctx.get_query_param("key", "") if selector == "free": if value_format == "%": free = round(100 * free / (used + free), 2) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 12ff6d2..8e7146b 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -236,7 +236,7 @@ def decorator(func): return decorator # ========================================= - # Static helpers for parsing + # Helpers for parsing # ========================================= @staticmethod @@ -255,18 +255,20 @@ def percent_decode(s: str): i += 1 return "".join(out) - @staticmethod - def get_url_encoded_query_param(query: str, key: str, default: str = None): + def get_query_param(self, key: str, default: str = None) -> str: """ Parse a query and return the value belonging to a key according to the x-www-form-urlencoded format. - :param query: query part :param key: key to parse from the query :param default: default value to return when key is not present + :return: value of the key or default """ - if query.startswith(key + "="): + if not self.query or not key: + return default + + if self.query.startswith(key + "="): idx_start = 0 - elif (idx_start := query.find("&" + key + "=")) != -1: + elif (idx_start := self.query.find("&" + key + "=")) != -1: idx_start += 1 elif default is None: raise KeyError() @@ -274,10 +276,10 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None): return default idx_end = -1 - idx_end = query.find("&", idx_start) + idx_end = self.query.find("&", idx_start) if idx_end > -1: - return query[idx_start + len(key) + 1 : idx_end] - return query[idx_start + len(key) + 1 :] + return self.query[idx_start + len(key) + 1 : idx_end] + return self.query[idx_start + len(key) + 1 :] @staticmethod def _is_matching_url_path(path: bytes, pattern: bytes) -> bool: @@ -524,6 +526,27 @@ def do_keep_alive(self): self.version == b"HTTP/1.1" and "close" not in connection_tokens ) + def _handle_route_response(self, callback_response: tuple | None): + """ + Terminate the state machine based on the return value of a + user-defined route handler. If the handler does not explicitly + set a status code, default to HTTP 200. If the handler returns + a response body and content type, set them accordingly. + """ + if not self.is_terminated(): + self.terminate(200, True) + + if callback_response is None: + return + + dtype, data = callback_response + if dtype.startswith("multipart/") and callable(data): + self.set_response_header(b"transfer-encoding", b"chunked") + self.generate_multipart_response(data, dtype) + return + + self.set_response_body(data, content_type=dtype) + def terminate(self, status_code: int, request_complete: bool = False): """ Regular state machine termination with a specific status code. @@ -605,7 +628,13 @@ def is_chunked(self): """ Determines if the request has a payload with chunked transfer-encoding. """ - return self.headers.get("transfer-encoding") == "chunked" + return self.headers.get("transfer-encoding", "").lower() == "chunked" + + def is_multipart(self): + """ + Determines if the request has a multipart payload. + """ + return self.headers.get("content-type", "").lower().startswith("multipart/") def has_payload(self): """ @@ -787,13 +816,17 @@ def _app_endpoint_st(self, rx): if self.has_payload(): if self.is_chunked(): if self.recv_chunk_size: - callback(self, bytes(rx.peek(self.recv_chunk_size))) + callback_response = callback( + self, bytes(rx.peek(self.recv_chunk_size)) + ) self._consume_payload(rx, self.recv_chunk_size + 2) - self.state = self._recv_chunk_size_st - return - # Last chunk, callback with empty body to signal end of request body - callback_response = callback(self, b"") - self._consume_payload(rx, self.recv_chunk_size + 2, last=True) + if not self.is_terminated(): + self.state = self._recv_chunk_size_st + return + else: + # Last chunk, callback with empty body to signal end of request body + callback_response = callback(self, b"") + self._consume_payload(rx, self.recv_chunk_size + 2, last=True) else: callback_response = callback( self, bytes(rx.peek(self.headers["content-length"])) @@ -802,19 +835,7 @@ def _app_endpoint_st(self, rx): else: callback_response = callback(self, b"") - if not self.is_terminated(): - self.terminate(200, True) - - if callback_response is None: - return - - dtype, data = callback_response - if dtype.startswith("multipart/") and callable(data): - self.set_response_header(b"transfer-encoding", b"chunked") - self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) - return - - self.set_response_body(data, content_type=dtype) + self._handle_route_response(callback_response) def _fs_retrieve_st(self, _): """ @@ -860,15 +881,13 @@ def _start_multipart_parser_st(self, rx): # pylint: disable=W0613 """ Initial state for processing multipart requests (placeholder). """ - self.terminate(503) + self.abort(503) - def _generate_multipart_response( - self, rx, callback, dtype - ): # pylint: disable=W0613 + def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613 """ Generate multipart response depening on the exact content type (placeholder). """ - self.terminate(503, True) + self.abort(503) def enable_optional_features(): diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index 83bf3b3..a674d31 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -113,42 +113,33 @@ def upload_file(http_ctx, payload: bytes): Callback function for handling single file uploads, supporting chunked transfer encoding. Uploads are saved to _UPLOAD_ROOT, with the name determined by the URL path. """ - content_type = http_ctx.headers.get("content-type") - if content_type and content_type.lower().startswith("multipart/"): + target_path = http_ctx.url.decode("ascii")[6:] + + if http_ctx.is_multipart() or not is_file_path_valid(target_path): http_ctx.terminate(400) return "text/plain", "Bad request" - is_chunked = http_ctx.headers.get("transfer-encoding") == "chunked" - - if is_chunked: - url_path = http_ctx.url.decode("ascii") - file_name_idx = url_path.rfind("/") + 1 - if not file_name_idx: - http_ctx.terminate(400) - return "text/plain", "Bad request" - file_path = _TMP_DIR + "/" + f"{url_path[file_name_idx:]}.{http_ctx.id}" - else: - file_path = normalize_path(http_ctx.url.decode("ascii")[6:]) - - if not is_file_path_valid(file_path): - http_ctx.terminate(400) - return "text/plain", "Invalid or missing filename" + if not normalize_path(target_path).startswith(_UPLOAD_ROOT): + http_ctx.terminate(403, True) + return "text/plain", "Forbidden" try: - if not file_path.startswith(_UPLOAD_ROOT) and not file_path.startswith( - _TMP_DIR - ): - http_ctx.terminate(403, True) - return "text/plain", "Forbidden" + if http_ctx.is_chunked(): + file_name_idx = target_path.rfind("/") + 1 + if not file_name_idx: + http_ctx.terminate(400) + return "text/plain", "Bad request" + + tmp_path = _TMP_DIR + "/" + f"{target_path[file_name_idx:]}.{http_ctx.id}" - if is_chunked: - if not payload: # Last chunk received, finalize upload - rename(file_path, normalize_path(http_ctx.url.decode("ascii")[6:])) - else: - with open(file_path, "ab") as f: + if payload: # Wait for more chunks before setting response status + with open(tmp_path, "ab") as f: f.write(payload) + return + # Last chunk received, finalize upload + rename(tmp_path, normalize_path(target_path)) else: - with open(file_path, "wb") as f: + with open(normalize_path(target_path), "wb") as f: f.write(payload) http_ctx.terminate(201, True) @@ -166,8 +157,7 @@ def bulk_upload_file(http_ctx, payload: tuple): same file name, the content of the second part is appended to the first part. Split files to multiple parts for chunking large files to avoid HTTP 413 errors. """ - content_type = http_ctx.headers.get("content-type") - if not content_type or not content_type.lower().startswith("multipart/form-data"): + if not http_ctx.is_multipart(): http_ctx.terminate(400) return "text/plain", "Bad request" diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index c21f264..86af41b 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -15,16 +15,16 @@ from pyrobusta.utils.helpers import add_method -def _generate_multipart_response(self, _, callback: callable, dtype: str): +def generate_multipart_response(self, callback: callable, dtype: str): """ Generate multipart response depening on the exact content type. The callback function is called without arguments, and it must return bytes-like objects. :param callback: function for part generation, each call generates a separate part :param dtype: exact multipart content-type (multipart/*) """ - if type(callback).__name__ not in ("function", "closure"): + if not callable(callback): raise ValueError("Invalid response handler") - self.terminate(200, True) + boundary = self.MULTIPART_BOUNDARY self.set_response_header( b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary @@ -134,12 +134,16 @@ def _parse_complete_part_st(self, rx): # Process complete part if not is_final: - callback(self, (part_headers, part_body)) + callback_response = callback(self, (part_headers, part_body)) if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter: raise http.MalformedRequest() self._consume_payload(rx, len(self.mp_delimiter)) self.mp_is_first = False - self.state = self._parse_boundary_st + if not self.is_terminated(): + # Proceed to next part if there is no early termination + self.state = self._parse_boundary_st + elif callback_response: + self._handle_route_response(callback_response) return # Process last part @@ -155,11 +159,9 @@ def _parse_complete_part_st(self, rx): self._consume_payload(rx, 0, last=True) self.mp_is_last = True - dtype, data = callback(self, (part_headers, part_body)) + callback_response = callback(self, (part_headers, part_body)) - if not self.is_terminated(): - self.terminate(200, True) - self.set_response_body(data, dtype) + self._handle_route_response(callback_response) def apply_patches(): @@ -178,7 +180,7 @@ def new_init(self, *args, **kwargs): http.HttpEngine.__init__ = new_init http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary" - add_method(http.HttpEngine, _generate_multipart_response) + add_method(http.HttpEngine, generate_multipart_response) add_method(http.HttpEngine, _multipart_wrapper_factory, "static") add_method(http.HttpEngine, _start_multipart_parser_st) add_method(http.HttpEngine, _parse_boundary_st) diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 2d85271..7214a41 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -121,7 +121,14 @@ def test_path_segment_validation(self): def test_file_path_validation(self): valid_paths = ["/file", "/dir1/file", "/dir-2/file", "/dir_3/file"] - invalid_paths = ["file", "dir1/file", "/dir\\segment/file"] + invalid_paths = [ + "file", + "dir1/file", + "/dir\\segment/file", + "/", + "/dir/", + "/dir/file/", + ] for path in valid_paths: self.assertTrue(self.helpers_module.is_file_path_valid(path)) diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index a02907c..db8ad5d 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -319,9 +319,7 @@ def test_single_url_encoded_query_parameter(self): self.rx.write(request[i : i + 1]) self.engine.state(self.rx) - self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param"), "value" - ) + self.assertEqual(self.engine.get_query_param("param"), "value") def test_multiple_url_encoded_query_parameter(self): request = ( @@ -333,15 +331,15 @@ def test_multiple_url_encoded_query_parameter(self): self.engine.state(self.rx) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param1"), + self.engine.get_query_param("param1"), "value1", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param2"), + self.engine.get_query_param("param2"), "value2", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param3"), + self.engine.get_query_param("param3"), "value3", ) @@ -353,22 +351,20 @@ def test_empty_or_missing_url_encoded_query_parameter(self): self.engine.state(self.rx) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param1"), + self.engine.get_query_param("param1"), "", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "param2"), + self.engine.get_query_param("param2"), "", ) self.assertEqual( - self.engine.get_url_encoded_query_param( - self.engine.query, "param3", "default" - ), + self.engine.get_query_param("param3", "default"), "default", ) with self.assertRaises(KeyError): - self.engine.get_url_encoded_query_param(self.engine.query, "param3") + self.engine.get_query_param("param3") def test_overlapping_url_encoded_query_parameter(self): request = b"GET /api/test?data=value1&ta=value2&a=value3 HTTP/1.1\r\n" @@ -378,15 +374,15 @@ def test_overlapping_url_encoded_query_parameter(self): self.engine.state(self.rx) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "data"), + self.engine.get_query_param("data"), "value1", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "ta"), + self.engine.get_query_param("ta"), "value2", ) self.assertEqual( - self.engine.get_url_encoded_query_param(self.engine.query, "a"), + self.engine.get_query_param("a"), "value3", ) diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py index 2bfb210..597f7ed 100644 --- a/tests/unit/test_http_file_server.py +++ b/tests/unit/test_http_file_server.py @@ -276,6 +276,23 @@ def test_file_serving_complete_file_invalid_name(self, *_): self.assertEqual(self.engine.status_code, 400) + def test_file_serving_complete_file_invalid_path(self, *_): + self.engine.url = b"/files/www/user_data/file/" + self.engine.method = b"PUT" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 28 + self.engine.headers["content-type"] = "application/octet-stream" + + self.engine.state = self.engine._app_endpoint_st + body_part = b"File uploaded for testing.\r\n" + self.rx.write(body_part) + + while self.engine.state is not None: + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 400) + def test_file_serving_chunked_file_upload(self, *_): self.engine.url = b"/files/www/user_data/chunked.txt" self.engine.method = b"PUT"