Skip to content

Commit 5125b15

Browse files
committed
Fix example requirements
1 parent 758fc19 commit 5125b15

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

cuda_core/tests/example_tests/test_basic_examples.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33

44
# If we have subcategories of examples in the future, this file can be split along those lines
55

6+
import functools
67
import glob
78
import os
89
import subprocess
910
import sys
1011

1112
import pytest
1213

13-
from cuda.core import Device
14+
from cuda.core import Device, system
1415

1516

1617
def pytorch_installed():
@@ -22,16 +23,48 @@ def pytorch_installed():
2223
return False
2324

2425

26+
@functools.cache
27+
def cupy_installed():
28+
try:
29+
import cupy # noqa: F401
30+
31+
return True
32+
except ImportError:
33+
return False
34+
35+
36+
def cffi_installed():
37+
try:
38+
import cffi # noqa: F401
39+
40+
return True
41+
except ImportError:
42+
return False
43+
44+
2545
def has_compute_capability_9_or_higher():
2646
dev = Device()
2747
arch = dev.compute_capability
2848
return arch > (9, 0)
2949

3050

51+
def has_multiple_devices():
52+
return system.get_num_devices() >= 2
53+
54+
3155
REQUIREMENTS = {
32-
"gl_interop_plasma.py": (lambda: False), # requires a display
33-
"pytorch_example.py": pytorch_installed, # requires PyTorch
34-
"thread_block_cluster.py": has_compute_capability_9_or_higher, # requires CC 9.0+
56+
"gl_interop_plasma.py": (lambda: False),
57+
"pytorch_example.py": pytorch_installed,
58+
"thread_block_cluster.py": has_compute_capability_9_or_higher,
59+
"cuda_graphs.py": cupy_installed,
60+
"jit_lto_fractal.py": cupy_installed,
61+
"memory_ops.py": cupy_installed,
62+
"saxpy.py": cupy_installed,
63+
"simple_multi_gpu_example.py": lambda: has_multiple_devices() and cupy_installed(),
64+
"strided_memory_view_cpu.py": cffi_installed,
65+
"strided_memory_view_gpu.py": cupy_installed,
66+
"tma_tensor_map.py": cupy_installed,
67+
"vector_add.py": cupy_installed,
3568
}
3669

3770
samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")

0 commit comments

Comments
 (0)