Can a universal function be compiled in TensorFlow? #36

@eserie

Consider a simple compiled function in TensorFlow.

import tensorflow as tf
a = tf.random.normal(shape=(2, 10))
b = tf.random.normal(shape=(10, 3))

@tf.function
def tf_compiled_func(a, b):
    c = tf.matmul(a, b)
    return c

tf_compiled_func(a, b)

This code works.

However, its "universal" version:

import eagerpy as ep

@tf.function
def compiled_universal_func(a, b):
    a, b = ep.astensors(a, b)
    c = a.matmul(b)
    return c.raw


a = tf.random.normal(shape=(2, 10))
b = tf.random.normal(shape=(10, 3))
compiled_universal_func(a, b)

does not work and raises the error:

.../lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

AttributeError: in user code:

    <ipython-input-8-6edbe80953ee>:4 compiled_universal_func  *
        c = a.matmul(b)
    .../lib/python3.7/site-packages/eagerpy/tensor/tensorflow.py:499 matmul  *
        if self.ndim != 2 or other.ndim != 2:
    .../lib/python3.7/site-packages/eagerpy/tensor/base.py:115 ndim
        return cast(int, self.raw.ndim)

    AttributeError: 'Tensor' object has no attribute 'ndim'

(It works if we comment out the @tf.function decorator.)
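The traceback points at eagerpy/tensor/base.py reading self.raw.ndim, and the graph-mode Tensor in this TensorFlow version has no such attribute. As a minimal sketch of what a rank lookup that does not rely on .ndim could look like, assuming only that the object exposes a shape of known length: rank_of and FakeSymbolicTensor below are hypothetical names for illustration, not part of eagerpy or TensorFlow.

```python
from typing import Any


def rank_of(x: Any) -> int:
    """Return the number of dimensions of a tensor-like object.

    Prefers `x.ndim` when the attribute exists and otherwise falls back
    to `len(x.shape)`, which assumes the rank is statically known.
    """
    ndim = getattr(x, "ndim", None)
    if ndim is not None:
        return int(ndim)
    return len(x.shape)


class FakeSymbolicTensor:
    """Hypothetical stand-in for a tensor that has `shape` but no `ndim`."""

    def __init__(self, shape):
        self.shape = tuple(shape)


a = FakeSymbolicTensor((2, 10))
b = FakeSymbolicTensor((10, 3))

# The stand-in reproduces the situation from the traceback: no `ndim`.
print(hasattr(a, "ndim"))
# The fallback still recovers the rank needed for the 2-D matmul check.
print(rank_of(a), rank_of(b))
```

Whether eagerpy should use a fallback like this, or a TensorFlow-specific property, is an open question for the maintainers; the sketch only shows that the rank is recoverable from shape alone.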

Note that the equivalent code with JAX seems to work:

import eagerpy as ep
import jax
from jax import jit

@jit
def compiled_universal_func(a, b):
    a, b = ep.astensors(a, b)
    c = a.matmul(b)
    return c.raw


seed = 1701
key = jax.random.PRNGKey(seed)
key_a, key_b = jax.random.split(key)  # independent keys for a and b
a = jax.random.normal(key_a, shape=(2, 10))
b = jax.random.normal(key_b, shape=(10, 3))
compiled_universal_func(a, b)

Is this a problem with eagerpy's TensorFlow integration?
