Consider a simple compiled function in TensorFlow:
```python
import tensorflow as tf

a = tf.random.normal(shape=(2, 10))
b = tf.random.normal(shape=(10, 3))

@tf.function
def tf_compiled_func(a, b):
    c = tf.matmul(a, b)
    return c

tf_compiled_func(a, b)
```
This code works.
However, its "universal" eagerpy version:
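```python
import eagerpy as ep
import tensorflow as tf

@tf.function
def compiled_universal_func(a, b):
    a, b = ep.astensors(a, b)  # wrap the raw tf.Tensors as eagerpy tensors
    c = a.matmul(b)
    return c.raw  # unwrap back to a raw tf.Tensor

a = tf.random.normal(shape=(2, 10))
b = tf.random.normal(shape=(10, 3))
compiled_universal_func(a, b)
```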
does not work and raises the following error:
```
.../lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

AttributeError: in user code:

    <ipython-input-8-6edbe80953ee>:4 compiled_universal_func  *
        c = a.matmul(b)
    .../lib/python3.7/site-packages/eagerpy/tensor/tensorflow.py:499 matmul  *
        if self.ndim != 2 or other.ndim != 2:
    .../lib/python3.7/site-packages/eagerpy/tensor/base.py:115 ndim
        return cast(int, self.raw.ndim)

    AttributeError: 'Tensor' object has no attribute 'ndim'
```
(It works if the `@tf.function` decorator is commented out.)
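The traceback suggests the failure is not in `matmul` itself but in eagerpy's rank check, which reads `self.raw.ndim`. Inside `tf.function` the inputs are symbolic graph tensors rather than eager tensors, and in the TensorFlow version used here they evidently lack `.ndim`. A minimal sketch, assuming the static shape is known at trace time, shows that the same check written with `x.shape.rank` (a `tf.TensorShape` property) does work in graph mode; `static_rank` and `compiled_matmul` are hypothetical names, not part of eagerpy or TensorFlow.

```python
import tensorflow as tf

# Hypothetical helper (not part of eagerpy): read the rank from the static
# shape, which is available on both eager tensors and symbolic graph tensors.
# Returns None if the rank is unknown at trace time.
def static_rank(x):
    return x.shape.rank

@tf.function
def compiled_matmul(a, b):
    # Inside tf.function, a and b are symbolic tensors; this check runs
    # once, at trace time, using plain Python ints.
    if static_rank(a) != 2 or static_rank(b) != 2:
        raise ValueError("matmul expects two 2-D tensors")
    return tf.matmul(a, b)

a = tf.random.normal(shape=(2, 10))
b = tf.random.normal(shape=(10, 3))
c = compiled_matmul(a, b)
print(c.shape)  # (2, 3)
```

If this holds for your TensorFlow version, it points to the eagerpy wrapper relying on an attribute that symbolic tensors do not provide, rather than to a limitation of `tf.matmul` under `tf.function`.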
Note that the equivalent code with JAX seems to work:
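```python
import eagerpy as ep
import jax
from jax import jit

@jit
def compiled_universal_func(a, b):
    a, b = ep.astensors(a, b)  # wrap the raw JAX arrays as eagerpy tensors
    c = a.matmul(b)
    return c.raw  # unwrap back to a raw JAX array

seed = 1701
key = jax.random.PRNGKey(seed)
# Split the key so that a and b are drawn independently
key_a, key_b = jax.random.split(key)
a = jax.random.normal(key_a, shape=(2, 10))
b = jax.random.normal(key_b, shape=(10, 3))
compiled_universal_func(a, b)
```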
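A plausible reason the JAX version works: eagerpy's `matmul` performs the same `ndim` check seen in the TensorFlow traceback, and under `jax.jit` the inputs are tracers that do expose `.ndim` (derived from their abstract shape) as a Python int. The sketch below, which uses plain JAX only and a hypothetical function name `checked_matmul`, illustrates that such a check runs fine at trace time; it assumes nothing about eagerpy internals beyond the check visible in the traceback.

```python
import jax

@jax.jit
def checked_matmul(a, b):
    # Under jit, a and b are tracers; .ndim is an ordinary Python int
    # taken from the abstract shape, so this branch is resolved at trace time.
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("matmul expects two 2-D arrays")
    return a @ b

key_a, key_b = jax.random.split(jax.random.PRNGKey(1701))
a = jax.random.normal(key_a, shape=(2, 10))
b = jax.random.normal(key_b, shape=(10, 3))
c = checked_matmul(a, b)
print(c.shape)  # (2, 3)
```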
Is this a problem with eagerpy's TensorFlow integration?