Skip to content

Commit 9dc014b

Browse files
authored
fix: change default float format from tf32 to fp32 (#12)
1 parent 414d090 commit 9dc014b

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

examples/online/main_dmc_offpolicy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import gymnasium as gym
44
import hydra
5+
import jax
56
import jax.numpy as jnp
67
import numpy as np
78
import omegaconf
@@ -17,6 +18,8 @@
1718
from flowrl.utils.logger import CompositeLogger
1819
from flowrl.utils.misc import set_seed_everywhere
1920

21+
jax.config.update("jax_default_matmul_precision", "float32")
22+
2023
SUPPORTED_AGENTS: Dict[str, BaseAgent] = {
2124
"sac": SACAgent,
2225
"td3": TD3Agent,

examples/online/main_mujoco_offpolicy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gymnasium as gym
44
import gymnasium_robotics
55
import hydra
6+
import jax
67
import numpy as np
78
import omegaconf
89
import wandb
@@ -16,6 +17,8 @@
1617
from flowrl.utils.logger import CompositeLogger
1718
from flowrl.utils.misc import set_seed_everywhere
1819

20+
jax.config.update("jax_default_matmul_precision", "float32")
21+
1922
SUPPORTED_AGENTS: Dict[str, BaseAgent] = {
2023
"sac": SACAgent,
2124
"td3": TD3Agent,

examples/online/main_mujoco_onpolicy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import gymnasium as gym
55
import gymnasium_robotics
66
import hydra
7+
import jax
78
import jax.numpy as jnp
89
import numpy as np
910
import wandb
@@ -16,6 +17,8 @@
1617
from flowrl.utils.logger import CompositeLogger
1718
from flowrl.utils.misc import set_seed_everywhere
1819

20+
jax.config.update("jax_default_matmul_precision", "float32")
21+
1922
SUPPORTED_AGENTS: Dict[str, BaseAgent] = {
2023
"ppo": PPOAgent,
2124
}

0 commit comments

Comments
 (0)