diff --git a/.pylintrc b/.pylintrc index ce7476a..e865567 100644 --- a/.pylintrc +++ b/.pylintrc @@ -8,3 +8,4 @@ disable=E0611, [DESIGN] max-attributes=15 max-public-methods=25 +max-branches=15 diff --git a/src/pyrobusta/bindings/http_connection.py b/src/pyrobusta/bindings/http_connection.py index 7edba07..0b8236c 100644 --- a/src/pyrobusta/bindings/http_connection.py +++ b/src/pyrobusta/bindings/http_connection.py @@ -80,7 +80,8 @@ async def _run_state_machine(self): self._engine.set_response_body(b"Read error: " + str(e).encode("ascii")) # [2] process request by state machine - for _ in self._engine.run(self._recv_buf): + while True: + self._engine.run(self._recv_buf) if self._prev_state == self._engine.state: # No state transition occurred, read more data break diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index e36986f..f58a1ad 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -357,54 +357,54 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: """ Basic parser to extract HTTP/MIME headers. """ - header_lines = bytes(raw_headers).split(b"\r\n") headers = {} - for line in header_lines: - # pylint: disable=W0511 - if any(c > 127 for c in line): + start = 0 + n = len(raw_headers) + + while start < n: + end = start + colon = -1 + while end < n: + c = raw_headers[end] + if c > 127: + raise InvalidHeaders() + if c == 58 and colon == -1: + colon = end + if end + 1 < n and c == 13 and raw_headers[end + 1] == 10: + break + end += 1 + + if colon in (-1, start): raise InvalidHeaders() - if b":" not in line: - raise InvalidHeaders() - name, value = line.split(b":", 1) - if not name: - raise InvalidHeaders() - for c in name: - if ( + + for i in range(start, colon): + c = raw_headers[i] + if not ( 48 <= c <= 57 # 0-9 or 65 <= c <= 90 # A-Z or 97 <= c <= 122 # a-z or c in (45, 95) # -_ ): - continue - raise InvalidHeaders() - name = name.strip().lower().decode("ascii") - if any((c < 32 and c != 9) or c == 127 for c in value): + raise InvalidHeaders() + + name = bytes(raw_headers[start:colon]).strip(b" ").lower().decode("ascii") + value_bytes = bytes(raw_headers[colon + 1 : end]).strip(b" ") + + if any((c < 32 and c != 9) or c == 127 for c in value_bytes): raise InvalidHeaders() if name == "content-length": - value = int(value.strip()) + if not all(48 <= c <= 57 for c in value_bytes): + raise InvalidHeaders() + value = int(value_bytes) else: - value = value.strip().decode("ascii") + value = value_bytes.decode("ascii") if name not in headers and value: headers[name] = value elif value: headers[name] += ", " + value # Combined field value - return headers - @classmethod - def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]: - """ - Parse part headers and body and return them as a tuple. - """ - blank_idx = -1 - for i in range(len(part) - 3): - if part[i : i + 4] == b"\r\n\r\n": - blank_idx = i - break - if blank_idx == -1: - raise InvalidHeaders() - headers = cls._parse_headers(part[:blank_idx]) - body = part[blank_idx + 4 :] - return headers, body + start = end + 2 + return headers # ========================================= # Helpers for state machine termination @@ -469,8 +469,6 @@ def set_response_body( :param body: body to be sent in the response :param content_type: content-type of the body """ - self._unset_response_handler() - if not body: body_encoded = b"" if isinstance(body, (bytes, bytearray, memoryview)): @@ -486,20 +484,17 @@ def set_response_body( b"content-length", str(len(body_encoded)).encode("ascii") ) + # Unset and clean up existing handler if set + if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"): + self.resp_handler.close() + self.resp_handler = None + if len(body_encoded): self.set_response_header(b"content-type", content_type.encode("ascii")) if self.method != self.HEAD: self.resp_handler = BytesIO(body_encoded) - def _unset_response_handler(self): - """ - Unset the response handler (if set). - """ - if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"): - self.resp_handler.close() - self.resp_handler = None - def do_keep_alive(self): """ Determine if the connection should be kept alive @@ -553,7 +548,7 @@ def abort(self, status_code: int): :param status_code: HTTP status code """ self.resp_headers = [] - self._unset_response_handler() + self.set_response_body(b"") self.terminate(status_code) def is_request_empty(self): @@ -572,15 +567,13 @@ def run(self, rx): """ Run the state machine, consuming the content of a request buffer (rx). Unlike individual states, this method does not raise an exception. - This method yields on every state transition allowing the calling side + This method returns on every state transition allowing the calling side to flush the response buffer. """ if self.is_terminated(): return try: - while not self.is_terminated(): - self.state(rx) - yield + self.state(rx) except BufferFullError: self.abort(500) self.set_response_body(b"Buffer full") diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index a5d96ec..e563efc 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -12,7 +12,7 @@ # pylint: disable=W0212,R0401 from pyrobusta.protocol import http -from pyrobusta.utils.helpers import add_method, add_property, patch_extra_property +from pyrobusta.utils.patch import add_method, add_property, patch_extra_property def generate_multipart_response(self, callback: callable, dtype: str): @@ -147,6 +147,7 @@ def _parse_boundary_st(self, rx): return if is_last and self.content_len_cnt + rx.size() < self.headers["content-length"]: + # Wait for optional trailing newline return self.state = self._parse_complete_part_st @@ -165,7 +166,17 @@ def _parse_complete_part_st(self, rx): and rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter ) - part_headers, part_body = http.HttpEngine._parse_body_part(part) + # Parse part headers and part body + blank_idx = -1 + for i in range(len(part) - 3): + if part[i : i + 4] == b"\r\n\r\n": + blank_idx = i + break + if blank_idx == -1: + raise http.InvalidHeaders() + part_headers = http.HttpEngine._parse_headers(part[:blank_idx]) + part_body = part[blank_idx + 4 :] + handler = http.HttpEngine._get_handler(self.url, self.method) # Process complete part diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py index 8337a8a..11fcf90 100644 --- a/src/pyrobusta/utils/helpers.py +++ b/src/pyrobusta/utils/helpers.py @@ -77,45 +77,3 @@ def is_path_segment_valid(filename: str): ): return False return True - - -def add_method(cls, func: callable, method_type="instance"): - """ - Helper to patch/extend classes with additional methods and states. - :param func: function to add - :param method_type: type of the method (instance, static, class) - """ - if method_type == "instance": - setattr(cls, func.__name__, func) - elif method_type == "static": - setattr(cls, func.__name__, staticmethod(func)) - elif method_type == "class": - setattr(cls, func.__name__, classmethod(func)) - else: - raise ValueError("Invalid type") - - -def add_property(cls, getter: callable, setter: callable = None): - """ - Add a property to a class. - """ - setattr(cls, getter.__name__, property(getter, setter)) - - -# pylint: disable=W0212 -def patch_extra_property(cls, name): - """ - Add a property to 'cls' that stores its value in the instance's - '_extras' dictionary. Intended for '__slots__' classes that cannot - have arbitrary instance attributes. - """ - - def getter(self): - return self._extras.get(name) if self._extras else None - - def setter(self, value): - if self._extras is None: - self._extras = {} - self._extras[name] = value - - setattr(cls, name, property(getter, setter)) diff --git a/src/pyrobusta/utils/patch.py b/src/pyrobusta/utils/patch.py new file mode 100644 index 0000000..32175af --- /dev/null +++ b/src/pyrobusta/utils/patch.py @@ -0,0 +1,46 @@ +""" +Helper methods for patching classes +""" + +# pylint: disable=W0212 + + +def add_method(cls, func: callable, method_type="instance"): + """ + Helper to patch/extend classes with additional methods and states. + :param func: function to add + :param method_type: type of the method (instance, static, class) + """ + if method_type == "instance": + setattr(cls, func.__name__, func) + elif method_type == "static": + setattr(cls, func.__name__, staticmethod(func)) + elif method_type == "class": + setattr(cls, func.__name__, classmethod(func)) + else: + raise ValueError("Invalid type") + + +def add_property(cls, getter: callable, setter: callable = None): + """ + Add a property to a class. + """ + setattr(cls, getter.__name__, property(getter, setter)) + + +def patch_extra_property(cls, name): + """ + Add a property to 'cls' that stores its value in the instance's + '_extras' dictionary. Intended for '__slots__' classes that cannot + have arbitrary instance attributes. + """ + + def getter(self): + return self._extras.get(name) if self._extras else None + + def setter(self, value): + if self._extras is None: + self._extras = {} + self._extras[name] = value + + setattr(cls, name, property(getter, setter)) diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index 57c1994..6cb3833 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -184,7 +184,6 @@ def test_header_parsing_incomplete_header(self): def test_header_parsing_error(self): for case in ( - b"", b":", b": value", b" leading-space: value",