Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,15 @@ def init_configs(
Initialize the Megatron-Bridge bridge and provider objects + hf_config and tokenizer
"""
tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
hf_config_original = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

override_config_kwargs = {
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
}
override_config_kwargs.update(model_config_kwargs.get("model_config", {}))
update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
hf_config = update_model_config(hf_config_original, override_config_kwargs=override_config_kwargs)

transformer_config_kwargs = (
transformer_config_kwargs
Expand Down Expand Up @@ -402,7 +402,10 @@ def init_configs(
self.provider = provider
self.bridge = bridge

self.strategy.hf_config = hf_config
# strategy.hf_config is the on-disk source-of-truth used by
# save_hf_configs and must NOT carry runtime overrides like
# mtp_num_layers=0; assign the un-mutated AutoConfig here.
self.strategy.hf_config = hf_config_original
self.tokenizer = tokenizer
self.enable_router_replay = megatron_config.moe_enable_routing_replay

Expand Down
20 changes: 18 additions & 2 deletions skyrl/train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import socket
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path

Expand Down Expand Up @@ -976,15 +977,30 @@ def peer_access_supported(max_num_gpus_per_node: int):


def update_model_config(module_config, override_config_kwargs):
"""Update the module config with the override_config_kwargs.
"""Return a copy of ``module_config`` with ``override_config_kwargs`` applied.

The returned config is a deep copy, so the caller's input is left
unmodified. Nested dict values in ``override_config_kwargs`` recurse into
the corresponding sub-config attribute (which is already part of the deep
copy, so the recursion mutates the copy in place).

Args:
module_config: The module config from Huggingface Transformers.
override_config_kwargs: The kwargs to override the module config.

Returns:
A new module config with the overrides applied.
"""
new_config = deepcopy(module_config)
_apply_overrides_in_place(new_config, override_config_kwargs)
return new_config


def _apply_overrides_in_place(module_config, override_config_kwargs):
"""Apply override kwargs to ``module_config`` in place (used for sub-configs)."""
for key, val in override_config_kwargs.items():
if isinstance(val, dict):
update_model_config(getattr(module_config, key), val)
_apply_overrides_in_place(getattr(module_config, key), val)
else:
setattr(module_config, key, val)
Comment on lines +999 to 1005
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _apply_overrides_in_place assumes that module_config and all its nested attributes are objects supporting getattr and setattr. However, many HuggingFace configurations contain standard dictionary attributes (e.g., rope_scaling, quantization_config). If an override contains a nested dictionary for one of these attributes, calling getattr on it returns a dict, and recursing into it will cause an AttributeError when setattr is called on the dictionary keys.

Additionally, if module_config is None or if a nested attribute does not exist (returning None), the function will crash. Enforcing defensive programming with proper None checks and handling both dictionaries and objects will make this utility much more robust.

def _apply_overrides_in_place(module_config, override_config_kwargs):
    """Apply override kwargs to ``module_config`` in place (used for sub-configs)."""
    if module_config is None:
        return
    for key, val in override_config_kwargs.items():
        is_dict = isinstance(module_config, dict)
        current_val = module_config.get(key) if is_dict else getattr(module_config, key, None)

        if isinstance(val, dict) and current_val is not None:
            _apply_overrides_in_place(current_val, val)
        else:
            if is_dict:
                module_config[key] = val
            else:
                setattr(module_config, key, val)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Real issue, but pre-existing — the recursive getattr/setattr pattern is unchanged from the prior in-place version; I only relocated it. The only caller (init_configs) passes scalar overrides, so no observable regression today.

Also, the suggested patch conflates "recurse into existing sub-config" with "assign new dict" when current_val is None. Worth a dedicated PR. Happy to file a follow-up issue.


Expand Down
38 changes: 38 additions & 0 deletions tests/train/utils/test_update_model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
uv run --isolated --extra dev pytest tests/train/utils/test_update_model_config.py
"""

from copy import deepcopy
from types import SimpleNamespace

from skyrl.train.utils.utils import update_model_config


def _make_config():
"""Build a config-like object with a nested sub-config attribute."""
sub = SimpleNamespace(mtp_num_layers=4, hidden_size=128)
return SimpleNamespace(num_nextn_predict_layers=4, sub_config=sub)


class TestUpdateModelConfigNonMutating:
"""Lock in the non-mutating contract of ``update_model_config``."""

def test_input_is_not_mutated_at_top_level(self):
cfg = _make_config()
before = deepcopy(cfg)
update_model_config(cfg, {"num_nextn_predict_layers": 0})
assert cfg.num_nextn_predict_layers == before.num_nextn_predict_layers
assert cfg.sub_config.mtp_num_layers == before.sub_config.mtp_num_layers

def test_returned_copy_carries_top_level_overrides(self):
cfg = _make_config()
new_cfg = update_model_config(cfg, {"num_nextn_predict_layers": 0})
assert new_cfg.num_nextn_predict_layers == 0
assert new_cfg is not cfg

def test_nested_overrides_do_not_leak_back_into_input(self):
cfg = _make_config()
new_cfg = update_model_config(cfg, {"sub_config": {"mtp_num_layers": 0}})
assert new_cfg.sub_config.mtp_num_layers == 0
assert cfg.sub_config.mtp_num_layers == 4
assert new_cfg.sub_config is not cfg.sub_config
Loading