|
4 | 4 | import multiprocessing |
5 | 5 | from itertools import cycle |
6 | 6 |
|
7 | | -from cuda.core.experimental import Buffer, Device, DeviceMemoryResource |
| 7 | +from cuda.core.experimental import Buffer, Device, DeviceMemoryResource, DeviceMemoryResourceOptions |
8 | 8 | from utility import IPCBufferTestHelper |
9 | 9 |
|
10 | 10 | CHILD_TIMEOUT_SEC = 4 |
@@ -43,9 +43,8 @@ def test_ipc_workerpool(self, device, ipc_memory_resource): |
43 | 43 |
|
44 | 44 | def test_ipc_workerpool_multi_mr(self, device, ipc_memory_resource): |
45 | 45 | """Test IPC with a worker pool using multiple memory resources.""" |
46 | | - mrs = [ipc_memory_resource] + [ |
47 | | - DeviceMemoryResource(device, dict(max_size=POOL_SIZE, ipc_enabled=True)) for _ in range(NMRS - 1) |
48 | | - ] |
| 46 | + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) |
| 47 | + mrs = [ipc_memory_resource] + [DeviceMemoryResource(device, options=options) for _ in range(NMRS - 1)] |
49 | 48 | buffers = [mr.allocate(NBYTES) for mr, _ in zip(cycle(mrs), range(NTASKS))] |
50 | 49 | with multiprocessing.Pool(processes=NWORKERS, initializer=self.init_worker, initargs=(mrs,)) as pool: |
51 | 50 | pool.starmap( |
@@ -88,9 +87,8 @@ def test_ipc_workerpool(self, device, ipc_memory_resource): |
88 | 87 |
|
89 | 88 | def test_ipc_workerpool_multi_mr(self, device, ipc_memory_resource): |
90 | 89 | """Test IPC with a worker pool using multiple memory resources.""" |
91 | | - mrs = [ipc_memory_resource] + [ |
92 | | - DeviceMemoryResource(device, dict(max_size=POOL_SIZE, ipc_enabled=True)) for _ in range(NMRS - 1) |
93 | | - ] |
| 90 | + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) |
| 91 | + mrs = [ipc_memory_resource] + [DeviceMemoryResource(device, options=options) for _ in range(NMRS - 1)] |
94 | 92 | buffers = [mr.allocate(NBYTES) for mr, _ in zip(cycle(mrs), range(NTASKS))] |
95 | 93 | with multiprocessing.Pool(processes=NWORKERS, initializer=self.init_worker, initargs=(mrs,)) as pool: |
96 | 94 | pool.map(self.process_buffer, buffers) |
|
0 commit comments