Skip to content

Commit 86e4ebc

Browse files
authored
Merge pull request #20 from cgiovanetti/feature/diffrax-throw
Add diffrax throw option to control solver error handling
2 parents b9c7728 + bbf7201 commit 86e4ebc

2 files changed

Lines changed: 59 additions & 38 deletions

File tree

linx/abundances.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,45 +15,48 @@
1515
from linx.thermo import rho_EM_std_v, p_EM_std_v, nB
1616
from linx.special_funcs import zeta_3
1717

18-
class AbundanceModel(eqx.Module):
18+
class AbundanceModel(eqx.Module):
1919
"""
20-
Abundance model and BBN abundance prediction.
20+
Abundance model and BBN abundance prediction.
2121
2222
Attributes
2323
----------
2424
nuclear_net : NuclearRates
25-
Nuclear network to be used for BBN prediction.
25+
Nuclear network to be used for BBN prediction.
2626
weak_rates : WeakRates
27-
Weak rates for neutron-proton interconversion.
27+
Weak rates for neutron-proton interconversion.
2828
species_dict : dict
29-
Dictionary of species considered in LINX.
29+
Dictionary of species considered in LINX.
3030
species_Z : list
31-
Number of protons in each species.
31+
Number of protons in each species.
3232
species_N : list
33-
Number of neutrons in each species.
33+
Number of neutrons in each species.
3434
species_A : list
35-
Atomic mass number of each species.
35+
Atomic mass number of each species.
3636
species_excess_mass : list
37-
Excess mass (mass - A*amu) of each species.
37+
Excess mass (mass - A*amu) of each species.
3838
species_spin : list
39-
Spin of each species.
39+
Spin of each species.
4040
species_binding_energy : list
41-
Binding energy of each species.
41+
Binding energy of each species.
4242
species_mass : list
43-
Mass of each species.
43+
Mass of each species.
44+
throw : bool
45+
Whether to raise exceptions on solver failure.
4446
"""
45-
nuclear_net : nucl.NuclearRates
46-
weak_rates : wr.WeakRates
47+
nuclear_net : nucl.NuclearRates
48+
weak_rates : wr.WeakRates
4749
species_dict : dict
4850
species_Z : list
4951
species_N : list
50-
species_A : list
52+
species_A : list
5153
species_excess_mass : dict
5254
species_spin : list
5355
species_binding_energy : list
5456
species_mass : list
57+
throw : bool
5558

