Skip to content

Commit 35b5213

Browse files
committed
Add PCH support to cuda.core.Program (#670)
When `create_pch` is set in ProgramOptions, compile() now automatically resizes the NVRTC PCH heap and retries with a fresh program when PCH creation fails due to heap exhaustion. The `pch_status` property reports the outcome ("created", "not_attempted", "failed", or None). Made-with: Cursor
1 parent 8d0ccdd commit 35b5213

File tree

4 files changed

+141
-5
lines changed

4 files changed

+141
-5
lines changed

cuda_core/cuda/core/_program.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@ cdef class Program:
1616
object _compile_lock # Per-instance lock for compile-time mutation
1717
bint _use_libdevice # Flag for libdevice loading
1818
bint _libdevice_added
19+
bytes _nvrtc_code # Source code for NVRTC retry (PCH auto-resize)
20+
str _pch_status # PCH creation outcome after compile

cuda_core/cuda/core/_program.pyx

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,25 @@ cdef class Program:
105105
"""
106106
return Program_compile(self, target_type, name_expressions, logs)
107107

108+
@property
def pch_status(self) -> str | None:
    """Report how precompiled-header (PCH) creation went in the latest :meth:`compile`.

    The value is one of:

    * ``"created"`` — the PCH file was written successfully.
    * ``"not_attempted"`` — no PCH creation was attempted (for example,
      the compiler decided not to, or automatic PCH processing skipped it).
    * ``"failed"`` — an error prevented the PCH from being created.
    * ``None`` — PCH was not requested, the program has not been compiled
      yet, or the NVRTC bindings are too old to report a status.

    If ``create_pch`` is set in :class:`ProgramOptions` and the PCH heap
    turns out to be too small, :meth:`compile` resizes the heap and retries
    automatically, so ``"created"`` is the expected result in most cases.
    """
    return self._pch_status
126+
108127
@property
109128
def backend(self) -> str:
110129
"""Return this Program instance's underlying backend."""
@@ -477,6 +496,8 @@ def _find_libdevice_path():
477496
return find_bitcode_lib("device")
478497

479498

499+
500+
480501
cdef inline bint _process_define_macro_inner(list options, object macro) except? -1:
481502
"""Process a single define macro, returning True if successful."""
482503
if isinstance(macro, str):
@@ -548,6 +569,8 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
548569
self._use_libdevice = False
549570
self._libdevice_added = False
550571

572+
self._pch_status = None
573+
551574
if code_type == "c++":
552575
assert_type(code, str)
553576
if options.extra_sources is not None:
@@ -562,6 +585,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
562585
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
563586
&nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL))
564587
self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog)
588+
self._nvrtc_code = code_bytes
565589
self._backend = "NVRTC"
566590
self._linker = None
567591

@@ -649,9 +673,15 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
649673
return 0
650674

651675

652-
cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs):
653-
"""Compile using NVRTC backend and return ObjectCode."""
654-
cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
676+
cdef object _nvrtc_compile_and_extract(
677+
cynvrtc.nvrtcProgram prog, str target_type, object name_expressions,
678+
object logs, list options_list, str name,
679+
):
680+
"""Run nvrtcCompileProgram on *prog* and extract the output.
681+
682+
This is the inner compile+extract loop, factored out so the PCH
683+
auto-retry path can call it on a fresh program handle.
684+
"""
655685
cdef size_t output_size = 0
656686
cdef size_t logsize = 0
657687
cdef vector[const char*] options_vec
@@ -669,7 +699,6 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp
669699
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcAddNameExpression(prog, name_ptr))
670700

671701
# Build options array
672-
options_list = self._options.as_bytes("nvrtc", target_type)
673702
options_vec.resize(len(options_list))
674703
for i in range(len(options_list)):
675704
options_vec[i] = <const char*>(<bytes>options_list[i])
@@ -716,7 +745,72 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp
716745
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLog(prog, data_ptr))
717746
logs.write(log.decode("utf-8", errors="backslashreplace"))
718747

