checkpoint.load_config should ignore None kernel initializer #664

@fengzileee

Kernel initializers can be set to None, in which case they are also saved as None. In those cases, load_config should not try to retrieve them from networks.KERNEL_INITIALIZER.
This issue is probably similar to the one encountered in 770e837.

For example, mean_kernel_init_fn can be None, which is also its default:

mean_kernel_init_fn: networks.Initializer | None = None,

When the config is loaded from a saved checkpoint, load_config tries to access networks.KERNEL_INITIALIZER[None], which raises a KeyError:

Traceback (most recent call last):
  File "/Users/linfeng/workspace/brax/scripts/reproduce.py", line 47, in <module>
    inference_fn = ppo_checkpoint.load_policy(
        checkpoints[0],
        network_factory=network_factory,
        deterministic=True,
    )
  File "/Users/linfeng/workspace/brax/brax/training/agents/ppo/checkpoint.py", line 83, in load_policy
    config = load_config(path)
  File "/Users/linfeng/workspace/brax/brax/training/agents/ppo/checkpoint.py", line 71, in load_config
    return checkpoint.load_config(config_path)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
  File "/Users/linfeng/workspace/brax/brax/training/checkpoint.py", line 229, in load_config
    networks.KERNEL_INITIALIZER[init_fn_name_]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
KeyError: None
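
A possible fix is to skip the registry lookup whenever the saved name is None. The snippet below is only a minimal sketch of that idea and is not the actual brax code: KERNEL_INITIALIZER is a stand-in dictionary with a placeholder entry, and resolve_initializer, resolve_config_initializers, and the other_kernel_init_fn key are hypothetical names introduced for illustration.

```python
from typing import Callable, Optional

# Stand-in for networks.KERNEL_INITIALIZER (assumption: a name -> initializer
# mapping). The single entry is a placeholder, not brax's real registry.
KERNEL_INITIALIZER: dict[str, Callable[[], str]] = {
    "lecun_uniform": lambda: "lecun_uniform initializer",
}


def resolve_initializer(name: Optional[str]) -> Optional[Callable[[], str]]:
    """Hypothetical helper: look up a saved name, passing None through."""
    if name is None:
        # Nothing was configured, so there is nothing to look up.
        return None
    return KERNEL_INITIALIZER[name]


def resolve_config_initializers(config: dict, keys: tuple[str, ...]) -> dict:
    """Hypothetical helper: replace saved names with callables, keeping None."""
    resolved = dict(config)  # copy so the loaded config is not mutated
    for key in keys:
        if key in resolved:
            resolved[key] = resolve_initializer(resolved[key])
    return resolved


# mean_kernel_init_fn was saved as None; other_kernel_init_fn is a
# hypothetical key that was saved with a registered name.
saved = {"other_kernel_init_fn": "lecun_uniform", "mean_kernel_init_fn": None}
restored = resolve_config_initializers(
    saved, ("other_kernel_init_fn", "mean_kernel_init_fn")
)
```

With a guard like this, a None entry stays None after loading instead of triggering the lookup, while named initializers are still resolved through the registry.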

A reproducible example, with python==3.13.0, brax==0.14.1 and playground==0.1.0, run on macOS 15.5:

#!/usr/bin/env python3
import glob
import functools
from pathlib import Path
from brax.training.agents.ppo import train
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import checkpoint as ppo_checkpoint
from mujoco_playground import wrapper

from mujoco_playground.config import manipulation_params
from mujoco_playground import registry

env_name = "AeroCubeRotateZAxis"
env_cfg = registry.get_default_config(env_name)
env_cfg.episode_length = 10
env = registry.load(env_name, env_cfg)

ppo_params = manipulation_params.brax_ppo_config(env_name)
ppo_params.num_timesteps = 1
ppo_params.num_evals = 1
ppo_params.num_envs = 1
ppo_params.num_minibatches = 1
ppo_params.num_updates_per_batch = 1
ppo_params.batch_size = 1
ppo_training_params = dict(ppo_params)

network_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_obs_key="privileged_state",
    value_obs_key="privileged_state",
    policy_hidden_layer_sizes=(128,),
    value_hidden_layer_sizes=(128,),
)
# Replace the config's network_factory entry with the callable factory.
ppo_training_params["network_factory"] = network_factory

checkpoints_path = Path("./checkpoints").expanduser().resolve().as_posix()
make_inference_fn, params, metrics = train.train(
    environment=env,
    **dict(ppo_training_params),
    save_checkpoint_path=checkpoints_path,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)

checkpoints = glob.glob(f"{checkpoints_path}/*/")

inference_fn = ppo_checkpoint.load_policy(
    checkpoints[0],
    network_factory=network_factory,
    deterministic=True,
)
