Skip to content

Commit 579b86c

Browse files
cpcloud and cursoragent committed
Redesign PCH API: auto-retry on heap exhaustion, add pch_status property
Replace the low-level 1:1 wrappers (get_pch_create_status, get_pch_heap_size_required) with a higher-level interface:

- compile() now automatically resizes the PCH heap and retries with a fresh NVRTC program when PCH creation fails due to heap exhaustion.
- The program.pch_status property returns a clean string ("created", "not_attempted", "failed") or None.
- get_pch_heap_size() / set_pch_heap_size() are retained for manual control.
- Factored compile+extract into _nvrtc_compile_and_extract() to support the retry path without duplicating code.

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 045b33a commit 579b86c

File tree

4 files changed

+111
-109
lines changed

4 files changed

+111
-109
lines changed

cuda_core/cuda/core/_program.pxd

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,3 +16,5 @@ cdef class Program:
1616
object _compile_lock # Per-instance lock for compile-time mutation
1717
bint _use_libdevice # Flag for libdevice loading
1818
bint _libdevice_added
19+
bytes _nvrtc_code # Source code for NVRTC retry (PCH auto-resize)
20+
str _pch_status # PCH creation outcome after compile

cuda_core/cuda/core/_program.pyx

Lines changed: 92 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -105,52 +105,24 @@ cdef class Program:
105105
"""
106106
return Program_compile(self, target_type, name_expressions, logs)
107107

108-
def get_pch_create_status(self):
109-
"""Return NVRTC's PCH creation status for the latest compile.
110-
111-
Returns the ``nvrtcResult`` enum value describing the outcome of
112-
the most recent PCH creation attempt. Possible values include
113-
``NVRTC_SUCCESS``, ``NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED``,
114-
``NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED``, and
115-
``NVRTC_ERROR_PCH_CREATE``.
116-
117-
This API is only available for ``code_type="c++"`` programs compiled
118-
through NVRTC, and requires CUDA 12.8+ NVRTC bindings.
119-
120-
Returns
121-
-------
122-
:class:`~cuda.bindings.nvrtc.nvrtcResult`
123-
The PCH creation status.
124-
125-
Raises
126-
------
127-
NVRTCError
128-
If the program handle is invalid.
129-
"""
130-
return Program_get_pch_create_status(self)
131-
132-
def get_pch_heap_size_required(self):
133-
"""Return the PCH heap size required to compile this program.
134-
135-
This is only meaningful after a compile where
136-
:meth:`get_pch_create_status` returned
137-
``NVRTC_SUCCESS`` or ``NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED``.
108+
@property
109+
def pch_status(self) -> str | None:
110+
"""PCH creation outcome from the most recent :meth:`compile` call.
138111

139-
This API is only available for ``code_type="c++"`` programs compiled
140-
through NVRTC, and requires CUDA 12.8+ NVRTC bindings.
112+
Possible values:
141113

142-
Returns
143-
-------
144-
int
145-
Required PCH heap size in bytes.
114+
* ``"created"`` — PCH file was written successfully.
115+
* ``"not_attempted"`` — PCH creation was not attempted (e.g. the
116+
compiler decided not to, or automatic PCH processing skipped it).
117+
* ``"failed"`` — an error prevented PCH creation.
118+
* ``None`` — PCH was not requested, or the program has not been
119+
compiled yet, or the NVRTC bindings are too old to report status.
146120

147-
Raises
148-
------
149-
NVRTCError
150-
If the program handle is invalid or the query is not
151-
applicable for the current PCH creation status.
121+
When ``create_pch`` is set in :class:`ProgramOptions` and the PCH
122+
heap is too small, :meth:`compile` automatically resizes the heap
123+
and retries, so ``"created"`` should be the common outcome.
152124
"""
153-
return Program_get_pch_heap_size_required(self)
125+
return self._pch_status
154126

155127
@staticmethod
156128
def get_pch_heap_size():
@@ -543,11 +515,6 @@ def _find_libdevice_path():
543515
return find_bitcode_lib("device")
544516

545517

546-
cdef inline void _require_nvrtc_backend_for_pch(Program self) except *:
547-
if self._backend != "NVRTC":
548-
raise RuntimeError("PCH APIs are only available for Program instances using the NVRTC backend")
549-
550-
551518
cdef inline void _require_nvrtc_pch_api(str api_name) except *:
552519
if not hasattr(nvrtc, api_name):
553520
version = get_binding_version()
@@ -557,30 +524,6 @@ cdef inline void _require_nvrtc_pch_api(str api_name) except *:
557524
)
558525

559526

