Skip to content

Commit 3e7ceaa

Browse files
smsharmaclaude
andcommitted
Replace jnp.interp with interpax.interp1d in thermo.py
Follow-up to PR #16 which replaced jnp.interp with interpax in abundances.py. This change applies the same improvement to thermo.py for consistency and better performance. Changes: - Import interpax module - Flip QED correction tables at load time (instead of at each call) for monotonically increasing x coordinates required by interpax - Replace 6 jnp.interp calls with interpax.interp1d: - rho_EM_std: 2 calls for QED corrections - p_EM_std: 1 call for QED correction - rho_plus_p_EM_std: 1 call for QED correction - G_nue_with_me: 1 call for collision factor interpolation - G_numt_with_me: 1 call for collision factor interpolation Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 8d81f2b commit 3e7ceaa

1 file changed

Lines changed: 29 additions & 28 deletions

File tree

linx/thermo.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import os
1+
import os
22

33
import numpy as np
44

5-
import jax.numpy as jnp
5+
import jax.numpy as jnp
66
import jax.lax as lax
77
from jax import grad, vmap, device_put, devices
8+
import interpax
89

910
import linx.const as const
1011
from linx.special_funcs import Li, K1, K2
@@ -633,10 +634,10 @@ def p_massive_MB(T, mu, m, g):
633634

634635
file_dir = os.path.dirname(__file__)
635636

636-
# QED Corrections
637-
P_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt")
638-
dPdT_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt")
639-
d2PdT2_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt")
637+
# QED Corrections - flip to ensure monotonically increasing T for interpax.interp1d
638+
P_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt"), axis=0)
639+
dPdT_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt"), axis=0)
640+
d2PdT2_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt"), axis=0)
640641

641642
# Effect of standard value of electron mass in scattering matrix elements
642643
f_nue_scat_tab = np.loadtxt(file_dir+"/data/background/"+"nue_scatt.txt")
@@ -704,13 +705,13 @@ def rho_EM_std(T_g, mu=0, LO=True, NLO=True):
704705
"""
705706

706707
corr_QED = (
707-
-jnp.interp(
708-
T_g, jnp.flip(P_QED_tab[:,0]),
709-
jnp.flip(LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2])
710-
)
711-
+ T_g*jnp.interp(
712-
T_g, jnp.flip(dPdT_QED_tab[:,0]),
713-
jnp.flip(LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2])
708+
-interpax.interp1d(
709+
T_g, P_QED_tab[:,0],
710+
LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2]
711+
)
712+
+ T_g*interpax.interp1d(
713+
T_g, dPdT_QED_tab[:,0],
714+
LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2]
714715
)
715716
)
716717

@@ -745,9 +746,9 @@ def p_EM_std(T_g, mu=0, LO=True, NLO=True):
745746
Units of MeV^4.
746747
"""
747748

748-
corr_QED = jnp.interp(
749-
T_g, jnp.flip(P_QED_tab[:,0]),
750-
jnp.flip(LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2])
749+
corr_QED = interpax.interp1d(
750+
T_g, P_QED_tab[:,0],
751+
LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2]
751752
)
752753

753754
return (
@@ -781,9 +782,9 @@ def rho_plus_p_EM_std(T_g, mu=0, LO=True, NLO=True):
781782
Units of MeV^4.
782783
"""
783784

784-
corr_QED = T_g * jnp.interp(
785-
T_g, jnp.flip(dPdT_QED_tab[:,0]),
786-
jnp.flip(LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2])
785+
corr_QED = T_g * interpax.interp1d(
786+
T_g, dPdT_QED_tab[:,0],
787+
LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2]
787788
)
788789

789790
return (
@@ -1018,10 +1019,10 @@ def G(T_1, mu_1, T_2, mu_2):
10181019

10191020
def G_nue_with_me(T_1, mu_1, T_2, mu_2):
10201021

1021-
def interp_f(f_tab):
1022-
1023-
return jnp.interp(
1024-
T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1]
1022+
def interp_f(f_tab):
1023+
# Tables have boundary values 0.0 (low T) and 1.0 (high T)
1024+
return interpax.interp1d(
1025+
T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0)
10251026
)
10261027

10271028
f_nue_ann = lax.cond(
@@ -1042,12 +1043,12 @@ def interp_f(f_tab):
10421043
)
10431044
)
10441045

1045-
def G_numt_with_me(T_1, mu_1, T_2, mu_2):
1046-
1047-
def interp_f(f_tab):
1046+
def G_numt_with_me(T_1, mu_1, T_2, mu_2):
10481047

1049-
return jnp.interp(
1050-
T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1]
1048+
def interp_f(f_tab):
1049+
# Tables have boundary values 0.0 (low T) and 1.0 (high T)
1050+
return interpax.interp1d(
1051+
T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0)
10511052
)
10521053

10531054
f_numt_ann = lax.cond(

0 commit comments

Comments
 (0)