Skip to content

Commit 7fcb618

Browse files
committed
device_put fix
1 parent 89e7b92 commit 7fcb618

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

linx/thermo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import jax.numpy as jnp
66
import jax.lax as lax
7-
from jax import grad, vmap, device_put
7+
from jax import grad, vmap, device_put, devices
88

99
import linx.const as const
1010
from linx.special_funcs import Li, K1, K2
@@ -647,7 +647,7 @@ def p_massive_MB(T, mu, m, g):
647647
f_numu_ann_tab = np.loadtxt(file_dir+"/data/background/"+"numu_ann.txt")
648648

649649
try:
650-
gpus = jax.devices('gpu')
650+
gpus = devices('gpu')
651651
P_QED_tab = device_put(
652652
P_QED_tab, device=gpus[0]
653653
)

0 commit comments

Comments
 (0)