Skip to content

Commit e5ea645

Browse files
committed
Simplify GraphMemoryResourceAttributes.
1 parent 4ac5c99 commit e5ea645

File tree

2 files changed

+43
-46
lines changed

2 files changed

+43
-46
lines changed

cuda_core/cuda/core/experimental/_memory/_graph_memory_resource.pyx

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from libc.stdint cimport uintptr_t, intptr_t, uint64_t
7+
from libc.stdint cimport intptr_t
88

99
from cuda.bindings cimport cydriver
1010
from cuda.core.experimental._memory._buffer cimport Buffer, MemoryResource
@@ -14,8 +14,6 @@ from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
1414
from functools import cache
1515
from typing import TYPE_CHECKING
1616

17-
from cuda.core.experimental._utils.cuda_utils import driver
18-
1917
if TYPE_CHECKING:
2018
from cuda.core.experimental._memory.buffer import DevicePointerT
2119

@@ -41,64 +39,63 @@ cdef class GraphMemoryResourceAttributes:
4139
if not attr.startswith("_")
4240
)
4341

44-
@GMRA_mem_attribute(int)
42+
cdef int _getattribute(self, cydriver.CUgraphMem_attribute attr_enum, void* value) except?-1:
43+
with nogil:
44+
HANDLE_RETURN(cydriver.cuDeviceGetGraphMemAttribute(self._dev_id, attr_enum, value))
45+
return 0
46+
47+
cdef int _setattribute(self, cydriver.CUgraphMem_attribute attr_enum, void* value) except?-1:
48+
with nogil:
49+
HANDLE_RETURN(cydriver.cuDeviceSetGraphMemAttribute(self._dev_id, attr_enum, value))
50+
return 0
51+
52+
@property
4553
def reserved_mem_current(self):
4654
"""Current amount of backing memory allocated."""
55+
cdef cydriver.cuuint64_t value
56+
self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT, &value)
57+
return int(value)
4758

48-
@GMRA_mem_attribute(int, settable=True)
59+
@property
4960
def reserved_mem_high(self):
5061
"""
5162
High watermark of backing memory allocated. It can be set to zero to
5263
reset it to the current usage.
5364
"""
65+
cdef cydriver.cuuint64_t value
66+
self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH, &value)
67+
return int(value)
68+
69+
@reserved_mem_high.setter
70+
def reserved_mem_high(self, value: int):
71+
if value != 0:
72+
raise AttributeError(f"Attribute 'reserved_mem_high' may only be set to zero (got {value}).")
73+
cdef cydriver.cuuint64_t zero = 0
74+
self._setattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH, &zero)
5475

55-
@GMRA_mem_attribute(int)
76+
@property
5677
def used_mem_current(self):
5778
"""Current amount of memory in use."""
79+
cdef cydriver.cuuint64_t value
80+
self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, &value)
81+
return int(value)
5882

59-
@GMRA_mem_attribute(int, settable=True)
83+
@property
6084
def used_mem_high(self):
6185
"""
6286
High watermark of memory in use. It can be set to zero to reset it to
6387
the current usage.
6488
"""
65-
66-
67-
cdef GMRA_mem_attribute(property_type: type, settable: bool = False):
68-
_settable = settable
69-
70-
def decorator(stub):
71-
attr_enum = getattr(
72-
driver.CUgraphMem_attribute, f"CU_GRAPH_MEM_ATTR_{stub.__name__.upper()}"
73-
)
74-
75-
def fget(GraphMemoryResourceAttributes self) -> property_type:
76-
value = GMRA_getattribute(self._dev_id, <cydriver.CUgraphMem_attribute><uintptr_t> attr_enum)
77-
return property_type(value)
78-
79-
if _settable:
80-
def fset(GraphMemoryResourceAttributes self, uint64_t value):
81-
if value != 0:
82-
raise AttributeError(f"Attribute {stub.__name__!r} may only be set to zero (got {value}).")
83-
GMRA_setattribute(self._dev_id, <cydriver.CUgraphMem_attribute><uintptr_t> attr_enum)
84-
else:
85-
fset = None
86-
87-
return property(fget=fget, fset=fset, doc=stub.__doc__)
88-
return decorator
89-
90-
91-
cdef inline uint64_t GMRA_getattribute(int device_id, cydriver.CUgraphMem_attribute attr_enum):
92-
cdef uint64_t value
93-
with nogil:
94-
HANDLE_RETURN(cydriver.cuDeviceGetGraphMemAttribute(device_id, attr_enum, <void *> &value))
95-
return value
96-
97-
98-
cdef inline void GMRA_setattribute(int device_id, cydriver.CUgraphMem_attribute attr_enum):
99-
cdef uint64_t zero = 0
100-
with nogil:
101-
HANDLE_RETURN(cydriver.cuDeviceSetGraphMemAttribute(device_id, attr_enum, <void *> &zero))
89+
cdef cydriver.cuuint64_t value
90+
self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, &value)
91+
return int(value)
92+
93+
@used_mem_high.setter
94+
def used_mem_high(self, value: int):
95+
if value != 0:
96+
raise AttributeError(f"Attribute 'used_mem_high' may only be set to zero (got {value}).")
97+
cdef cydriver.cuuint64_t zero = 0
98+
self._setattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, &zero)
10299

103100

104101
cdef class cyGraphMemoryResource(MemoryResource):

cuda_core/tests/test_graph_mem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ def test_graph_mem_set_attributes(init_cuda, mode):
169169
assert gmr.attributes.used_mem_high > 0
170170

171171
# Incorrect attribute usage.
172-
with pytest.raises(AttributeError, match=r"property 'reserved_mem_current' .* has no setter"):
172+
with pytest.raises(AttributeError, match=r"attribute 'reserved_mem_current' .* is not writable"):
173173
gmr.attributes.reserved_mem_current = 0
174174

175175
with pytest.raises(AttributeError, match=r"Attribute 'reserved_mem_high' may only be set to zero \(got 1\)\."):
176176
gmr.attributes.reserved_mem_high = 1
177177

178-
with pytest.raises(AttributeError, match=r"property 'used_mem_current' .* has no setter"):
178+
with pytest.raises(AttributeError, match=r"attribute 'used_mem_current' .* is not writable"):
179179
gmr.attributes.used_mem_current = 0
180180

181181
with pytest.raises(AttributeError, match=r"Attribute 'used_mem_high' may only be set to zero \(got 1\)\."):

0 commit comments

Comments
 (0)