diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 43462d0..3a8e73e 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -620,11 +620,19 @@ def _consume_payload(self, rx, size, last=False): """ Consume data from the request buffer and increment content length counter. Raise an exception if the content length is exceeded. Allow strict checking - of content length when the last flag is set. + of content length when the last flag is set. When the request is chunked, + the content length should not be set, otherwise it is ignored. """ - if "content-length" in self.headers and ( - (self.content_len_cnt + size > self.headers["content-length"]) - or (last and self.headers["content-length"] != self.content_len_cnt + size) + if ( + not self.is_chunked() + and "content-length" in self.headers + and ( + (self.content_len_cnt + size > self.headers["content-length"]) + or ( + last + and self.headers["content-length"] != self.content_len_cnt + size + ) + ) ): raise InvalidContentLength() self.content_len_cnt += size @@ -713,7 +721,9 @@ def _route_request_st(self, _): elif self.is_chunked(): # Request body is chunked if "content-length" in self.headers: - raise MalformedRequest() + # Ignore content-length as per RFC 9112, + # chunked transfer-encoding takes precedence + pass self.state = self._recv_chunk_size_st else: self.state = self._recv_payload_st diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index 0d04679..c21f264 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -29,6 +29,7 @@ def _generate_multipart_response(self, _, callback: callable, dtype: str): self.set_response_header( b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary ) + self.set_response_header(b"transfer-encoding", b"chunked") if self.method != self.HEAD: self.resp_handler = self._multipart_wrapper_factory(callback, boundary) @@ -51,25 +52,32 @@ def _multipart_wrapper(tx): :return bool: true if the stream is completed """ while True: - tx.write(delimiter) part = callback() + if not part: - tx.write(b"--") + tx.write(delimiter) + tx.write(b"--\r\n") yield True + return + content_type, part_body = part - tx.write(b"\r\n") - tx.write(b"content-type:") - tx.write(content_type.encode("ascii")) - tx.write(b"\r\n\r\n") - written = 0 - while written < len(part_body): - to_write = tx.capacity - tx.size() - if not to_write: - raise BufferError() - tx.write(part_body[written : written + to_write]) - written += to_write - yield False - tx.write(b"\r\n") + headers = ( + delimiter + + b"\r\ncontent-type:" + + content_type.encode("ascii") + + b"\r\n\r\n" + ) + + for chunk_part in (headers, part_body, b"\r\n"): + written = 0 + while written < len(chunk_part): + to_write = tx.capacity - tx.size() + if not to_write: + raise BufferError() + chunk_part = chunk_part[written : written + to_write] + tx.write(chunk_part) + written += len(chunk_part) + yield False return _multipart_wrapper