Skip to content

Commit 408c751

Browse files
authored
Loader ckp fixes (#119)
Updates `get_latest` and `get_oldest` to use the same sorting function, and allows the dataloader ckp handler to pass in its custom sort manually. Removes the bug where excessive path joins lead to repeated path prefixes in dataloader ckp loading. Fixes GPTBigCode signatures used for speculator training, to match superclass signatures (currently preventing other PRs from landing). Includes and subsumes #110 and #96. Full credit to @weiji14 and @Akash-Nayak respectively
1 parent 01439c8 commit 408c751

3 files changed

Lines changed: 28 additions & 17 deletions

File tree

fms_fsdp/utils/checkpointing_utils.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,45 @@
2020
from torch.distributed.fsdp import StateDictType
2121

2222

23-
def get_latest(targdir, qualifier=lambda x: True):
24-
"""Fetch the latest file or folder written to target directory, subject to name passing the qualifier fn.
25-
If directory is empty or nonexistent or no items qualify, return None."""
23+
def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime):
24+
"""
25+
Fetch the full path of the latest file or folder written to target directory,
26+
subject to name passing the qualifier fn.
27+
Optional key fn can be used for custom sorting.
28+
Both functions take full path arguments.
29+
If directory is empty or nonexistent or no items qualify, return None.
30+
"""
2631
if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
2732
latest = max(
2833
[
2934
os.path.join(targdir, x)
3035
for x in os.listdir(targdir)
3136
if qualifier(os.path.join(targdir, x))
3237
],
33-
key=lambda path: int(path.split("/")[-1].split("_")[1]),
38+
key=key,
3439
)
35-
return os.path.join(targdir, latest)
40+
return latest
3641
return None
3742

3843

39-
def get_oldest(targdir, qualifier=lambda x: True):
40-
"""Fetch the oldest file or folder written to target directory, subject to name passing the qualifier fn.
41-
If directory is empty or nonexistent or no items qualify, return None."""
44+
def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime):
45+
"""
46+
Fetch the full path of the oldest file or folder written to target directory,
47+
subject to name passing the qualifier fn.
48+
Optional key fn can be used for custom sorting.
49+
Both functions take full path arguments.
50+
If directory is empty or nonexistent or no items qualify, return None.
51+
"""
4252
if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
4353
oldest = min(
4454
[
4555
os.path.join(targdir, x)
4656
for x in os.listdir(targdir)
4757
if qualifier(os.path.join(targdir, x))
4858
],
49-
key=os.path.getctime,
59+
key=key,
5060
)
51-
return os.path.join(targdir, oldest)
61+
return oldest
5262
return None
5363

5464

fms_fsdp/utils/dataset_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be
3333
re-distributed over workers (i.e. buffers).
3434
35-
Our loaders obey the following type heirarchy:
35+
Our loaders obey the following type hierarchy:
3636
torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset.
3737
`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a
3838
single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times,
@@ -510,8 +510,8 @@ def _validate_ckp_path(self, path: str, verbose: bool = False):
510510
f" Dataset: No valid checkpoint detected at {path}, dataset starting from scratch."
511511
)
512512
return ""
513-
# Check latest path
514-
latest = os.path.join(path, get_latest(path))
513+
# Check latest path, using ckp naming syntax
514+
latest = get_latest(path, key=lambda path: int(path.split("_")[-2]))
515515
if verbose:
516516
self.report(f"Checkpoint detected at {latest}")
517517
# If item is not a folder, exit early

speculator/train_speculator_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import re
33
import time
4-
from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union
4+
from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
55

66
import torch
77
import torch.distributed as dist
@@ -437,11 +437,12 @@ class EmbedGPTBigCode(GPTBigCode):
437437
# Overrides the forward function of GPTBigCode to allow returning embedding vectors
438438
def forward(
439439
self,
440-
x: torch.LongTensor,
440+
x: torch.Tensor,
441441
mask: Optional[torch.Tensor] = None,
442-
position_ids: Optional[torch.LongTensor] = None,
443-
past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
442+
position_ids: Optional[torch.Tensor] = None,
443+
past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
444444
use_cache: bool = False,
445+
only_last_token: bool = False,
445446
attn_algorithm: Optional[str] = None,
446447
include_embeds: bool = False,
447448
):

0 commit comments

Comments
 (0)