@@ -18,6 +18,7 @@ from tensorflow import (
1818 io as io ,
1919 keras as keras ,
2020 math as math ,
21+ nn as nn ,
2122 random as random ,
2223 types as types ,
2324)
@@ -37,7 +38,7 @@ from tensorflow.core.protobuf import struct_pb2
3738from tensorflow .dtypes import *
3839from tensorflow .experimental .dtensor import Layout
3940from tensorflow .keras import losses as losses
40- from tensorflow .linalg import eye as eye
41+ from tensorflow .linalg import eye as eye , matmul as matmul
4142
4243# Most tf.math functions are exported as tf, but sadly not all are.
4344from tensorflow .math import (
@@ -385,6 +386,13 @@ def squeeze(
385386) -> Tensor : ...
386387@overload
387388def squeeze (input : RaggedTensor , axis : int | tuple [int , ...] | list [int ], name : str | None = None ) -> RaggedTensor : ...
389+ def split (
390+ value : TensorCompatible ,
391+ num_or_size_splits : int | TensorCompatible ,
392+ axis : int | Tensor = 0 ,
393+ num : int | None = None ,
394+ name : str | None = "split" ,
395+ ) -> list [Tensor ]: ...
388396def tensor_scatter_nd_update (
389397 tensor : TensorCompatible , indices : TensorCompatible , updates : TensorCompatible , name : str | None = None
390398) -> Tensor : ...
@@ -434,4 +442,10 @@ def gather_nd(
434442 name : str | None = None ,
435443 bad_indices_policy : Literal ["" , "DEFAULT" , "ERROR" , "IGNORE" ] = "" ,
436444) -> Tensor : ...
445+ def transpose (
446+ a : Tensor , perm : Sequence [int ] | IntArray | None = None , conjugate : _bool = False , name : str = "transpose"
447+ ) -> Tensor : ...
448+ def clip_by_value (
449+ t : Tensor | IndexedSlices , clip_value_min : TensorCompatible , clip_value_max : TensorCompatible , name : str | None = None
450+ ) -> Tensor : ...
437451def __getattr__ (name : str ): ... # incomplete module
0 commit comments