Skip to content

Commit 758fc19

Browse files
committed
Fix tests
1 parent 362d0be commit 758fc19

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

.github/workflows/test-wheel-linux.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ jobs:
101101
continue-on-error: false
102102
with:
103103
# for artifact fetching, graphics libs
104-
dependencies: "jq wget libgl1 libegl1"
104+
dependencies: "jq wget libgl1 libegl1 g++"
105105
dependent_exes: "jq wget"
106106

107107
- name: Set environment variables

cuda_core/tests/example_tests/test_basic_examples.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,42 @@
1010

1111
import pytest
1212

13-
DISALLOW_LIST = {
14-
"gl_interop_plasma.py", # requires a display
13+
from cuda.core import Device
14+
15+
16+
def pytorch_installed():
17+
try:
18+
import torch # noqa: F401
19+
20+
return True
21+
except ImportError:
22+
return False
23+
24+
25+
def has_compute_capability_9_or_higher():
26+
dev = Device()
27+
arch = dev.compute_capability
28+
return arch > (9, 0)
29+
30+
31+
REQUIREMENTS = {
32+
"gl_interop_plasma.py": (lambda: False), # requires a display
33+
"pytorch_example.py": pytorch_installed, # requires PyTorch
34+
"thread_block_cluster.py": has_compute_capability_9_or_higher, # requires CC 9.0+
1535
}
1636

1737
samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
1838
sample_files = [
19-
x
39+
(x, REQUIREMENTS.get(x, lambda: True))
2040
for x in (os.path.basename(x) for x in glob.glob(samples_path + "**/*.py", recursive=True))
21-
if x not in DISALLOW_LIST
2241
]
2342

2443

25-
@pytest.mark.parametrize("example", sample_files)
44+
@pytest.mark.parametrize("example,requirement", sample_files)
2645
class TestExamples:
27-
def test_example(self, example):
46+
def test_example(self, example, requirement):
47+
if not requirement():
48+
pytest.skip(f"Skipping {example} due to unmet requirements")
2849
example_path = os.path.join(samples_path, example)
2950
process = subprocess.run([sys.executable, example_path], capture_output=True) # noqa: S603
3051
if process.returncode != 0:

0 commit comments

Comments
 (0)