diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py
index e146ac94..2bbeef18 100644
--- a/fms_fsdp/utils/checkpointing_utils.py
+++ b/fms_fsdp/utils/checkpointing_utils.py
@@ -20,9 +20,14 @@
 from torch.distributed.fsdp import StateDictType
 
 
-def get_latest(targdir, qualifier=lambda x: True):
-    """Fetch the latest file or folder written to target directory, subject to name passing the qualifier fn.
-    If directory is empty or nonexistent or no items qualify, return None."""
+def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime):
+    """
+    Fetch the full path of the latest file or folder written to target directory,
+    subject to name passing the qualifier fn.
+    Optional key fn can be used for custom sorting.
+    Both functions take full path arguments.
+    If directory is empty or nonexistent or no items qualify, return None.
+    """
     if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
         latest = max(
             [
@@ -30,15 +35,20 @@ def get_latest(targdir, qualifier=lambda x: True):
                 for x in os.listdir(targdir)
                 if qualifier(os.path.join(targdir, x))
             ],
-            key=lambda path: int(path.split("/")[-1].split("_")[1]),
+            key=key,
         )
-        return os.path.join(targdir, latest)
+        return latest
     return None
 
 
-def get_oldest(targdir, qualifier=lambda x: True):
-    """Fetch the oldest file or folder written to target directory, subject to name passing the qualifier fn.
-    If directory is empty or nonexistent or no items qualify, return None."""
+def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime):
+    """
+    Fetch the full path of the oldest file or folder written to target directory,
+    subject to name passing the qualifier fn.
+    Optional key fn can be used for custom sorting.
+    Both functions take full path arguments.
+    If directory is empty or nonexistent or no items qualify, return None.
+    """
     if os.path.exists(targdir) and len(os.listdir(targdir)) > 0:
         oldest = min(
             [
@@ -46,9 +56,9 @@ def get_oldest(targdir, qualifier=lambda x: True):
                 for x in os.listdir(targdir)
                 if qualifier(os.path.join(targdir, x))
             ],
-            key=os.path.getctime,
+            key=key,
         )
-        return os.path.join(targdir, oldest)
+        return oldest
     return None
@@ -118,7 +128,7 @@ def _cleanup(self):
         ckp_to_remove = Path(
             get_oldest(self.ckp_path, qualifier=lambda x: "tmp" in x)
         )
-        if os.path.is_file(ckp_to_remove):
+        if os.path.isfile(ckp_to_remove):
             ckp_to_remove.unlink()
         else:
             shutil.rmtree(ckp_to_remove)
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index f8996a28..d1d442d7 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -32,7 +32,7 @@
 rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists
 that can be re-distributed over workers (i.e. buffers).
 
-Our loaders obey the following type heirarchy:
+Our loaders obey the following type hierarchy:
 torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset.
 `_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds
 a single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times,
@@ -510,8 +510,8 @@ def _validate_ckp_path(self, path: str, verbose: bool = False):
                 f" Dataset: No valid checkpoint detected at {path}, dataset starting from scratch."
             )
            return ""
-        # Check latest path
-        latest = os.path.join(path, get_latest(path))
+        # Check latest path, using ckp naming syntax
+        latest = get_latest(path, key=lambda path: int(path.split("_")[-2]))
        if verbose:
            self.report(f"Checkpoint detected at {latest}")
        # If item is not a folder, exit early
diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py
index 87b4e7b2..0a265a63 100644
--- a/speculator/train_speculator_utils.py
+++ b/speculator/train_speculator_utils.py
@@ -1,7 +1,7 @@
 import os
 import re
 import time
-from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -437,11 +437,12 @@ class EmbedGPTBigCode(GPTBigCode):
     # Overrides the forward function of GPTBigCode to allow returning embedding vectors
     def forward(
         self,
-        x: torch.LongTensor,
+        x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
+        only_last_token: bool = False,
         attn_algorithm: Optional[str] = None,
         include_embeds: bool = False,
     ):
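As a quick sanity check of the changed signatures, below is a minimal, self-contained sketch (not part of the patch) showing how the new `key` argument and the existing `qualifier` argument of get_latest/get_oldest compose, and that both now return full paths. The "step_<n>_ckp" folder naming and the by_step helper are illustrative assumptions only; any name whose second-to-last "_"-separated token is an integer would satisfy the int(path.split("_")[-2]) parsing used in _validate_ckp_path.

# Minimal usage sketch (not part of the diff); requires the fms_fsdp package.
# "step_<n>_ckp" names and by_step() are hypothetical, chosen for illustration.
import os
import tempfile

from fms_fsdp.utils.checkpointing_utils import get_latest, get_oldest


def by_step(p: str) -> int:
    # Same parsing idea as the key passed in _validate_ckp_path:
    # pull the step number out of the checkpoint folder name.
    return int(p.split("_")[-2])


with tempfile.TemporaryDirectory() as targdir:
    for step in (100, 250, 50):
        os.makedirs(os.path.join(targdir, f"step_{step}_ckp"))

    # Both helpers now return full paths; the default key is os.path.getctime.
    print(get_latest(targdir))

    # Custom key: sort by the embedded step number instead of filesystem timestamp.
    assert get_latest(targdir, key=by_step).endswith("step_250_ckp")
    assert get_oldest(targdir, key=by_step).endswith("step_50_ckp")

    # qualifier still filters candidates by full path before sorting.
    assert get_latest(targdir, qualifier=lambda p: "step_100" in p).endswith("step_100_ckp")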