Skip to content

Commit 4c0c375

Browse files
committed
address PR feedback to not use enums and instead use strings as a more pythonic version
1 parent 0e26ea8 commit 4c0c375

File tree

4 files changed

+83
-50
lines changed

4 files changed

+83
-50
lines changed

cuda_core/cuda/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
GraphCompleteOptions,
3838
GraphDebugPrintOptions,
3939
)
40-
from cuda.core._graphics import GraphicsRegisterFlags, GraphicsResource # noqa: E402
40+
from cuda.core._graphics import GraphicsResource # noqa: E402
4141
from cuda.core._launch_config import LaunchConfig # noqa: E402
4242
from cuda.core._launcher import launch # noqa: E402
4343
from cuda.core._layout import _StridedLayout # noqa: E402

cuda_core/cuda/core/_graphics.pyx

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,34 @@ from cuda.core._resource_handles cimport (
1313
from cuda.core._stream cimport Stream, Stream_accept
1414
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
1515

16-
from enum import IntEnum
17-
1816
from cuda.core._memory import Buffer
1917

20-
__all__ = ['GraphicsResource', 'GraphicsRegisterFlags']
21-
22-
23-
class GraphicsRegisterFlags(IntEnum):
24-
"""Flags for registering a graphics resource with CUDA.
25-
26-
These flags specify the intended usage when registering a graphics
27-
resource (e.g., an OpenGL buffer) for CUDA access.
28-
"""
29-
NONE = cydriver.CU_GRAPHICS_REGISTER_FLAGS_NONE
30-
"""No hints about how this resource will be used. CUDA may read and write."""
31-
READ_ONLY = cydriver.CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY
32-
"""CUDA will not write to this resource."""
33-
WRITE_DISCARD = cydriver.CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD
34-
"""CUDA will not read from this resource and will write over the entire contents."""
35-
SURFACE_LOAD_STORE = cydriver.CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST
36-
"""CUDA will bind this resource to a surface reference."""
37-
TEXTURE_GATHER = cydriver.CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER
38-
"""CUDA will perform texture gather operations on this resource."""
18+
__all__ = ['GraphicsResource']
19+
20+
_REGISTER_FLAGS = {
21+
"none": cydriver.CU_GRAPHICS_REGISTER_FLAGS_NONE,
22+
"read_only": cydriver.CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY,
23+
"write_discard": cydriver.CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD,
24+
"surface_load_store": cydriver.CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST,
25+
"texture_gather": cydriver.CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER,
26+
}
27+
28+
29+
def _parse_register_flags(flags):
30+
if flags is None:
31+
return 0
32+
if isinstance(flags, str):
33+
flags = (flags,)
34+
result = 0
35+
for f in flags:
36+
try:
37+
result |= _REGISTER_FLAGS[f]
38+
except KeyError:
39+
raise ValueError(
40+
f"Unknown register flag {f!r}. "
41+
f"Valid flags: {', '.join(sorted(_REGISTER_FLAGS))}"
42+
) from None
43+
return result
3944

4045

4146
class _MappedBufferContext:
@@ -114,16 +119,20 @@ cdef class GraphicsResource:
114119
)
115120

116121
@classmethod
117-
def from_gl_buffer(cls, int gl_buffer, *, int flags=0) -> GraphicsResource:
122+
def from_gl_buffer(cls, int gl_buffer, *, flags=None) -> GraphicsResource:
118123
"""Register an OpenGL buffer object for CUDA access.
119124

120125
Parameters
121126
----------
122127
gl_buffer : int
123128
The OpenGL buffer name (``GLuint``) to register.
124-
flags : int, optional
125-
Registration flags from :class:`GraphicsRegisterFlags`.
126-
Defaults to :attr:`GraphicsRegisterFlags.NONE`.
129+
flags : str or sequence of str, optional
130+
Registration flags specifying intended usage. Accepted values:
131+
``"none"``, ``"read_only"``, ``"write_discard"``,
132+
``"surface_load_store"``, ``"texture_gather"``.
133+
Multiple flags can be combined by passing a sequence
134+
(e.g., ``("surface_load_store", "read_only")``).
135+
Defaults to ``None`` (no flags).
127136

128137
Returns
129138
-------
@@ -135,21 +144,24 @@ cdef class GraphicsResource:
135144
CUDAError
136145
If the registration fails (e.g., no current GL context, invalid
137146
buffer name, or operating system error).
147+
ValueError
148+
If an unknown flag string is provided.
138149
"""
139150
cdef GraphicsResource self = GraphicsResource.__new__(cls)
140151
cdef cydriver.CUgraphicsResource resource
141152
cdef cydriver.GLuint cy_buffer = <cydriver.GLuint>gl_buffer
153+
cdef unsigned int cy_flags = _parse_register_flags(flags)
142154
with nogil:
143155
HANDLE_RETURN(
144-
cydriver.cuGraphicsGLRegisterBuffer(&resource, cy_buffer, <unsigned int>flags)
156+
cydriver.cuGraphicsGLRegisterBuffer(&resource, cy_buffer, cy_flags)
145157
)
146158
self._handle = create_graphics_resource_handle(resource)
147159
self._mapped = False
148160
return self
149161

