Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/12234.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed zstd decompression failure for multi-frame responses (e.g. chunked transfer encoding).
-- by :user:`rootvector2`
52 changes: 51 additions & 1 deletion aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,57 @@ def decompress_sync(
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
else max_length
)
return self._obj.decompress(data, zstd_max_length)

# The zstd streaming API raises EOFError when trying to decompress
# data after reaching frame EOF. If there is additional compressed
# data, it belongs to a new frame and must be handled by a fresh
# decompressor instance.
decompressed_chunks: list[bytes] = []
total_size = 0
pending_data = data
stalled_pending_data: bytes | None = None

while True:
chunk_max_length = (
zstd_max_length
if zstd_max_length < 0
else max(0, zstd_max_length - total_size)
)
if chunk_max_length == 0:
break

try:
chunk = self._obj.decompress(pending_data, chunk_max_length)
except EOFError:
if not pending_data:
raise
self._obj = ZstdDecompressor()
continue

if chunk:
decompressed_chunks.append(chunk)
total_size += len(chunk)

if zstd_max_length >= 0 and total_size >= zstd_max_length:
break

if self._obj.unused_data:
if not chunk and self._obj.unused_data == pending_data:
if stalled_pending_data == pending_data:
raise EOFError(
"Stalled while decoding zstd frames: "
"unused_data did not shrink"
)
stalled_pending_data = pending_data
else:
stalled_pending_data = None
pending_data = self._obj.unused_data
self._obj = ZstdDecompressor()
continue

break

return b"".join(decompressed_chunks)

def flush(self) -> bytes:
    """Return buffered output; nothing is held between calls, so always empty."""
    return bytes()
55 changes: 54 additions & 1 deletion tests/test_compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import pytest

from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor
from aiohttp.compression_utils import (
ZLibBackend,
ZLibCompressor,
ZLibDecompressor,
ZSTDDecompressor,
)


@pytest.mark.usefixtures("parametrize_zlib_backend")
Expand Down Expand Up @@ -33,3 +38,51 @@ async def test_compression_round_trip_in_event_loop() -> None:
compressed_data = await compressor.compress(data) + compressor.flush()
decompressed_data = await decompressor.decompress(compressed_data)
assert data == decompressed_data


def test_zstd_decompressor_stalled_unused_data_raises(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A decoder whose unused_data never shrinks must raise instead of spinning."""

    class _NeverProgressingDecoder:
        """Stub zstd decoder that consumes nothing and echoes all input as unused."""

        def __init__(self) -> None:
            self.unused_data = b""

        def decompress(self, data: bytes, max_length: int) -> bytes:
            # Report the entire input back as unused so the caller sees no progress.
            self.unused_data = data
            return b""

    monkeypatch.setattr("aiohttp.compression_utils.HAS_ZSTD", True)
    monkeypatch.setattr(
        "aiohttp.compression_utils.ZstdDecompressor", _NeverProgressingDecoder
    )

    with pytest.raises(EOFError, match="unused_data did not shrink"):
        ZSTDDecompressor().decompress_sync(b"malformed")


def test_zstd_decompressor_allows_single_unchanged_unused_data_rollover(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """One frame boundary that returns unused_data verbatim is tolerated."""

    class _OneStallThenDecodeDecoder:
        """Stub decoder: stalls exactly once, then succeeds on the retry."""

        # Class-level counter so the replacement instance sees the second call.
        _calls = 0

        def __init__(self) -> None:
            self.unused_data = b""

        def decompress(self, data: bytes, max_length: int) -> bytes:
            cls = type(self)
            cls._calls += 1
            if cls._calls == 1:
                # First attempt: pretend nothing was consumed.
                self.unused_data = data
                return b""

            # Retry with a fresh decoder: decode cleanly.
            self.unused_data = b""
            return b"decoded"

    monkeypatch.setattr("aiohttp.compression_utils.HAS_ZSTD", True)
    monkeypatch.setattr(
        "aiohttp.compression_utils.ZstdDecompressor", _OneStallThenDecodeDecoder
    )

    assert ZSTDDecompressor().decompress_sync(b"frame") == b"decoded"
89 changes: 89 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,6 +2163,95 @@ async def test_feed_eof_no_err_zstandard(self, protocol: BaseProtocol) -> None:
dbuf.feed_eof()
assert [b"line"] == list(buf._buffer)

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_feed_data_zstd_multiple_frames(self, protocol: BaseProtocol) -> None:
    """Two independently compressed zstd frames fed separately both decode."""
    assert zstandard is not None

    def make_frame(raw: bytes) -> bytes:
        # A fresh compressor per payload yields fully independent frames.
        comp = zstandard.ZstdCompressor()
        return comp.compress(raw) + comp.flush()

    first = b"A" * 50_000
    second = b"B" * 50_000

    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    decoder = DeflateBuffer(stream, "zstd")

    decoder.feed_data(make_frame(first))
    decoder.feed_data(make_frame(second))
    decoder.feed_eof()

    assert b"".join(stream._buffer) == first + second

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_feed_data_zstd_partial_frame_across_chunks(
    self, protocol: BaseProtocol
) -> None:
    """A single frame split mid-stream across two feed_data calls still decodes."""
    assert zstandard is not None
    expected = b"partial-frame-data-" * 8192

    comp = zstandard.ZstdCompressor()
    frame = comp.compress(expected) + comp.flush()
    midpoint = len(frame) // 2

    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    decoder = DeflateBuffer(stream, "zstd")

    # Deliver the frame in two arbitrary halves to exercise partial-frame buffering.
    for piece in (frame[:midpoint], frame[midpoint:]):
        decoder.feed_data(piece)
    decoder.feed_eof()

    assert b"".join(stream._buffer) == expected

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_feed_data_zstd_multiple_frames_single_chunk(
    self, protocol: BaseProtocol
) -> None:
    """Three concatenated frames arriving in one chunk decode back-to-back."""
    assert zstandard is not None
    payloads = [b"frame-1-" * 4096, b"frame-2-" * 4096, b"frame-3-" * 4096]

    frames = []
    for raw in payloads:
        # Fresh compressor for each payload so each becomes its own frame.
        comp = zstandard.ZstdCompressor()
        frames.append(comp.compress(raw) + comp.flush())

    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    decoder = DeflateBuffer(stream, "zstd")

    decoder.feed_data(b"".join(frames))
    decoder.feed_eof()

    assert b"".join(stream._buffer) == b"".join(payloads)

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_feed_data_zstd_mixed_small_large_frames(
    self, protocol: BaseProtocol
) -> None:
    """Frames of very different sizes, split unevenly across chunks, all decode."""
    assert zstandard is not None
    tiny, mid, big = b"s", b"m" * 2048, b"L" * (2**20)

    def one_frame(raw: bytes) -> bytes:
        comp = zstandard.ZstdCompressor()
        return comp.compress(raw) + comp.flush()

    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    decoder = DeflateBuffer(stream, "zstd")

    # The two small frames share one chunk; the large frame arrives on its own.
    decoder.feed_data(one_frame(tiny) + one_frame(mid))
    decoder.feed_data(one_frame(big))
    decoder.feed_eof()

    assert b"".join(stream._buffer) == tiny + mid + big

async def test_empty_body(self, protocol: BaseProtocol) -> None:
buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
dbuf = DeflateBuffer(buf, "deflate")
Expand Down
Loading