|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 | """ |
4 | | -Tests for Python object protocols (__eq__, __hash__, __weakref__, __repr__). |
| 4 | +Tests for Python object protocols (__eq__, __hash__, __weakref__, __repr__, pickle). |
5 | 5 |
|
6 | 6 | This module tests that core cuda.core classes properly implement standard Python |
7 | | -object protocols for identity, hashing, weak references, and string representation. |
| 7 | +object protocols for identity, hashing, weak references, string representation, |
| 8 | +and serialization. |
8 | 9 | """ |
9 | 10 |
|
10 | 11 | import itertools |
|
15 | 16 | from helpers.graph_kernels import compile_common_kernels |
16 | 17 | from helpers.misc import try_create_condition |
17 | 18 |
|
18 | | -from cuda.core import Buffer, Device, Kernel, LaunchConfig, Program, Stream, system |
| 19 | +from cuda.core import ( |
| 20 | + Buffer, |
| 21 | + Device, |
| 22 | + DeviceMemoryResource, |
| 23 | + DeviceMemoryResourceOptions, |
| 24 | + Kernel, |
| 25 | + LaunchConfig, |
| 26 | + Program, |
| 27 | + Stream, |
| 28 | + system, |
| 29 | +) |
19 | 30 | from cuda.core._graph._graphdef import GraphDef |
20 | 31 | from cuda.core._program import _can_load_generated_ptx |
21 | 32 |
|
@@ -208,6 +219,30 @@ def sample_kernel_alt(sample_object_code_alt): |
208 | 219 | return sample_object_code_alt.get_kernel("test_kernel_alt") |
209 | 220 |
|
210 | 221 |
|
| 222 | +# ============================================================================= |
| 223 | +# Fixtures - IPC samples (for pickle tests) |
| 224 | +# ============================================================================= |
| 225 | + |
| 226 | +POOL_SIZE = 2097152 |
| 227 | + |
| 228 | + |
| 229 | +@pytest.fixture |
| 230 | +def sample_ipc_buffer_descriptor(ipc_device): |
| 231 | + """An IPCBufferDescriptor.""" |
| 232 | + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) |
| 233 | + mr = DeviceMemoryResource(ipc_device, options=options) |
| 234 | + buf = mr.allocate(64) |
| 235 | + return buf.get_ipc_descriptor() |
| 236 | + |
| 237 | + |
| 238 | +@pytest.fixture |
| 239 | +def sample_ipc_event_descriptor(ipc_device): |
| 240 | + """An IPCEventDescriptor.""" |
| 241 | + stream = ipc_device.create_stream() |
| 242 | + e = stream.record(options={"ipc_enabled": True}) |
| 243 | + return e.get_ipc_descriptor() |
| 244 | + |
| 245 | + |
211 | 246 | # ============================================================================= |
212 | 247 | # Fixtures - Graph types (GraphDef and GraphNode) |
213 | 248 | # ============================================================================= |
@@ -606,6 +641,20 @@ def sample_switch_node_alt(sample_graphdef): |
606 | 641 | ("sample_kernel", lambda k: Kernel.from_handle(int(k.handle))), |
607 | 642 | ] |
608 | 643 |
|
| 644 | +# Types with __reduce__ support (pickle/cloudpickle). |
| 645 | +# Event, Buffer, and memory resources are excluded: Event only supports |
| 646 | +# IPC-based serialization via multiprocessing reduction; Buffer and memory |
| 647 | +# resource __reduce__ use a cross-process registry that doesn't support |
| 648 | +# same-process roundtrips. |
| 649 | +PICKLE_TYPES = [ |
| 650 | + "sample_device", |
| 651 | + "sample_object_code_cubin", |
| 652 | + "sample_ipc_buffer_descriptor", |
| 653 | + "sample_ipc_event_descriptor", |
| 654 | +] |
| 655 | + |
| 656 | +PICKLE_MODULES = ["pickle", "cloudpickle"] |
| 657 | + |
609 | 658 | # Derived type groupings for collection tests |
610 | 659 | DICT_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES)) |
611 | 660 | WEAK_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES) & set(WEAKREF_TYPES)) |
@@ -796,3 +845,18 @@ def test_repr_format(fixture_name, pattern, request): |
796 | 845 | obj = request.getfixturevalue(fixture_name) |
797 | 846 | result = repr(obj) |
798 | 847 | assert re.fullmatch(pattern, result) |
| 848 | + |
| 849 | + |
| 850 | +# ============================================================================= |
| 851 | +# Pickle tests |
| 852 | +# ============================================================================= |
| 853 | + |
| 854 | + |
| 855 | +@pytest.mark.parametrize("pickle_module", PICKLE_MODULES) |
| 856 | +@pytest.mark.parametrize("fixture_name", PICKLE_TYPES) |
| 857 | +def test_pickle_roundtrip(fixture_name, pickle_module, request): |
| 858 | + """Object survives a pickle/cloudpickle roundtrip.""" |
| 859 | + mod = pytest.importorskip(pickle_module) |
| 860 | + obj = request.getfixturevalue(fixture_name) |
| 861 | + result = mod.loads(mod.dumps(obj)) |
| 862 | + assert type(result) is type(obj) |
0 commit comments