Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/flow_factory/models/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,13 +1491,17 @@ def _resolve_checkpoint_path(self, path: str) -> str:
Raises:
FileNotFoundError: When the spec is neither a local path nor a reachable HF repo.
"""
# Normalize leading ``~`` for local-path inputs; no-op for HF specs since
# ``expanduser`` only acts on a leading ``~``.
path = os.path.expanduser(path)
force_hf = path.startswith(HF_PATH_PREFIX)
spec = path[len(HF_PATH_PREFIX):] if force_hf else path

if not force_hf and os.path.exists(spec):
return spec
# Local path wins unless an explicit ``hf://`` prefix forces remote.
if not force_hf and os.path.exists(path):
return path

repo_id, subfolder, revision = parse_hf_checkpoint_path(spec)
# ``parse_hf_checkpoint_path`` handles the ``hf://`` prefix internally.
repo_id, subfolder, revision = parse_hf_checkpoint_path(path)

try:
local_path = download_hf_checkpoint(repo_id, subfolder, revision)
Expand Down Expand Up @@ -1741,12 +1745,7 @@ def load_checkpoint(
- 'state': Load full training state (model + optimizer + scheduler + RNG)
- None: Auto-detect based on checkpoint directory contents
"""
path = os.path.expanduser(path)
path = self._resolve_checkpoint_path(path)
if not os.path.exists(path):
raise FileNotFoundError(
f"Checkpoint path not found locally or on Hugging Face Hub: {path!r}"
)

# Auto-detect if not specified
if resume_type is None:
Expand Down