Skip to content

Commit 57b4d7b

Browse files
authored
[JAX] Remove import jax.extend.ffi (NVIDIA#2193)
* remove import jax.extend.ffi Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent 5b3092a commit 57b4d7b

File tree

6 files changed

+6
-42
lines changed

6 files changed

+6
-42
lines changed

transformer_engine/jax/cpp_extensions/activation.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from typing import Sequence, Union, Callable, Optional, Tuple
66
import operator
77
from functools import reduce, partial
8-
from packaging import version
98

109
import jax
1110
import jax.numpy as jnp
12-
from jax import dtypes
11+
from jax import dtypes, ffi
1312
from jax.experimental.custom_partitioning import SdyShardingRule
1413
from jax.sharding import PartitionSpec
1514

@@ -37,10 +36,6 @@
3736
ScalingMode,
3837
)
3938

40-
if version.parse(jax.__version__) >= version.parse("0.5.0"):
41-
from jax import ffi # pylint: disable=ungrouped-imports
42-
else:
43-
from jax.extend import ffi # pylint: disable=ungrouped-imports
4439

4540
__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
4641

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
from dataclasses import dataclass, replace
99
from functools import partial, reduce
1010
from typing import Optional, Tuple
11-
from packaging import version
1211

1312
import jax
1413
import jax.numpy as jnp
15-
from jax import dtypes, lax
14+
from jax import dtypes, lax, ffi
1615
from jax.sharding import PartitionSpec, NamedSharding
1716
from jax.experimental.custom_partitioning import SdyShardingRule
1817

@@ -49,12 +48,6 @@
4948
)
5049

5150

52-
if version.parse(jax.__version__) >= version.parse("0.5.0"):
53-
from jax import ffi # pylint: disable=ungrouped-imports
54-
else:
55-
from jax.extend import ffi # pylint: disable=ungrouped-imports
56-
57-
5851
__all__ = [
5952
"FusedAttnHelper",
6053
"fused_attn_fwd",

transformer_engine/jax/cpp_extensions/base.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,16 @@
77
import warnings
88
from abc import ABCMeta, abstractmethod
99
from functools import partial
10-
from packaging import version
1110

1211
from jax.extend import core
1312
from jax.interpreters import xla, mlir
1413
from jax.experimental.custom_partitioning import custom_partitioning
1514
from jax._src.interpreters import batching
1615
from jax._src import dispatch
16+
from jax import ffi
1717

18-
import jax
1918
import transformer_engine_jax
2019

21-
if version.parse(jax.__version__) >= version.parse("0.5.0"):
22-
from jax import ffi # pylint: disable=ungrouped-imports
23-
else:
24-
from jax.extend import ffi # pylint: disable=ungrouped-imports
25-
2620

2721
class BasePrimitive(metaclass=ABCMeta):
2822
"""

transformer_engine/jax/cpp_extensions/normalization.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
import operator
88
from functools import partial, cache, reduce
99
from typing import Optional, Union
10-
from packaging import version
1110

1211
import jax
1312
import jax.numpy as jnp
14-
from jax import dtypes
13+
from jax import dtypes, ffi
1514
from jax.experimental.custom_partitioning import SdyShardingRule
1615
from jax.interpreters.mlir import ir
1716
from jax.sharding import PartitionSpec
@@ -38,11 +37,6 @@
3837
ScalingMode,
3938
)
4039

41-
if version.parse(jax.__version__) >= version.parse("0.5.0"):
42-
from jax import ffi # pylint: disable=ungrouped-imports
43-
else:
44-
from jax.extend import ffi # pylint: disable=ungrouped-imports
45-
4640

4741
__all__ = [
4842
"layernorm_fwd",

transformer_engine/jax/cpp_extensions/quantization.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
from functools import reduce
77
from typing import Tuple, Optional, Union
88
import math
9-
from packaging import version
109

1110
import jax
1211
import jax.numpy as jnp
13-
from jax import dtypes
12+
from jax import dtypes, ffi
1413
from jax.experimental.custom_partitioning import SdyShardingRule
1514
from jax.sharding import PartitionSpec
1615

@@ -41,11 +40,6 @@
4140
NoScaleTensor,
4241
)
4342

44-
if version.parse(jax.__version__) >= version.parse("0.5.0"):
45-
from jax import ffi # pylint: disable=ungrouped-imports
46-
else:
47-
from jax.extend import ffi # pylint: disable=ungrouped-imports
48-
4943

5044
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
5145

transformer_engine/jax/cpp_extensions/softmax.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,16 @@
66
from functools import partial, reduce
77
import operator
88
import warnings
9-
from packaging import version
109

1110
import jax
1211
import jax.numpy as jnp
13-
from jax import dtypes
12+
from jax import dtypes, ffi
1413
from jax.sharding import PartitionSpec, NamedSharding
1514

1615
from .base import BasePrimitive, register_primitive
1716
from .misc import get_padded_spec, check_valid_batch_dims
1817
from ..softmax import SoftmaxType
1918

20-
if version.parse(jax.__version__) >= version.parse("0.5.0"):
21-
from jax import ffi # pylint: disable=ungrouped-imports
22-
else:
23-
from jax.extend import ffi # pylint: disable=ungrouped-imports
24-
2519

2620
__all__ = [
2721
"scaled_softmax_fwd",

0 commit comments

Comments
 (0)