[CI] Support dynamic multi-GPU architecture testing (sm80/86/90/100) to fix NoneType extension load failures #191

Description

@Flink-ddd

Problem Description

Currently, running test_det_gemm.py in our CI pipeline results in multiple test failures with the following error:

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
>       return _C.det_gemm_fwd(a, b)
               ^^^^^^^^^^^^^^^
E       AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'

Root Cause:

This is caused by a CUDA architecture mismatch. The current CI script (ci/run_gpu_ci.sh) hardcodes the compilation target to sm86 (export TORCH_CUDA_ARCH_LIST=8.6). When the cloud provider (RunPod) provisions a GPU with a different architecture, such as an H100 (sm90) or B200 (sm100), PyTorch silently fails to load the compiled C++ extension (.so). As a result, _C is None, and the first call into it raises the AttributeError shown above.

Proposed Solution

To resolve this and ensure cross-architecture compatibility (including the latest Blackwell series), we need to implement the following pipeline improvements:

Dynamic GPU Provisioning:

Introduce a TARGET_SM environment variable in ci/run_gpu_ci.sh to dynamically rent the appropriate GPU instance based on the target architecture:

  • TARGET_SM=100 → B200 / B100 (sm100)
  • TARGET_SM=90 → H100 (sm90)
  • TARGET_SM=80 → A100 (sm80)
  • TARGET_SM=86 → RTX A4000 / A40 (sm86, default fallback)
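The mapping above could be resolved along the following lines. This is a minimal sketch, not the CI script itself: the function name resolve_gpu_candidates is hypothetical, the GPU names are plain labels rather than actual RunPod identifiers, and "default fallback" is interpreted here as "used when TARGET_SM is unset".

```python
import os

# Mapping copied from the list in this issue. The GPU names are labels only;
# the identifiers expected by the provider's API may differ.
GPU_CANDIDATES = {
    "100": ["B200", "B100"],
    "90": ["H100"],
    "80": ["A100"],
    "86": ["RTX A4000", "A40"],
}
DEFAULT_SM = "86"


def resolve_gpu_candidates(env=None):
    """Return (target_sm, gpu_names) based on the TARGET_SM variable."""
    env = os.environ if env is None else env
    target_sm = env.get("TARGET_SM", DEFAULT_SM)
    if target_sm not in GPU_CANDIDATES:
        # Rejecting unknown values avoids renting hardware for an untested target.
        supported = ", ".join(sorted(GPU_CANDIDATES, key=int))
        raise ValueError(f"Unsupported TARGET_SM={target_sm!r}; expected one of: {supported}")
    return target_sm, GPU_CANDIDATES[target_sm]
```

Raising on an unknown value, rather than quietly falling back to sm86, is a design choice made for this sketch so that a typo in TARGET_SM cannot reintroduce the original mismatch.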

Dynamic Compilation Flags:

Replace the hardcoded TORCH_CUDA_ARCH_LIST with a dynamically injected variable that matches the rented hardware (e.g., 10.0, 9.0, 8.0, or 8.6).
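One way to derive the flag from the target is sketched below. It assumes TARGET_SM is a digit string such as "90" or "100" whose last digit is the minor version; the helper names sm_to_arch_list and build_env are hypothetical.

```python
import os


def sm_to_arch_list(target_sm):
    """Convert an SM identifier such as "86" or "100" to "8.6" or "10.0"."""
    if not target_sm.isdigit() or len(target_sm) < 2:
        raise ValueError(f"Invalid TARGET_SM value: {target_sm!r}")
    major, minor = target_sm[:-1], target_sm[-1]
    return f"{int(major)}.{minor}"


def build_env(target_sm, base_env=None):
    """Return a copy of the environment with TORCH_CUDA_ARCH_LIST set for the build."""
    env = dict(os.environ if base_env is None else base_env)
    env["TORCH_CUDA_ARCH_LIST"] = sm_to_arch_list(target_sm)
    return env
```

The resulting dictionary could be passed as the env argument of subprocess.run when invoking the build step, so the compiler flags always follow the rented hardware.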

Fail-Fast Validation:

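Add a sanity check (python -c "import torch; import rl_engine") immediately after compilation. If the C++ extension fails to load, the script should exit immediately with a clear architecture mismatch message, rather than running the entire test suite and generating misleading logs. Because the current failure leaves _C as None without raising, the check should also verify that the extension module is actually loaded, not only that the import succeeds.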
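A fail-fast check might look like the sketch below. It assumes the extension is exposed as an attribute named _C on the rl_engine package, which this issue does not confirm; the function names extension_loaded and main are hypothetical.

```python
import importlib
import sys


def extension_loaded(package_name, attr="_C"):
    """Return True if the package imports and its extension attribute is not None."""
    module = importlib.import_module(package_name)
    return getattr(module, attr, None) is not None


def main(package_name="rl_engine"):
    """Return a process exit code: 0 if the extension loaded, 1 otherwise."""
    try:
        ok = extension_loaded(package_name)
    except ImportError as exc:
        print(f"[ci] could not import {package_name}: {exc}", file=sys.stderr)
        return 1
    if not ok:
        print(
            f"[ci] {package_name} imported but its C++ extension is missing; "
            "check that TORCH_CUDA_ARCH_LIST matches the provisioned GPU.",
            file=sys.stderr,
        )
        return 1
    return 0
```

The CI script could call sys.exit(main()) right after the build step and stop before pytest starts whenever the exit code is nonzero.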

GitHub Actions Matrix Strategy:

Update the CI workflow YAML to use a matrix strategy that triggers parallel test jobs for sm80, sm86, sm90, and sm100. This provides coverage across all supported NVIDIA generations and prevents architecture-specific regressions in future PRs.
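If the matrix is generated rather than written by hand, a small helper could emit it as JSON. The sketch below is one possible shape, not the project's workflow: the keys target_sm and cuda_arch and the function names are hypothetical.

```python
import json

# Architectures named in this issue.
SUPPORTED_SMS = ("80", "86", "90", "100")


def build_matrix(sms=SUPPORTED_SMS):
    """Build a matrix with one entry per target architecture."""
    include = []
    for sm in sms:
        # "100" -> "10.0", "86" -> "8.6": the last digit is the minor version.
        cuda_arch = f"{int(sm[:-1])}.{sm[-1]}"
        include.append({"target_sm": sm, "cuda_arch": cuda_arch})
    return {"include": include}


def matrix_json(sms=SUPPORTED_SMS):
    """Serialize the matrix compactly so it fits on a single output line."""
    return json.dumps(build_matrix(sms), separators=(",", ":"))
```

A setup job could publish the string returned by matrix_json() as a job output, and the test job could read it through fromJSON in its strategy.matrix. A static matrix listing the four values directly in the YAML would work equally well.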

Metadata

Labels:

  • bug (Something isn't working)
  • platform: cuda (Specific optimizations or bugs in NVIDIA graphics cards, such as FlashInfer, TMA optimizations)
  • priority: high (Severe congestion issues require the highest priority for resolution.)
  • type: ci-cd (Modify GitHub Actions, automated tests, and packaging/deployment tasks.)