Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .github/actions/install-deps/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ inputs:
description: 'Dependency groups to install (space-separated, e.g., "test doc")'
required: false
default: ''
resolution:
description: 'Resolution strategy (e.g., "lowest-direct"). Use "none" for the default strategy.'
required: false
default: 'none'

runs:
using: composite
Expand All @@ -20,7 +24,7 @@ runs:
shell: bash
run: uv venv

- name: Install dependencies (options=[${{ inputs.options }}], groups=[${{ inputs.groups }}])
- name: Install dependencies (options=[${{ inputs.options }}], groups=[${{ inputs.groups }}], resolution=[${{ inputs.resolution }}])
shell: bash
run: |
# Map "none" to the empty string
Expand All @@ -42,6 +46,12 @@ runs:
cmd="$cmd --group $group"
done

# Add resolution strategy
resolution_input="${{ inputs.resolution }}"
if [ "$resolution_input" != "none" ] && [ -n "$resolution_input" ]; then
cmd="$cmd --resolution $resolution_input"
fi

# Print and execute the command
echo $cmd
eval $cmd
7 changes: 4 additions & 3 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
tests:
# Default config: py3.14, ubuntu-latest, float32, full options.
# The idea is to make each of those params vary one by one, to limit the number of tests to run.
name: Tests (py${{ matrix.python-version || '3.14' }}, ${{ matrix.os || 'ubuntu-latest' }}, ${{ matrix.dtype || 'float32' }}, ${{ matrix.options || 'full' }}${{ matrix.extra_groups && format(', {0}', matrix.extra_groups) || '' }})
name: Tests (py${{ matrix.python-version || '3.14' }}, ${{ matrix.os || 'ubuntu-latest' }}, ${{ matrix.dtype || 'float32' }}, ${{ matrix.options || 'full' }}${{ matrix.resolution && format(', {0}', matrix.resolution) || '' }})
runs-on: ${{ matrix.os || 'ubuntu-latest' }}
strategy:
fail-fast: false
Expand All @@ -38,7 +38,7 @@ jobs:
- options: 'none'
# Lower-bounds of all dependencies and Python version.
- python-version: '3.10.0'
extra_groups: 'lower_bounds'
resolution: 'lowest-direct'

steps:
- name: Checkout repository
Expand All @@ -52,7 +52,8 @@ jobs:
- uses: ./.github/actions/install-deps
with:
options: ${{ matrix.options || 'full' }}
groups: test ${{ matrix.extra_groups }}
groups: test
resolution: ${{ matrix.resolution || 'none' }}

- name: Run unit tests
run: uv run pytest -W error tests/unit --cov=src --cov-report=xml
Expand Down
51 changes: 27 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ Source = "https://github.com/SimplexLab/TorchJD"
Changelog = "https://github.com/SimplexLab/TorchJD/blob/main/CHANGELOG.md"

[dependency-groups]
_numpy = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
]
_quadprog = [
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
]
_qpsolvers = [
"qpsolvers>=1.0.1", # Does not work before 1.0.1
]
_cvxpy = [
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
]
_ecos = [
"ecos>=2.0.14", # Does not work before 2.0.14
]

check = [
"ruff>=0.14.14",
"ty>=0.0.14",
Expand All @@ -90,46 +106,33 @@ test = [
]

plot = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
{include-group = "_numpy"},
"plotly[kaleido]>=5.19.0", # Recent version to avoid problems, could be relaxed
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
"matplotlib>=3.10.0", # Recent version to avoid problems, could be relaxed
]
# Dependency group allowing to easily resolve version of the recommended dependencies to the lower
# bound.
lower_bounds = [
"torch==2.3.0",
"numpy==1.21.2",
"quadprog==0.1.9",
"qpsolvers==1.0.1",
]

[project.optional-dependencies]
quadprog_projector = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"qpsolvers>=1.0.1", # Does not work before 1.0.1
{include-group = "_numpy"},
{include-group = "_quadprog"},
{include-group = "_qpsolvers"},
]
nash_mtl = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"cvxpy>=1.3.0", # Could be relaxed
"ecos>=2.0.14", # Does not work before 2.0.14
{include-group = "_numpy"},
{include-group = "_cvxpy"},
{include-group = "_ecos"},
]
cagrad = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
{include-group = "_numpy"},
{include-group = "_cvxpy"},
]
fairgrad = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
{include-group = "_numpy"},
"scipy",
]
full = [
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"qpsolvers>=1.0.1", # Does not work before 1.0.1
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
"ecos>=2.0.14", # Does not work before 2.0.14
"scipy",
"torchjd[quadprog_projector,nash_mtl,cagrad,fairgrad]",
]

[tool.pytest.ini_options]
Expand Down
Loading