
Regression colab example - JAX implementation  #8

@conorhassan

Hi, great paper!

I implemented the regression Colab example (or at least the first VBLLMLP example) in JAX. I wrote the equivalent of distributions.py by subclassing numpyro.distributions, and implemented the Regression and VBLLMLP classes in Flax. The model trains, but the uncertainty bands are a bit of a mess.
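For reference when debugging the uncertainty bands, here is a minimal sketch in plain JAX (no Flax or NumPyro) of the closed-form quantities a VBLL-style regression head relies on. It assumes a scalar output, a Gaussian posterior `q(w) = N(m, S)` over the last-layer weights with `S = L L^T`, a fixed known noise variance (the VBLL layers also learn the noise, which is omitted here), an isotropic Gaussian prior, and a KL weight of `1 / n_data`. All function names are hypothetical and are not the vbll package's API. The key identities are `E_q[log N(y | w^T phi, s2)] = log N(y | m^T phi, s2) - phi^T S phi / (2 s2)` for training and the predictive `N(m^T phi, phi^T S phi + s2)` for the bands, so comparing a Flax/NumPyro layer's outputs against these values on the same inputs is one way to localize a bug.

```python
import jax
import jax.numpy as jnp


def expected_log_lik(phi, y, m, L, noise_var):
    """Closed-form E_q[log N(y | w^T phi, noise_var)] with q(w) = N(m, L L^T).

    phi: [N, D] last-layer features, y: [N] targets, m: [D] posterior mean,
    L: [D, D] lower-triangular Cholesky factor of S, noise_var: scalar > 0.
    """
    mean = phi @ m                                # [N] m^T phi
    quad = jnp.sum((phi @ L) ** 2, axis=-1)       # [N] phi^T S phi = ||L^T phi||^2
    log_n = -0.5 * (jnp.log(2 * jnp.pi * noise_var) + (y - mean) ** 2 / noise_var)
    return log_n - 0.5 * quad / noise_var         # Gaussian log-lik minus variance penalty


def kl_to_isotropic_prior(m, L, prior_var):
    """KL(N(m, L L^T) || N(0, prior_var * I)) for lower-triangular L."""
    d = m.shape[0]
    trace_s = jnp.sum(L ** 2)                     # tr(L L^T) = sum of squared entries
    logdet_s = 2.0 * jnp.sum(jnp.log(jnp.abs(jnp.diag(L))))
    return 0.5 * (trace_s / prior_var + m @ m / prior_var - d
                  + d * jnp.log(prior_var) - logdet_s)


def vbll_style_loss(phi, y, m, L_raw, noise_var, prior_var, n_data):
    """Negative ELBO per datapoint; the KL is weighted by 1 / n_data (assumption)."""
    L = jnp.tril(L_raw)                           # keep only the lower triangle
    elbo = jnp.mean(expected_log_lik(phi, y, m, L, noise_var))
    elbo = elbo - kl_to_isotropic_prior(m, L, prior_var) / n_data
    return -elbo


def predictive(phi, m, L, noise_var):
    """Predictive mean and variance of y: N(m^T phi, phi^T S phi + noise_var)."""
    mean = phi @ m
    var = jnp.sum((phi @ L) ** 2, axis=-1) + noise_var
    return mean, var


if __name__ == "__main__":
    key_phi, key_y = jax.random.split(jax.random.PRNGKey(0))
    phi = jax.random.normal(key_phi, (8, 3))      # stand-in for MLP features
    y = jax.random.normal(key_y, (8,))
    m, L_raw = jnp.zeros(3), 0.5 * jnp.eye(3)
    loss, grads = jax.value_and_grad(vbll_style_loss, argnums=(2, 3))(
        phi, y, m, L_raw, 0.1, 1.0, 8)
    mean, var = predictive(phi, m, jnp.tril(L_raw), 0.1)
    lower, upper = mean - 2 * jnp.sqrt(var), mean + 2 * jnp.sqrt(var)  # ~95% band
```

In a Flax model, `phi` would be the MLP's penultimate features, and `m`, `L_raw`, and the noise variance would be module parameters (in practice the diagonal of `L` is usually kept positive, e.g. via `exp`); that wiring is left out of this sketch.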

Are there any plans for a JAX implementation? If so, I'd be keen to help out a little. I'd also like to track down the errors in my Colab somehow...

Here is the Colab: https://colab.research.google.com/drive/1Rh895u0jP9xEpK7eMOz9JHUX_2CluyLO?usp=sharing

Thanks,
Conor
