diff --git a/CHANGES/13016.bugfix.rst b/CHANGES/13016.bugfix.rst new file mode 100644 index 00000000000..a984f64e333 --- /dev/null +++ b/CHANGES/13016.bugfix.rst @@ -0,0 +1 @@ +Fixed request body not being read on rejected WebSocket upgrades -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 825e5238b0b..475b3263764 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -322,6 +322,7 @@ cdef class HttpParser: set _seen_singletons list _raw_headers bint _upgraded + bint _pending_upgrade list _messages bint _more_data_available bint _paused @@ -398,6 +399,7 @@ cdef class HttpParser: self._response_with_body = response_with_body self._read_until_eof = read_until_eof self._upgraded = False + self._pending_upgrade = False self._auto_decompress = auto_decompress self._content_encoding = None self._lax = False @@ -482,10 +484,15 @@ cdef class HttpParser: raise BadHttpMessage("Missing 'Host' header in request.") h_upg = headers.get("upgrade", "") if (upgrade and h_upg.isascii() and h_upg.lower() in ALLOWED_UPGRADES) or self._cparser.method == cparser.HTTP_CONNECT: - self._upgraded = True + # https://www.rfc-editor.org/info/rfc9110/#section-7.8-15 + # Defer the protocol switch until the complete request has been + # received. + self._pending_upgrade = True else: if upgrade and self._cparser.status_code == 101: - self._upgraded = True + # llhttp pauses for a 101 on its own; just mark the pending + # switch so feed_data returns the upgraded-protocol tail. + self._pending_upgrade = True # do not support old websocket spec if SEC_WEBSOCKET_KEY1 in headers: @@ -644,6 +651,10 @@ cdef class HttpParser: if errno is cparser.HPE_PAUSED_UPGRADE: cparser.llhttp_resume_after_upgrade(self._cparser) nb = cparser.llhttp_get_error_pos(self._cparser) - base + if self._pending_upgrade: + # A supported upgrade whose request body has now been fully read. + self._upgraded = True + self._pending_upgrade = False elif errno is cparser.HPE_PAUSED: cparser.llhttp_resume(self._cparser) pos = cparser.llhttp_get_error_pos(self._cparser) - base @@ -862,8 +873,6 @@ cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1: pyparser._last_error = exc return -1 else: - if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT: - return 2 if not pyparser._response_with_body: return 1 return 0 diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 6abbe04feae..ae158e7e223 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -283,6 +283,7 @@ def __init__( self._lines: list[bytes] = [] self._tail = b"" self._upgraded = False + self._pending_upgrade = False self._payload = None self._payload_parser: HttpPayloadParser | None = None self._payload_has_more_data = False @@ -411,9 +412,7 @@ def get_content_length() -> int | None: if SEC_WEBSOCKET_KEY1 in msg.headers: raise InvalidHeader(SEC_WEBSOCKET_KEY1) - self._upgraded = msg.upgrade and _is_supported_upgrade( - msg.headers - ) + upgraded = msg.upgrade and _is_supported_upgrade(msg.headers) method = getattr(msg, "method", self.method) # code is only present on responses @@ -425,8 +424,7 @@ def get_content_length() -> int | None: method and method in EMPTY_BODY_METHODS ) if not empty_body and ( - ((length is not None and length > 0) or msg.chunked) - and not self._upgraded + (length is not None and length > 0) or msg.chunked ): payload = StreamReader( self.protocol, @@ -452,6 +450,10 @@ def get_content_length() -> int | None: ) if not payload_parser.done: self._payload_parser = payload_parser + # https://www.rfc-editor.org/info/rfc9110/#section-7.8-15 + # Defer any requested upgrade until the + # complete request has been read. + self._pending_upgrade = upgraded elif method == METH_CONNECT: assert isinstance(msg, RawRequestMessage) payload = StreamReader( @@ -498,6 +500,11 @@ def get_content_length() -> int | None: ) if not payload_parser.done: self._payload_parser = payload_parser + elif upgraded: + # No body to read, so the connection switches to + # the upgraded protocol immediately. + self._upgraded = True + payload = EMPTY_PAYLOAD else: payload = EMPTY_PAYLOAD @@ -555,6 +562,11 @@ def get_content_length() -> int | None: start_pos = 0 data_len = len(data) self._payload_parser = None + if self._pending_upgrade: + # Body fully read: the deferred upgrade takes effect and + # the rest of the connection is the upgraded protocol. + self._upgraded = True + self._pending_upgrade = False if data and start_pos < data_len: data = data[start_pos:] diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index e61af88cc16..7f95b491b08 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1621,6 +1621,54 @@ async def test_http_request_upgrade_unknown(parser: HttpRequestParser) -> None: assert await messages[0][-1].read() == b"{}" +@pytest.mark.parametrize("chunked", (False, True), ids=("content-length", "chunked")) +async def test_http_request_upgrade_with_body_read( + parser: HttpRequestParser, chunked: bool +) -> None: + body_request = b"foobarbaz\r\n\r\n" + if chunked: + framing = b"Transfer-Encoding: chunked\r\n" + body = b"%x\r\n%s\r\n0\r\n\r\n" % (len(body_request), body_request) + else: + framing = b"Content-Length: %d\r\n" % len(body_request) + body = body_request + after = b"GET /after HTTP/1.1\r\nHost: a\r\n\r\n" + text = ( + b"GET /ws HTTP/1.1\r\nHost: a\r\n" + b"Connection: Upgrade\r\nUpgrade: websocket\r\n" + + framing + + b"\r\n" + + body + + after + ) + messages, upgrade, tail = parser.feed_data(text) + assert len(messages) == 1 + msg, payload = messages[0] + assert msg.method == "GET" + assert msg.path == "/ws" + assert msg.upgrade + assert await payload.read() == body_request + # The connection switches protocols only after the body is fully read. + assert upgrade + assert tail == after + + +def test_http_request_upgrade_empty_body_allowed(parser: HttpRequestParser) -> None: + text = ( + b"GET /ws HTTP/1.1\r\n" + b"Host: a\r\n" + b"Connection: Upgrade\r\n" + b"Upgrade: websocket\r\n" + b"Content-Length: 0\r\n\r\n" + b"some raw data" + ) + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.upgrade + assert upgrade + assert tail == b"some raw data" + + @pytest.fixture def xfail_c_parser_url(request: pytest.FixtureRequest) -> None: if isinstance(request.getfixturevalue("parser"), HttpRequestParserPy): diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index b8488588906..73dd094d89e 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1847,6 +1847,57 @@ def raw_get(path: str) -> bytes: assert len(handled) == pipelined_requests + 1 +async def test_declined_websocket_upgrade_reads_body( + aiohttp_server: AiohttpServer, +) -> None: + body_read = b"" + + async def ws_handler(request: web.Request) -> web.Response: + nonlocal body_read + # Decline the upgrade; read the body as a normal handler may. + body_read = await request.read() + return web.Response(text="declined") + + async def after_handler(request: web.Request) -> web.Response: + return web.Response(text="after") + + app = web.Application() + app.router.add_get("/ws", ws_handler) + app.router.add_get("/after", after_handler) + server = await aiohttp_server(app) + + body = b"FooBarBaz\r\n\r\n" + # Use raw connection in order to pipeline requests. + pipeline = ( + ( + b"GET /ws HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: Upgrade\r\n" + b"Upgrade: websocket\r\n" + b"Content-Length: %d\r\n\r\n" % len(body) + ) + + body + + b"GET /after HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ) + + reader, writer = await asyncio.open_connection(server.host, server.port) + try: + writer.write(pipeline) + await writer.drain() + # The trailing request sends Connection: close, so the server closes + # once it has answered it -- reading to EOF gathers both responses. + response = await asyncio.wait_for(reader.read(), 5) + finally: + writer.close() + with suppress(ConnectionResetError, BrokenPipeError): + await writer.wait_closed() + + assert body_read == body + # Both the upgrade request and the pipelined request were served. + assert response.count(b"HTTP/1.") == 2, response + assert response.count(b" 200 ") == 2, response + + @pytest.mark.parametrize("decompressed_size", [4 * 1024 * 1024, 32 * 1024 * 1024]) async def test_unread_compressed_body_drain_is_bounded( aiohttp_server: AiohttpServer,