Skip to content

Commit 3ccd590

Browse files
lukebaumanncopybara-github
authored andcommitted
Use the internal pxla API explicitly because the public pxla is deprecated as of JAX 0.8.2.
Also updated `jax.extend.ifrt_programs` import to import the module and make an alias for the `jax.extend.ifrt_programs.ifrt_programs` variable. PiperOrigin-RevId: 852516195
1 parent b5279c6 commit 3ccd590

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

pathwaysutils/plugin_executable.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
from typing import List, Sequence, Tuple, Union
1919

2020
import jax
21-
from jax.interpreters import pxla
22-
from jax.extend.ifrt_programs import ifrt_programs
21+
from jax._src.interpreters import pxla # pylint: disable=protected-access
22+
from jax.extend import ifrt_programs
23+
24+
25+
ifrt_programs = ifrt_programs.ifrt_programs
2326

2427

2528
class PluginExecutable:

0 commit comments

Comments
 (0)