diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 58061e9c3c..0c9544b9f8 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -331,7 +331,7 @@ def init_configs( Initialize the Megatron-Bridge bridge and provider objects + hf_config and tokenizer """ tokenizer = get_tokenizer(model_path, trust_remote_code=True) - hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + hf_config_original = AutoConfig.from_pretrained(model_path, trust_remote_code=True) override_config_kwargs = { "bos_token_id": tokenizer.bos_token_id, @@ -339,7 +339,7 @@ def init_configs( "pad_token_id": tokenizer.pad_token_id, } override_config_kwargs.update(model_config_kwargs.get("model_config", {})) - update_model_config(hf_config, override_config_kwargs=override_config_kwargs) + hf_config = update_model_config(hf_config_original, override_config_kwargs=override_config_kwargs) transformer_config_kwargs = ( transformer_config_kwargs @@ -402,7 +402,10 @@ def init_configs( self.provider = provider self.bridge = bridge - self.strategy.hf_config = hf_config + # strategy.hf_config is the on-disk source-of-truth used by + # save_hf_configs and must NOT carry runtime overrides like + # mtp_num_layers=0; assign the un-mutated AutoConfig here. + self.strategy.hf_config = hf_config_original self.tokenizer = tokenizer self.enable_router_replay = megatron_config.moe_enable_routing_replay diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 4ac09c3eb3..0811859bea 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -6,6 +6,7 @@ import socket import sys import time +from copy import deepcopy from datetime import datetime from pathlib import Path @@ -976,15 +977,30 @@ def peer_access_supported(max_num_gpus_per_node: int): def update_model_config(module_config, override_config_kwargs): - """Update the module config with the override_config_kwargs. + """Return a copy of ``module_config`` with ``override_config_kwargs`` applied. + + The returned config is a deep copy, so the caller's input is left + unmodified. Nested dict values in ``override_config_kwargs`` recurse into + the corresponding sub-config attribute (which is already part of the deep + copy, so the recursion mutates the copy in place). Args: module_config: The module config from Huggingface Transformers. override_config_kwargs: The kwargs to override the module config. + + Returns: + A new module config with the overrides applied. """ + new_config = deepcopy(module_config) + _apply_overrides_in_place(new_config, override_config_kwargs) + return new_config + + +def _apply_overrides_in_place(module_config, override_config_kwargs): + """Apply override kwargs to ``module_config`` in place (used for sub-configs).""" for key, val in override_config_kwargs.items(): if isinstance(val, dict): - update_model_config(getattr(module_config, key), val) + _apply_overrides_in_place(getattr(module_config, key), val) else: setattr(module_config, key, val) diff --git a/tests/train/utils/test_update_model_config.py b/tests/train/utils/test_update_model_config.py new file mode 100644 index 0000000000..4e8591d2db --- /dev/null +++ b/tests/train/utils/test_update_model_config.py @@ -0,0 +1,38 @@ +""" +uv run --isolated --extra dev pytest tests/train/utils/test_update_model_config.py +""" + +from copy import deepcopy +from types import SimpleNamespace + +from skyrl.train.utils.utils import update_model_config + + +def _make_config(): + """Build a config-like object with a nested sub-config attribute.""" + sub = SimpleNamespace(mtp_num_layers=4, hidden_size=128) + return SimpleNamespace(num_nextn_predict_layers=4, sub_config=sub) + + +class TestUpdateModelConfigNonMutating: + """Lock in the non-mutating contract of ``update_model_config``.""" + + def test_input_is_not_mutated_at_top_level(self): + cfg = _make_config() + before = deepcopy(cfg) + update_model_config(cfg, {"num_nextn_predict_layers": 0}) + assert cfg.num_nextn_predict_layers == before.num_nextn_predict_layers + assert cfg.sub_config.mtp_num_layers == before.sub_config.mtp_num_layers + + def test_returned_copy_carries_top_level_overrides(self): + cfg = _make_config() + new_cfg = update_model_config(cfg, {"num_nextn_predict_layers": 0}) + assert new_cfg.num_nextn_predict_layers == 0 + assert new_cfg is not cfg + + def test_nested_overrides_do_not_leak_back_into_input(self): + cfg = _make_config() + new_cfg = update_model_config(cfg, {"sub_config": {"mtp_num_layers": 0}}) + assert new_cfg.sub_config.mtp_num_layers == 0 + assert cfg.sub_config.mtp_num_layers == 4 + assert new_cfg.sub_config is not cfg.sub_config