560-
cdef inline object Program_get_pch_create_status(Program self):
561-
_require_nvrtc_backend_for_pch(self)
562-
_require_nvrtc_pch_api("nvrtcGetPCHCreateStatus")
563-
cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
564-
cdef cynvrtc.nvrtcResult err
565-
with nogil:
566-
err = cynvrtc.nvrtcGetPCHCreateStatus(prog)
567-
if err == cynvrtc.nvrtcResult.NVRTC_ERROR_INVALID_PROGRAM:
568-
HANDLE_RETURN_NVRTC(prog, err)
569-
return nvrtc.nvrtcResult(err)
570-
571-
572-
cdef inline size_t Program_get_pch_heap_size_required(Program self) except? 0:
573-
_require_nvrtc_backend_for_pch(self)
574-
_require_nvrtc_pch_api("nvrtcGetPCHHeapSizeRequired")
575-
cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
576-
cdef size_t size = 0
577-
cdef cynvrtc.nvrtcResult err
578-
with nogil:
579-
err = cynvrtc.nvrtcGetPCHHeapSizeRequired(prog, &size)
580-
HANDLE_RETURN_NVRTC(prog, err)
581-
return size
582-
583-
584527
cdef inline size_t Program_get_pch_heap_size() except? 0:
585528
_require_nvrtc_pch_api("nvrtcGetPCHHeapSize")
586529
cdef size_t size = 0
@@ -666,6 +609,8 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
666609
self._use_libdevice = False
667610
self._libdevice_added = False
668611

612+
self._pch_status = None
613+
669614
if code_type == "c++":
670615
assert_type(code, str)
671616
if options.extra_sources is not None:
@@ -680,6 +625,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
680625
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
681626
&nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL))
682627
self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog)
628+
self._nvrtc_code = code_bytes
683629
self._backend = "NVRTC"
684630
self._linker = None
685631

@@ -767,9 +713,15 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
767713
return 0
768714

769715

770-
cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs):
771-
"""Compile using NVRTC backend and return ObjectCode."""
772-
cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
716+
cdef object _nvrtc_compile_and_extract(
717+
cynvrtc.nvrtcProgram prog, str target_type, object name_expressions,
718+
object logs, list options_list, str name,
719+
):
720+
"""Run nvrtcCompileProgram on *prog* and extract the output.
721+
722+
This is the inner compile+extract loop, factored out so the PCH
723+
auto-retry path can call it on a fresh program handle.
724+
"""
773725
cdef size_t output_size = 0
774726
cdef size_t logsize = 0
775727
cdef vector[const char*] options_vec
@@ -787,7 +739,6 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp
787739
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcAddNameExpression(prog, name_ptr))
788740

789741
# Build options array
790-
options_list = self._options.as_bytes("nvrtc", target_type)
791742
options_vec.resize(len(options_list))
792743
for i in range(len(options_list)):
793744
options_vec[i] = <const char*>(<bytes>options_list[i])
@@ -834,7 +785,72 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp
834785
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLog(prog, data_ptr))
835786
logs.write(log.decode("utf-8", errors="backslashreplace"))
836787

837-
return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=self._options.name)
788+
return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=name)
789+
790+
791+
cdef bint _has_nvrtc_pch_apis():
792+
return hasattr(nvrtc, "nvrtcGetPCHCreateStatus")
793+
794+
795+
cdef str _PCH_STATUS_CREATED = "created"
796+
cdef str _PCH_STATUS_NOT_ATTEMPTED = "not_attempted"
797+
cdef str _PCH_STATUS_FAILED = "failed"
798+
799+
800+
cdef str _read_pch_status(cynvrtc.nvrtcProgram prog):
801+
"""Query nvrtcGetPCHCreateStatus and translate to a high-level string."""
802+
cdef cynvrtc.nvrtcResult err
803+
with nogil:
804+
err = cynvrtc.nvrtcGetPCHCreateStatus(prog)
805+
if err == cynvrtc.nvrtcResult.NVRTC_SUCCESS:
806+
return _PCH_STATUS_CREATED
807+
if err == cynvrtc.nvrtcResult.NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED:
808+
return None # sentinel: caller should auto-retry
809+
if err == cynvrtc.nvrtcResult.NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED:
810+
return _PCH_STATUS_NOT_ATTEMPTED
811+
return _PCH_STATUS_FAILED
812+
813+
814+
cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs):
815+
"""Compile using NVRTC backend and return ObjectCode."""
816+
cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
817+
cdef list options_list = self._options.as_bytes("nvrtc", target_type)
818+
819+
result = _nvrtc_compile_and_extract(
820+
prog, target_type, name_expressions, logs, options_list, self._options.name,
821+
)
822+
823+
if not self._options.create_pch or not _has_nvrtc_pch_apis():
824+
self._pch_status = None
825+
return result
826+
827+
# PCH was requested — check creation status
828+
cdef str status = _read_pch_status(prog)
829+
if status is not None:
830+
self._pch_status = status
831+
return result
832+
833+
# Heap exhausted — auto-resize and retry with a fresh program
834+
cdef size_t required = 0
835+
with nogil:
836+
cynvrtc.nvrtcGetPCHHeapSizeRequired(prog, &required)
837+
cynvrtc.nvrtcSetPCHHeapSize(required)
838+
839+
cdef cynvrtc.nvrtcProgram retry_prog
840+
cdef const char* code_ptr = <const char*>self._nvrtc_code
841+
cdef const char* name_ptr = <const char*>self._options._name
842+
with nogil:
843+
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
844+
&retry_prog, code_ptr, name_ptr, 0, NULL, NULL))
845+
self._h_nvrtc = create_nvrtc_program_handle(retry_prog)
846+
847+
result = _nvrtc_compile_and_extract(
848+
retry_prog, target_type, name_expressions, logs, options_list, self._options.name,
849+
)
850+
851+
status = _read_pch_status(retry_prog)
852+
self._pch_status = status if status is not None else _PCH_STATUS_FAILED
853+
return result
838854

