File tree Expand file tree Collapse file tree 3 files changed +9
-0
lines changed
Expand file tree Collapse file tree 3 files changed +9
-0
lines changed Original file line number Diff line number Diff line change 22
33import gymnasium as gym
44import hydra
5+ import jax
56import jax .numpy as jnp
67import numpy as np
78import omegaconf
1718from flowrl .utils .logger import CompositeLogger
1819from flowrl .utils .misc import set_seed_everywhere
1920
21+ jax .config .update ("jax_default_matmul_precision" , "float32" )
22+
2023SUPPORTED_AGENTS : Dict [str , BaseAgent ] = {
2124 "sac" : SACAgent ,
2225 "td3" : TD3Agent ,
Original file line number Diff line number Diff line change 33import gymnasium as gym
44import gymnasium_robotics
55import hydra
6+ import jax
67import numpy as np
78import omegaconf
89import wandb
1617from flowrl .utils .logger import CompositeLogger
1718from flowrl .utils .misc import set_seed_everywhere
1819
20+ jax .config .update ("jax_default_matmul_precision" , "float32" )
21+
1922SUPPORTED_AGENTS : Dict [str , BaseAgent ] = {
2023 "sac" : SACAgent ,
2124 "td3" : TD3Agent ,
Original file line number Diff line number Diff line change 44import gymnasium as gym
55import gymnasium_robotics
66import hydra
7+ import jax
78import jax .numpy as jnp
89import numpy as np
910import wandb
1617from flowrl .utils .logger import CompositeLogger
1718from flowrl .utils .misc import set_seed_everywhere
1819
20+ jax .config .update ("jax_default_matmul_precision" , "float32" )
21+
1922SUPPORTED_AGENTS : Dict [str , BaseAgent ] = {
2023 "ppo" : PPOAgent ,
2124}
You can’t perform that action at this time.
0 commit comments