719-
return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=self._options.name)
748+
return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=name)
749+
750+
751+
cdef bint _has_nvrtc_pch_apis():
    """Return True when the ``nvrtc`` bindings expose the PCH runtime APIs.

    All three entry points used by the PCH status / auto-retry logic in
    ``Program_compile_nvrtc`` are probed, rather than only the status query,
    so the retry path is never entered with a partially available API set.
    """
    # Explicit loop instead of a generator expression: this is a cdef
    # function, and a plain loop avoids creating a closure here.
    for api_name in (
        "nvrtcGetPCHCreateStatus",
        "nvrtcGetPCHHeapSizeRequired",
        "nvrtcSetPCHHeapSize",
    ):
        if not hasattr(nvrtc, api_name):
            return False
    return True
753+
754+
755+
# High-level PCH outcome strings produced by ``_read_pch_status``.
cdef str _PCH_STATUS_CREATED = "created"
cdef str _PCH_STATUS_NOT_ATTEMPTED = "not_attempted"
cdef str _PCH_STATUS_FAILED = "failed"


cdef str _read_pch_status(cynvrtc.nvrtcProgram prog):
    """Translate the result of ``nvrtcGetPCHCreateStatus`` into a status string.

    Returns
    -------
    str or None
        ``"created"`` for ``NVRTC_SUCCESS``, ``"not_attempted"`` for
        ``NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED``, and ``"failed"`` for any
        other result, except ``NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED``, which
        yields ``None`` as a sentinel telling the caller to auto-retry.
    """
    cdef cynvrtc.nvrtcResult rc
    with nogil:
        rc = cynvrtc.nvrtcGetPCHCreateStatus(prog)

    if rc == cynvrtc.nvrtcResult.NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED:
        # Sentinel: the caller should auto-retry.
        return None
    elif rc == cynvrtc.nvrtcResult.NVRTC_SUCCESS:
        return _PCH_STATUS_CREATED
    elif rc == cynvrtc.nvrtcResult.NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED:
        return _PCH_STATUS_NOT_ATTEMPTED
    else:
        return _PCH_STATUS_FAILED
772+
773+
774+
cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs):
    """Compile using the NVRTC backend and return the resulting ObjectCode.

    When ``create_pch`` is set in the program options and the NVRTC bindings
    expose the PCH APIs, the PCH creation outcome is recorded in
    ``self._pch_status``. If PCH creation failed because the PCH heap was
    exhausted, the heap is resized to the size NVRTC reports as required and
    the compilation is repeated once on a fresh program handle.

    Parameters
    ----------
    target_type : str
        Forwarded to ``ProgramOptions.as_bytes`` and to the compile helper.
    name_expressions : object
        Forwarded unchanged to ``_nvrtc_compile_and_extract``.
    logs : object
        Forwarded unchanged to ``_nvrtc_compile_and_extract``.

    Returns
    -------
    object
        The value returned by ``_nvrtc_compile_and_extract`` (from the retry
        when one was needed, otherwise from the first attempt).
    """
    cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
    cdef cynvrtc.nvrtcProgram retry_prog
    cdef list options_list = self._options.as_bytes("nvrtc", target_type)
    cdef str status
    cdef size_t required = 0
    cdef bytes code_bytes
    cdef object name_obj
    cdef const char* code_ptr
    cdef const char* name_ptr

    # Drop any outcome left over from an earlier compile() so that an
    # exception raised below cannot leave a stale status visible.
    self._pch_status = None

    result = _nvrtc_compile_and_extract(
        prog, target_type, name_expressions, logs, options_list, self._options.name,
    )

    # PCH was not requested, or the bindings cannot report its status.
    if not self._options.create_pch or not _has_nvrtc_pch_apis():
        return result

    # PCH was requested — check creation status.
    status = _read_pch_status(prog)
    if status is not None:
        self._pch_status = status
        return result

    # Heap exhausted — auto-resize and retry with a fresh program. Both
    # return codes are checked: a failed size query or resize previously
    # went unnoticed and the retry ran against an unchanged heap.
    with nogil:
        HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPCHHeapSizeRequired(prog, &required))
        HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcSetPCHHeapSize(required))

    # Hold local references so the buffers behind the C pointers stay alive
    # for the whole nogil section.
    code_bytes = self._nvrtc_code
    name_obj = self._options._name
    code_ptr = <const char*>code_bytes
    name_ptr = <const char*>name_obj
    with nogil:
        HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
            &retry_prog, code_ptr, name_ptr, 0, NULL, NULL))
    # From here on the Program owns the fresh handle; ``prog`` is not used again.
    self._h_nvrtc = create_nvrtc_program_handle(retry_prog)

    # NOTE(review): anything the helper wrote to ``logs`` during the first
    # attempt stays there; the retry appends to it — confirm this is wanted.
    result = _nvrtc_compile_and_extract(
        retry_prog, target_type, name_expressions, logs, options_list, self._options.name,
    )

    # A second heap-exhausted result is reported as a plain failure instead
    # of retrying again.
    status = _read_pch_status(retry_prog)
    self._pch_status = status if status is not None else _PCH_STATUS_FAILED
    return result
