Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import io
import logging
import pickle # nosec # noqa: S403
from typing import Any
from typing import Any, cast

from ..exceptions import WorkflowCheckpointException

Expand All @@ -59,6 +59,10 @@
# Marker to identify pickled values in serialized JSON
_PICKLE_MARKER = "__pickled__"
_TYPE_MARKER = "__type__"
# Dict keys that the pickle envelope itself uses. ``_encode`` pickles any dict
# whose stringified keys include one of these rather than encoding it
# key-by-key, because ``_decode`` treats a dict carrying both markers as an
# envelope to unpickle.
_RESERVED_DICT_KEYS: frozenset[str] = frozenset({
_PICKLE_MARKER,
_TYPE_MARKER,
})

# Types that are natively JSON-serializable and don't need pickling
_JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
Expand All @@ -69,6 +73,13 @@
# Module prefix for OpenAI SDK types that are always allowed
_OPENAI_MODULE_PREFIX = "openai.types."

# ``module:qualname`` keys that ``_RestrictedUnpickler.find_class`` refuses with
# ``pickle.UnpicklingError``. The entries name checkpoint encode/decode helpers,
# so a pickle stream cannot resolve them as globals.
# NOTE(review): presumably these refer to this module's own helpers — confirm
# the module path matches where this file is installed.
_DENIED_ALLOWED_TYPE_KEYS: frozenset[str] = frozenset({
"agent_framework._workflows._checkpoint_encoding:_RestrictedUnpickler",
"agent_framework._workflows._checkpoint_encoding:_base64_to_unpickle",
"agent_framework._workflows._checkpoint_encoding:decode_checkpoint_value",
"agent_framework._workflows._checkpoint_encoding:encode_checkpoint_value",
})

# Built-in types considered safe for checkpoint deserialization.
# Each entry is a ``module:qualname`` string matching the format produced by
# :func:`_type_to_key`. These are the classes for which pickle's
Expand Down Expand Up @@ -128,14 +139,19 @@ def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
def find_class(self, module: str, name: str) -> type:
type_key = f"{module}:{name}"

if (
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
or type_key in self._allowed_types
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
or module.startswith(_OPENAI_MODULE_PREFIX)
):
if type_key in _DENIED_ALLOWED_TYPE_KEYS:
raise pickle.UnpicklingError(f"Checkpoint deserialization blocked for type '{type_key}'.")

if type_key in _BUILTIN_ALLOWED_TYPE_KEYS or type_key in self._allowed_types:
return super().find_class(module, name) # type: ignore[no-any-return] # nosec

if module.startswith(_FRAMEWORK_MODULE_PREFIX) or module.startswith(_OPENAI_MODULE_PREFIX):
if "." in name:
raise pickle.UnpicklingError(f"Checkpoint deserialization blocked for type '{type_key}'.")
resolved = super().find_class(module, name) # nosec
if isinstance(resolved, type):
return resolved

