7 changes: 2 additions & 5 deletions .github/labeler.yml
@@ -152,9 +152,7 @@
- changed-files:
- any-glob-to-any-file:
- 'torchrl/modules/**'
- 'test/test_modules.py'
- 'test/test_tensordictmodules.py'
- 'test/test_actors.py'
- 'test/modules/**'

"Transforms":
- changed-files:
@@ -183,8 +181,7 @@
- changed-files:
- any-glob-to-any-file:
- 'torchrl/data/replay_buffers/**'
- 'test/test_rb.py'
- 'test/test_storage_map.py'
- 'test/rb/**'

"Services":
- changed-files:
6 changes: 3 additions & 3 deletions .github/unittest/linux/scripts/run_all.sh
@@ -344,9 +344,9 @@ run_distributed_tests() {
local json_report_dir="${RUNNER_ARTIFACT_DIR:-${root_dir}}"
local json_report_args="--json-report --json-report-file=${json_report_dir}/test-results-distributed.json --json-report-indent=2"

# Run both test_distributed.py and test_rb_distributed.py (both use torch.distributed)
# Run both test/test_distributed.py and test/rb/test_rb_distributed.py (both use torch.distributed)
# Note: distributed tests always run on GPU, no need for GPU_MARKER_FILTER here
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py test/test_rb_distributed.py \
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py test/rb/test_rb_distributed.py \
${json_report_args} \
--instafail --durations 200 -vv --capture no \
--timeout=120 --mp_fork_if_no_cuda
@@ -362,7 +362,7 @@ run_non_distributed_tests() {
# - Shard 2: test/envs/, test_collectors.py (multiprocessing-heavy)
# - Shard 3: Everything else (can use pytest-xdist for parallelism)
local shard="${TORCHRL_TEST_SHARD:-all}"
local common_ignores="--ignore test/test_rlhf.py --ignore test/test_distributed.py --ignore test/test_rb_distributed.py --ignore test/llm --ignore test/test_setup.py"
local common_ignores="--ignore test/test_rlhf.py --ignore test/test_distributed.py --ignore test/rb/test_rb_distributed.py --ignore test/llm --ignore test/test_setup.py"
local common_args="--instafail --durations 200 -vv --capture no --timeout=120 --mp_fork_if_no_cuda"

# JSON report output for flaky test tracking
3 changes: 1 addition & 2 deletions setup.cfg
@@ -15,8 +15,7 @@ per-file-ignores =
test/smoke_test_deps.py: F401
test_*.py: F841, E731, E266
test/opengl_rendering.py: F401
test/test_modules.py: F841, E731, E266, TOR101
test/test_tensordictmodules.py: F841, E731, E266, TOR101
test/modules/test_*.py: F841, E731, E266, TOR101
torchrl/objectives/cql.py: TOR101
torchrl/objectives/deprecated.py: TOR101
torchrl/objectives/iql.py: TOR101
45 changes: 45 additions & 0 deletions test/modules/_modules_common.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import sys

import torch
from packaging import version

_has_transformers = importlib.util.find_spec("transformers") is not None
_has_vllm = importlib.util.find_spec("vllm") is not None

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"


def _has_triton_backend() -> bool:
"""Mirror of the triton-availability check inside the RNN backend.

Triton must be installed, CUDA must be available, and the Triton build
must expose the ``triton.language.extra.libdevice`` submodule
(Triton >= 2.2). Older Triton installations are routed to scan/pad
backends, so the triton-specific tests are skipped there.
"""
if importlib.util.find_spec("triton") is None or not torch.cuda.is_available():
return False
return importlib.util.find_spec("triton.language.extra.libdevice") is not None


_has_triton = _has_triton_backend()
_triton_skip_reason = "requires triton (>= 2.2) and CUDA"

_has_functorch = False
try:
try:
from torch import vmap as vmap # noqa: F401
except ImportError:
from functorch import vmap as vmap # noqa: F401

_has_functorch = True
except ImportError:
pass
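
For reference, a minimal usage sketch (not part of this PR's diff) of how a test module under test/modules/ might consume these shared helpers. It assumes the sibling module is importable as _modules_common when pytest collects files from that directory; the test name and body are hypothetical placeholders.

# Hypothetical usage sketch -- assumes ``_modules_common`` is importable as a
# sibling module when pytest collects tests under test/modules/.
import pytest
import torch

from _modules_common import _has_triton, _triton_skip_reason


@pytest.mark.skipif(not _has_triton, reason=_triton_skip_reason)
def test_triton_rnn_backend_smoke():
    # Placeholder body: a real test would exercise the triton-backed RNN modules on CUDA.
    assert torch.cuda.is_available()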
16 changes: 16 additions & 0 deletions test/modules/conftest.py
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import pytest
import torch


@pytest.fixture
def double_prec_fixture():
dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
yield
torch.set_default_dtype(dtype)
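
As a usage note (not taken from the diff): a hedged sketch of how a test in test/modules/ could request this fixture. Pytest injects it by name from the directory's conftest.py, and the fixture restores the previous default dtype on teardown; the test below is hypothetical.

# Hypothetical example -- no import is needed because pytest discovers the
# fixture from test/modules/conftest.py.
import torch


def test_linear_defaults_to_double(double_prec_fixture):
    layer = torch.nn.Linear(3, 4)
    # With the default dtype switched to torch.double, new parameters are float64.
    assert layer.weight.dtype == torch.float64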