diff --git a/skyrl/backends/skyrl_train/distributed/megatron/megatron_strategy.py b/skyrl/backends/skyrl_train/distributed/megatron/megatron_strategy.py index 0544d3dc29..3c4d097636 100644 --- a/skyrl/backends/skyrl_train/distributed/megatron/megatron_strategy.py +++ b/skyrl/backends/skyrl_train/distributed/megatron/megatron_strategy.py @@ -521,6 +521,14 @@ def save_hf_model(self, bridge, model: MegatronModelWrapper, output_dir: str, to # Only rank 0 saves the Huggingface config and tokenizer. if self.is_rank_0(): + # Preserve any custom modeling artifacts (e.g. modeling_*.py, + # special_tokens_map.json, auto_map-referenced files) that + # trust_remote_code models depend on. save_hf_configs below + # overwrites config.json/tokenizer files with the strategy's + # current view, but save_artifacts is required to copy the + # custom Python modules and other artifacts that + # save_pretrained() alone does not emit. + bridge.hf_pretrained.save_artifacts(work_dir) self.save_hf_configs(self.hf_config, work_dir, tokenizer) self.print(f"Successfully saved HF config and tokenizer to {output_dir}") diff --git a/tests/backends/skyrl_train/distributed/test_megatron_correctness.py b/tests/backends/skyrl_train/distributed/test_megatron_correctness.py index e57f4f5df3..c7cedcd2d9 100644 --- a/tests/backends/skyrl_train/distributed/test_megatron_correctness.py +++ b/tests/backends/skyrl_train/distributed/test_megatron_correctness.py @@ -273,3 +273,86 @@ async def test_non_colocated_megatron_merge_lora_still_pauses(self): dispatch._inference_engine_client.pause_generation.assert_awaited_once() dispatch._inference_engine_client.resume_generation.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# save_hf_model: save_artifacts ordering +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _has_megatron, reason="megatron-core not installed") +class TestSaveHFModelArtifacts: + """Verify ``save_hf_model`` invokes ``save_artifacts`` rank-0-only and in + the correct order relative to ``save_hf_weights`` / ``save_hf_configs``. + """ + + def _build_strategy(self, *, is_rank_0: bool): + from skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy import ( + MegatronStrategy, + ) + + strategy = MegatronStrategy.__new__(MegatronStrategy) + strategy.hf_config = MagicMock(name="hf_config") + strategy.is_rank_0 = MagicMock(return_value=is_rank_0) + strategy.save_hf_configs = MagicMock(name="save_hf_configs") + strategy.print = MagicMock(name="print") + return strategy + + def _build_bridge_and_model(self): + bridge = MagicMock(name="bridge") + model = MagicMock(name="model") + return bridge, model + + def _patch_module_io(self, *, work_dir: str): + """Patch ``io`` and ``dist`` at the megatron_strategy module level. + + ``io.local_work_dir`` is a context manager yielding ``work_dir``. + """ + io_mock = MagicMock(name="io") + io_mock.local_work_dir.return_value.__enter__.return_value = work_dir + io_mock.local_work_dir.return_value.__exit__.return_value = False + dist_mock = MagicMock(name="dist") + return ( + patch( + "skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy.io", + io_mock, + ), + patch( + "skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy.dist", + dist_mock, + ), + io_mock, + dist_mock, + ) + + def test_rank0_calls_save_artifacts_before_save_hf_configs(self): + strategy = self._build_strategy(is_rank_0=True) + bridge, model = self._build_bridge_and_model() + io_patch, dist_patch, _io_mock, _dist_mock = self._patch_module_io(work_dir="/tmp/work") + + parent = MagicMock() + parent.attach_mock(bridge.save_hf_weights, "save_hf_weights") + parent.attach_mock(bridge.hf_pretrained.save_artifacts, "save_artifacts") + parent.attach_mock(strategy.save_hf_configs, "save_hf_configs") + + with io_patch, dist_patch: + strategy.save_hf_model(bridge=bridge, model=model, output_dir="/out", tokenizer="tok") + + bridge.save_hf_weights.assert_called_once_with(model.actor_module, "/tmp/work") + bridge.hf_pretrained.save_artifacts.assert_called_once_with("/tmp/work") + strategy.save_hf_configs.assert_called_once_with(strategy.hf_config, "/tmp/work", "tok") + + call_order = [c[0] for c in parent.mock_calls] + assert call_order == ["save_hf_weights", "save_artifacts", "save_hf_configs"] + + def test_non_rank0_skips_save_artifacts_and_save_hf_configs(self): + strategy = self._build_strategy(is_rank_0=False) + bridge, model = self._build_bridge_and_model() + io_patch, dist_patch, _io_mock, _dist_mock = self._patch_module_io(work_dir="/tmp/work") + + with io_patch, dist_patch: + strategy.save_hf_model(bridge=bridge, model=model, output_dir="/out", tokenizer="tok") + + bridge.save_hf_weights.assert_called_once_with(model.actor_module, "/tmp/work") + bridge.hf_pretrained.save_artifacts.assert_not_called() + strategy.save_hf_configs.assert_not_called()