|
4 | 4 |
|
5 | 5 | import jax.numpy as jnp |
6 | 6 | import jax.lax as lax |
7 | | -from jax import grad, vmap |
| 7 | +from jax import grad, vmap, device_put |
8 | 8 |
|
9 | 9 | import linx.const as const |
10 | 10 | from linx.special_funcs import Li, K1, K2 |
@@ -648,27 +648,27 @@ def p_massive_MB(T, mu, m, g): |
648 | 648 |
|
649 | 649 | try: |
650 | 650 | gpus = jax.devices('gpu') |
651 | | - P_QED_tab = jax.device_put( |
| 651 | + P_QED_tab = device_put( |
652 | 652 | P_QED_tab, device=gpus[0] |
653 | 653 | ) |
654 | | - dPdT_QED_tab = jax.device_put( |
| 654 | + dPdT_QED_tab = device_put( |
655 | 655 | dPdT_QED_tab, device=gpus[0] |
656 | 656 | ) |
657 | | - d2PdT2_QED_tab = jax.device_put( |
| 657 | + d2PdT2_QED_tab = device_put( |
658 | 658 | d2PdT2_QED_tab , device=gpus[0] |
659 | 659 | ) |
660 | 660 |
|
661 | | - f_nue_scat_tab = jax.device_put( |
| 661 | + f_nue_scat_tab = device_put( |
662 | 662 | f_nue_scat_tab, device=gpus[0] |
663 | 663 | ) |
664 | | - f_numu_scat_tab = jax.device_put( |
| 664 | + f_numu_scat_tab = device_put( |
665 | 665 | f_numu_scat_tab, device=gpus[0] |
666 | 666 | ) |
667 | 667 |
|
668 | | - f_nue_ann_tab = jax.device_put( |
| 668 | + f_nue_ann_tab = device_put( |
669 | 669 | f_nue_ann_tab, device=gpus[0] |
670 | 670 | ) |
671 | | - f_numu_ann_tab = jax.device_put( |
| 671 | + f_numu_ann_tab = device_put( |
672 | 672 | f_numu_ann_tab, device=gpus[0] |
673 | 673 | ) |
674 | 674 | except: |
|
0 commit comments