Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 118 additions & 25 deletions src/mistralai/extra/tests/test_workflow_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,32 @@
NetworkEncodedInput,
WorkflowContext,
)
from mistralai.extra.workflows.encoding.payload_encoder import PayloadEncoder
from mistralai.extra.workflows.encoding.payload_encoder import (
CompressedPayloadData,
PayloadEncoder,
)
from mistralai.extra.tests.fixtures.workflow_encoding import InMemoryBlobStorage


_COMPRESSED_TEST_PAYLOAD = CompressedPayloadData.from_payload(
b"compressed-data", ZstdCompressionConfig(level=3)
)


def _compressed_payload_json(
compressed_payload: CompressedPayloadData,
*,
invalid_compression: dict[str, object] | None = None,
invalid_base64: bool = False,
) -> bytes:
payload_data = compressed_payload.model_dump(mode="json")
if invalid_compression is not None:
payload_data["compression"] = invalid_compression
if invalid_base64:
payload_data["b64payload"] = f"{payload_data['b64payload']}!"
return json.dumps(payload_data).encode()


@pytest.fixture
def encryption_config() -> WorkflowEncodingConfig:
"""Create a test encryption config."""
Expand Down Expand Up @@ -169,15 +191,77 @@ async def test_payload_encoder_compresses_network_inputs():
)

assert encoded.encoding_options == [EncodedPayloadOptions.COMPRESSED]
assert encoded.encoding_metadata == {
"compression": {"algorithm": "zstd", "level": 3}
}
assert not encoded.get_payload().startswith(b"{")
compressed_payload = CompressedPayloadData.model_validate_json(encoded.get_payload())
assert compressed_payload.compression == ZstdCompressionConfig(level=3)

decoded = await encoder.decode_network_result(encoded.model_dump(mode="json"))
assert decoded == payload


@pytest.mark.asyncio
async def test_payload_encoder_content_keeps_two_value_contract_for_compression():
# Workflow workers use this low-level API directly from their Temporal codec.
# Keep compression self-describing without changing the two-value contract.
config = WorkflowEncodingConfig(
payload_compression=PayloadCompressionConfig(
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
)
)
encoder = PayloadEncoder(encoding_config=config)
raw = json.dumps({"data": "x" * 20_000}).encode()

encoded_data, encoding_options = await encoder.encode_payload_content(
raw, WorkflowContext(namespace="test", execution_id="exec")
)

assert isinstance(encoded_data, bytes)
assert encoding_options == [EncodedPayloadOptions.COMPRESSED]


@pytest.mark.asyncio
async def test_payload_encoder_wraps_compression_config_in_payload_content():
# Temporal metadata only carries encoding_options, so compressed bytes must
# include the algorithm config needed to decode them independently.
config = WorkflowEncodingConfig(
payload_compression=PayloadCompressionConfig(
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
)
)
encoder = PayloadEncoder(encoding_config=config)
raw = json.dumps({"data": "x" * 20_000}).encode()

encoded_data, encoding_options = await encoder.encode_payload_content(
raw, WorkflowContext(namespace="test", execution_id="exec")
)
compressed_payload = CompressedPayloadData.model_validate_json(encoded_data)

assert encoding_options == [EncodedPayloadOptions.COMPRESSED]
assert compressed_payload.compression == ZstdCompressionConfig(level=3)
assert compressed_payload.get_payload() != raw


@pytest.mark.asyncio
async def test_payload_encoder_decodes_compressed_payload_content_without_metadata():
# This mirrors Temporal payload decoding, where the codec passes only bytes
# plus encoding_options back into PayloadEncoder.
config = WorkflowEncodingConfig(
payload_compression=PayloadCompressionConfig(
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
)
)
encoder = PayloadEncoder(encoding_config=config)
raw = json.dumps({"data": "x" * 20_000}).encode()

encoded_data, encoding_options = await encoder.encode_payload_content(
raw, WorkflowContext(namespace="test", execution_id="exec")
)
decoded = await PayloadEncoder(WorkflowEncodingConfig()).decode_payload_content(
encoded_data, encoding_options
)

assert decoded == raw


