Skip to content

Commit b0c4aec

Browse files
committed
fix: Fix zstd decompression of multi-frame responses
A ZstdDecompressor instance can decode only a single frame: once that frame ends, any subsequent decompress() call raises EOFError. This broke HTTP responses in which the server sends multiple zstd frames (common with chunked transfer encoding). The fix detects frame boundaries via the decompressor's eof/unused_data attributes and creates a fresh decompressor instance for each subsequent frame.
1 parent 993989c commit b0c4aec

5 files changed

Lines changed: 141 additions & 2 deletions

File tree

CHANGES/12234.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fixed zstd decompression failing with ``ClientPayloadError`` when the server
2+
sends a response as multiple zstd frames -- by :user:`josu-moreno`.

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Jordan Borean
214214
Josep Cugat
215215
Josh Junon
216216
Joshu Coats
217+
Josu Moreno
217218
Julia Tsemusheva
218219
Julien Duponchelle
219220
Jungkook Park

aiohttp/compression_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,28 @@ def decompress_sync(
342342
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
343343
else max_length
344344
)
345-
return self._obj.decompress(data, zstd_max_length)
345+
result = self._obj.decompress(data, zstd_max_length)
346+
347+
# Handle multi-frame zstd streams (RFC 8878 §3.1.1):
348+
# ZstdDecompressor handles one frame only. When a frame ends,
349+
# eof becomes True and any trailing data goes to unused_data.
350+
# We create a fresh decompressor to continue with the next frame.
351+
while self._obj.eof and self._obj.unused_data:
352+
unused = self._obj.unused_data
353+
self._obj = ZstdDecompressor()
354+
if zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED:
355+
zstd_max_length -= len(result)
356+
if zstd_max_length <= 0:
357+
break
358+
result += self._obj.decompress(unused, zstd_max_length)
359+
360+
# Frame ended exactly at chunk boundary — no unused_data, but the
361+
# next decompress() call (triggered by the next feed_data()) would raise EOFError on the spent decompressor.
362+
# Prepare a fresh one for the next chunk.
363+
if self._obj.eof:
364+
self._obj = ZstdDecompressor()
365+
366+
return result
346367

347368
def flush(self) -> bytes:
348369
return b""

tests/test_compression_utils.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
11
"""Tests for compression utils."""
22

3+
import sys
4+
35
import pytest
46

5-
from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor
7+
from aiohttp.compression_utils import (
8+
ZLibBackend,
9+
ZLibCompressor,
10+
ZLibDecompressor,
11+
ZSTDDecompressor,
12+
)
13+
14+
try:
15+
if sys.version_info >= (3, 14):
16+
import compression.zstd as zstandard # noqa: I900
17+
else:
18+
import backports.zstd as zstandard
19+
except ImportError: # pragma: no cover
20+
zstandard = None # type: ignore[assignment]
621

722

823
@pytest.mark.usefixtures("parametrize_zlib_backend")
@@ -33,3 +48,30 @@ async def test_compression_round_trip_in_event_loop() -> None:
3348
compressed_data = await compressor.compress(data) + compressor.flush()
3449
decompressed_data = await decompressor.decompress(compressed_data)
3550
assert data == decompressed_data
51+
52+
53+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
54+
def test_zstd_multi_frame_unlimited() -> None:
55+
d = ZSTDDecompressor()
56+
frame1 = zstandard.compress(b"AAAA")
57+
frame2 = zstandard.compress(b"BBBB")
58+
result = d.decompress_sync(frame1 + frame2)
59+
assert result == b"AAAABBBB"
60+
61+
62+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
63+
def test_zstd_multi_frame_max_length_partial() -> None:
64+
d = ZSTDDecompressor()
65+
frame1 = zstandard.compress(b"AAAA")
66+
frame2 = zstandard.compress(b"BBBB")
67+
result = d.decompress_sync(frame1 + frame2, max_length=6)
68+
assert result == b"AAAABB"
69+
70+
71+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
72+
def test_zstd_multi_frame_max_length_exhausted() -> None:
73+
d = ZSTDDecompressor()
74+
frame1 = zstandard.compress(b"AAAA")
75+
frame2 = zstandard.compress(b"BBBB")
76+
result = d.decompress_sync(frame1 + frame2, max_length=4)
77+
assert result == b"AAAA"

tests/test_http_parser.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,6 +2081,79 @@ async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None:
20812081
assert b"zstd data" == out._buffer[0]
20822082
assert out.is_eof()
20832083

2084+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2085+
async def test_http_payload_zstandard_multi_frame(
2086+
self, protocol: BaseProtocol
2087+
) -> None:
2088+
frame1 = zstandard.compress(b"first")
2089+
frame2 = zstandard.compress(b"second")
2090+
payload = frame1 + frame2
2091+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2092+
p = HttpPayloadParser(
2093+
out,
2094+
length=len(payload),
2095+
compression="zstd",
2096+
headers_parser=HeadersParser(),
2097+
)
2098+
p.feed_data(payload)
2099+
assert b"firstsecond" == b"".join(out._buffer)
2100+
assert out.is_eof()
2101+
2102+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2103+
async def test_http_payload_zstandard_multi_frame_chunked(
2104+
self, protocol: BaseProtocol
2105+
) -> None:
2106+
frame1 = zstandard.compress(b"chunk1")
2107+
frame2 = zstandard.compress(b"chunk2")
2108+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2109+
p = HttpPayloadParser(
2110+
out,
2111+
length=len(frame1) + len(frame2),
2112+
compression="zstd",
2113+
headers_parser=HeadersParser(),
2114+
)
2115+
p.feed_data(frame1)
2116+
p.feed_data(frame2)
2117+
assert b"chunk1chunk2" == b"".join(out._buffer)
2118+
assert out.is_eof()
2119+
2120+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2121+
async def test_http_payload_zstandard_frame_split_mid_chunk(
2122+
self, protocol: BaseProtocol
2123+
) -> None:
2124+
frame1 = zstandard.compress(b"AAAA")
2125+
frame2 = zstandard.compress(b"BBBB")
2126+
combined = frame1 + frame2
2127+
split_point = len(frame1) + 3 # 3 bytes into frame2
2128+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2129+
p = HttpPayloadParser(
2130+
out,
2131+
length=len(combined),
2132+
compression="zstd",
2133+
headers_parser=HeadersParser(),
2134+
)
2135+
p.feed_data(combined[:split_point])
2136+
p.feed_data(combined[split_point:])
2137+
assert b"AAAABBBB" == b"".join(out._buffer)
2138+
assert out.is_eof()
2139+
2140+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2141+
async def test_http_payload_zstandard_many_small_frames(
2142+
self, protocol: BaseProtocol
2143+
) -> None:
2144+
parts = [f"part{i}".encode() for i in range(10)]
2145+
payload = b"".join(zstandard.compress(p) for p in parts)
2146+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2147+
p = HttpPayloadParser(
2148+
out,
2149+
length=len(payload),
2150+
compression="zstd",
2151+
headers_parser=HeadersParser(),
2152+
)
2153+
p.feed_data(payload)
2154+
assert b"".join(parts) == b"".join(out._buffer)
2155+
assert out.is_eof()
2156+
20842157

20852158
class TestDeflateBuffer:
20862159
async def test_feed_data(self, protocol: BaseProtocol) -> None:

0 commit comments

Comments
 (0)