Stream HF→OLMo state conversion to lower load_hf_model peak memory #661
finbarrtimbers wants to merge 1 commit into
…-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7f595db9c8
assert file_exists(f"{prefix}/model.safetensors.index.json") or file_exists(
    f"{prefix}/model.safetensors"
)
Keep pytorch_model.bin support in HF checkpoint loading
load_hf_model now enforces safetensors-only inputs, so any HF checkpoint that only contains pytorch_model.bin (common in older/internal repos) will fail immediately instead of loading. This is a behavioral regression from the previous implementation, which could load .bin checkpoints through AutoModelForCausalLM.from_pretrained, and it will break existing conversion workflows unless users manually re-export models to safetensors first.
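One possible way to keep the fallback is to detect the checkpoint format before choosing a loading path, instead of asserting that safetensors files exist. The sketch below is illustrative only: `find_checkpoint` is a hypothetical helper name, and it assumes a local directory (the PR's `file_exists` may also handle remote paths, which this does not).

```python
from pathlib import Path
from typing import Tuple


def find_checkpoint(prefix: str) -> Tuple[str, Path]:
    """Return (format, path) for the first checkpoint layout found under prefix.

    Hypothetical helper: prefers safetensors, falls back to pytorch_model.bin.
    """
    root = Path(prefix)
    candidates = [
        ("safetensors-sharded", root / "model.safetensors.index.json"),
        ("safetensors", root / "model.safetensors"),
        ("bin", root / "pytorch_model.bin"),  # legacy layout, non-streaming path
    ]
    for fmt, path in candidates:
        if path.is_file():
            return fmt, path
    raise FileNotFoundError(f"No supported checkpoint files found under {prefix}")
```

A caller could then stream when the format is one of the safetensors variants and route `"bin"` checkpoints to the previous loading path, at the cost of the higher peak memory that path implies.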
Summary
- load_hf_model previously instantiated the full HF model twice (once on rank 0 to warm the cache, then again on every rank) just to extract its state_dict. For a 32B bf16 model that's ~64 GB resident per rank during conversion. This PR drops the model materialization entirely: it reads AutoConfig, then streams tensors directly from the on-disk safetensors files (sharded or single-file) via safe_open.
- StateConverter.iter_convert(...) yields (dest_key, tensor) pairs and frees each mapping's source/intermediate tensors before moving on; convert(...) is a thin dict(self.iter_convert(...)) wrapper. A new iter_convert_state_from_hf(...) plumbs the same pattern through the HF-side converter (with the gemma3 +1.0 norm transform applied per-key inline).
- load_hf_model consumes that iterator directly, so each tensor is redistributed into its target DTensor and the source HF tensor is freed before the next read.
- Pins huggingface-hub<1.0 in pyproject.toml to keep transformers happy (4.57.x requires <1.0); without this, uv run pytest resolved to huggingface-hub 1.12 and broke from transformers import ... on import.
Test plan
- uv run pytest src/test/nn/hf/convert_test.py src/test/nn/conversion/: 34/34 pass, including a new test_iter_convert_state_from_hf_matches_convert_state_from_hf covering embeddings, lm_head, attention QKV/O, MLP, layernorms, and q/k norms.
- make style-check / make lint-check clean.
- load_hf_model on a real model (recommend reviewer or follow-up CI run, since local env can't pull large checkpoints).
🤖 Generated with Claude Code
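The iterator-plus-thin-wrapper pattern described in the Summary can be illustrated with a toy sketch. This is not the PR's code: the function signatures, the `load` callback, and the key names (`hf.embed.weight`, `olmo.embed.weight`, and so on) are made up for illustration, and plain Python lists stand in for tensors.

```python
from typing import Callable, Dict, Iterator, List, Tuple

Tensor = List[float]  # stand-in for a real tensor type (assumption)


def iter_convert(
    source_keys: List[str],
    load: Callable[[str], Tensor],
    rename: Dict[str, str],
) -> Iterator[Tuple[str, Tensor]]:
    """Lazily yield (dest_key, tensor) pairs, reading one source tensor per step."""
    for src_key in source_keys:
        tensor = load(src_key)  # nothing is read until the consumer asks for it
        yield rename[src_key], tensor
        del tensor  # drop the generator's reference before the next read


def convert(
    source_keys: List[str],
    load: Callable[[str], Tensor],
    rename: Dict[str, str],
) -> Dict[str, Tensor]:
    """Eager wrapper: materializes every converted tensor at once."""
    return dict(iter_convert(source_keys, load, rename))


# Toy storage and a loader that records which keys have been read so far.
storage = {"hf.embed.weight": [1.0, 2.0], "hf.lm_head.weight": [3.0]}
rename = {"hf.embed.weight": "olmo.embed.weight", "hf.lm_head.weight": "olmo.lm_head.weight"}
reads: List[str] = []


def load(key: str) -> Tensor:
    reads.append(key)
    return list(storage[key])


stream = iter_convert(list(storage), load, rename)
first_key, first_tensor = next(stream)
reads_after_first = list(reads)  # only the first key has been read at this point
remaining = dict(stream)
```

A streaming consumer handles each pair and releases it before requesting the next one, so at most one source tensor needs to be alive at a time; calling `convert` gives up that property in exchange for a plain dict.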