diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 2042937..e36986f 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -66,11 +66,7 @@ class HttpEngine: "recv_chunk_size", "is_req_empty", "_is_req_complete", - "mp_boundary", - "mp_is_first", - "mp_is_last", - "mp_delimiter", - "mp_last_delimiter", + "_extras", ) ROUTES = [] # (route, handler, HTTP method) @@ -93,6 +89,8 @@ class HttpEngine: b"408 Request Timeout", 413, b"413 Content Too Large", + 415, + b"415 Unsupported Media Type", 500, b"500 Internal Server Error", 503, @@ -163,8 +161,8 @@ def __init__(self): self.is_req_empty = True self._is_req_complete = False - # [Multipart state] - self.mp_boundary = None + # [Extras] + self._extras = None def reset(self): """ @@ -184,7 +182,7 @@ def reset(self): self.recv_chunk_size = 0 self.is_req_empty = True self._is_req_complete = False - self.mp_boundary = None + self._extras = None # ========================================= # Methods/decorators for routing @@ -364,12 +362,12 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: for line in header_lines: # pylint: disable=W0511 if any(c > 127 for c in line): - raise InvalidHeaders("Non-ASCII character") + raise InvalidHeaders() if b":" not in line: raise InvalidHeaders() name, value = line.split(b":", 1) if not name: - raise InvalidHeaders("Empty header name") + raise InvalidHeaders() for c in name: if ( 48 <= c <= 57 # 0-9 @@ -378,10 +376,10 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: or c in (45, 95) # -_ ): continue - raise InvalidHeaders("Invalid header name") + raise InvalidHeaders() name = name.strip().lower().decode("ascii") if any((c < 32 and c != 9) or c == 127 for c in value): - raise InvalidHeaders("Invalid header value") + raise InvalidHeaders() if name == "content-length": value = int(value.strip()) else: @@ -392,38 +390,6 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: headers[name] += ", " + value # Combined field value return headers - @staticmethod - def _get_mp_boundary(headers: dict) -> str: - """ - Determine from the headers if a request is multipart, - and return the boundary value. - """ - content_type = headers.get("content-type") - if not content_type or not content_type.lower().startswith("multipart/"): - return None - - parts = content_type.split(";") - for part in parts[1:]: - if "=" not in part: - continue - key, value = part.strip().split("=", 1) - - if key.strip().lower() != "boundary": - continue - value = value.strip() - - if value.startswith('"'): - if len(value) < 2 or not value.endswith('"'): - raise InvalidHeaders() - value = value[1:-1] - elif value.endswith('"'): - raise InvalidHeaders() - - if not value: - raise InvalidHeaders() - return value - raise InvalidHeaders() - @classmethod def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]: """ @@ -756,9 +722,7 @@ def _route_request_st(self, _): if self.has_payload(): if self.method in (self.GET, self.HEAD): raise MalformedRequest() - if mp_boundary := self._get_mp_boundary(self.headers): - # Request body is multipart - self.mp_boundary = mp_boundary.encode("ascii") + if self.is_multipart(): self.state = self._start_multipart_parser_st elif self.is_chunked(): # Request body is chunked @@ -784,6 +748,9 @@ def _route_request_st(self, _): return # Fallback: serve file if self.method in (self.GET, self.HEAD): + if self.has_payload(): + raise MalformedRequest() + self._is_req_complete = True self.state = self._fs_retrieve_st return self.terminate(404) @@ -925,6 +892,9 @@ def _terminal_st(self, rx): # pylint: disable=W0613 ): self.set_response_header(b"content-length", b"0") + if not self.get_response_header(b"cache-control"): + self.set_response_header(b"cache-control", b"no-store") + self.state = None diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index 4c79ca9..93281b8 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -97,12 +97,12 @@ def delete_file(http_ctx, _): return "text/plain", "Directory not empty" rmdir(fs_path) http_ctx.terminate(204) - return "text/plain", "Deleted" + return "text/plain", "OK" # Delete file remove(fs_path) http_ctx.terminate(204) - return "text/plain", "Deleted" + return "text/plain", "OK" except OSError: http_ctx.terminate(404) return "text/plain", "Not found" diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index 7645ae7..a5d96ec 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -12,7 +12,7 @@ # pylint: disable=W0212,R0401 from pyrobusta.protocol import http -from pyrobusta.utils.helpers import add_method +from pyrobusta.utils.helpers import add_method, add_property, patch_extra_property def generate_multipart_response(self, callback: callable, dtype: str): @@ -23,7 +23,7 @@ def generate_multipart_response(self, callback: callable, dtype: str): :param dtype: exact multipart content-type (multipart/*) """ if not callable(callback): - raise ValueError("Invalid function callback") + raise ValueError("Invalid callback") boundary = self.MULTIPART_BOUNDARY self.set_response_header( @@ -82,6 +82,38 @@ def _multipart_wrapper(tx): return _multipart_wrapper +def _get_mp_boundary(headers: dict) -> str: + """ + Determine from the headers if a request is multipart, + and return the boundary value. + """ + content_type = headers.get("content-type") + if not content_type or not content_type.lower().startswith("multipart/"): + return None + + parts = content_type.split(";") + for part in parts[1:]: + if "=" not in part: + continue + key, value = part.strip().split("=", 1) + + if key.strip().lower() != "boundary": + continue + value = value.strip() + + if value.startswith('"'): + if len(value) < 2 or not value.endswith('"'): + raise http.InvalidHeaders() + value = value[1:-1] + elif value.endswith('"'): + raise http.InvalidHeaders() + + if not value: + raise http.InvalidHeaders() + return value + raise http.InvalidHeaders() + + def _start_multipart_parser_st(self, rx): """ Initial state for processing multipart requests. @@ -90,13 +122,17 @@ def _start_multipart_parser_st(self, rx): """ if not "content-length" in self.headers: raise http.InvalidContentLength() + + self.mp_boundary = _get_mp_boundary(self.headers).encode("ascii") + if (start_delimiter := rx.find(b"\r\n")) == -1: return - self.mp_delimiter = b"--" + self.mp_boundary + b"\r\n" - self.mp_last_delimiter = b"--" + self.mp_boundary + b"--" + if rx.peek(start_delimiter + 2) != self.mp_delimiter: raise http.MalformedRequest() self._consume_payload(rx, start_delimiter + 2) + self.mp_is_first = True + self.mp_is_last = False self.state = self._parse_boundary_st @@ -168,20 +204,29 @@ def apply_patches(): """ Apply patches to class attributes for multipart parsing. """ - orig_init = http.HttpEngine.__init__ - def new_init(self, *args, **kwargs): - orig_init(self, *args, **kwargs) - self.mp_is_first = True - self.mp_is_last = False - self.mp_delimiter = None - self.mp_last_delimiter = None + def mp_delimiter(self): + if self.mp_boundary is None: + return None + return b"--" + self.mp_boundary + b"\r\n" - http.HttpEngine.__init__ = new_init - http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary" + def mp_last_delimiter(self): + if self.mp_boundary is None: + return None + return b"--" + self.mp_boundary + b"--" + + add_property(http.HttpEngine, mp_delimiter) + add_property(http.HttpEngine, mp_last_delimiter) + + patch_extra_property(http.HttpEngine, "mp_boundary") + patch_extra_property(http.HttpEngine, "mp_is_first") + patch_extra_property(http.HttpEngine, "mp_is_last") add_method(http.HttpEngine, generate_multipart_response) + add_method(http.HttpEngine, _get_mp_boundary, "static") add_method(http.HttpEngine, _multipart_wrapper_factory, "static") add_method(http.HttpEngine, _start_multipart_parser_st) add_method(http.HttpEngine, _parse_boundary_st) add_method(http.HttpEngine, _parse_complete_part_st) + + http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary" diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py index e90f282..8337a8a 100644 --- a/src/pyrobusta/utils/helpers.py +++ b/src/pyrobusta/utils/helpers.py @@ -93,3 +93,29 @@ def add_method(cls, func: callable, method_type="instance"): setattr(cls, func.__name__, classmethod(func)) else: raise ValueError("Invalid type") + + +def add_property(cls, getter: callable, setter: callable = None): + """ + Add a property to a class. + """ + setattr(cls, getter.__name__, property(getter, setter)) + + +# pylint: disable=W0212 +def patch_extra_property(cls, name): + """ + Add a property to 'cls' that stores its value in the instance's + '_extras' dictionary. Intended for '__slots__' classes that cannot + have arbitrary instance attributes. + """ + + def getter(self): + return self._extras.get(name) if self._extras else None + + def setter(self, value): + if self._extras is None: + self._extras = {} + self._extras[name] = value + + setattr(cls, name, property(getter, setter)) diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py index 56e886b..a72c545 100644 --- a/tests/unit/test_http_file_server.py +++ b/tests/unit/test_http_file_server.py @@ -370,11 +370,9 @@ def test_file_serving_single_file_upload(self, *_): self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 151 - self.engine.headers["content-type"] = "multipart/form-data" - - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) self.engine.state = self.engine._start_multipart_parser_st body_part = ( @@ -421,11 +419,9 @@ def test_file_serving_multiple_file_upload(self, *_): self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 287 - self.engine.headers["content-type"] = "multipart/form-data" - - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) self.engine.state = self.engine._start_multipart_parser_st body_part = ( @@ -488,11 +484,9 @@ def test_file_serving_single_file_multiple_parts_upload(self, *_): self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 285 - self.engine.headers["content-type"] = "multipart/form-data" - - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) self.engine.state = self.engine._start_multipart_parser_st body_part = ( @@ -543,11 +537,9 @@ def test_file_serving_multiple_file_chunked_upload(self, *_): self.engine.method = b"POST" self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 565 - self.engine.headers["content-type"] = "multipart/form-data" - - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) self.engine.state = self.engine._start_multipart_parser_st body_part = ( diff --git a/tests/unit/test_http_multipart.py b/tests/unit/test_http_multipart.py index 5ff22f3..2c0af5c 100644 --- a/tests/unit/test_http_multipart.py +++ b/tests/unit/test_http_multipart.py @@ -16,6 +16,11 @@ def setUpClass(cls): cls.base_config = {"http_multipart": "True", "http_files_api": "False"} cls.cwd = os.getcwd() + def feed_body_part(self, body_part): + for i in range(len(body_part)): + self.rx.write(body_part[i : i + 1]) + self.engine.state(self.rx) + def test_multipart_parser(self): for case in [ ({}, None), @@ -53,12 +58,13 @@ def test_multipart_parser(self): def test_multipart_receiver_valid(self): self.engine.state = self.engine._start_multipart_parser_st self.engine.headers["content-length"] = 100 - self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) + body_part = b"--test-boundary\r\nContent-Type:text/plain" - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) self.assertEqual(self.engine.state, self.engine._parse_boundary_st) self.assertEqual(self.rx.peek(), b"Content-Type:text/plain") @@ -67,16 +73,17 @@ def test_multipart_receiver_boundary_mismatch(self): self.engine.state = self.engine._start_multipart_parser_st self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 100 - self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) + body_part = b"--test-boundary-delimiter\r\nContent-Type:text/plain" with self.assertRaises(self.http_module.MalformedRequest): - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) def test_multipart_receiver_complete_part(self): - self.engine.state = self.engine._parse_boundary_st + self.engine.state = self.engine._start_multipart_parser_st self.engine.url = b"/api/test" self.engine.method = b"GET" @@ -84,25 +91,21 @@ def test_multipart_receiver_complete_part(self): self.engine.register("/api/test", test_handler) self.engine.headers["content-length"] = 1000 - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) body_part = ( + b"--test-boundary\r\n" b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' b"Content-Type:text/plain\r\n\r\n" b"Upload content\r\n" b"--test-boundary\r\n" ) - for i in range(len(body_part)): - self.assertEqual(self.engine.state, self.engine._parse_boundary_st) - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) - self.assertEqual(self.rx.peek(), body_part) - self.assertEqual(self.engine.mp_is_first, True) self.engine.state(self.rx) @@ -120,6 +123,30 @@ def test_multipart_receiver_complete_part(self): self.assertEqual(self.engine.mp_is_first, False) self.assertEqual(self.engine.mp_is_last, False) + def test_multipart_receiver_first_part(self): + self.engine.state = self.engine._start_multipart_parser_st + self.engine.url = b"/api/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.headers["content-length"] = 131 + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) + + body_part = ( + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + ) + + self.feed_body_part(body_part) + + self.assertEqual(self.engine.state, self.engine._parse_boundary_st) + self.assertEqual(self.engine.mp_boundary, b"test-boundary") + self.assertEqual(self.engine.mp_is_first, True) + self.assertEqual(self.engine.mp_is_last, False) + def test_multipart_receiver_last_part(self): self.engine.state = self.engine._parse_boundary_st self.engine.url = b"/api/test" @@ -127,8 +154,6 @@ def test_multipart_receiver_last_part(self): self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 131 self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" test_handler = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_handler) @@ -140,10 +165,7 @@ def test_multipart_receiver_last_part(self): b"--test-boundary--" ) - for i in range(len(body_part)): - self.assertEqual(self.engine.state, self.engine._parse_boundary_st) - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) self.assertEqual(self.rx.peek(), body_part) @@ -162,7 +184,6 @@ def test_multipart_receiver_last_part(self): b"Upload content", ), ) - self.assertEqual(self.engine.mp_is_first, True) self.assertEqual(self.engine.mp_is_last, True) def test_multipart_content_length_match(self): @@ -171,7 +192,9 @@ def test_multipart_content_length_match(self): self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 148 - self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) test_handler = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_handler) @@ -184,9 +207,7 @@ def test_multipart_content_length_match(self): b"--test-boundary--" ) - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) while self.engine.state is not None: self.engine.state(self.rx) @@ -214,6 +235,9 @@ def test_multipart_content_length_smaller(self): self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 148 - 1 self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) test_handler = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_handler) @@ -226,9 +250,7 @@ def test_multipart_content_length_smaller(self): b"--test-boundary--" ) - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) with self.assertRaises(self.http_module.InvalidContentLength): while self.engine.state is not None: @@ -244,7 +266,9 @@ def test_multipart_content_length_larger(self): self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 148 + 1 - self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) test_handler = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_handler) @@ -257,9 +281,7 @@ def test_multipart_content_length_larger(self): b"--test-boundary--" ) - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) self.engine.state(self.rx) @@ -276,7 +298,9 @@ def test_multipart_epilogue_data(self): self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 148 + 13 - self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) test_handler = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_handler) @@ -289,9 +313,7 @@ def test_multipart_epilogue_data(self): b"--test-boundary--epilogue-data" ) - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) with self.assertRaises(self.http_module.InvalidContentLength): while self.engine.state is not None: @@ -303,7 +325,9 @@ def test_multipart_complete_part_trailing_crlf(self): self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["content-length"] = 150 - self.engine.mp_boundary = b"test-boundary" + self.engine.headers["content-type"] = ( + 'multipart/form-data;boundary="test-boundary"' + ) test_handler = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_handler) @@ -316,9 +340,7 @@ def test_multipart_complete_part_trailing_crlf(self): b"--test-boundary--\r\n" ) - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) + self.feed_body_part(body_part) while self.engine.state is not None: self.engine.state(self.rx)