Skip to content

Commit 1f99cdf

Browse files
authored
fix(core): unit-test peer access planning without multi-GPU hardware (#1773)
Extract the peer-access transition planning from DeviceMemoryResource so stale-state regressions can be covered on single-GPU systems. Keep the existing multi-GPU integration tests for end-to-end peer access behavior. Made-with: Cursor
1 parent b4a704c commit 1f99cdf

File tree

4 files changed

+142
-42
lines changed

4 files changed

+142
-42
lines changed

cuda_core/cuda/core/_memory/_device_memory_resource.pyx

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import multiprocessing
2525
import platform # no-cython-lint
2626
import uuid
2727

28+
from ._peer_access_utils import plan_peer_access_update
2829
from cuda.core._utils.cuda_utils import check_multiprocessing_start_method
2930

3031
__all__ = ['DeviceMemoryResource', 'DeviceMemoryResourceOptions']
@@ -281,17 +282,24 @@ cdef inline _DMR_query_peer_access(DeviceMemoryResource self):
281282
cdef inline _DMR_set_peer_accessible_by(DeviceMemoryResource self, devices):
282283
from .._device import Device
283284

284-
cdef set[int] target_ids = {Device(dev).device_id for dev in devices}
285-
target_ids.discard(self._dev_id)
286285
this_dev = Device(self._dev_id)
287-
cdef list bad = [dev for dev in target_ids if not this_dev.can_access_peer(dev)]
288-
if bad:
289-
raise ValueError(f"Device {self._dev_id} cannot access peer(s): {', '.join(map(str, bad))}")
286+
cdef object resolve_device_id = lambda dev: Device(dev).device_id
287+
cdef object plan
288+
cdef tuple target_ids
289+
cdef tuple to_add
290+
cdef tuple to_rm
290291
if not self._mempool_owned:
291292
_DMR_query_peer_access(self)
292-
cdef set[int] cur_ids = set(self._peer_accessible_by)
293-
cdef set[int] to_add = target_ids - cur_ids
294-
cdef set[int] to_rm = cur_ids - target_ids
293+
plan = plan_peer_access_update(
294+
owner_device_id=self._dev_id,
295+
current_peer_ids=self._peer_accessible_by,
296+
requested_devices=devices,
297+
resolve_device_id=resolve_device_id,
298+
can_access_peer=this_dev.can_access_peer,
299+
)
300+
target_ids = plan.target_ids
301+
to_add = plan.to_add
302+
to_rm = plan.to_remove
295303
cdef size_t count = len(to_add) + len(to_rm)
296304
cdef cydriver.CUmemAccessDesc* access_desc = NULL
297305
cdef size_t i = 0
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from __future__ import annotations
6+
7+
from collections.abc import Callable, Iterable
8+
from dataclasses import dataclass
9+
10+
11+
@dataclass(frozen=True)
class PeerAccessPlan:
    """Normalized peer-access target state and the driver updates it requires.

    Instances are immutable (``frozen=True``) and compare by value, so a plan
    can be checked against an expected plan with ``==``. Every field is a
    tuple of integer device IDs; ``plan_peer_access_update`` fills them with
    sorted, de-duplicated IDs.
    """

    # Peer device IDs that should have access once the plan is applied.
    target_ids: tuple[int, ...]
    # IDs present in the target state but absent from the current state.
    to_add: tuple[int, ...]
    # IDs present in the current state but absent from the target state.
    to_remove: tuple[int, ...]
18+
19+
20+
def normalize_peer_access_targets(
    owner_device_id: int,
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
) -> tuple[int, ...]:
    """Return sorted, unique peer device IDs, excluding the owner device.

    Args:
        owner_device_id: ID of the device that owns the memory resource; it
            is dropped from the result if it was requested.
        requested_devices: Device-like objects to grant access to. The
            iterable is consumed exactly once and may contain duplicates.
        resolve_device_id: Callback mapping each requested object to its
            integer device ID. Any exception it raises propagates unchanged.

    Returns:
        The requested device IDs as an ascending tuple with duplicates and
        the owner device removed.
    """
    # A set collapses duplicates such as ``[1, Device(1)]`` into one ID.
    target_ids = {resolve_device_id(device) for device in requested_devices}
    # ``discard`` (not ``remove``) so a request without the owner is fine.
    target_ids.discard(owner_device_id)
    return tuple(sorted(target_ids))
31+
32+
33+
def plan_peer_access_update(
    owner_device_id: int,
    current_peer_ids: Iterable[int],
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
    can_access_peer: Callable[[int], bool],
) -> PeerAccessPlan:
    """Compute the peer-access target state and add/remove deltas.

    Args:
        owner_device_id: ID of the device that owns the memory resource.
        current_peer_ids: Device IDs that currently have peer access.
        requested_devices: Device-like objects that should have access after
            the update; normalized via ``normalize_peer_access_targets``.
        resolve_device_id: Callback mapping a requested object to its
            integer device ID.
        can_access_peer: Callback returning whether peer access with the
            given device ID is possible; it is called once per normalized
            target ID.

    Returns:
        A ``PeerAccessPlan`` whose ``target_ids`` is the normalized request
        and whose ``to_add`` / ``to_remove`` are the sorted set differences
        between the target and current states.

    Raises:
        ValueError: If ``can_access_peer`` returns a false value for any
            target ID. The message lists the offending IDs in ascending
            order.
    """
    target_ids = normalize_peer_access_targets(
        owner_device_id,
        requested_devices,
        resolve_device_id=resolve_device_id,
    )

    # Validate every target before computing deltas so an invalid request
    # never yields a partial plan.
    bad = tuple(dev_id for dev_id in target_ids if not can_access_peer(dev_id))
    if bad:
        bad_ids = ", ".join(str(dev_id) for dev_id in bad)
        raise ValueError(f"Device {owner_device_id} cannot access peer(s): {bad_ids}")

    current_ids = set(current_peer_ids)
    target_id_set = set(target_ids)
    return PeerAccessPlan(
        target_ids=target_ids,
        to_add=tuple(sorted(target_id_set - current_ids)),
        to_remove=tuple(sorted(current_ids - target_id_set)),
    )

cuda_core/tests/test_memory_peer_access.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55
from helpers.buffers import PatternGen, compare_buffer_to_constant, make_scratch_buffer
66

7-
import cuda.core
87
from cuda.core import DeviceMemoryResource, DeviceMemoryResourceOptions
98
from cuda.core._utils.cuda_utils import CUDAError
109

@@ -48,39 +47,6 @@ def test_peer_access_basic(mempool_device_x2):
4847
zero_on_dev0.copy_from(buf_on_dev1, stream=stream_on_dev0)
4948

5049

51-
def test_peer_access_property_x2(mempool_device_x2):
52-
"""The the dmr.peer_accessible_by property (but not its functionality)."""
53-
# The peer access list is a sorted tuple and always excludes the self
54-
# device.
55-
dev0, dev1 = mempool_device_x2
56-
# Use owned pool to ensure clean initial state (no stale peer access).
57-
dmr = DeviceMemoryResource(dev0, DeviceMemoryResourceOptions())
58-
59-
def check(expected):
60-
assert isinstance(dmr.peer_accessible_by, tuple)
61-
assert dmr.peer_accessible_by == expected
62-
63-
# No access to begin with.
64-
check(expected=())
65-
# fmt: off
66-
dmr.peer_accessible_by = (0,) ; check(expected=())
67-
dmr.peer_accessible_by = (1,) ; check(expected=(1,))
68-
dmr.peer_accessible_by = (0, 1) ; check(expected=(1,))
69-
dmr.peer_accessible_by = () ; check(expected=())
70-
dmr.peer_accessible_by = [0, 1] ; check(expected=(1,))
71-
dmr.peer_accessible_by = set() ; check(expected=())
72-
dmr.peer_accessible_by = [1, 1, 1, 1, 1] ; check(expected=(1,))
73-
# fmt: on
74-
75-
with pytest.raises(ValueError, match=r"device_id must be \>\= 0"):
76-
dmr.peer_accessible_by = [-1] # device ID out of bounds
77-
78-
num_devices = len(cuda.core.Device.get_all_devices())
79-
80-
with pytest.raises(ValueError, match=r"device_id must be within \[0, \d+\)"):
81-
dmr.peer_accessible_by = [num_devices] # device ID out of bounds
82-
83-
8450
def test_peer_access_transitions(mempool_device_x3):
8551
"""Advanced tests for dmr.peer_accessible_by."""
8652

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
8+
import pytest
9+
10+
from cuda.core._memory._peer_access_utils import PeerAccessPlan, plan_peer_access_update
11+
12+
13+
@dataclass(frozen=True)
class DummyDevice:
    """Minimal stand-in for a device object that exposes only ``device_id``.

    Lets the planning tests pass device-like objects without any GPU.
    """

    # Integer device ID that ``_resolve_device_id`` reads from this object.
    device_id: int
16+
17+
18+
def _resolve_device_id(device) -> int:
    """Return the integer device ID for a ``DummyDevice`` or an int-like value.

    ``DummyDevice`` instances yield their ``device_id``; anything else is
    converted with ``int()``.
    """
    if isinstance(device, DummyDevice):
        return device.device_id
    return int(device)
22+
23+
24+
def test_plan_peer_access_update_normalizes_requests():
    """Duplicates, mixed input types and the owner device are normalized away."""
    plan = plan_peer_access_update(
        owner_device_id=1,
        current_peer_ids=(),
        # Mixes raw ints with DummyDevice objects, repeats IDs 2 and 3, and
        # includes the owner (1), which must be dropped.
        requested_devices=[1, DummyDevice(3), 2, DummyDevice(2), 3],
        resolve_device_id=_resolve_device_id,
        can_access_peer=lambda _device_id: True,
    )

    assert plan == PeerAccessPlan(
        target_ids=(2, 3),
        to_add=(2, 3),
        to_remove=(),
    )
38+
39+
40+
def test_plan_peer_access_update_rejects_inaccessible_peers():
    """Inaccessible peers raise ``ValueError`` listing their IDs in sorted order."""
    # Only device 1 is reachable, so the requested peers 2 and 4 must both be
    # reported; the owner (0) is dropped before validation.
    with pytest.raises(ValueError, match=r"Device 0 cannot access peer\(s\): 2, 4"):
        plan_peer_access_update(
            owner_device_id=0,
            current_peer_ids=(1,),
            requested_devices=[4, 0, DummyDevice(2), 1],
            resolve_device_id=_resolve_device_id,
            can_access_peer=lambda device_id: device_id == 1,
        )
49+
50+
51+
def test_plan_peer_access_update_covers_all_state_transitions():
    """Every current-to-requested transition over peers {1, 2} yields the right deltas."""
    # Each state is already a sorted tuple that excludes the owner (0), so
    # the requested state doubles as the expected ``target_ids``.
    states = [(), (1,), (2,), (1, 2)]
    for current_state in states:
        for requested_state in states:
            plan = plan_peer_access_update(
                owner_device_id=0,
                current_peer_ids=current_state,
                requested_devices=requested_state,
                resolve_device_id=_resolve_device_id,
                can_access_peer=lambda device_id: device_id in {1, 2},
            )

            assert plan == PeerAccessPlan(
                target_ids=requested_state,
                to_add=tuple(sorted(set(requested_state) - set(current_state))),
                to_remove=tuple(sorted(set(current_state) - set(requested_state))),
            )

0 commit comments

Comments
 (0)