839855

840856
cdef object Program_compile_nvvm(Program self, str target_type, object logs):

cuda_core/docs/source/release/0.6.0-notes.rst

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -54,9 +54,12 @@ New features
5454
- Added CUDA version compatibility check at import time to detect mismatches between
5555
``cuda.core`` and the installed ``cuda-bindings`` version.
5656

57-
- Added NVRTC PCH runtime APIs on ``Program``: ``get_pch_create_status()``,
58-
``get_pch_heap_size_required()``, ``get_pch_heap_size()``, and
59-
``set_pch_heap_size()``.
57+
- ``Program.compile()`` now automatically resizes the NVRTC PCH heap and
58+
retries when precompiled header creation fails due to heap exhaustion.
59+
The ``pch_status`` property reports the PCH creation outcome
60+
(``"created"``, ``"not_attempted"``, ``"failed"``, or ``None``).
61+
``Program.get_pch_heap_size()`` and ``Program.set_pch_heap_size()``
62+
are available for manual heap management.
6063

6164

6265
Fixes and enhancements

cuda_core/tests/test_program.py

Lines changed: 11 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -64,10 +64,7 @@ def _has_nvrtc_pch_apis_for_tests():
6464
"nvrtcGetPCHCreateStatus",
6565
"nvrtcGetPCHHeapSizeRequired",
6666
)
67-
try:
68-
return all(hasattr(nvrtc, name) for name in required)
69-
except Exception:
70-
return False
67+
return all(hasattr(nvrtc, name) for name in required)
7168

7269

7370
nvrtc_pch_available = pytest.mark.skipif(
@@ -336,28 +333,13 @@ def test_cpp_program_with_pch_options(init_cuda, tmp_path):
336333

337334

338335
@nvrtc_pch_available
339-
def test_cpp_program_pch_runtime_apis(init_cuda, tmp_path):
336+
def test_cpp_program_pch_auto_creates(init_cuda, tmp_path):
340337
code = 'extern "C" __global__ void my_kernel() {}'
341-
options = ProgramOptions(create_pch=str(tmp_path / "test.pch"))
342-
program = Program(code, "c++", options)
343-
assert program.backend == "NVRTC"
338+
pch_path = str(tmp_path / "test.pch")
339+
program = Program(code, "c++", ProgramOptions(create_pch=pch_path))
340+
assert program.pch_status is None # not compiled yet
344341
program.compile("ptx")
345-
346-
status = program.get_pch_create_status()
347-
valid_status_names = (
348-
"NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED",
349-
"NVRTC_ERROR_PCH_CREATE",
350-
"NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED",
351-
)
352-
valid_statuses = {nvrtc.nvrtcResult.NVRTC_SUCCESS}
353-
valid_statuses.update(
354-
getattr(nvrtc.nvrtcResult, name) for name in valid_status_names if hasattr(nvrtc.nvrtcResult, name)
355-
)
356-
assert status in valid_statuses
357-
358-
required_heap_size = program.get_pch_heap_size_required()
359-
assert isinstance(required_heap_size, int)
360-
assert required_heap_size > 0
342+
assert program.pch_status in ("created", "not_attempted", "failed")
361343
program.close()
362344

363345

@@ -378,12 +360,11 @@ def test_cpp_program_pch_set_heap_size_rejects_negative():
378360
Program.set_pch_heap_size(-1)
379361

380362

381-
def test_cpp_program_pch_runtime_apis_require_nvrtc_backend(init_cuda, ptx_code_object):
382-
program = Program(ptx_code_object.code.decode(), "ptx")
383-
with pytest.raises(RuntimeError, match="only available for Program instances using the NVRTC backend"):
384-
program.get_pch_create_status()
385-
with pytest.raises(RuntimeError, match="only available for Program instances using the NVRTC backend"):
386-
program.get_pch_heap_size_required()
363+
def test_cpp_program_pch_status_none_without_pch(init_cuda):
364+
code = 'extern "C" __global__ void my_kernel() {}'
365+
program = Program(code, "c++")
366+
program.compile("ptx")
367+
assert program.pch_status is None
387368
program.close()
388369

389370

0 commit comments

Comments
 (0)