11"""Tests for workflow encoding configuration lifecycle."""
22
33import gc
4+ import json
45
56import pytest
67from pydantic import SecretStr
1213 configure_workflow_encoding ,
1314)
1415from mistralai .extra .workflows import (
15- WorkflowEncodingConfig ,
16+ BlobStorageConfig ,
17+ EncryptedStrField ,
18+ PayloadCompressionConfig ,
1619 PayloadEncryptionConfig ,
1720 PayloadEncryptionMode ,
21+ PayloadOffloadingConfig ,
22+ StorageProvider ,
23+ WorkflowEncodingConfig ,
24+ ZstdCompressionConfig ,
25+ )
26+ from mistralai .extra .workflows .encoding .models import (
27+ EncodedPayloadOptions ,
28+ WorkflowContext ,
1829)
30+ from mistralai .extra .workflows .encoding .payload_encoder import PayloadEncoder
31+ from mistralai .extra .tests .fixtures .workflow_encoding import InMemoryBlobStorage
1932
2033
2134@pytest .fixture
@@ -29,7 +42,9 @@ def encryption_config() -> WorkflowEncodingConfig:
2942 )
3043
3144
32- def test_payload_encoder_cleanup_on_client_gc (encryption_config : WorkflowEncodingConfig ):
45+ def test_payload_encoder_cleanup_on_client_gc (
46+ encryption_config : WorkflowEncodingConfig ,
47+ ):
3348 """Test that PayloadEncoder is cleaned up when client is garbage collected."""
3449 initial_config_count = len (_workflow_configs )
3550
@@ -56,7 +71,9 @@ def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodin
5671 assert len (_workflow_configs ) == initial_config_count
5772
5873
59- def test_multiple_clients_independent_configs (encryption_config : WorkflowEncodingConfig ):
74+ def test_multiple_clients_independent_configs (
75+ encryption_config : WorkflowEncodingConfig ,
76+ ):
6077 """Test that multiple clients have independent configs."""
6178 initial_config_count = len (_workflow_configs )
6279
@@ -132,3 +149,189 @@ def test_reconfigure_same_client(encryption_config: WorkflowEncodingConfig):
132149 del client
133150 gc .collect ()
134151 assert config_id not in _workflow_configs
152+
153+
154+ @pytest .mark .asyncio
155+ async def test_payload_encoder_compresses_network_inputs ():
156+ config = WorkflowEncodingConfig (
157+ payload_compression = PayloadCompressionConfig (
158+ min_size_bytes = 1 , algorithm_config = ZstdCompressionConfig (level = 3 )
159+ )
160+ )
161+ encoder = PayloadEncoder (encoding_config = config )
162+ payload = {"data" : "x" * 20_000 }
163+
164+ encoded = await encoder .encode_network_input (
165+ payload , WorkflowContext (namespace = "test" , execution_id = "exec" )
166+ )
167+
168+ assert encoded .encoding_options == [EncodedPayloadOptions .COMPRESSED ]
169+ assert encoded .encoding_metadata == {}
170+
171+ compressed_payload = json .loads (encoded .get_payload ())
172+ assert compressed_payload ["algorithm_config" ] == {"algorithm" : "zstd" , "level" : 3 }
173+
174+ decoded = await encoder .decode_network_result (encoded .model_dump (mode = "json" ))
175+ assert decoded == payload
176+
177+
178+ @pytest .mark .asyncio
179+ async def test_payload_encoder_partially_encrypts_before_offloading (monkeypatch ):
180+ storage = InMemoryBlobStorage ()
181+ monkeypatch .setattr (
182+ "mistralai.extra.workflows.encoding.payload_encoder.get_blob_storage" ,
183+ lambda _ : storage ,
184+ )
185+ config = WorkflowEncodingConfig (
186+ payload_encryption = PayloadEncryptionConfig (
187+ mode = PayloadEncryptionMode .PARTIAL ,
188+ main_key = SecretStr ("0" * 64 ),
189+ ),
190+ payload_offloading = PayloadOffloadingConfig (
191+ min_size_bytes = 1 ,
192+ storage_config = BlobStorageConfig (
193+ storage_provider = StorageProvider .S3 ,
194+ bucket_name = "test-bucket" ,
195+ ),
196+ ),
197+ )
198+ encoder = PayloadEncoder (encoding_config = config )
199+ payload = {
200+ "data" : "plain value" ,
201+ "secret" : EncryptedStrField (data = "secret value" ).model_dump (),
202+ }
203+
204+ encoded = await encoder .encode_network_input (
205+ payload , WorkflowContext (namespace = "test" , execution_id = "exec" )
206+ )
207+ offloaded_payload = json .loads (encoded .get_payload ())
208+ offloaded_bytes = storage .blobs [offloaded_payload ["key" ]]
209+
210+ assert encoded .encoding_options == [
211+ EncodedPayloadOptions .PARTIALLY_ENCRYPTED ,
212+ EncodedPayloadOptions .OFFLOADED ,
213+ ]
214+ assert b"plain value" in offloaded_bytes
215+ assert b"secret value" not in offloaded_bytes
216+
217+ decoded = await encoder .decode_network_result (encoded .model_dump (mode = "json" ))
218+ assert decoded == payload
219+
220+
221+ @pytest .mark .asyncio
222+ @pytest .mark .parametrize (
223+ ("encryption_mode" , "expected_options" ),
224+ [
225+ (
226+ PayloadEncryptionMode .PARTIAL ,
227+ [
228+ EncodedPayloadOptions .PARTIALLY_ENCRYPTED ,
229+ EncodedPayloadOptions .COMPRESSED ,
230+ EncodedPayloadOptions .OFFLOADED ,
231+ ],
232+ ),
233+ (
234+ PayloadEncryptionMode .FULL ,
235+ [
236+ EncodedPayloadOptions .COMPRESSED ,
237+ EncodedPayloadOptions .OFFLOADED ,
238+ EncodedPayloadOptions .ENCRYPTED ,
239+ ],
240+ ),
241+ ],
242+ )
243+ async def test_payload_encoder_compression_offloading_encryption_roundtrip (
244+ monkeypatch ,
245+ encryption_mode : PayloadEncryptionMode ,
246+ expected_options : list [EncodedPayloadOptions ],
247+ ):
248+ storage = InMemoryBlobStorage ()
249+ monkeypatch .setattr (
250+ "mistralai.extra.workflows.encoding.payload_encoder.get_blob_storage" ,
251+ lambda _ : storage ,
252+ )
253+ config = WorkflowEncodingConfig (
254+ payload_encryption = PayloadEncryptionConfig (
255+ mode = encryption_mode ,
256+ main_key = SecretStr ("0" * 64 ),
257+ ),
258+ payload_compression = PayloadCompressionConfig (
259+ min_size_bytes = 1 , algorithm_config = ZstdCompressionConfig (level = 3 )
260+ ),
261+ payload_offloading = PayloadOffloadingConfig (
262+ min_size_bytes = 1 ,
263+ storage_config = BlobStorageConfig (
264+ storage_provider = StorageProvider .S3 ,
265+ bucket_name = "test-bucket" ,
266+ ),
267+ ),
268+ )
269+ encoder = PayloadEncoder (encoding_config = config )
270+ payload = {
271+ "data" : "x" * 20_000 ,
272+ "secret" : EncryptedStrField (data = "secret value" ).model_dump (),
273+ }
274+
275+ encoded = await encoder .encode_network_input (
276+ payload , WorkflowContext (namespace = "test" , execution_id = "exec" )
277+ )
278+
279+ assert encoded .encoding_options == expected_options
280+ assert encoded .encoding_metadata == {}
281+ assert len (storage .blobs ) == 1
282+ decoded = await encoder .decode_network_result (encoded .model_dump (mode = "json" ))
283+ assert decoded == payload
284+
285+
286+ @pytest .mark .asyncio
287+ async def test_payload_encoder_does_not_partially_encrypt_when_no_marked_fields ():
288+ config = WorkflowEncodingConfig (
289+ payload_encryption = PayloadEncryptionConfig (
290+ mode = PayloadEncryptionMode .PARTIAL ,
291+ main_key = SecretStr ("0" * 64 ),
292+ ),
293+ payload_compression = PayloadCompressionConfig (
294+ min_size_bytes = 1 , algorithm_config = ZstdCompressionConfig (level = 3 )
295+ ),
296+ )
297+ encoder = PayloadEncoder (encoding_config = config )
298+ payload = {"data" : "x" * 20_000 }
299+
300+ encoded = await encoder .encode_network_input (
301+ payload , WorkflowContext (namespace = "test" , execution_id = "exec" )
302+ )
303+
304+ assert encoded .encoding_options == [EncodedPayloadOptions .COMPRESSED ]
305+ assert encoded .encoding_metadata == {}
306+ decoded = await encoder .decode_network_result (encoded .model_dump (mode = "json" ))
307+ assert decoded == payload
308+
309+
310+ @pytest .mark .asyncio
311+ async def test_payload_encoder_event_payload_orders_compression_before_full_encryption ():
312+ config = WorkflowEncodingConfig (
313+ payload_encryption = PayloadEncryptionConfig (
314+ mode = PayloadEncryptionMode .FULL ,
315+ main_key = SecretStr ("0" * 64 ),
316+ ),
317+ payload_compression = PayloadCompressionConfig (
318+ min_size_bytes = 1 , algorithm_config = ZstdCompressionConfig (level = 3 )
319+ ),
320+ )
321+ encoder = PayloadEncoder (encoding_config = config )
322+ payload = json .dumps ({"data" : "x" * 20_000 }).encode ()
323+
324+ (
325+ encoded ,
326+ encoding_options ,
327+ encoding_metadata ,
328+ ) = await encoder .encode_event_payload_content (payload )
329+ decoded = await encoder .decode_payload_content (
330+ encoded , encoding_options , encoding_metadata
331+ )
332+
333+ assert encoding_options == [
334+ EncodedPayloadOptions .COMPRESSED ,
335+ EncodedPayloadOptions .ENCRYPTED ,
336+ ]
337+ assert decoded == payload
0 commit comments