diff --git a/src/mistralai/extra/tests/test_workflow_encoding.py b/src/mistralai/extra/tests/test_workflow_encoding.py index 8bab8ef1..453a4952 100644 --- a/src/mistralai/extra/tests/test_workflow_encoding.py +++ b/src/mistralai/extra/tests/test_workflow_encoding.py @@ -30,10 +30,32 @@ NetworkEncodedInput, WorkflowContext, ) -from mistralai.extra.workflows.encoding.payload_encoder import PayloadEncoder +from mistralai.extra.workflows.encoding.payload_encoder import ( + CompressedPayloadData, + PayloadEncoder, +) from mistralai.extra.tests.fixtures.workflow_encoding import InMemoryBlobStorage +_COMPRESSED_TEST_PAYLOAD = CompressedPayloadData.from_payload( + b"compressed-data", ZstdCompressionConfig(level=3) +) + + +def _compressed_payload_json( + compressed_payload: CompressedPayloadData, + *, + invalid_compression: dict[str, object] | None = None, + invalid_base64: bool = False, +) -> bytes: + payload_data = compressed_payload.model_dump(mode="json") + if invalid_compression is not None: + payload_data["compression"] = invalid_compression + if invalid_base64: + payload_data["b64payload"] = f"{payload_data['b64payload']}!" + return json.dumps(payload_data).encode() + + @pytest.fixture def encryption_config() -> WorkflowEncodingConfig: """Create a test encryption config.""" @@ -169,15 +191,77 @@ async def test_payload_encoder_compresses_network_inputs(): ) assert encoded.encoding_options == [EncodedPayloadOptions.COMPRESSED] - assert encoded.encoding_metadata == { - "compression": {"algorithm": "zstd", "level": 3} - } - assert not encoded.get_payload().startswith(b"{") + compressed_payload = CompressedPayloadData.model_validate_json(encoded.get_payload()) + assert compressed_payload.compression == ZstdCompressionConfig(level=3) decoded = await encoder.decode_network_result(encoded.model_dump(mode="json")) assert decoded == payload +@pytest.mark.asyncio +async def test_payload_encoder_content_keeps_two_value_contract_for_compression(): + # Workflow workers use this low-level API directly from their Temporal codec. + # Keep compression self-describing without changing the two-value contract. + config = WorkflowEncodingConfig( + payload_compression=PayloadCompressionConfig( + min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3) + ) + ) + encoder = PayloadEncoder(encoding_config=config) + raw = json.dumps({"data": "x" * 20_000}).encode() + + encoded_data, encoding_options = await encoder.encode_payload_content( + raw, WorkflowContext(namespace="test", execution_id="exec") + ) + + assert isinstance(encoded_data, bytes) + assert encoding_options == [EncodedPayloadOptions.COMPRESSED] + + +@pytest.mark.asyncio +async def test_payload_encoder_wraps_compression_config_in_payload_content(): + # Temporal metadata only carries encoding_options, so compressed bytes must + # include the algorithm config needed to decode them independently. + config = WorkflowEncodingConfig( + payload_compression=PayloadCompressionConfig( + min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3) + ) + ) + encoder = PayloadEncoder(encoding_config=config) + raw = json.dumps({"data": "x" * 20_000}).encode() + + encoded_data, encoding_options = await encoder.encode_payload_content( + raw, WorkflowContext(namespace="test", execution_id="exec") + ) + compressed_payload = CompressedPayloadData.model_validate_json(encoded_data) + + assert encoding_options == [EncodedPayloadOptions.COMPRESSED] + assert compressed_payload.compression == ZstdCompressionConfig(level=3) + assert compressed_payload.get_payload() != raw + + +@pytest.mark.asyncio +async def test_payload_encoder_decodes_compressed_payload_content_without_metadata(): + # This mirrors Temporal payload decoding, where the codec passes only bytes + # plus encoding_options back into PayloadEncoder. + config = WorkflowEncodingConfig( + payload_compression=PayloadCompressionConfig( + min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3) + ) + ) + encoder = PayloadEncoder(encoding_config=config) + raw = json.dumps({"data": "x" * 20_000}).encode() + + encoded_data, encoding_options = await encoder.encode_payload_content( + raw, WorkflowContext(namespace="test", execution_id="exec") + ) + decoded = await PayloadEncoder(WorkflowEncodingConfig()).decode_payload_content( + encoded_data, encoding_options + ) + + assert decoded == raw + + @pytest.mark.asyncio async def test_payload_encoder_skips_compression_below_min_size(): config = WorkflowEncodingConfig( @@ -292,9 +376,8 @@ async def test_payload_encoder_decodes_compressed_payload_with_decoder_config( ) assert encoded.encoding_options == [EncodedPayloadOptions.COMPRESSED] - assert encoded.encoding_metadata == { - "compression": {"algorithm": "zstd", "level": 22} - } + compressed_payload = CompressedPayloadData.model_validate_json(encoded.get_payload()) + assert compressed_payload.compression == ZstdCompressionConfig(level=22) assert decoded == payload @@ -355,6 +438,8 @@ async def test_payload_encoder_decodes_encrypted_compressed_payload_with_differe @pytest.mark.asyncio async def test_payload_encoder_decodes_with_tampered_compression_level(): + # Zstd decompression must depend on the frame data, not on the compression + # level that was used when the payload was encoded. encoder = PayloadEncoder( encoding_config=WorkflowEncodingConfig( payload_compression=PayloadCompressionConfig( @@ -367,10 +452,12 @@ async def test_payload_encoder_decodes_with_tampered_compression_level(): encoded = await encoder.encode_network_input( payload, WorkflowContext(namespace="test", execution_id="exec") ) + compressed_payload = CompressedPayloadData.model_validate_json(encoded.get_payload()) + tampered_payload = compressed_payload.model_copy( + update={"compression": ZstdCompressionConfig(level=1)} + ) tampered = NetworkEncodedInput.from_data( - encoded.get_payload(), - encoded.encoding_options, - {"compression": {"algorithm": "zstd", "level": 1}}, + tampered_payload.model_dump_json().encode(), encoded.encoding_options ) decoded = await PayloadEncoder(WorkflowEncodingConfig()).decode_network_result( @@ -382,18 +469,25 @@ async def test_payload_encoder_decodes_with_tampered_compression_level(): @pytest.mark.asyncio @pytest.mark.parametrize( - "encoding_metadata", + "compressed_payload", [ - {}, - {"compression": {"algorithm": "lz4", "level": 1}}, - {"compression": {"level": 3}}, + b"compressed-data", + _compressed_payload_json( + _COMPRESSED_TEST_PAYLOAD, + invalid_compression={"algorithm": "lz4", "level": 1}, + ), + _compressed_payload_json( + _COMPRESSED_TEST_PAYLOAD, + invalid_compression={"level": 3}, + ), + _compressed_payload_json(_COMPRESSED_TEST_PAYLOAD, invalid_base64=True), ], ) -async def test_payload_encoder_invalid_compression_metadata_is_error( - encoding_metadata: dict[str, object], +async def test_payload_encoder_invalid_compressed_payload_is_error( + compressed_payload: bytes, ): encoded = NetworkEncodedInput.from_data( - b"compressed-data", [EncodedPayloadOptions.COMPRESSED], encoding_metadata + compressed_payload, [EncodedPayloadOptions.COMPRESSED] ) with pytest.raises(WorkflowPayloadCompressionException): @@ -404,10 +498,12 @@ async def test_payload_encoder_invalid_compression_metadata_is_error( @pytest.mark.asyncio async def test_payload_encoder_corrupted_compressed_data_is_error(): + compressed_payload = CompressedPayloadData.from_payload( + b"corrupted-data", ZstdCompressionConfig(level=3) + ) encoded = NetworkEncodedInput.from_data( - b"corrupted-data", + compressed_payload.model_dump_json().encode(), [EncodedPayloadOptions.COMPRESSED], - {"compression": {"algorithm": "zstd", "level": 3}}, ) with pytest.raises(zstandard.ZstdError): @@ -571,18 +667,15 @@ async def test_payload_encoder_encodes_event_content_without_offloading(): ) payload = json.dumps({"data": "x" * 20_000}).encode() - encoded, encoding_options, encoding_metadata = await encoder.encode_payload_content( + encoded, encoding_options = await encoder.encode_payload_content( payload, allow_offloading=False, force_full_encryption=True, ) - decoded = await decoder.decode_payload_content( - encoded, encoding_options, encoding_metadata - ) + decoded = await decoder.decode_payload_content(encoded, encoding_options) assert encoding_options == [ EncodedPayloadOptions.COMPRESSED, EncodedPayloadOptions.ENCRYPTED, ] - assert encoding_metadata == {"compression": {"algorithm": "zstd", "level": 3}} assert decoded == payload diff --git a/src/mistralai/extra/workflows/encoding/payload_compressor.py b/src/mistralai/extra/workflows/encoding/payload_compressor.py index 2f372ac4..e8d8c578 100644 --- a/src/mistralai/extra/workflows/encoding/payload_compressor.py +++ b/src/mistralai/extra/workflows/encoding/payload_compressor.py @@ -15,7 +15,6 @@ ) _ALGORITHM_CONFIG_ADAPTER: TypeAdapter[AlgorithmConfig] = TypeAdapter(AlgorithmConfig) -_COMPRESSION_METADATA_KEY = "compression" class Compressor(ABC): @@ -60,6 +59,14 @@ def decompress(self, data: bytes) -> bytes: return result +def compressor_from_config(algo_config: AlgorithmConfig) -> Compressor: + if isinstance(algo_config, ZstdCompressionConfig): + return ZstdCompressor(algo_config) + raise WorkflowPayloadCompressionException( + f"Unsupported compression algorithm: {algo_config.algorithm!r}" + ) + + @lru_cache(maxsize=8) def _build_compressor_for_config(config_json: str) -> Compressor: try: @@ -69,11 +76,7 @@ def _build_compressor_for_config(config_json: str) -> Compressor: f"Invalid compression config in payload: {exc}" ) from exc - if isinstance(algo_config, ZstdCompressionConfig): - return ZstdCompressor(algo_config) - raise WorkflowPayloadCompressionException( - f"Unsupported compression algorithm: {algo_config.algorithm!r}" - ) + return compressor_from_config(algo_config) def build_compressor( @@ -84,24 +87,3 @@ def build_compressor( return _build_compressor_for_config( compression_config.algorithm_config.model_dump_json() ) - - -def compression_metadata(compressor: Compressor) -> dict[str, object]: - return { - _COMPRESSION_METADATA_KEY: compressor.algorithm_config.model_dump(mode="json") - } - - -def compressor_from_metadata(metadata: dict[str, object]) -> Compressor: - config = metadata.get(_COMPRESSION_METADATA_KEY) - if not isinstance(config, dict): - raise WorkflowPayloadCompressionException( - "Missing compression config in payload metadata" - ) - try: - algo_config = _ALGORITHM_CONFIG_ADAPTER.validate_python(config) - except ValidationError as exc: - raise WorkflowPayloadCompressionException( - f"Invalid compression config in payload metadata: {exc}" - ) from exc - return _build_compressor_for_config(algo_config.model_dump_json()) diff --git a/src/mistralai/extra/workflows/encoding/payload_encoder.py b/src/mistralai/extra/workflows/encoding/payload_encoder.py index 0ef49de2..55ae68de 100644 --- a/src/mistralai/extra/workflows/encoding/payload_encoder.py +++ b/src/mistralai/extra/workflows/encoding/payload_encoder.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import binascii import functools import hashlib import json @@ -9,7 +10,7 @@ import urllib.parse from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError if TYPE_CHECKING: from cryptography.exceptions import InvalidTag @@ -25,6 +26,7 @@ from pydantic_core import from_json, to_json from mistralai.extra.workflows.encoding.config import ( + AlgorithmConfig, PayloadEncryptionConfig, PayloadEncryptionMode, PayloadOffloadingConfig, @@ -32,8 +34,7 @@ ) from mistralai.extra.workflows.encoding.payload_compressor import ( build_compressor, - compression_metadata, - compressor_from_metadata, + compressor_from_config, ) from .storage.blob_storage import get_blob_storage, BlobNotFoundError from mistralai.extra.workflows.encoding.models import ( @@ -44,6 +45,7 @@ WorkflowContext, ) from mistralai.extra.exceptions import ( + WorkflowPayloadCompressionException, WorkflowPayloadEncryptionException, WorkflowPayloadOffloadingException, ) @@ -55,6 +57,28 @@ class OffloadedPayloadData(BaseModel): key: str +class CompressedPayloadData(BaseModel): + compression: AlgorithmConfig + b64payload: str + + @classmethod + def from_payload( + cls, payload: bytes, compression: AlgorithmConfig + ) -> "CompressedPayloadData": + return cls( + compression=compression, + b64payload=base64.b64encode(payload).decode("utf-8"), + ) + + def get_payload(self) -> bytes: + try: + return base64.b64decode(self.b64payload, validate=True) + except binascii.Error as exc: + raise WorkflowPayloadCompressionException( + "Invalid compressed payload data" + ) from exc + + class PayloadEncoder: """This class is in charge of payload encoding/decoding operations such as: - Field-level or full-payload encryption @@ -252,17 +276,21 @@ def _compress(self, data: bytes) -> tuple[bytes, bool]: compressed = self.compressor.compress(data) if len(compressed) >= len(data): return data, False - return compressed, True - - def _compression_metadata(self) -> dict[str, object]: - assert self.compressor is not None, ( - "This should never be reached: compression metadata was requested " - "but PayloadEncoder.__init__ did not build a compressor" + compressed_payload = CompressedPayloadData.from_payload( + compressed, self.compressor.algorithm_config ) - return compression_metadata(self.compressor) + return compressed_payload.model_dump_json().encode(), True - def _decompress(self, data: bytes, encoding_metadata: dict[str, object]) -> bytes: - return compressor_from_metadata(encoding_metadata).decompress(data) + def _decompress(self, data: bytes) -> bytes: + try: + compressed_payload = CompressedPayloadData.model_validate_json(data) + except ValidationError as exc: + raise WorkflowPayloadCompressionException( + "Invalid compressed payload metadata" + ) from exc + return compressor_from_config(compressed_payload.compression).decompress( + compressed_payload.get_payload() + ) async def encode_payload_content( self, @@ -271,7 +299,7 @@ async def encode_payload_content( *, allow_offloading: bool = True, force_full_encryption: bool = False, - ) -> tuple[bytes, list[EncodedPayloadOptions], dict[str, object]]: + ) -> tuple[bytes, list[EncodedPayloadOptions]]: """Handle payload encoding. Encoding options are appended in the exact order in which transforms are @@ -281,7 +309,6 @@ async def encode_payload_content( data = data.encode() encoding_options: list[EncodedPayloadOptions] = [] - encoding_metadata: dict[str, object] = {} # Partial encryption needs the original JSON fields. It must run before # compression or offloading, which make field-level markers unavailable. @@ -299,7 +326,6 @@ async def encode_payload_content( data, compressed = self._compress(data) if compressed: encoding_options.append(EncodedPayloadOptions.COMPRESSED) - encoding_metadata.update(self._compression_metadata()) if allow_offloading and self.offloading_config is not None: data, offloaded = await self._handle_offloading(data, context) @@ -316,7 +342,7 @@ async def encode_payload_content( data = self._encrypt(data) encoding_options.append(EncodedPayloadOptions.ENCRYPTED) - return data, encoding_options, encoding_metadata + return data, encoding_options async def decode_payload_content( self, @@ -324,7 +350,6 @@ async def decode_payload_content( encoding_options: List[EncodedPayloadOptions], encoding_metadata: dict[str, object] | None = None, ) -> bytes: - encoding_metadata = encoding_metadata or {} # Decode in the exact reverse order of the encoding_options wire list. for option in reversed(encoding_options): if option == EncodedPayloadOptions.ENCRYPTED: @@ -332,7 +357,7 @@ async def decode_payload_content( elif option == EncodedPayloadOptions.PARTIALLY_ENCRYPTED: data, _ = await self._partially_decrypt_fields(data) elif option == EncodedPayloadOptions.COMPRESSED: - data = self._decompress(data, encoding_metadata) + data = self._decompress(data) elif option == EncodedPayloadOptions.OFFLOADED: if ( self.offloading_config is None @@ -391,14 +416,10 @@ async def encode_network_input( """This method MUST be called to format every payload send to Mistral Workflows control plane to ensure a proper encoding of the payload. """ - ( - encoded_data, - encoding_options, - encoding_metadata, - ) = await self.encode_payload_content(to_json(data), context) - network_input = NetworkEncodedInput.from_data( - encoded_data, encoding_options, encoding_metadata + encoded_data, encoding_options = await self.encode_payload_content( + to_json(data), context ) + network_input = NetworkEncodedInput.from_data(encoded_data, encoding_options) return network_input async def decode_network_result(self, data: Any) -> Any: