Skip to content

Commit d202362

Browse files
committed
Add workflow payload compression encoder
1 parent 7ebe84b commit d202362

12 files changed

Lines changed: 613 additions & 50 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ workflow_payload_offloading = [
4848
workflow_payload_encryption = [
4949
"cryptography>=41.0.0,<47.0.0",
5050
]
51+
workflow_payload_compression = [
52+
"zstandard>=0.25.0,<0.26",
53+
]
5154

5255

5356
[project.urls]
@@ -69,6 +72,7 @@ dev = [
6972
"griffe>=1.7.3,<2",
7073
"authlib>=1.5.2,<2",
7174
"websockets >=13.0",
75+
"zstandard>=0.25.0,<0.26",
7276
]
7377
lint = [
7478
"ruff>=0.11.10,<0.12",

src/mistralai/extra/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class WorkflowPayloadEncryptionException(MistralClientException):
2828
"""Workflow payload encryption exception"""
2929

3030

31+
class WorkflowPayloadCompressionException(MistralClientException):
32+
"""Workflow payload compression exception"""
33+
34+
3135
class RunException(MistralClientException):
3236
"""Conversation run errors."""
3337

src/mistralai/extra/tests/fixtures/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
6+
class InMemoryBlobStorage:
7+
def __init__(self) -> None:
8+
self.blobs: dict[str, bytes] = {}
9+
10+
async def __aenter__(self) -> "InMemoryBlobStorage":
11+
return self
12+
13+
async def __aexit__(self, *_args: Any) -> None:
14+
pass
15+
16+
async def upload_blob(self, key: str, content: bytes) -> str:
17+
self.blobs[key] = content
18+
return key
19+
20+
async def get_blob(self, key: str) -> bytes:
21+
return self.blobs[key]
22+
23+
async def get_blob_properties(self, key: str) -> dict[str, Any] | None:
24+
if key not in self.blobs:
25+
return None
26+
return {"size": len(self.blobs[key]), "last_modified": "test"}

src/mistralai/extra/tests/test_workflow_encoding.py

Lines changed: 206 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for workflow encoding configuration lifecycle."""
22

33
import gc
4+
import json
45

56
import pytest
67
from pydantic import SecretStr
@@ -12,10 +13,22 @@
1213
configure_workflow_encoding,
1314
)
1415
from mistralai.extra.workflows import (
15-
WorkflowEncodingConfig,
16+
BlobStorageConfig,
17+
EncryptedStrField,
18+
PayloadCompressionConfig,
1619
PayloadEncryptionConfig,
1720
PayloadEncryptionMode,
21+
PayloadOffloadingConfig,
22+
StorageProvider,
23+
WorkflowEncodingConfig,
24+
ZstdCompressionConfig,
25+
)
26+
from mistralai.extra.workflows.encoding.models import (
27+
EncodedPayloadOptions,
28+
WorkflowContext,
1829
)
30+
from mistralai.extra.workflows.encoding.payload_encoder import PayloadEncoder
31+
from mistralai.extra.tests.fixtures.workflow_encoding import InMemoryBlobStorage
1932

2033

2134
@pytest.fixture
@@ -29,7 +42,9 @@ def encryption_config() -> WorkflowEncodingConfig:
2942
)
3043

3144

32-
def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodingConfig):
45+
def test_payload_encoder_cleanup_on_client_gc(
46+
encryption_config: WorkflowEncodingConfig,
47+
):
3348
"""Test that PayloadEncoder is cleaned up when client is garbage collected."""
3449
initial_config_count = len(_workflow_configs)
3550

@@ -56,7 +71,9 @@ def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodin
5671
assert len(_workflow_configs) == initial_config_count
5772

5873

59-
def test_multiple_clients_independent_configs(encryption_config: WorkflowEncodingConfig):
74+
def test_multiple_clients_independent_configs(
75+
encryption_config: WorkflowEncodingConfig,
76+
):
6077
"""Test that multiple clients have independent configs."""
6178
initial_config_count = len(_workflow_configs)
6279

@@ -132,3 +149,189 @@ def test_reconfigure_same_client(encryption_config: WorkflowEncodingConfig):
132149
del client
133150
gc.collect()
134151
assert config_id not in _workflow_configs
152+
153+
154+
@pytest.mark.asyncio
155+
async def test_payload_encoder_compresses_network_inputs():
156+
config = WorkflowEncodingConfig(
157+
payload_compression=PayloadCompressionConfig(
158+
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
159+
)
160+
)
161+
encoder = PayloadEncoder(encoding_config=config)
162+
payload = {"data": "x" * 20_000}
163+
164+
encoded = await encoder.encode_network_input(
165+
payload, WorkflowContext(namespace="test", execution_id="exec")
166+
)
167+
168+
assert encoded.encoding_options == [EncodedPayloadOptions.COMPRESSED]
169+
assert encoded.encoding_metadata == {}
170+
171+
compressed_payload = json.loads(encoded.get_payload())
172+
assert compressed_payload["algorithm_config"] == {"algorithm": "zstd", "level": 3}
173+
174+
decoded = await encoder.decode_network_result(encoded.model_dump(mode="json"))
175+
assert decoded == payload
176+
177+
178+
@pytest.mark.asyncio
179+
async def test_payload_encoder_partially_encrypts_before_offloading(monkeypatch):
180+
storage = InMemoryBlobStorage()
181+
monkeypatch.setattr(
182+
"mistralai.extra.workflows.encoding.payload_encoder.get_blob_storage",
183+
lambda _: storage,
184+
)
185+
config = WorkflowEncodingConfig(
186+
payload_encryption=PayloadEncryptionConfig(
187+
mode=PayloadEncryptionMode.PARTIAL,
188+
main_key=SecretStr("0" * 64),
189+
),
190+
payload_offloading=PayloadOffloadingConfig(
191+
min_size_bytes=1,
192+
storage_config=BlobStorageConfig(
193+
storage_provider=StorageProvider.S3,
194+
bucket_name="test-bucket",
195+
),
196+
),
197+
)
198+
encoder = PayloadEncoder(encoding_config=config)
199+
payload = {
200+
"data": "plain value",
201+
"secret": EncryptedStrField(data="secret value").model_dump(),
202+
}
203+
204+
encoded = await encoder.encode_network_input(
205+
payload, WorkflowContext(namespace="test", execution_id="exec")
206+
)
207+
offloaded_payload = json.loads(encoded.get_payload())
208+
offloaded_bytes = storage.blobs[offloaded_payload["key"]]
209+
210+
assert encoded.encoding_options == [
211+
EncodedPayloadOptions.PARTIALLY_ENCRYPTED,
212+
EncodedPayloadOptions.OFFLOADED,
213+
]
214+
assert b"plain value" in offloaded_bytes
215+
assert b"secret value" not in offloaded_bytes
216+
217+
decoded = await encoder.decode_network_result(encoded.model_dump(mode="json"))
218+
assert decoded == payload
219+
220+
221+
@pytest.mark.asyncio
222+
@pytest.mark.parametrize(
223+
("encryption_mode", "expected_options"),
224+
[
225+
(
226+
PayloadEncryptionMode.PARTIAL,
227+
[
228+
EncodedPayloadOptions.PARTIALLY_ENCRYPTED,
229+
EncodedPayloadOptions.COMPRESSED,
230+
EncodedPayloadOptions.OFFLOADED,
231+
],
232+
),
233+
(
234+
PayloadEncryptionMode.FULL,
235+
[
236+
EncodedPayloadOptions.COMPRESSED,
237+
EncodedPayloadOptions.OFFLOADED,
238+
EncodedPayloadOptions.ENCRYPTED,
239+
],
240+
),
241+
],
242+
)
243+
async def test_payload_encoder_compression_offloading_encryption_roundtrip(
244+
monkeypatch,
245+
encryption_mode: PayloadEncryptionMode,
246+
expected_options: list[EncodedPayloadOptions],
247+
):
248+
storage = InMemoryBlobStorage()
249+
monkeypatch.setattr(
250+
"mistralai.extra.workflows.encoding.payload_encoder.get_blob_storage",
251+
lambda _: storage,
252+
)
253+
config = WorkflowEncodingConfig(
254+
payload_encryption=PayloadEncryptionConfig(
255+
mode=encryption_mode,
256+
main_key=SecretStr("0" * 64),
257+
),
258+
payload_compression=PayloadCompressionConfig(
259+
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
260+
),
261+
payload_offloading=PayloadOffloadingConfig(
262+
min_size_bytes=1,
263+
storage_config=BlobStorageConfig(
264+
storage_provider=StorageProvider.S3,
265+
bucket_name="test-bucket",
266+
),
267+
),
268+
)
269+
encoder = PayloadEncoder(encoding_config=config)
270+
payload = {
271+
"data": "x" * 20_000,
272+
"secret": EncryptedStrField(data="secret value").model_dump(),
273+
}
274+
275+
encoded = await encoder.encode_network_input(
276+
payload, WorkflowContext(namespace="test", execution_id="exec")
277+
)
278+
279+
assert encoded.encoding_options == expected_options
280+
assert encoded.encoding_metadata == {}
281+
assert len(storage.blobs) == 1
282+
decoded = await encoder.decode_network_result(encoded.model_dump(mode="json"))
283+
assert decoded == payload
284+
285+
286+
@pytest.mark.asyncio
287+
async def test_payload_encoder_does_not_partially_encrypt_when_no_marked_fields():
288+
config = WorkflowEncodingConfig(
289+
payload_encryption=PayloadEncryptionConfig(
290+
mode=PayloadEncryptionMode.PARTIAL,
291+
main_key=SecretStr("0" * 64),
292+
),
293+
payload_compression=PayloadCompressionConfig(
294+
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
295+
),
296+
)
297+
encoder = PayloadEncoder(encoding_config=config)
298+
payload = {"data": "x" * 20_000}
299+
300+
encoded = await encoder.encode_network_input(
301+
payload, WorkflowContext(namespace="test", execution_id="exec")
302+
)
303+
304+
assert encoded.encoding_options == [EncodedPayloadOptions.COMPRESSED]
305+
assert encoded.encoding_metadata == {}
306+
decoded = await encoder.decode_network_result(encoded.model_dump(mode="json"))
307+
assert decoded == payload
308+
309+
310+
@pytest.mark.asyncio
311+
async def test_payload_encoder_event_payload_orders_compression_before_full_encryption():
312+
config = WorkflowEncodingConfig(
313+
payload_encryption=PayloadEncryptionConfig(
314+
mode=PayloadEncryptionMode.FULL,
315+
main_key=SecretStr("0" * 64),
316+
),
317+
payload_compression=PayloadCompressionConfig(
318+
min_size_bytes=1, algorithm_config=ZstdCompressionConfig(level=3)
319+
),
320+
)
321+
encoder = PayloadEncoder(encoding_config=config)
322+
payload = json.dumps({"data": "x" * 20_000}).encode()
323+
324+
(
325+
encoded,
326+
encoding_options,
327+
encoding_metadata,
328+
) = await encoder.encode_event_payload_content(payload)
329+
decoded = await encoder.decode_payload_content(
330+
encoded, encoding_options, encoding_metadata
331+
)
332+
333+
assert encoding_options == [
334+
EncodedPayloadOptions.COMPRESSED,
335+
EncodedPayloadOptions.ENCRYPTED,
336+
]
337+
assert decoded == payload

src/mistralai/extra/workflows/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
WorkflowExtensions,
1010
)
1111
from .encoding import (
12-
WorkflowEncodingConfig,
13-
PayloadOffloadingConfig,
12+
AlgorithmConfig,
13+
BlobStorageConfig,
14+
EncryptedStrField,
15+
PayloadCompressionConfig,
1416
PayloadEncryptionConfig,
1517
PayloadEncryptionMode,
16-
BlobStorageConfig,
18+
PayloadOffloadingConfig,
1719
StorageProvider,
18-
EncryptedStrField,
20+
WorkflowEncodingConfig,
21+
ZstdCompressionConfig,
1922
configure_workflow_encoding,
2023
generate_two_part_id,
2124
)
@@ -27,10 +30,13 @@
2730
"ConnectorSlot",
2831
"WorkflowExtensions",
2932
"execute_with_connector_auth_async",
33+
"AlgorithmConfig",
3034
"WorkflowEncodingConfig",
3135
"PayloadOffloadingConfig",
3236
"PayloadEncryptionConfig",
3337
"PayloadEncryptionMode",
38+
"PayloadCompressionConfig",
39+
"ZstdCompressionConfig",
3440
"BlobStorageConfig",
3541
"StorageProvider",
3642
"EncryptedStrField",

src/mistralai/extra/workflows/encoding/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
from .config import (
2-
WorkflowEncodingConfig,
3-
PayloadOffloadingConfig,
2+
AlgorithmConfig,
3+
BlobStorageConfig,
4+
PayloadCompressionConfig,
45
PayloadEncryptionConfig,
56
PayloadEncryptionMode,
6-
BlobStorageConfig,
7+
PayloadOffloadingConfig,
78
StorageProvider,
9+
WorkflowEncodingConfig,
10+
ZstdCompressionConfig,
811
)
912
from .models import EncryptedStrField
1013
from .payload_encoder import PayloadEncoder
1114
from .helpers import configure_workflow_encoding, generate_two_part_id
1215

1316
__all__ = [
17+
"AlgorithmConfig",
1418
"WorkflowEncodingConfig",
1519
"PayloadOffloadingConfig",
1620
"PayloadEncryptionConfig",
1721
"PayloadEncryptionMode",
22+
"PayloadCompressionConfig",
23+
"ZstdCompressionConfig",
1824
"BlobStorageConfig",
1925
"StorageProvider",
2026
"EncryptedStrField",

src/mistralai/extra/workflows/encoding/config.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from enum import Enum
2-
from pydantic import SecretStr, BaseModel
3-
from typing import Optional
2+
from typing import Annotated, Literal, Optional, Union
3+
4+
from pydantic import BaseModel, Field, SecretStr
45

56

67
class StorageProvider(str, Enum):
@@ -47,6 +48,22 @@ class PayloadEncryptionConfig(BaseModel):
4748
secondary_key: Optional[SecretStr] = None
4849

4950

51+
class ZstdCompressionConfig(BaseModel):
52+
algorithm: Literal["zstd"] = "zstd"
53+
level: int = Field(default=3, ge=1, le=22)
54+
55+
56+
AlgorithmConfig = Annotated[
57+
Union[ZstdCompressionConfig], Field(discriminator="algorithm")
58+
]
59+
60+
61+
class PayloadCompressionConfig(BaseModel):
62+
min_size_bytes: int = 1024 * 1024 # 1MB
63+
algorithm_config: AlgorithmConfig = Field(default_factory=ZstdCompressionConfig)
64+
65+
5066
class WorkflowEncodingConfig(BaseModel):
5167
payload_offloading: PayloadOffloadingConfig | None = None
5268
payload_encryption: PayloadEncryptionConfig | None = None
69+
payload_compression: PayloadCompressionConfig | None = None

0 commit comments

Comments
 (0)