1- import os
1+ import os
22
33import numpy as np
44
5- import jax .numpy as jnp
5+ import jax .numpy as jnp
66import jax .lax as lax
77from jax import grad , vmap , device_put , devices
8+ import interpax
89
910import linx .const as const
1011from linx .special_funcs import Li , K1 , K2
@@ -633,10 +634,10 @@ def p_massive_MB(T, mu, m, g):
633634
634635file_dir = os .path .dirname (__file__ )
635636
636- # QED Corrections
637- P_QED_tab = np .loadtxt (file_dir + "/data/background/" + "QED_P_int.txt" )
638- dPdT_QED_tab = np .loadtxt (file_dir + "/data/background/" + "QED_dP_intdT.txt" )
639- d2PdT2_QED_tab = np .loadtxt (file_dir + "/data/background/" + "QED_d2P_intdT2.txt" )
637+ # QED Corrections - flip to ensure monotonically increasing T for interpax.interp1d
638+ P_QED_tab = np .flip ( np . loadtxt (file_dir + "/data/background/" + "QED_P_int.txt" ), axis = 0 )
639+ dPdT_QED_tab = np .flip ( np . loadtxt (file_dir + "/data/background/" + "QED_dP_intdT.txt" ), axis = 0 )
640+ d2PdT2_QED_tab = np .flip ( np . loadtxt (file_dir + "/data/background/" + "QED_d2P_intdT2.txt" ), axis = 0 )
640641
641642# Effect of standard value of electron mass in scattering matrix elements
642643f_nue_scat_tab = np .loadtxt (file_dir + "/data/background/" + "nue_scatt.txt" )
@@ -705,13 +706,13 @@ def rho_EM_std(T_g, mu=0, LO=True, NLO=True):
705706 """
706707
707708 corr_QED = (
708- - jnp . interp (
709- T_g , jnp . flip ( P_QED_tab [:,0 ]),
710- jnp . flip ( LO * P_QED_tab [:,1 ]+ NLO * P_QED_tab [:,2 ])
711- )
712- + T_g * jnp . interp (
713- T_g , jnp . flip ( dPdT_QED_tab [:,0 ]) ,
714- jnp . flip ( LO * dPdT_QED_tab [:,1 ]+ NLO * dPdT_QED_tab [:,2 ])
709+ - interpax . interp1d (
710+ T_g , P_QED_tab [:,0 ],
711+ LO * P_QED_tab [:,1 ]+ NLO * P_QED_tab [:,2 ]
712+ )
713+ + T_g * interpax . interp1d (
714+ T_g , dPdT_QED_tab [:,0 ],
715+ LO * dPdT_QED_tab [:,1 ]+ NLO * dPdT_QED_tab [:,2 ]
715716 )
716717 )
717718
@@ -746,9 +747,9 @@ def p_EM_std(T_g, mu=0, LO=True, NLO=True):
746747 Units of MeV^4.
747748 """
748749
749- corr_QED = jnp . interp (
750- T_g , jnp . flip ( P_QED_tab [:,0 ]) ,
751- jnp . flip ( LO * P_QED_tab [:,1 ] + NLO * P_QED_tab [:,2 ])
750+ corr_QED = interpax . interp1d (
751+ T_g , P_QED_tab [:,0 ],
752+ LO * P_QED_tab [:,1 ] + NLO * P_QED_tab [:,2 ]
752753 )
753754
754755 return (
@@ -782,9 +783,9 @@ def rho_plus_p_EM_std(T_g, mu=0, LO=True, NLO=True):
782783 Units of MeV^4.
783784 """
784785
785- corr_QED = T_g * jnp . interp (
786- T_g , jnp . flip ( dPdT_QED_tab [:,0 ]) ,
787- jnp . flip ( LO * dPdT_QED_tab [:,1 ] + NLO * dPdT_QED_tab [:,2 ])
786+ corr_QED = T_g * interpax . interp1d (
787+ T_g , dPdT_QED_tab [:,0 ],
788+ LO * dPdT_QED_tab [:,1 ] + NLO * dPdT_QED_tab [:,2 ]
788789 )
789790
790791 return (
@@ -1019,10 +1020,10 @@ def G(T_1, mu_1, T_2, mu_2):
10191020
10201021 def G_nue_with_me (T_1 , mu_1 , T_2 , mu_2 ):
10211022
1022- def interp_f (f_tab ):
1023-
1024- return jnp . interp (
1025- T_1 , f_tab [:,0 ], f_tab [:,1 ], left = f_tab [ 0 , 1 ], right = f_tab [ - 1 , 1 ]
1023+ def interp_f (f_tab ):
1024+ # Tables have boundary values 0.0 (low T) and 1.0 (high T)
1025+ return interpax . interp1d (
1026+ T_1 , f_tab [:,0 ], f_tab [:,1 ], extrap = ( 0.0 , 1.0 )
10261027 )
10271028
10281029 f_nue_ann = lax .cond (
@@ -1043,12 +1044,12 @@ def interp_f(f_tab):
10431044 )
10441045 )
10451046
1046- def G_numt_with_me (T_1 , mu_1 , T_2 , mu_2 ):
1047-
1048- def interp_f (f_tab ):
1047+ def G_numt_with_me (T_1 , mu_1 , T_2 , mu_2 ):
10491048
1050- return jnp .interp (
1051- T_1 , f_tab [:,0 ], f_tab [:,1 ], left = f_tab [0 ,1 ], right = f_tab [- 1 ,1 ]
1049+ def interp_f (f_tab ):
1050+ # Tables have boundary values 0.0 (low T) and 1.0 (high T)
1051+ return interpax .interp1d (
1052+ T_1 , f_tab [:,0 ], f_tab [:,1 ], extrap = (0.0 , 1.0 )
10521053 )
10531054
10541055 f_numt_ann = lax .cond (
0 commit comments