
Problem occurred in enn_demo.ipynb #11

Description

@fazaghifari

Hi!

I was trying enn_demo.ipynb on Google Colab. Everything worked fine until I ran this block of code:

# Train the experiment
experiment.train(FLAGS.num_batch)

and got the following error. Is there something wrong with the JAX version?

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/enn/networks/ensembles.py in apply(params, states, inputs, index)
     82       sub_states = jax.tree_map(particle_selector, states)
     83       out, new_sub_states = model.apply(sub_params, sub_states, inputs)
---> 84       new_states = jax.tree_multimap(
     85           lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
     86       return out, new_states

AttributeError: module 'jax' has no attribute 'tree_multimap'
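A likely cause, offered as an assumption rather than a confirmed diagnosis: newer JAX releases dropped `jax.tree_multimap`, which had been deprecated in favour of `jax.tree_util.tree_map`. That function already accepts several pytrees with the same structure and maps over them jointly. The sketch below mirrors the failing call in `ensembles.py` using `tree_map` instead. `update_particle_states` is a hypothetical helper, and the toy `states` dict only stands in for the real network state; patching the library this way, or installing an older JAX release that still has `tree_multimap`, are both untested workarounds.

```python
import jax
import jax.numpy as jnp

# Check the installed version and whether the removed alias is still present.
print(jax.__version__, hasattr(jax, "tree_multimap"))


def update_particle_states(states, new_sub_states, index):
  """Hypothetical helper mirroring the failing lines in ensembles.py.

  `states` leaves carry a leading ensemble axis; `new_sub_states` leaves are
  the updated state of a single ensemble member. `tree_map` walks both pytrees
  together, replacing the removed `jax.tree_multimap`.
  """
  return jax.tree_util.tree_map(
      lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)


# Toy example: 3 ensemble members, each with a state vector of length 2.
states = {"bn": jnp.zeros((3, 2))}
new_sub_states = {"bn": jnp.ones((2,))}
new_states = update_particle_states(states, new_sub_states, 1)
print(new_states["bn"])  # row 1 is now ones, rows 0 and 2 stay zeros
```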

Thanks,
Adam
