Skip to content

Commit 653939e

Browse files
authored
Add pickle/cloudpickle regression tests for cuda.core objects (#1810)
Adds parametrized roundtrip tests covering both pickle and cloudpickle for Device, ObjectCode, IPCBufferDescriptor, and IPCEventDescriptor. These tests guard against the Cython @classmethod identity issue fixed in #1660 (cloudpickle fails when __reduce__ references a @classmethod). Closes #1671 Made-with: Cursor
1 parent cf62591 commit 653939e

File tree

1 file changed

+67
-3
lines changed

1 file changed

+67
-3
lines changed

cuda_core/tests/test_object_protocols.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
"""
4-
Tests for Python object protocols (__eq__, __hash__, __weakref__, __repr__).
4+
Tests for Python object protocols (__eq__, __hash__, __weakref__, __repr__, pickle).
55
66
This module tests that core cuda.core classes properly implement standard Python
7-
object protocols for identity, hashing, weak references, and string representation.
7+
object protocols for identity, hashing, weak references, string representation,
8+
and serialization.
89
"""
910

1011
import itertools
@@ -15,7 +16,17 @@
1516
from helpers.graph_kernels import compile_common_kernels
1617
from helpers.misc import try_create_condition
1718

18-
from cuda.core import Buffer, Device, Kernel, LaunchConfig, Program, Stream, system
19+
from cuda.core import (
20+
Buffer,
21+
Device,
22+
DeviceMemoryResource,
23+
DeviceMemoryResourceOptions,
24+
Kernel,
25+
LaunchConfig,
26+
Program,
27+
Stream,
28+
system,
29+
)
1930
from cuda.core._graph._graphdef import GraphDef
2031
from cuda.core._program import _can_load_generated_ptx
2132

@@ -208,6 +219,30 @@ def sample_kernel_alt(sample_object_code_alt):
208219
return sample_object_code_alt.get_kernel("test_kernel_alt")
209220

210221

222+
# =============================================================================
223+
# Fixtures - IPC samples (for pickle tests)
224+
# =============================================================================
225+
226+
POOL_SIZE = 2097152
227+
228+
229+
@pytest.fixture
230+
def sample_ipc_buffer_descriptor(ipc_device):
231+
"""An IPCBufferDescriptor."""
232+
options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True)
233+
mr = DeviceMemoryResource(ipc_device, options=options)
234+
buf = mr.allocate(64)
235+
return buf.get_ipc_descriptor()
236+
237+
238+
@pytest.fixture
239+
def sample_ipc_event_descriptor(ipc_device):
240+
"""An IPCEventDescriptor."""
241+
stream = ipc_device.create_stream()
242+
e = stream.record(options={"ipc_enabled": True})
243+
return e.get_ipc_descriptor()
244+
245+
211246
# =============================================================================
212247
# Fixtures - Graph types (GraphDef and GraphNode)
213248
# =============================================================================
@@ -606,6 +641,20 @@ def sample_switch_node_alt(sample_graphdef):
606641
("sample_kernel", lambda k: Kernel.from_handle(int(k.handle))),
607642
]
608643

644+
# Types with __reduce__ support (pickle/cloudpickle).
645+
# Event, Buffer, and memory resources are excluded: Event only supports
646+
# IPC-based serialization via multiprocessing reduction; Buffer and memory
647+
# resource __reduce__ use a cross-process registry that doesn't support
648+
# same-process roundtrips.
649+
PICKLE_TYPES = [
650+
"sample_device",
651+
"sample_object_code_cubin",
652+
"sample_ipc_buffer_descriptor",
653+
"sample_ipc_event_descriptor",
654+
]
655+
656+
PICKLE_MODULES = ["pickle", "cloudpickle"]
657+
609658
# Derived type groupings for collection tests
610659
DICT_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES))
611660
WEAK_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES) & set(WEAKREF_TYPES))
@@ -796,3 +845,18 @@ def test_repr_format(fixture_name, pattern, request):
796845
obj = request.getfixturevalue(fixture_name)
797846
result = repr(obj)
798847
assert re.fullmatch(pattern, result)
848+
849+
850+
# =============================================================================
851+
# Pickle tests
852+
# =============================================================================
853+
854+
855+
@pytest.mark.parametrize("pickle_module", PICKLE_MODULES)
856+
@pytest.mark.parametrize("fixture_name", PICKLE_TYPES)
857+
def test_pickle_roundtrip(fixture_name, pickle_module, request):
858+
"""Object survives a pickle/cloudpickle roundtrip."""
859+
mod = pytest.importorskip(pickle_module)
860+
obj = request.getfixturevalue(fixture_name)
861+
result = mod.loads(mod.dumps(obj))
862+
assert type(result) is type(obj)

0 commit comments

Comments
 (0)