56-
def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
59+
def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True):
5760
"""
5861
Initialize the AbundanceModel with nuclear and weak rate networks.
5962
@@ -65,6 +68,9 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
6568
weak_rates : WeakRates, optional
6669
Weak interaction rates for neutron-proton interconversion.
6770
Defaults to standard WeakRates instance.
71+
throw : bool, optional
72+
If True, raise exceptions on solver failure. Default is True.
73+
Set to False for parameter scans where some combinations may fail.
6874
6975
Notes
7076
-----
@@ -76,6 +82,7 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
7682

7783
self.nuclear_net = nuclear_net
7884
self.weak_rates = weak_rates
85+
self.throw = throw
7986

8087
self.species_dict = {
8188
0:'n', 1:'p', 2:'d', 3:'t', 4:'He3', 5:'a', 6:'Li7', 7:'Be7',
@@ -291,15 +298,18 @@ def __call__(
291298
saveat = SaveAt(t1=True)
292299

293300
sol = diffeqsolve(
294-
ODETerm(self.Y_prime), solver,
295-
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
296-
args = (
297-
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
301+
ODETerm(self.Y_prime), solver,
302+
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
303+
args=(
304+
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
298305
nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q
299-
), saveat=saveat, stepsize_controller = PIDController(
306+
),
307+
saveat=saveat,
308+
stepsize_controller=PIDController(
300309
rtol=rtol, atol=atol,
301-
),
302-
max_steps=max_steps
310+
),
311+
max_steps=max_steps,
312+
throw=self.throw
303313
)
304314

305315
if save_history:
@@ -369,12 +379,13 @@ def dt_prime(rho_tot, t, args):
369379
rho_tot_fin = rho_tot_vec[-1]
370380

371381
sol_t = diffeqsolve(
372-
ODETerm(dt_prime), Tsit5(),
373-
t0=rho_tot_init, t1=rho_tot_fin,
374-
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
382+
ODETerm(dt_prime), Tsit5(),
383+
t0=rho_tot_init, t1=rho_tot_fin,
384+
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
375385
dt0=None, max_steps=4096,
376-
saveat=SaveAt(ts=rho_tot_vec),
377-
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10)
386+
saveat=SaveAt(ts=rho_tot_vec),
387+
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
388+
throw=self.throw
378389
)
379390

380391
return sol_t.ys
@@ -441,13 +452,14 @@ def dlna_prime(rho_tot, t, args):
441452
rho_tot_init = rho_tot_vec[0]
442453
rho_tot_fin = rho_tot_vec[-1]
443454

444-
# a_0 = 1 arbitrarily, will rescale later.
455+
# a_0 = 1 arbitrarily, will rescale later.
445456
sol_lna = diffeqsolve(
446-
ODETerm(dlna_prime), Tsit5(),
447-
t0=rho_tot_init, t1=rho_tot_fin,
457+
ODETerm(dlna_prime), Tsit5(),
458+
t0=rho_tot_init, t1=rho_tot_fin,
448459
y0=0., dt0=None, max_steps=4096,
449460
saveat=SaveAt(ts=rho_tot_vec),
450-
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10)
461+
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
462+
throw=self.throw
451463
)
452464

453465
a_fin = const.T0CMB / T_g_vec[-1]

linx/background.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
thermo.rho_massless_FD, in_axes=(0, None, None)
1717
)
1818

19-
class BackgroundModel(eqx.Module):
19+
class BackgroundModel(eqx.Module):
2020
"""Background model.
2121
2222
Attributes
@@ -31,15 +31,19 @@ class BackgroundModel(eqx.Module):
3131
Whether to use leading order QED correction. Default is `True`.
3232
NLO : bool, optional
3333
Whether to use next-to-leading order QED correction. Default is True.
34+
throw : bool, optional
35+
Whether to raise exceptions on solver failure. Default is `True`.
36+
Set to `False` for parameter scans where some combinations may fail.
3437
"""
3538

3639
decoupled : bool
3740
use_FD : bool
3841
collision_me : bool
3942
LO : bool
4043
NLO : bool
44+
throw : bool
4145

42-
def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO = True):
46+
def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO=True, throw=True):
4347
"""
4448
Initialize the BackgroundModel with thermodynamic options.
4549
@@ -56,13 +60,17 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO
5660
If True, include leading order QED corrections. Default is True.
5761
NLO : bool, optional
5862
If True, include next-to-leading order QED corrections. Default is True.
63+
throw : bool, optional
64+
If True, raise exceptions on solver failure. Default is True.
65+
Set to False for parameter scans where some combinations may fail.
5966
"""
6067

6168
self.decoupled = decoupled
6269
self.use_FD = use_FD
6370
self.collision_me = collision_me
6471
self.LO = LO
6572
self.NLO = NLO
73+
self.throw = throw
6674

6775
@eqx.filter_jit
6876
def __call__(
@@ -131,12 +139,13 @@ def T_EM_check(t, y, args, **kwargs):
131139

132140
sol = diffeqsolve(
133141
ODETerm(self.dY), solver, args=(lna_init, rho_extra_init),
134-
t0=0., t1=jnp.inf, dt0=None, y0=Y0,
142+
t0=0., t1=jnp.inf, dt0=None, y0=Y0,
135143
saveat=SaveAt(steps=True), event=Event(T_EM_check),
136-
stepsize_controller = PIDController(
144+
stepsize_controller=PIDController(
137145
rtol=rtol, atol=atol
138-
),
139-
max_steps=max_steps
146+
),
147+
max_steps=max_steps,
148+
throw=self.throw
140149
)
141150

142151
a_vec = jnp.exp(sol.ys[0])

0 commit comments

Comments
 (0)