@pytest.mark.asyncio
async def test_payload_encoder_skips_compression_below_min_size():
config = WorkflowEncodingConfig(
Expand Down Expand Up @@ -292,9 +376,8 @@ async def test_payload_encoder_decodes_compressed_payload_with_decoder_config(
)

assert encoded.encoding_options == [EncodedPayloadOptions.COMPRESSED]
assert encoded.encoding_metadata == {
"compression": {"algorithm": "zstd", "level": 22}
}
compressed_payload = CompressedPayloadData.model_validate_json(encoded.get_payload())
assert compressed_payload.compression == ZstdCompressionConfig(level=22)
assert decoded == payload


Expand Down Expand Up @@ -355,6 +438,8 @@ async def test_payload_encoder_decodes_encrypted_compressed_payload_with_differe

@pytest.mark.asyncio
async def test_payload_encoder_decodes_with_tampered_compression_level():
# Zstd decompression must depend on the frame data, not on the compression
# level that was used when the payload was encoded.
encoder = PayloadEncoder(
encoding_config=WorkflowEncodingConfig(
payload_compression=PayloadCompressionConfig(
Expand All @@ -367,10 +452,12 @@ async def test_payload_encoder_decodes_with_tampered_compression_level():
encoded = await encoder.encode_network_input(
payload, WorkflowContext(namespace="test", execution_id="exec")
)
compressed_payload = CompressedPayloadData.model_validate_json(encoded.get_payload())
tampered_payload = compressed_payload.model_copy(
update={"compression": ZstdCompressionConfig(level=1)}
)
tampered = NetworkEncodedInput.from_data(
encoded.get_payload(),
encoded.encoding_options,
{"compression": {"algorithm": "zstd", "level": 1}},
tampered_payload.model_dump_json().encode(), encoded.encoding_options
)

decoded = await PayloadEncoder(WorkflowEncodingConfig()).decode_network_result(
Expand All @@ -382,18 +469,25 @@ async def test_payload_encoder_decodes_with_tampered_compression_level():

@pytest.mark.asyncio
@pytest.mark.parametrize(
"encoding_metadata",
"compressed_payload",
[
{},
{"compression": {"algorithm": "lz4", "level": 1}},
{"compression": {"level": 3}},
b"compressed-data",
_compressed_payload_json(
_COMPRESSED_TEST_PAYLOAD,
invalid_compression={"algorithm": "lz4", "level": 1},
),
_compressed_payload_json(
_COMPRESSED_TEST_PAYLOAD,
invalid_compression={"level": 3},
),
_compressed_payload_json(_COMPRESSED_TEST_PAYLOAD, invalid_base64=True),
],
)
async def test_payload_encoder_invalid_compression_metadata_is_error(
encoding_metadata: dict[str, object],
async def test_payload_encoder_invalid_compressed_payload_is_error(
compressed_payload: bytes,
):
encoded = NetworkEncodedInput.from_data(
b"compressed-data", [EncodedPayloadOptions.COMPRESSED], encoding_metadata
compressed_payload, [EncodedPayloadOptions.COMPRESSED]
)

with pytest.raises(WorkflowPayloadCompressionException):
Expand All @@ -404,10 +498,12 @@ async def test_payload_encoder_invalid_compression_metadata_is_error(

@pytest.mark.asyncio
async def test_payload_encoder_corrupted_compressed_data_is_error():
compressed_payload = CompressedPayloadData.from_payload(
b"corrupted-data", ZstdCompressionConfig(level=3)
)
encoded = NetworkEncodedInput.from_data(
b"corrupted-data",
compressed_payload.model_dump_json().encode(),
[EncodedPayloadOptions.COMPRESSED],
{"compression": {"algorithm": "zstd", "level": 3}},
)

with pytest.raises(zstandard.ZstdError):
Expand Down Expand Up @@ -571,18 +667,15 @@ async def test_payload_encoder_encodes_event_content_without_offloading():
)
payload = json.dumps({"data": "x" * 20_000}).encode()

encoded, encoding_options, encoding_metadata = await encoder.encode_payload_content(
encoded, encoding_options = await encoder.encode_payload_content(
payload,
allow_offloading=False,
force_full_encryption=True,
)
decoded = await decoder.decode_payload_content(
encoded, encoding_options, encoding_metadata
)
decoded = await decoder.decode_payload_content(encoded, encoding_options)

assert encoding_options == [
EncodedPayloadOptions.COMPRESSED,
EncodedPayloadOptions.ENCRYPTED,
]
assert encoding_metadata == {"compression": {"algorithm": "zstd", "level": 3}}
assert decoded == payload
36 changes: 9 additions & 27 deletions src/mistralai/extra/workflows/encoding/payload_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)

_ALGORITHM_CONFIG_ADAPTER: TypeAdapter[AlgorithmConfig] = TypeAdapter(AlgorithmConfig)
_COMPRESSION_METADATA_KEY = "compression"


class Compressor(ABC):
Expand Down Expand Up @@ -60,6 +59,14 @@ def decompress(self, data: bytes) -> bytes:
return result


def compressor_from_config(algo_config: AlgorithmConfig) -> Compressor:
if isinstance(algo_config, ZstdCompressionConfig):
return ZstdCompressor(algo_config)
raise WorkflowPayloadCompressionException(
f"Unsupported compression algorithm: {algo_config.algorithm!r}"
)


@lru_cache(maxsize=8)
def _build_compressor_for_config(config_json: str) -> Compressor:
try:
Expand All @@ -69,11 +76,7 @@ def _build_compressor_for_config(config_json: str) -> Compressor:
f"Invalid compression config in payload: {exc}"
) from exc

if isinstance(algo_config, ZstdCompressionConfig):
return ZstdCompressor(algo_config)
raise WorkflowPayloadCompressionException(
f"Unsupported compression algorithm: {algo_config.algorithm!r}"
)
return compressor_from_config(algo_config)


def build_compressor(
Expand All @@ -84,24 +87,3 @@ def build_compressor(
return _build_compressor_for_config(
compression_config.algorithm_config.model_dump_json()
)


def compression_metadata(compressor: Compressor) -> dict[str, object]:
return {
_COMPRESSION_METADATA_KEY: compressor.algorithm_config.model_dump(mode="json")
}


def compressor_from_metadata(metadata: dict[str, object]) -> Compressor:
config = metadata.get(_COMPRESSION_METADATA_KEY)
if not isinstance(config, dict):
raise WorkflowPayloadCompressionException(
"Missing compression config in payload metadata"
)
try:
algo_config = _ALGORITHM_CONFIG_ADAPTER.validate_python(config)
except ValidationError as exc:
raise WorkflowPayloadCompressionException(
f"Invalid compression config in payload metadata: {exc}"
) from exc
return _build_compressor_for_config(algo_config.model_dump_json())
Loading