Merge branch 'main' into fix-typo
daviswer authored Oct 10, 2024
2 parents f7f0dff + 2a3b5b0 commit 3cbb433
Showing 14 changed files with 1,908 additions and 587 deletions.
21 changes: 21 additions & 0 deletions fms_fsdp/config/training.py
@@ -12,6 +12,9 @@ class train_config:
# dataset and dataloader
use_dummy_dataset: bool = False
data_path: str = "/fsx/data"
file_type: str = "arrow"
col_name: str = "tokens"
tokenizer_path: str = "/fsx/tokenizer"
datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
seq_length: int = 4096
@@ -22,6 +25,7 @@ class train_config:
eol_token: Optional[int] = None
strip_tokens: str = ""
logical_shards: int = 1024
num_workers: int = 1

# fsdp policies
sharding_strategy: str = "hsdp"
@@ -33,10 +37,14 @@ class train_config:
# training spec
batch_size: int = 2
num_steps: int = 1000000
training_stage: str = "initial"
learning_rate: float = 3e-4
grad_clip_thresh: float = 1.0
seed: int = 2023

# continued training spec
resuming_dataset: bool = False

# profiling
use_profiler: bool = False
profiler_rank0_only: bool = True
@@ -51,3 +59,16 @@ class train_config:

# compile
use_torch_compile: bool = True

# speculator training
tp_size: int = 8
model_arch: str = "embedllama"
model_path: str = "/path/to/model/"
n_speculator_heads: int = 3
speculator_width: int = 4096
speculator_tie_weights: bool = True
speculator_scale_input: bool = True
stage2_start_step: int = 15000
stage2_prompt_length: int = 64
stage2_batch_size: int = 96
stage2_seq_length: int = 256
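
The new fields split training into an initial stage and a later speculator stage, and add knobs for dataloader workers and dataset resumption. A minimal usage sketch follows; it assumes train_config is a dataclass importable as fms_fsdp.config.training.train_config, and every value shown is illustrative rather than a recommended setting.

    from fms_fsdp.config.training import train_config  # import path assumed from the file location above

    # Illustrative overrides of the newly added fields.
    cfg = train_config(
        file_type="arrow",
        col_name="tokens",
        num_workers=2,
        training_stage="initial",
        resuming_dataset=True,      # reuse saved dataloader state instead of restarting the stream
        # speculator training block
        model_arch="embedllama",
        n_speculator_heads=3,
        speculator_width=4096,
        stage2_start_step=15000,    # step at which stage-2 (prompt-based) batches begin
        stage2_prompt_length=64,
        stage2_batch_size=96,
        stage2_seq_length=256,
    )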
51 changes: 40 additions & 11 deletions fms_fsdp/utils/checkpointing_utils.py
@@ -69,6 +69,8 @@ class Checkpointer:
report_fn : Callable or None
Optional function for reporting or logging status updates. Expected to handle arbitrary *args, **kwargs.
Defaults to self._selective_print().
model_auto_placement : bool
Optional; if True, automatically detect the GPU device to move the model to, as set in the device mesh init.
Methods
-------
@@ -87,6 +89,7 @@ def __init__(
rank,
local_rank,
report_fn=None,
model_auto_placement=False,
):
self.max_ckps = n_to_save
self.rank = rank
@@ -96,6 +99,7 @@ def __init__(
self.p_mode = parallel_mode
assert parallel_mode in ["fsdp", "hsdp", "ddp"]
self.report = self._selective_print if report_fn is None else report_fn
self.model_auto_placement = model_auto_placement

def _selective_print(self, *args, **kwargs):
if self.rank == 0:
@@ -168,7 +172,14 @@ def _validate_ckp_path(self, path):
return None

def load(
self, model, optimizer, dataloader, path="", reset_stepcount=False, strict=True
self,
model,
optimizer,
dataloader,
path="",
reset_stepcount=False,
strict=True,
is_compiled=False,
):
"""
Handle checkpoint loading for model/optimizer/dataloader from given path, according to arguments.
@@ -178,27 +189,38 @@ def load(
Strict determines whether to use strict loading or not FOR SINGLEFILE LOADING ONLY.
Returns model, optimizer, dataloader, current step, and current tokens seen.
"""
is_resuming = False
if self._validate_ckp_path(self.ckp_path) is not None:
path = self.ckp_path
reset_stepcount = False
is_resuming = True
load_path = self._validate_ckp_path(path)
if load_path is None:
self.report(
f"No valid checkpoint detected at {path}, starting from scratch."
)
return model, optimizer, dataloader, 0, 0
return model, optimizer, dataloader, 0, 0, False
else:
self.report(f"Prior checkpoint {load_path} detected.")
model_load_time = time.time()
if os.path.isfile(load_path):
checkpoint_data = torch.load(load_path, map_location="cpu")
model.load_state_dict(checkpoint_data.get("model_state"), strict=strict)
model.to(self.local_rank)
if is_compiled:
model._orig_mod.load_state_dict(
checkpoint_data.get("model_state"), strict=strict
)
else:
model.load_state_dict(
checkpoint_data.get("model_state"), strict=strict
)
if self.model_auto_placement:
model.to("cuda")
else:
model.to(self.local_rank)
self.report(
f"Checkpoint {load_path} is a single-file checkpoint containing only a model. Optimizer and dataloader are from scratch.",
model_load_time=time.time() - model_load_time,
)
return model, optimizer, dataloader, 0, 0
return model, optimizer, dataloader, 0, 0, is_resuming
else:
# Load model
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
@@ -210,12 +232,15 @@ def load(
planner=DefaultLoadPlanner(),
)
model.load_state_dict(model_ckp["model_state"])
model.to(self.local_rank)
if self.model_auto_placement:
model.to("cuda")
else:
model.to(self.local_rank)
self.report(model_load_time=time.time() - model_load_time)
step = 0
ntok = 0
# Load metadata
if not reset_stepcount:
if is_resuming:
metadata = torch.load(os.path.join(load_path, "metadata.pth"))
step = metadata.get("step", 0)
ntok = metadata.get("tokens_seen", 0)
@@ -239,11 +264,11 @@ def load(
# Load dataset
if dataloader is not None:
data_load_time = time.time()
dataloader.dataset.load_from_path(load_path)
dataloader.dataset.load_from_path(path)
self.report(dataset_load_time=time.time() - data_load_time)
else:
self.report("Skipping dataset load, no dataloader provided.")
return model, optimizer, dataloader, step, ntok
return model, optimizer, dataloader, step, ntok, is_resuming

def save(
self,
@@ -284,6 +309,7 @@ def save_single_file(
self,
step,
model,
is_compiled=False,
**kwargs,
):
# Note: metadata kwargs cannot contain any of:
@@ -295,7 +321,10 @@ def save_single_file(
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
model_state = model.state_dict()
if is_compiled:
model_state = model._orig_mod.state_dict()
else:
model_state = model.state_dict()
if self.rank == 0:
metadata = kwargs
metadata["step"] = step
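
Taken together, the checkpointing changes mean load() now reports whether a run is resuming and can unwrap models compiled with torch.compile. Below is a sketch of an updated call site; the Checkpointer constructor arguments other than rank, local_rank, and model_auto_placement, and the cfg attribute names, are assumptions for illustration only.

    # Hypothetical call site reflecting the new signature: load() now returns a sixth
    # value (is_resuming) and takes is_compiled for models wrapped by torch.compile,
    # whose underlying weights live on model._orig_mod.
    checkpointer = Checkpointer(
        ckpt_save_path,              # positional args before rank/local_rank are assumed
        2,                           # n_to_save
        "hsdp",                      # parallel_mode
        rank,
        local_rank,
        model_auto_placement=True,   # move the model to the mesh-assigned "cuda" device
    )
    model, optimizer, train_loader, start_step, tokens_seen, is_resuming = checkpointer.load(
        model,
        optimizer,
        train_loader,
        path=cfg.ckpt_load_path,
        is_compiled=cfg.use_torch_compile,
    )
    # is_resuming is True only when a valid checkpoint was found at the checkpointer's
    # own save path, in which case the step count and dataloader state are restored.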
46 changes: 45 additions & 1 deletion fms_fsdp/utils/config_utils.py
@@ -39,6 +39,8 @@ def get_model_config(model_variant):
kvheads=8,
nlayers=48,
hidden_grow_factor=22016 / 8192,
max_expected_seq_len=16384,
rope_theta=1000000.0,
)
elif model_variant == "llama2_13b":
llama_config = LLaMAConfig(
@@ -48,12 +50,17 @@ def get_model_config(model_variant):
hidden_grow_factor=13824 / 5120,
)
elif model_variant == "llama2_7b":
llama_config = LLaMAConfig()
llama_config = LLaMAConfig(
hidden_grow_factor=11008 / 4096,
kvheads=32,
)
elif model_variant == "llama2_1.4b":
llama_config = LLaMAConfig(
emb_dim=2048,
nheads=16,
nlayers=24,
hidden_grow_factor=3,
kvheads=4,
)
elif model_variant == "llama3_8b":
llama_config = LLaMAConfig(
@@ -64,6 +71,7 @@ def get_model_config(model_variant):
nlayers=32,
hidden_grow_factor=3.5,
max_expected_seq_len=8192,
rope_theta=500000.0,
)
elif model_variant == "llama3_8b_4k":
llama_config = LLaMAConfig(
@@ -74,6 +82,7 @@ def get_model_config(model_variant):
nlayers=32,
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_1.8b":
llama_config = LLaMAConfig(
@@ -84,6 +93,7 @@ def get_model_config(model_variant):
nlayers=24,
hidden_grow_factor=3.5,
max_expected_seq_len=8192,
rope_theta=500000.0,
)
elif model_variant == "llama3_1.8b_4k":
llama_config = LLaMAConfig(
@@ -94,6 +104,29 @@ def get_model_config(model_variant):
nlayers=24,
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=3072,
nheads=24,
kvheads=8,
nlayers=24,
hidden_grow_factor=8 / 3,
max_expected_seq_len=8192,
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b_4k":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=3072,
nheads=24,
kvheads=8,
nlayers=24,
hidden_grow_factor=8 / 3,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b":
llama_config = LLaMAConfig(
@@ -104,6 +137,7 @@ def get_model_config(model_variant):
nlayers=80,
hidden_grow_factor=3.5,
max_expected_seq_len=8192,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b_4k":
llama_config = LLaMAConfig(
@@ -114,6 +148,16 @@ def get_model_config(model_variant):
nlayers=80,
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_194m_4k":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=1024,
nheads=8,
nlayers=10,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
else:
raise ValueError(f"model variant {model_variant} not supported.")
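
The variant table now covers llama3_3.2b, llama3_3.2b_4k, and llama3_194m_4k, and the llama2/llama3 entries carry explicit rope_theta, max_expected_seq_len, hidden_grow_factor, and kvheads values. A small usage sketch, assuming get_model_config returns the LLaMAConfig built above:

    from fms_fsdp.utils.config_utils import get_model_config  # import path assumed from the file header above

    # Illustrative: build one of the newly added variants.
    llama_config = get_model_config("llama3_3.2b")
    # llama3 variants now set rope_theta=500000.0 and max_expected_seq_len=8192 (4096 for the _4k variants).

    # Unsupported names raise, per the else branch above.
    try:
        get_model_config("llama4_1t")
    except ValueError as err:
        print(err)  # model variant llama4_1t not supported.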