Skip to content

Commit 9fd0b0e

Browse files
authored
Merge pull request #5 from cgiovanetti/fix-diffrax-deprecation-warning
Fix diffrax deprecation warning
2 parents 63996cd + 17b52bd commit 9fd0b0e

5 files changed

Lines changed: 16 additions & 10 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ jobs:
2626
- name: Run tests
2727
run: |
2828
cd pytest
29-
pytest .
29+
pytest -m "not slow"

linx/background.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import equinox as eqx
66

7-
from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, DiscreteTerminatingEvent
7+
from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, Event
88

99
import linx.thermo as thermo
1010
import linx.const as const
@@ -109,14 +109,13 @@ def __call__(
109109

110110
Y0 = (lna_init, T_EM_init, T_nu_init)
111111

112-
def T_EM_check(state, **kwargs):
113-
114-
return state.y[1] < T_end
112+
def T_EM_check(t, y, args, **kwargs):
113+
return y[1] < T_end
115114

116115
sol = diffeqsolve(
117116
ODETerm(self.dY), solver, args=(lna_init, rho_extra_init),
118117
t0=0., t1=jnp.inf, dt0=None, y0=Y0,
119-
saveat=SaveAt(steps=True), discrete_terminating_event = DiscreteTerminatingEvent(T_EM_check),
118+
saveat=SaveAt(steps=True), event=Event(T_EM_check),
120119
stepsize_controller = PIDController(
121120
rtol=rtol, atol=atol
122121
),

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
markers =
3+
slow: marks tests as slow (deselect with '-m "not slow"')

pytest/test_numpyro.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import sys
2-
sys.path.append('../scripts')
2+
import os
3+
import pytest
4+
# Add absolute path to the scripts directory
5+
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'scripts'))
36

47
from run_numpyro import run
58

69

10+
@pytest.mark.slow
711
def test_run_numpyro():
812
try:
913
run(bbn_only=True, n_steps_svi=5, n_warmup_mcmc=5, n_samples_mcmc=5, n_chains=1)

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ cloudpickle
55
corner
66
cosmopower-jax
77
Cython
8-
diffrax==0.4.1
8+
diffrax>=0.6.2
99
dynesty
1010
emcee
1111
equinox
12-
jax==0.4.28
13-
jaxlib==0.4.28
12+
jax>=0.4.38
13+
jaxlib>=0.4.38
1414
jaxopt
1515
jaxtyping
1616
joblib

0 commit comments

Comments
 (0)