|
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 |
|
13 | | -DISALLOW_LIST = { |
14 | | - "gl_interop_plasma.py", # requires a display |
| 13 | +from cuda.core import Device |
| 14 | + |
| 15 | + |
| 16 | +def pytorch_installed(): |
| 17 | + try: |
| 18 | + import torch # noqa: F401 |
| 19 | + |
| 20 | + return True |
| 21 | + except ImportError: |
| 22 | + return False |
| 23 | + |
| 24 | + |
| 25 | +def has_compute_capability_9_or_higher(): |
| 26 | + dev = Device() |
| 27 | + arch = dev.compute_capability |
| 28 | + return arch > (9, 0) |
| 29 | + |
| 30 | + |
| 31 | +REQUIREMENTS = { |
| 32 | + "gl_interop_plasma.py": (lambda: False), # requires a display |
| 33 | + "pytorch_example.py": pytorch_installed, # requires PyTorch |
| 34 | + "thread_block_cluster.py": has_compute_capability_9_or_higher, # requires CC 9.0+ |
15 | 35 | } |
16 | 36 |
|
17 | 37 | samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples") |
18 | 38 | sample_files = [ |
19 | | - x |
| 39 | + (x, REQUIREMENTS.get(x, lambda: True)) |
20 | 40 | for x in (os.path.basename(x) for x in glob.glob(samples_path + "**/*.py", recursive=True)) |
21 | | - if x not in DISALLOW_LIST |
22 | 41 | ] |
23 | 42 |
|
24 | 43 |
|
25 | | -@pytest.mark.parametrize("example", sample_files) |
| 44 | +@pytest.mark.parametrize("example,requirement", sample_files) |
26 | 45 | class TestExamples: |
27 | | - def test_example(self, example): |
| 46 | + def test_example(self, example, requirement): |
| 47 | + if not requirement(): |
| 48 | + pytest.skip(f"Skipping {example} due to unmet requirements") |
28 | 49 | example_path = os.path.join(samples_path, example) |
29 | 50 | process = subprocess.run([sys.executable, example_path], capture_output=True) # noqa: S603 |
30 | 51 | if process.returncode != 0: |
|
0 commit comments