720814

721815

722816
cdef object Program_compile_nvvm(Program self, str target_type, object logs):

cuda_core/docs/source/release/0.6.0-notes.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ New features
5454
- Added CUDA version compatibility check at import time to detect mismatches between
5555
``cuda.core`` and the installed ``cuda-bindings`` version.
5656

57+
- ``Program.compile()`` now automatically resizes the NVRTC PCH heap and
58+
retries when precompiled header creation fails due to heap exhaustion.
59+
The ``pch_status`` property reports the PCH creation outcome
60+
(``"created"``, ``"not_attempted"``, ``"failed"``, or ``None``).
61+
5762

5863
Fixes and enhancements
5964
----------------------

cuda_core/tests/test_program.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ def _get_nvrtc_version_for_tests():
5757
return None
5858

5959

60+
def _has_nvrtc_pch_apis_for_tests():
    """Return True if the ``nvrtc`` bindings module exposes all PCH runtime APIs."""
    for api_name in (
        "nvrtcGetPCHHeapSize",
        "nvrtcSetPCHHeapSize",
        "nvrtcGetPCHCreateStatus",
        "nvrtcGetPCHHeapSizeRequired",
    ):
        if not hasattr(nvrtc, api_name):
            return False
    return True


# Skip marker for tests that need the PCH runtime APIs: applied when the
# detected NVRTC version is unknown or below 12800, or when any of the PCH
# entry points is missing from the bindings.
nvrtc_pch_available = pytest.mark.skipif(
    (_get_nvrtc_version_for_tests() or 0) < 12800 or not _has_nvrtc_pch_apis_for_tests(),
    reason="PCH runtime APIs require NVRTC >= 12.8 bindings",
)
74+
75+
6076
_libnvvm_version = None
6177
_libnvvm_version_attempted = False
6278

@@ -316,6 +332,25 @@ def test_cpp_program_with_pch_options(init_cuda, tmp_path):
316332
program.close()
317333

318334

335+
@nvrtc_pch_available
def test_cpp_program_pch_auto_creates(init_cuda, tmp_path):
    """``pch_status`` is None before compile and a known outcome string afterwards."""
    code = 'extern "C" __global__ void my_kernel() {}'
    pch_path = str(tmp_path / "test.pch")
    program = Program(code, "c++", ProgramOptions(create_pch=pch_path))
    # try/finally so the program is closed even when an assertion fails.
    try:
        assert program.pch_status is None  # not compiled yet
        program.compile("ptx")
        assert program.pch_status in ("created", "not_attempted", "failed")
    finally:
        program.close()
344+
345+
346+
def test_cpp_program_pch_status_none_without_pch(init_cuda):
    """``pch_status`` stays None after compiling when no PCH option was given."""
    code = 'extern "C" __global__ void my_kernel() {}'
    program = Program(code, "c++")
    # try/finally so the program is closed even when an assertion fails.
    try:
        program.compile("ptx")
        assert program.pch_status is None
    finally:
        program.close()
352+
353+
319354
options = [
320355
ProgramOptions(max_register_count=32),
321356
ProgramOptions(debug=True),

0 commit comments

Comments
 (0)