raise pickle.UnpicklingError(
f"Checkpoint deserialization blocked for type '{type_key}'. "
f"To allow this type, either include its 'module:qualname' key in the "
Expand Down Expand Up @@ -217,17 +233,18 @@ def _encode(value: Any) -> Any:

# Recursively encode dict values (keys become strings)
if isinstance(value, dict):
return {str(k): _encode(v) for k, v in value.items()} # type: ignore
typed_dict = cast(dict[Any, Any], value) # type: ignore[redundant-cast]
if any(str(k) in _RESERVED_DICT_KEYS for k in typed_dict):
return _encode_pickle(value)
encoded_dict: dict[str, Any] = {str(k): _encode(v) for k, v in typed_dict.items()}
return encoded_dict
Comment thread
moonbox3 marked this conversation as resolved.

# Recursively encode list items (lists are JSON-native collections)
if isinstance(value, list):
return [_encode(item) for item in value] # type: ignore

# Everything else (tuples, sets, dataclasses, custom objects, etc.): pickle and base64 encode
return {
_PICKLE_MARKER: _pickle_to_base64(value),
_TYPE_MARKER: _type_to_key(type(value)), # type: ignore
}
return _encode_pickle(value)


def _decode(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
Expand All @@ -238,14 +255,15 @@ def _decode(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:

# Handle encoded dicts
if isinstance(value, dict):
typed_dict = cast(dict[str, Any], value)
# Pickled value: decode, unpickle, and verify type
if _PICKLE_MARKER in value and _TYPE_MARKER in value:
obj = _base64_to_unpickle(value[_PICKLE_MARKER], allowed_types=allowed_types) # type: ignore
_verify_type(obj, value.get(_TYPE_MARKER)) # type: ignore
if _PICKLE_MARKER in typed_dict and _TYPE_MARKER in typed_dict:
obj = _base64_to_unpickle(cast(str, typed_dict[_PICKLE_MARKER]), allowed_types=allowed_types)
_verify_type(obj, cast(str, typed_dict.get(_TYPE_MARKER)))
return obj

# Regular dict: decode values recursively
return {k: _decode(v, allowed_types=allowed_types) for k, v in value.items()} # type: ignore
return {k: _decode(v, allowed_types=allowed_types) for k, v in typed_dict.items()}

# Handle encoded lists
if isinstance(value, list):
Expand All @@ -254,6 +272,14 @@ def _decode(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
return value


def _encode_pickle(value: Any) -> dict[str, str]:
    """Wrap ``value`` in a pickle envelope.

    Returns:
        A two-entry dict: the base64-encoded pickle of ``value`` under
        ``_PICKLE_MARKER`` and the ``module:qualname`` key of its type under
        ``_TYPE_MARKER``.
    """
    envelope: dict[str, str] = {}
    envelope[_PICKLE_MARKER] = _pickle_to_base64(value)
    envelope[_TYPE_MARKER] = _value_type_to_key(value)
    return envelope


def _verify_type(obj: Any, expected_type_key: str) -> None:
"""Verify that an unpickled object matches its recorded type.

Expand Down Expand Up @@ -306,6 +332,11 @@ def _base64_to_unpickle(encoded: str, *, allowed_types: frozenset[str] | None =
raise WorkflowCheckpointException(f"Failed to decode pickled checkpoint data: {exc}") from exc


def _type_to_key(t: type) -> str:
def _type_to_key(t: type[Any]) -> str:
"""Convert a type to a module:qualname string."""
return f"{t.__module__}:{t.__qualname__}"


def _value_type_to_key(value: object) -> str:
    """Return the ``module:qualname`` key for the runtime type of ``value``."""
    value_type = type(value)
    return _type_to_key(value_type)
29 changes: 25 additions & 4 deletions python/packages/core/tests/workflow/test_checkpoint_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
decode_checkpoint_value,
encode_checkpoint_value,
)

Expand Down Expand Up @@ -303,13 +304,33 @@ def test_encode_complex_mixed_structure() -> None:
assert _PICKLE_MARKER in result["dataclass_value"]


def test_encode_round_trips_dict_with_pickle_marker_key() -> None:
    """Test that regular dicts containing reserved marker keys remain user data."""
    data = {
        _PICKLE_MARKER: "some_value",
        _TYPE_MARKER: "some_type",
        "other_key": "test",
    }
    result = encode_checkpoint_value(data)
    # ``_encode`` pickles a dict that uses reserved keys as a whole, so the
    # encoded form is an envelope: the marker key is present, but the user's
    # own values are no longer visible as top-level entries.
    assert isinstance(result, dict)
    assert _PICKLE_MARKER in result
    # Decoding with no extra allowed types must give back the original dict.
    assert decode_checkpoint_value(result, allowed_types=frozenset()) == data


def test_encode_round_trips_nested_dict_with_pickle_marker_key() -> None:
    """A marker-shaped dict nested inside a list survives an encode/decode round trip."""
    marker_shaped = {
        _PICKLE_MARKER: "some_value",
        _TYPE_MARKER: "some_type",
    }
    data = {"items": [marker_shaped]}

    encoded = encode_checkpoint_value(data)
    decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())

    assert decoded == data


def test_decode_preserves_user_dict_matching_old_escape_shape() -> None:
    """Decoding leaves a dict shaped like the old escape envelope untouched."""
    data = {"__agent_framework_checkpoint_dict__": True, "value": {"safe": "data"}}

    decoded = decode_checkpoint_value(data)

    assert decoded == data
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any

import pytest

Expand All @@ -25,6 +26,7 @@
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER,
_TYPE_MARKER,
_base64_to_unpickle, # pyright: ignore[reportPrivateUsage]
decode_checkpoint_value,
encode_checkpoint_value,
)
Expand All @@ -37,6 +39,16 @@ def __reduce__(self):
return (os.getpid, ())


class FrameworkHelperPayload:
    """Pickle payload whose reduction points at a framework helper.

    ``__reduce__`` asks pickle to rebuild the object by calling
    ``_base64_to_unpickle`` with the stored string, so unpickling it would
    invoke that helper on ``nested_payload``.
    """

    def __init__(self, nested_payload: str) -> None:
        self.nested_payload = nested_payload

    def __reduce__(self) -> tuple[Any, tuple[str]]:
        reconstructor = _base64_to_unpickle
        arguments = (self.nested_payload,)
        return (reconstructor, arguments)


def test_restricted_decode_blocks_arbitrary_callable():
"""Restricted decoding blocks arbitrary module-level callables."""
pickled = pickle.dumps(os.getpid, protocol=pickle.HIGHEST_PROTOCOL)
Expand Down Expand Up @@ -98,6 +110,54 @@ def test_restricted_decode_prevents_code_execution():
)


def test_restricted_decode_blocks_framework_deserialization_helpers() -> None:
    """Restricted decoding refuses a payload that reduces to a framework helper.

    The outer pickle reduces to ``_base64_to_unpickle`` applied to an inner
    pickle whose own reduction would ``eval`` code that writes a marker file.
    Decoding must raise, and the marker file must not appear.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        marker_file = os.path.join(tmpdir, "checkpoint_helper_marker")

        # Inner object: unpickling it would run ``eval`` on a file-writing snippet.
        exploit_cls = type(
            "NestedExploit",
            (),
            {
                "__reduce__": lambda self: (
                    eval,
                    (f"open({marker_file!r}, 'w').write('pwned')",),
                )
            },
        )
        nested_payload = pickle.dumps(exploit_cls(), protocol=pickle.HIGHEST_PROTOCOL)
        nested_b64 = base64.b64encode(nested_payload).decode("ascii")

        # Outer object: reduces to the framework helper applied to the inner pickle.
        payload = FrameworkHelperPayload(nested_b64)
        outer_pickle = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
        encoded_b64 = base64.b64encode(outer_pickle).decode("ascii")

        checkpoint_value = {
            _PICKLE_MARKER: encoded_b64,
            _TYPE_MARKER: "builtins:int",
        }
        with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
            decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())

        # Checked while the temp dir still exists, so a written file would be seen.
        assert not os.path.exists(marker_file)


def test_restricted_decode_blocks_dotted_framework_global() -> None:
    """Restricted decoding refuses a dotted global name inside a framework module.

    The pickle stream is assembled by hand so that it asks for the global
    ``pickle.loads`` from the checkpoint-encoding module and then calls it.
    """
    module = b"agent_framework._workflows._checkpoint_encoding"
    name = b"pickle.loads"
    dotted_global_payload = b"".join([
        b"\x80\x04",  # PROTO 4
        b"\x8c", bytes([len(module)]), module,  # SHORT_BINUNICODE: module name
        b"\x8c", bytes([len(name)]), name,  # SHORT_BINUNICODE: dotted global name
        b"\x93C\x05NESTD\x85R.",  # STACK_GLOBAL, 5-byte bytes arg, TUPLE1, REDUCE, STOP
    ])
    encoded_b64 = base64.b64encode(dotted_global_payload).decode("ascii")

    checkpoint_value = {_PICKLE_MARKER: encoded_b64, _TYPE_MARKER: "builtins:int"}

    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
        decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())


def test_file_checkpoint_storage_accepts_allowed_types():
"""FileCheckpointStorage.__init__ accepts allowed_checkpoint_types."""
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down Expand Up @@ -207,6 +267,28 @@ async def test_file_storage_allows_listed_user_type():
assert loaded.state["data"].value == 99


async def test_file_storage_round_trips_marker_shaped_dict_state() -> None:
    """A marker-shaped dict stored in checkpoint state loads back equal to what was saved."""
    from agent_framework import WorkflowCheckpoint

    state_data = {_PICKLE_MARKER: "some_value", _TYPE_MARKER: "some_type"}

    with tempfile.TemporaryDirectory() as tmpdir:
        storage = FileCheckpointStorage(tmpdir)
        checkpoint = WorkflowCheckpoint(
            workflow_name="test",
            graph_signature_hash="hash",
            state={"data": state_data},
        )

        # Save and reload while the temporary directory still exists.
        checkpoint_id = await storage.save(checkpoint)
        loaded = await storage.load(checkpoint_id)

        assert loaded.state["data"] == state_data


def test_restricted_unpickler_raises_pickle_error():
"""_RestrictedUnpickler.find_class raises pickle.UnpicklingError, not a framework exception."""
from agent_framework._workflows._checkpoint_encoding import _RestrictedUnpickler
Expand Down
Loading