150162
@classmethod
151163
def from_gl_image(
152-
cls, int image, int target, *, int flags=0
164+
cls, int image, int target, *, flags=None
153165
) -> GraphicsResource:
154166
"""Register an OpenGL texture or renderbuffer for CUDA access.
155167

@@ -159,9 +171,13 @@ cdef class GraphicsResource:
159171
The OpenGL texture or renderbuffer name (``GLuint``) to register.
160172
target : int
161173
The OpenGL target type (e.g., ``GL_TEXTURE_2D``).
162-
flags : int, optional
163-
Registration flags from :class:`GraphicsRegisterFlags`.
164-
Defaults to :attr:`GraphicsRegisterFlags.NONE`.
174+
flags : str or sequence of str, optional
175+
Registration flags specifying intended usage. Accepted values:
176+
``"none"``, ``"read_only"``, ``"write_discard"``,
177+
``"surface_load_store"``, ``"texture_gather"``.
178+
Multiple flags can be combined by passing a sequence
179+
(e.g., ``("surface_load_store", "read_only")``).
180+
Defaults to ``None`` (no flags).
165181

166182
Returns
167183
-------
@@ -172,14 +188,17 @@ cdef class GraphicsResource:
172188
------
173189
CUDAError
174190
If the registration fails.
191+
ValueError
192+
If an unknown flag string is provided.
175193
"""
176194
cdef GraphicsResource self = GraphicsResource.__new__(cls)
177195
cdef cydriver.CUgraphicsResource resource
178196
cdef cydriver.GLuint cy_image = <cydriver.GLuint>image
179197
cdef cydriver.GLenum cy_target = <cydriver.GLenum>target
198+
cdef unsigned int cy_flags = _parse_register_flags(flags)
180199
with nogil:
181200
HANDLE_RETURN(
182-
cydriver.cuGraphicsGLRegisterImage(&resource, cy_image, cy_target, <unsigned int>flags)
201+
cydriver.cuGraphicsGLRegisterImage(&resource, cy_image, cy_target, cy_flags)
183202
)
184203
self._handle = create_graphics_resource_handle(resource)
185204
self._mapped = False

