Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/13016.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed request body not being read on rejected WebSocket upgrades -- by :user:`Dreamsorcerer`.
17 changes: 13 additions & 4 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ cdef class HttpParser:
set _seen_singletons
list _raw_headers
bint _upgraded
bint _pending_upgrade
list _messages
bint _more_data_available
bint _paused
Expand Down Expand Up @@ -398,6 +399,7 @@ cdef class HttpParser:
self._response_with_body = response_with_body
self._read_until_eof = read_until_eof
self._upgraded = False
self._pending_upgrade = False
self._auto_decompress = auto_decompress
self._content_encoding = None
self._lax = False
Expand Down Expand Up @@ -482,10 +484,15 @@ cdef class HttpParser:
raise BadHttpMessage("Missing 'Host' header in request.")
h_upg = headers.get("upgrade", "")
if (upgrade and h_upg.isascii() and h_upg.lower() in ALLOWED_UPGRADES) or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
# https://www.rfc-editor.org/info/rfc9110/#section-7.8-15
# Defer the protocol switch until the complete request has been
# received.
self._pending_upgrade = True
else:
if upgrade and self._cparser.status_code == 101:
self._upgraded = True
# llhttp pauses for a 101 on its own; just mark the pending
# switch so feed_data returns the upgraded-protocol tail.
self._pending_upgrade = True

# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
Expand Down Expand Up @@ -644,6 +651,10 @@ cdef class HttpParser:
if errno is cparser.HPE_PAUSED_UPGRADE:
cparser.llhttp_resume_after_upgrade(self._cparser)
nb = cparser.llhttp_get_error_pos(self._cparser) - base
if self._pending_upgrade:
# A supported upgrade whose request body has now been fully read.
self._upgraded = True
self._pending_upgrade = False
elif errno is cparser.HPE_PAUSED:
cparser.llhttp_resume(self._cparser)
pos = cparser.llhttp_get_error_pos(self._cparser) - base
Expand Down Expand Up @@ -862,8 +873,6 @@ cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
pyparser._last_error = exc
return -1
else:
if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
return 2
if not pyparser._response_with_body:
return 1
return 0
Expand Down
22 changes: 17 additions & 5 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def __init__(
self._lines: list[bytes] = []
self._tail = b""
self._upgraded = False
self._pending_upgrade = False
self._payload = None
self._payload_parser: HttpPayloadParser | None = None
self._payload_has_more_data = False
Expand Down Expand Up @@ -411,9 +412,7 @@ def get_content_length() -> int | None:
if SEC_WEBSOCKET_KEY1 in msg.headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)

self._upgraded = msg.upgrade and _is_supported_upgrade(
msg.headers
)
upgraded = msg.upgrade and _is_supported_upgrade(msg.headers)

method = getattr(msg, "method", self.method)
# code is only present on responses
Expand All @@ -425,8 +424,7 @@ def get_content_length() -> int | None:
method and method in EMPTY_BODY_METHODS
)
if not empty_body and (
((length is not None and length > 0) or msg.chunked)
and not self._upgraded
(length is not None and length > 0) or msg.chunked
):
payload = StreamReader(
self.protocol,
Expand All @@ -452,6 +450,10 @@ def get_content_length() -> int | None:
)
if not payload_parser.done:
self._payload_parser = payload_parser
# https://www.rfc-editor.org/info/rfc9110/#section-7.8-15
# Defer any requested upgrade until the
# complete request has been read.
self._pending_upgrade = upgraded
elif method == METH_CONNECT:
assert isinstance(msg, RawRequestMessage)
payload = StreamReader(
Expand Down Expand Up @@ -498,6 +500,11 @@ def get_content_length() -> int | None:
)
if not payload_parser.done:
self._payload_parser = payload_parser
elif upgraded:
# No body to read, so the connection switches to
# the upgraded protocol immediately.
self._upgraded = True
payload = EMPTY_PAYLOAD
else:
payload = EMPTY_PAYLOAD

