Skip to content

Commit 9da92ce

Browse files
committed
fix device_put error
1 parent 43f26ba commit 9da92ce

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

linx/thermo.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import jax.numpy as jnp
66
import jax.lax as lax
7-
from jax import grad, vmap
7+
from jax import grad, vmap, device_put
88

99
import linx.const as const
1010
from linx.special_funcs import Li, K1, K2
@@ -648,27 +648,27 @@ def p_massive_MB(T, mu, m, g):
648648

649649
try:
650650
gpus = jax.devices('gpu')
651-
P_QED_tab = jax.device_put(
651+
P_QED_tab = device_put(
652652
P_QED_tab, device=gpus[0]
653653
)
654-
dPdT_QED_tab = jax.device_put(
654+
dPdT_QED_tab = device_put(
655655
dPdT_QED_tab, device=gpus[0]
656656
)
657-
d2PdT2_QED_tab = jax.device_put(
657+
d2PdT2_QED_tab = device_put(
658658
d2PdT2_QED_tab , device=gpus[0]
659659
)
660660

661-
f_nue_scat_tab = jax.device_put(
661+
f_nue_scat_tab = device_put(
662662
f_nue_scat_tab, device=gpus[0]
663663
)
664-
f_numu_scat_tab = jax.device_put(
664+
f_numu_scat_tab = device_put(
665665
f_numu_scat_tab, device=gpus[0]
666666
)
667667

668-
f_nue_ann_tab = jax.device_put(
668+
f_nue_ann_tab = device_put(
669669
f_nue_ann_tab, device=gpus[0]
670670
)
671-
f_numu_ann_tab = jax.device_put(
671+
f_numu_ann_tab = device_put(
672672
f_numu_ann_tab, device=gpus[0]
673673
)
674674
except:

0 commit comments

Comments
 (0)