Skip to content

Commit da88e4d

Browse files
authored
Merge pull request #24 from cgiovanetti/siddharth/interpax-thermo
Replace jnp.interp with interpax.interp1d in thermo.py
2 parents 86e4ebc + 3e7ceaa commit da88e4d

1 file changed

Lines changed: 29 additions & 28 deletions

File tree

linx/thermo.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import os
1+
import os
22

33
import numpy as np
44

5-
import jax.numpy as jnp
5+
import jax.numpy as jnp
66
import jax.lax as lax
77
from jax import grad, vmap, device_put, devices
8+
import interpax
89

910
import linx.const as const
1011
from linx.special_funcs import Li, K1, K2
@@ -633,10 +634,10 @@ def p_massive_MB(T, mu, m, g):
633634

634635
file_dir = os.path.dirname(__file__)
635636

636-
# QED Corrections
637-
P_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt")
638-
dPdT_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt")
639-
d2PdT2_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt")
637+
# QED Corrections - flip to ensure monotonically increasing T for interpax.interp1d
638+
P_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt"), axis=0)
639+
dPdT_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt"), axis=0)
640+
d2PdT2_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt"), axis=0)
640641

641642
# Effect of standard value of electron mass in scattering matrix elements
642643
f_nue_scat_tab = np.loadtxt(file_dir+"/data/background/"+"nue_scatt.txt")
@@ -705,13 +706,13 @@ def rho_EM_std(T_g, mu=0, LO=True, NLO=True):
705706
"""
706707

707708
corr_QED = (
708-
-jnp.interp(
709-
T_g, jnp.flip(P_QED_tab[:,0]),
710-
jnp.flip(LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2])
711-
)
712-
+ T_g*jnp.interp(
713-
T_g, jnp.flip(dPdT_QED_tab[:,0]),
714-
jnp.flip(LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2])
709+
-interpax.interp1d(
710+
T_g, P_QED_tab[:,0],
711+
LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2]
712+
)
713+
+ T_g*interpax.interp1d(
714+
T_g, dPdT_QED_tab[:,0],
715+
LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2]
715716
)
716717
)
717718

@@ -746,9 +747,9 @@ def p_EM_std(T_g, mu=0, LO=True, NLO=True):
746747
Units of MeV^4.
747748
"""
748749

749-
corr_QED = jnp.interp(
750-
T_g, jnp.flip(P_QED_tab[:,0]),
751-
jnp.flip(LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2])
750+
corr_QED = interpax.interp1d(
751+
T_g, P_QED_tab[:,0],
752+
LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2]
752753
)
753754

754755
return (
@@ -782,9 +783,9 @@ def rho_plus_p_EM_std(T_g, mu=0, LO=True, NLO=True):
782783
Units of MeV^4.
783784
"""
784785

785-
corr_QED = T_g * jnp.interp(
786-
T_g, jnp.flip(dPdT_QED_tab[:,0]),
787-
jnp.flip(LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2])
786+
corr_QED = T_g * interpax.interp1d(
787+
T_g, dPdT_QED_tab[:,0],
788+
LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2]
788789
)
789790

790791
return (
@@ -1019,10 +1020,10 @@ def G(T_1, mu_1, T_2, mu_2):
10191020

10201021
def G_nue_with_me(T_1, mu_1, T_2, mu_2):
10211022

1022-
def interp_f(f_tab):
1023-
1024-
return jnp.interp(
1025-
T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1]
1023+
def interp_f(f_tab):
1024+
# Tables have boundary values 0.0 (low T) and 1.0 (high T)
1025+
return interpax.interp1d(
1026+
T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0)
10261027
)
10271028

10281029
f_nue_ann = lax.cond(
@@ -1043,12 +1044,12 @@ def interp_f(f_tab):
10431044
)
10441045
)
10451046

1046-
def G_numt_with_me(T_1, mu_1, T_2, mu_2):
1047-
1048-
def interp_f(f_tab):
1047+
def G_numt_with_me(T_1, mu_1, T_2, mu_2):
10491048

1050-
return jnp.interp(
1051-
T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1]
1049+
def interp_f(f_tab):
1050+
# Tables have boundary values 0.0 (low T) and 1.0 (high T)
1051+
return interpax.interp1d(
1052+
T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0)
10521053
)
10531054

10541055
f_numt_ann = lax.cond(

0 commit comments

Comments
 (0)