Expand Down Expand Up @@ -555,6 +562,11 @@ def get_content_length() -> int | None:
start_pos = 0
data_len = len(data)
self._payload_parser = None
if self._pending_upgrade:
# Body fully read: the deferred upgrade takes effect and
# the rest of the connection is the upgraded protocol.
self._upgraded = True
self._pending_upgrade = False

if data and start_pos < data_len:
data = data[start_pos:]
Expand Down
48 changes: 48 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,54 @@ async def test_http_request_upgrade_unknown(parser: HttpRequestParser) -> None:
assert await messages[0][-1].read() == b"{}"


@pytest.mark.parametrize("chunked", (False, True), ids=("content-length", "chunked"))
async def test_http_request_upgrade_with_body_read(
parser: HttpRequestParser, chunked: bool
) -> None:
body_request = b"foobarbaz\r\n\r\n"
if chunked:
framing = b"Transfer-Encoding: chunked\r\n"
body = b"%x\r\n%s\r\n0\r\n\r\n" % (len(body_request), body_request)
else:
framing = b"Content-Length: %d\r\n" % len(body_request)
body = body_request
after = b"GET /after HTTP/1.1\r\nHost: a\r\n\r\n"
text = (
b"GET /ws HTTP/1.1\r\nHost: a\r\n"
b"Connection: Upgrade\r\nUpgrade: websocket\r\n"
+ framing
+ b"\r\n"
+ body
+ after
)
messages, upgrade, tail = parser.feed_data(text)
assert len(messages) == 1
msg, payload = messages[0]
assert msg.method == "GET"
assert msg.path == "/ws"
assert msg.upgrade
assert await payload.read() == body_request
# The connection switches protocols only after the body is fully read.
assert upgrade
assert tail == after


def test_http_request_upgrade_empty_body_allowed(parser: HttpRequestParser) -> None:
text = (
b"GET /ws HTTP/1.1\r\n"
b"Host: a\r\n"
b"Connection: Upgrade\r\n"
b"Upgrade: websocket\r\n"
b"Content-Length: 0\r\n\r\n"
b"some raw data"
)
messages, upgrade, tail = parser.feed_data(text)
msg = messages[0][0]
assert msg.upgrade
assert upgrade
assert tail == b"some raw data"


@pytest.fixture
def xfail_c_parser_url(request: pytest.FixtureRequest) -> None:
if isinstance(request.getfixturevalue("parser"), HttpRequestParserPy):
Expand Down
51 changes: 51 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,57 @@ def raw_get(path: str) -> bytes:
assert len(handled) == pipelined_requests + 1


async def test_declined_websocket_upgrade_reads_body(
aiohttp_server: AiohttpServer,
) -> None:
body_read = b""

async def ws_handler(request: web.Request) -> web.Response:
nonlocal body_read
# Decline the upgrade; read the body as a normal handler may.
body_read = await request.read()
return web.Response(text="declined")

async def after_handler(request: web.Request) -> web.Response:
return web.Response(text="after")

app = web.Application()
app.router.add_get("/ws", ws_handler)
app.router.add_get("/after", after_handler)
server = await aiohttp_server(app)

body = b"FooBarBaz\r\n\r\n"
# Use raw connection in order to pipeline requests.
pipeline = (
(
b"GET /ws HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Connection: Upgrade\r\n"
b"Upgrade: websocket\r\n"
b"Content-Length: %d\r\n\r\n" % len(body)
)
+ body
+ b"GET /after HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
)

reader, writer = await asyncio.open_connection(server.host, server.port)
try:
writer.write(pipeline)
await writer.drain()
# The trailing request sends Connection: close, so the server closes
# once it has answered it -- reading to EOF gathers both responses.
response = await asyncio.wait_for(reader.read(), 5)
finally:
writer.close()
with suppress(ConnectionResetError, BrokenPipeError):
await writer.wait_closed()

assert body_read == body
# Both the upgrade request and the pipelined request were served.
assert response.count(b"HTTP/1.") == 2, response
assert response.count(b" 200 ") == 2, response


@pytest.mark.parametrize("decompressed_size", [4 * 1024 * 1024, 32 * 1024 * 1024])
async def test_unread_compressed_body_drain_is_bounded(
aiohttp_server: AiohttpServer,
Expand Down
Loading