cuda_core/examples/gl_interop_plasma.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
import numpy as np
6565
from cuda.core import (
6666
Device,
67-
GraphicsRegisterFlags,
6867
GraphicsResource,
6968
LaunchConfig,
7069
Program,
@@ -296,7 +295,7 @@ def main():
296295
# THIS IS THE KEY LINE. GraphicsResource.from_gl_buffer() tells the
297296
# CUDA driver "I want to access this OpenGL buffer from CUDA kernels."
298297
# WRITE_DISCARD means CUDA will overwrite the entire buffer each frame.
299-
resource = GraphicsResource.from_gl_buffer(pbo_id, flags=GraphicsRegisterFlags.WRITE_DISCARD)
298+
resource = GraphicsResource.from_gl_buffer(pbo_id, flags="write_discard")
300299

301300
# --- Step 6: Render loop ---
302301
start_time = time.monotonic()

cuda_core/tests/test_graphics.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from cuda.core import (
1515
Buffer,
1616
Device,
17-
GraphicsRegisterFlags,
1817
GraphicsResource,
1918
StridedMemoryView,
2019
)
@@ -140,17 +139,33 @@ def _gl_context_and_texture(width=16, height=16):
140139

141140

142141
# ---------------------------------------------------------------------------
143-
# GraphicsRegisterFlags tests
142+
# Register flags parsing tests
144143
# ---------------------------------------------------------------------------
145144

146145

147-
class TestGraphicsRegisterFlags:
148-
def test_enum_values(self):
149-
assert int(GraphicsRegisterFlags.NONE) == 0
150-
assert int(GraphicsRegisterFlags.READ_ONLY) == 1
151-
assert int(GraphicsRegisterFlags.WRITE_DISCARD) == 2
152-
assert int(GraphicsRegisterFlags.SURFACE_LOAD_STORE) == 4
153-
assert int(GraphicsRegisterFlags.TEXTURE_GATHER) == 8
146+
class TestRegisterFlags:
147+
def test_parse_none(self):
148+
from cuda.core._graphics import _parse_register_flags
149+
150+
assert _parse_register_flags(None) == 0
151+
152+
def test_parse_single_string(self):
153+
from cuda.core._graphics import _parse_register_flags
154+
155+
assert _parse_register_flags("read_only") == 1
156+
assert _parse_register_flags("write_discard") == 2
157+
158+
def test_parse_combined_flags(self):
159+
from cuda.core._graphics import _parse_register_flags
160+
161+
result = _parse_register_flags(("surface_load_store", "read_only"))
162+
assert result == 4 | 1
163+
164+
def test_parse_invalid_raises(self):
165+
from cuda.core._graphics import _parse_register_flags
166+
167+
with pytest.raises(ValueError, match="Unknown register flag"):
168+
_parse_register_flags("bogus")
154169

155170

156171
# ---------------------------------------------------------------------------
@@ -179,7 +194,7 @@ def test_register_default_flags(self):
179194

180195
def test_register_write_discard(self):
181196
with _gl_context_and_buffer() as (gl_buf, nbytes):
182-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
197+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
183198
assert resource.handle != 0
184199
resource.close()
185200

@@ -212,7 +227,7 @@ def test_register_image(self):
212227
class TestMapUnmap:
213228
def test_map_returns_buffer(self):
214229
with _gl_context_and_buffer(nbytes=4096) as (gl_buf, nbytes):
215-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
230+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
216231
mapped = resource.map()
217232
assert resource.is_mapped
218233
# mapped is a _MappedBufferContext; its .handle and .size delegate to Buffer
@@ -224,7 +239,7 @@ def test_map_returns_buffer(self):
224239

225240
def test_context_manager_unmaps(self):
226241
with _gl_context_and_buffer(nbytes=4096) as (gl_buf, nbytes):
227-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
242+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
228243
with resource.map() as buf:
229244
assert isinstance(buf, Buffer)
230245
assert resource.is_mapped
@@ -234,7 +249,7 @@ def test_context_manager_unmaps(self):
234249

235250
def test_context_manager_unmaps_on_exception(self):
236251
with _gl_context_and_buffer(nbytes=4096) as (gl_buf, nbytes):
237-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
252+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
238253
with pytest.raises(ValueError, match="test error"), resource.map() as _buf:
239254
assert resource.is_mapped
240255
raise ValueError("test error")
@@ -246,7 +261,7 @@ def test_strided_memory_view_from_mapped_buffer(self):
246261
"""End-to-end: register, map, create StridedMemoryView."""
247262
nbytes = 256 * 4 # 256 float32 elements
248263
with _gl_context_and_buffer(nbytes=nbytes) as (gl_buf, _):
249-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
264+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
250265
with resource.map() as buf:
251266
view = StridedMemoryView.from_buffer(buf, shape=(256,), dtype=np.float32)
252267
assert view.ptr == int(buf.handle)
@@ -259,7 +274,7 @@ def test_map_with_stream(self):
259274
dev = Device(0)
260275
dev.set_current()
261276
stream = dev.create_stream()
262-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
277+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
263278
with resource.map(stream=stream) as buf:
264279
assert buf.size > 0
265280
resource.close()
@@ -304,7 +319,7 @@ def test_unmap_after_close_raises(self):
304319
def test_close_while_mapped(self):
305320
"""close() should unmap before unregistering."""
306321
with _gl_context_and_buffer() as (gl_buf, nbytes):
307-
resource = GraphicsResource.from_gl_buffer(gl_buf, flags=GraphicsRegisterFlags.WRITE_DISCARD)
322+
resource = GraphicsResource.from_gl_buffer(gl_buf, flags="write_discard")
308323
resource.map()
309324
assert resource.is_mapped
310325
resource.close() # Should unmap + unregister without error

0 commit comments

Comments
 (0)