
Porting PyTorch weights to JAX #1819

@ranlucienwang

Assume we have a PyTorch model and an equivalent JAX model. Is there a framework for porting PyTorch layer weights to JAX? I may need to port many models from PyTorch to JAX, and the only way I can think of to test the correctness of the implementation is to initialize the PyTorch model, port its weights into the JAX model, and compare the outputs.
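
One common, manual way to do this is to export the PyTorch `state_dict` as NumPy arrays, rearrange each tensor into the layout the JAX library expects, and compare outputs on the same input. The sketch below assumes a Flax (linen) target and a flat state dict whose module names match the Flax module names; `torch_to_flax_params` and its key-naming scheme are hypothetical. It only covers `nn.Linear` to `nn.Dense` (transpose `(out, in)` to `(in, out)`) and `nn.Conv2d` to `nn.Conv` (`OIHW` to `HWIO`), and it checks equivalence with plain NumPy reference forward passes rather than calling either framework.

```python
import numpy as np


def torch_to_flax_params(state_dict):
    """Hypothetical helper: turn a flat PyTorch-style state dict of NumPy arrays
    into a nested Flax-style params dict. Assumes keys look like
    '<module>.weight' / '<module>.bias' and that the Flax modules reuse <module>."""
    params = {}
    for key, value in state_dict.items():
        module, name = key.rsplit(".", 1)
        layer = params.setdefault(module, {})
        if name == "weight" and value.ndim == 2:
            # torch nn.Linear: (out_features, in_features) -> Flax nn.Dense kernel: (in, out)
            layer["kernel"] = value.T
        elif name == "weight" and value.ndim == 4:
            # torch nn.Conv2d: (out, in, kh, kw) -> Flax nn.Conv kernel: (kh, kw, in, out)
            layer["kernel"] = np.transpose(value, (2, 3, 1, 0))
        elif name == "bias":
            layer["bias"] = value  # 1-D bias has the same layout in both
        else:
            raise ValueError(f"no conversion rule for {key!r} with shape {value.shape}")
    return params


# NumPy reference forward passes, one per convention.
def torch_linear(x, weight, bias):
    return x @ weight.T + bias  # PyTorch: y = x W^T + b


def flax_dense(x, kernel, bias):
    return x @ kernel + bias  # Flax: y = x K + b


def torch_conv2d(x_nchw, weight):
    # Stride-1 'valid' cross-correlation, NCHW input, OIHW weight (requires NumPy >= 1.20).
    win = np.lib.stride_tricks.sliding_window_view(x_nchw, weight.shape[2:], axis=(2, 3))
    return np.einsum("nchwij,ocij->nohw", win, weight)


def flax_conv(x_nhwc, kernel):
    # Same operation with NHWC input and HWIO kernel.
    win = np.lib.stride_tricks.sliding_window_view(x_nhwc, kernel.shape[:2], axis=(1, 2))
    return np.einsum("nhwcij,ijco->nhwo", win, kernel)


rng = np.random.default_rng(0)
# Stand-in for {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
state_dict = {
    "fc.weight": rng.standard_normal((3, 5)),
    "fc.bias": rng.standard_normal(3),
    "conv.weight": rng.standard_normal((4, 2, 3, 3)),
}
params = torch_to_flax_params(state_dict)

x = rng.standard_normal((8, 5))
y_torch = torch_linear(x, state_dict["fc.weight"], state_dict["fc.bias"])
y_flax = flax_dense(x, params["fc"]["kernel"], params["fc"]["bias"])
print(np.allclose(y_torch, y_flax))  # True

img = rng.standard_normal((1, 2, 6, 6))  # NCHW
out_torch = torch_conv2d(img, state_dict["conv.weight"])
out_flax = flax_conv(img.transpose(0, 2, 3, 1), params["conv"]["kernel"])
print(np.allclose(out_torch.transpose(0, 2, 3, 1), out_flax))  # True
```

With real models, the converted dict would typically be passed to the Flax side as `model.apply({"params": params}, x)`, after renaming keys to Flax's module names (e.g. the auto-generated `Dense_0`) if the modules are not named explicitly. The outputs can then be compared with `np.allclose` against the PyTorch model in `eval()` mode. Layers with extra state, such as BatchNorm running statistics (which Flax keeps in a separate `batch_stats` collection), need their own conversion rules.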

Labels: enhancement (New feature or request)