From da29217d88efc7c504383a159fb21baeece8bd0e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 22 May 2024 15:06:25 -0400 Subject: [PATCH 01/73] Set llama2-1.4b to gqa --- fms_fsdp/utils/config_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 9d3f0386..13384031 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -54,6 +54,8 @@ def get_model_config(model_variant): emb_dim=2048, nheads=16, nlayers=24, + hidden_grow_factor=3, + kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From 41ae740d00838701afe4936f0b17501cd1767ebb Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 28 May 2024 14:58:55 -0400 Subject: [PATCH 02/73] Add singlefile ckp saving/conversion --- fms_fsdp/utils/train_utils.py | 1 + main_training.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 9b507e6b..1a75bd44 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -77,6 +77,7 @@ def train( start = time.time() loop_start = time.time() + train_loss = -1 for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1): if batch_idx > cfg.num_steps: break diff --git a/main_training.py b/main_training.py index 9c4c5e44..21a24429 100644 --- a/main_training.py +++ b/main_training.py @@ -156,6 +156,8 @@ def main(**kwargs): tokens_seen, ) + checkpointer.save_single_file(cfg.num_steps, model) + dist.barrier() dist.destroy_process_group() From 5171b5d427acdc4723b61e32fde8c50c24f946f6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 31 May 2024 18:39:08 -0400 Subject: [PATCH 03/73] Turn off GQA on 1.4B --- fms_fsdp/utils/config_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 13384031..3e046f63 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -54,8 +54,8 @@ def get_model_config(model_variant): emb_dim=2048, nheads=16, nlayers=24, - hidden_grow_factor=3, - kvheads=4, + # hidden_grow_factor=3, + # kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From abd5b197cda7ec69dc32b6f9c8fb8f251247860f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 5 Jun 2024 14:34:40 -0400 Subject: [PATCH 04/73] GQA on, add for 7b --- fms_fsdp/utils/config_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 3e046f63..b63f9b10 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -48,14 +48,17 @@ def get_model_config(model_variant): hidden_grow_factor=13824 / 5120, ) elif model_variant == "llama2_7b": - llama_config = LLaMAConfig() + llama_config = LLaMAConfig( + hidden_grow_factor=3, + kvheads=8, + ) elif model_variant == "llama2_1.4b": llama_config = LLaMAConfig( emb_dim=2048, nheads=16, nlayers=24, - # hidden_grow_factor=3, - # kvheads=4, + hidden_grow_factor=3, + kvheads=4, ) elif model_variant == "llama3_8b": llama_config = LLaMAConfig( From 0ac0a5f00ef6b06e1648b7cf738a813dff93b27e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Jun 2024 19:15:47 -0400 Subject: [PATCH 05/73] Add llama3 tele cfg --- fms_fsdp/utils/config_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index b63f9b10..3546cd5b 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -100,6 +100,15 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) + elif model_variant == "llama3_1.8b_tele": + llama_config = LLaMAConfig( + src_vocab_size=128256, + emb_dim=2048, + nheads=32, + kvheads=2, + nlayers=24, + hidden_grow_factor=3.75, + max_expected_seq_len=4096, elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 8caeaa24bda8bf7e92727bd6386e1b842983689f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 10 Jun 2024 19:37:57 -0400 Subject: [PATCH 06/73] Add missing paren --- fms_fsdp/utils/config_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 3546cd5b..fdd31780 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -109,6 +109,7 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=3.75, max_expected_seq_len=4096, + ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 941e98fe5c57f8380485977b574fed6dd73391e7 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 13 Jun 2024 11:21:20 -0400 Subject: [PATCH 07/73] Back to gqa4 for llama3 --- fms_fsdp/utils/config_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index fdd31780..5e0ae2e7 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -104,10 +104,10 @@ def get_model_config(model_variant): llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, - nheads=32, - kvheads=2, + nheads=16, + kvheads=4, nlayers=24, - hidden_grow_factor=3.75, + hidden_grow_factor=3.5, max_expected_seq_len=4096, ) elif model_variant == "llama3_70b": From 44edc0d57feb4520c5b5b2d20a1583b3041030fd Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 18 Jun 2024 22:36:34 -0400 Subject: [PATCH 08/73] Nonstrict ckpt load --- main_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main_training.py b/main_training.py index 21a24429..4eabf152 100644 --- a/main_training.py +++ b/main_training.py @@ -123,6 +123,7 @@ def main(**kwargs): optimizer, None, path=os.path.join(cfg.ckpt_load_path, "checkpoints/"), + strict=False, ) # LR schedule From a48a055288d6e6ec74356048553ce65ab66bb9a5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 18 Jun 2024 22:56:52 -0400 Subject: [PATCH 09/73] If singlefile load, don't append "checkpoints" folder --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 4eabf152..afae1b49 100644 --- a/main_training.py +++ b/main_training.py @@ -122,7 +122,7 @@ def main(**kwargs): model, optimizer, None, - path=os.path.join(cfg.ckpt_load_path, "checkpoints/"), + path=os.path.join(cfg.ckpt_load_path, "checkpoints/") if not os.path.isfile(cfg.ckpt_load_path) else cfg.ckpt_load_path, strict=False, ) From 9031328050b05d3d1f3a12ed1c9b637c8c8a43b5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 21 Jun 2024 12:47:48 -0400 Subject: [PATCH 10/73] Add reset stepcount field --- fms_fsdp/config/training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index e8b4df06..6b1f3888 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -36,6 +36,7 @@ class train_config: learning_rate: float = 3e-4 grad_clip_thresh: float = 1.0 seed: int = 2023 + reset_stepcount: bool = False # profiling use_profiler: bool = False From 0e3430a451200b116aa69093a03e3a7dd252de6c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 21 Jun 2024 12:48:48 -0400 Subject: [PATCH 11/73] Add reset stepcount support --- main_training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main_training.py b/main_training.py index afae1b49..cd858156 100644 --- a/main_training.py +++ b/main_training.py @@ -125,6 +125,8 @@ def main(**kwargs): path=os.path.join(cfg.ckpt_load_path, "checkpoints/") if not os.path.isfile(cfg.ckpt_load_path) else cfg.ckpt_load_path, strict=False, ) + if cfg.reset_stepcount: + start_step = 0 # LR schedule warmup_interval = min(2000, cfg.num_steps // 20) From 45d7e414089233e4f605c3b6cbf24a54efa80309 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 21 Jun 2024 15:07:07 -0400 Subject: [PATCH 12/73] Override optimizer LR values with desired --- main_training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main_training.py b/main_training.py index cd858156..fc3625ee 100644 --- a/main_training.py +++ b/main_training.py @@ -127,6 +127,9 @@ def main(**kwargs): ) if cfg.reset_stepcount: start_step = 0 + # Override loaded optim hyperparams with the current values + for g in optimizer.param_groups: + g["initial_lr"] = cfg.learning_rate # LR schedule warmup_interval = min(2000, cfg.num_steps // 20) From 9cb032986cfa3bb641aa1726159e651be503f46c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Jun 2024 11:41:57 -0400 Subject: [PATCH 13/73] gqa16 --- fms_fsdp/utils/config_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 5e0ae2e7..fdd31780 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -104,10 +104,10 @@ def get_model_config(model_variant): llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, - nheads=16, - kvheads=4, + nheads=32, + kvheads=2, nlayers=24, - hidden_grow_factor=3.5, + hidden_grow_factor=3.75, max_expected_seq_len=4096, ) elif model_variant == "llama3_70b": From 756c3eea36696b642d1d3b2ecb003f47e9e1d49d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Jun 2024 11:48:01 -0400 Subject: [PATCH 14/73] GOTHERE --- fms_fsdp/utils/train_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 1a75bd44..a5fc9b40 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -89,6 +89,7 @@ def train( output = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) + print("GOTHERE") loss.backward() ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() From fd28fb72a1d77e8f8d535e91e5a356839f680ba8 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 24 Jun 2024 11:50:48 -0400 Subject: [PATCH 15/73] No gothere --- fms_fsdp/utils/train_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index a5fc9b40..1a75bd44 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -89,7 +89,6 @@ def train( output = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) - print("GOTHERE") loss.backward() ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() From ffded3584e8111012292d2d5c78427f2fb0aa4be Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Jun 2024 13:20:15 -0400 Subject: [PATCH 16/73] Nonstrict fsdp load --- fms_fsdp/utils/checkpointing_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 41dd8e2d..d299baed 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -208,6 +208,7 @@ def load( state_dict=model_ckp, storage_reader=FileSystemReader(load_path), planner=DefaultLoadPlanner(), + strict=strict, ) model.load_state_dict(model_ckp["model_state"]) model.to(self.local_rank) From 1050d1dc68809cbee1d5ef950a4b233adc576d1c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Jun 2024 13:23:23 -0400 Subject: [PATCH 17/73] Nonstrict fsdp load pt2 --- fms_fsdp/utils/checkpointing_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index d299baed..b27b75a4 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -208,9 +208,8 @@ def load( state_dict=model_ckp, storage_reader=FileSystemReader(load_path), planner=DefaultLoadPlanner(), - strict=strict, ) - model.load_state_dict(model_ckp["model_state"]) + model.load_state_dict(model_ckp["model_state"], strict=strict) model.to(self.local_rank) self.report(model_load_time=time.time() - model_load_time) step = 0 From 166c01de04dca59315d019326f054c51102bede1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 25 Jun 2024 13:26:34 -0400 Subject: [PATCH 18/73] Stop nonstrict fsdp load --- fms_fsdp/utils/checkpointing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index b27b75a4..41dd8e2d 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -209,7 +209,7 @@ def load( storage_reader=FileSystemReader(load_path), planner=DefaultLoadPlanner(), ) - model.load_state_dict(model_ckp["model_state"], strict=strict) + model.load_state_dict(model_ckp["model_state"]) model.to(self.local_rank) self.report(model_load_time=time.time() - model_load_time) step = 0 From fee4c4878805a22ccb186921c41322b3d98a5603 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 1 Jul 2024 13:45:39 -0400 Subject: [PATCH 19/73] Separate gqa4 and 16 cfgs --- fms_fsdp/utils/config_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index fdd31780..a96fb66b 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -100,7 +100,7 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) - elif model_variant == "llama3_1.8b_tele": + elif model_variant == "llama3_1.8b_tele16": llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, @@ -110,6 +110,16 @@ def get_model_config(model_variant): hidden_grow_factor=3.75, max_expected_seq_len=4096, ) + elif model_variant == "llama3_1.8b_tele4": + llama_config = LLaMAConfig( + src_vocab_size=128256, + emb_dim=2048, + nheads=16, + kvheads=4, + nlayers=24, + hidden_grow_factor=3.5, + max_expected_seq_len=4096, + ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From 6f3fd09a4661ad8b2e702dec4002d46f139c0a65 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 1 Jul 2024 13:48:20 -0400 Subject: [PATCH 20/73] Fix indent --- fms_fsdp/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index a96fb66b..e4270892 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -110,7 +110,7 @@ def get_model_config(model_variant): hidden_grow_factor=3.75, max_expected_seq_len=4096, ) - elif model_variant == "llama3_1.8b_tele4": + elif model_variant == "llama3_1.8b_tele4": llama_config = LLaMAConfig( src_vocab_size=128256, emb_dim=2048, From 63b36afe801a0461da1491b1b43665b7683e573b Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Mon, 8 Jul 2024 12:14:02 -0400 Subject: [PATCH 21/73] add 3b config --- fms_fsdp/utils/config_utils.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 9d3f0386..c42823da 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -95,6 +95,26 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) + elif model_variant == "llama3_3.2b": + llama_config = LLaMAConfig( + src_vocab_size=128256, + emb_dim=3072, + nheads=24, + kvheads=8, + nlayers=24, + hidden_grow_factor=8 / 3, + max_expected_seq_len=8192, + ) + elif model_variant == "llama3_3.2b_4k": + llama_config = LLaMAConfig( + src_vocab_size=128256, + emb_dim=3072, + nheads=24, + kvheads=8, + nlayers=24, + hidden_grow_factor=8 / 3, + max_expected_seq_len=4096, + ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( src_vocab_size=128256, From f5a707e92fe12a1d8378d76fb7a137b52344674d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 16 Jul 2024 11:48:21 -0400 Subject: [PATCH 22/73] Add mini llama cfg --- fms_fsdp/utils/config_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index e4270892..8f47b765 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -140,6 +140,13 @@ def get_model_config(model_variant): hidden_grow_factor=3.5, max_expected_seq_len=4096, ) + elif model_variant == "llama3_194m_4k": + llama_config = LLaMAConfig( + emb_dim=1024, + nheads=8, + nlayers=10, + max_expected_seq_len=4096, + ) else: raise ValueError(f"model variant {model_variant} not supported.") From 57e3ffd598c1eb1ff6e8e7f8dc52f5d9f9a26c02 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 16 Jul 2024 11:56:03 -0400 Subject: [PATCH 23/73] mini llama3 vsize --- fms_fsdp/utils/config_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 8f47b765..3eb7104f 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -142,6 +142,7 @@ def get_model_config(model_variant): ) elif model_variant == "llama3_194m_4k": llama_config = LLaMAConfig( + src_vocab_size=128256, emb_dim=1024, nheads=8, nlayers=10, From ffad2b9cd71e5800e6d6c2fcd846135691042bff Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 22 Jul 2024 16:23:02 -0400 Subject: [PATCH 24/73] Create new data utils file --- fms_fsdp/utils/dataset_utils_v2.py | 1097 ++++++++++++++++++++++++++++ 1 file changed, 1097 insertions(+) create mode 100644 fms_fsdp/utils/dataset_utils_v2.py diff --git a/fms_fsdp/utils/dataset_utils_v2.py b/fms_fsdp/utils/dataset_utils_v2.py new file mode 100644 index 00000000..08e47eba --- /dev/null +++ b/fms_fsdp/utils/dataset_utils_v2.py @@ -0,0 +1,1097 @@ +import csv +import logging +import math +import os +import random +import time +from typing import Any, Callable, List, Optional, Set, Type, Union + +import pyarrow as pa +import torch +import torch.utils.data as data + +from fms_fsdp.utils.checkpointing_utils import get_latest + + +""" +The following distributed dataloaders are designed around 3 main principles: + +1. Efficient, asynchronous operation. Workers on different devices do not communicate. +2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator + loading from disk and additional layers adding levels of post-processing (shuffling, + packing, padding, etc.). +3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal + state that can be written/read on disk via implemented recursive `state_dict()` and + `load_state_dict()` calls. +4. Rescalability. Users can save and load checkpoints to/from different numbers of workers + without losing the global state. This is accomplished by splitting state fields for each + layer into `state_params`, which are typically scalar-valued and can be discarded when + rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be + re-distributed over workers (i.e. buffers). + +Our loaders obey the following type heirarchy: +torch.data.IterableDataset -> _Stateful_Dataset -> _Wrapper_Dataset. +`_Stateful_Dataset` implements state and checkpointing logic. A `_Wrapper_Dataset` holds a +single `_Stateful_Dataset` and iterates via calling its wrapped dataset any number of times, +then applying some sort of post-processing and yielding the result. Users build data processing +pipelines by wrapping a base `_Stateful_Dataset` in any number of `_Wrapper_Dataset` layers, +which is then passed to the torch DataLoader. + +NOTE: `_Wrapper_Dataset` currently only implements wrapping a single instantiated sub-dataset layer. +Many layers need multiple sub-layers (i.e. random sampling from distinct data sources). These are +currently implemented as base `_Stateful_Datasets` that take the class of their sub-layers plus any +pass-through arguments, and instantiate all those sub-layers. This is easy on the user, who no longer +needs to instantiate large sets of sub-layers in their code, but leads to awkwardness in this file. +Cleanup is planned for the future. +""" + + +def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: + """ + Partition itemlist into worldsize chunks, grab chunk corresponding to rank and return. + """ + return itemlist[ + (rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize + ] + + +def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: + """ + In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, + and return the span including all owned items, fractional or otherwise. + """ + start = math.floor(len(itemlist) * rank / worldsize) + end = math.ceil(len(itemlist) * (rank + 1) / worldsize) + return itemlist[start:end] + + +class _Stateful_Dataset(data.IterableDataset): + """ + Stub for stateful datasets, extends data.IterableDataset with state_dict methods. + All subclasses should specify the params to be considered stateful or reshardable in the + self.state_params and self.reshard_params lists. + """ + + def __init__( + self, + rank: int, + worldsize: int, + ): + assert rank >= 0, f"Rank {rank} must be a positive integer" + assert ( + worldsize > rank + ), f"Worldsize {worldsize} must be greater than rank {rank}" + self.state_params: List[str] = [] + self.reshard_params: List[str] = [] + self.rank = rank + self.worldsize = worldsize + self.load_worldsize = ( + worldsize # Enable calling load_state_dict() directly, assume no rescaling + ) + + def statename(self, x: str): + # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline + return self.__class__.__name__ + "." + x + + def state_dict(self): + """ + Retrieve all state and reshard flags (each worker/process saves its own state dict shard) + """ + return { + self.statename(flag): getattr(self, flag) + for flag in self.state_params + self.reshard_params + } + + def _reshard(self, sharded_list): + """ + Sharded_list is a list of lists, where each "shard" sublist must have the same length. + These shards should tightly span only the partition of data owned by this worker. + (i.e. if global_list is the list of all entries, sharded_list = _shard_inclusive(global_list) ). + Determine fractional ownership of shards, and get the flattened partition owned by this worker. + """ + # How many shards did _shard_inclusive() drop to the left of sharded_list? + shard_offset = math.floor(self.load_worldsize * self.rank / self.worldsize) + # How long are the list shards? + shard_len = len(sharded_list[0]) + for i, shard in enumerate(sharded_list): + assert ( + len(shard) == shard_len + ), f"Shard {i} with length {len(shard)} does not match expected {shard_len}" + # How many list items did _shard_inclusive() drop to the left of the flattened sharded_list? + item_offset = shard_len * shard_offset + # How many list items are there in total? + n_items = self.load_worldsize * shard_len + # The indices of the flattened sharded_list that this worker owns + my_items = range( + int(n_items * self.rank / self.worldsize) - item_offset, + int(n_items * (self.rank + 1) / self.worldsize) - item_offset, + ) + # Pull out owned items + return [sharded_list[i // shard_len][i % shard_len] for i in my_items] + + def load_state_dict(self, state_dicts, sharded_input=False): + """ + Input state_dicts is a list of state_dicts. If sharded_input=False, this is expected to be the + global list of states across all checkpoint shard files. If sharded_input=True, this expects + _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading. + Workflow: + 1. if sharded_inputs is false, shard the inputs. + 2. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint + shard (state_dicts is a singleton list). + 3. If worldsize does not match checkpoint, toss state params and assemble reshard params from + across given state_dicts. In this case state_dicts may be singleton (for fractional ownership) + or multi-element (for multiple/partitioned ownership). + 4. Return reduced input for use by downstream loading functions + """ + if not sharded_input: + self.load_worldsize = len(state_dicts) + state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) + if self.load_worldsize == self.worldsize: + [ + setattr(self, flag, state_dicts[0][self.statename(flag)]) + for flag in self.state_params + self.reshard_params + ] + else: + for flag in self.reshard_params: + reshard = self._reshard( + [sd[self.statename(flag)] for sd in state_dicts] + ) + setattr(self, flag, reshard) + return state_dicts + + def load_from_path(self, path: str): + """ + Count shard files in the specified checkpoint folder and determine overlap with current + rank and worldsize partition. Load only matching shardfile(s) and pass to load_state_dict. + This is more efficient than sharding the full loaded state. + """ + assert os.path.exists(path), "Specified checkpoint does not exist" + assert not os.path.isfile(path), "Checkpoint should be a folder of shard states" + fileshards = [x for x in os.listdir(path) if "loader" in x] + fileshards = sorted(fileshards, key=lambda x: int(x.split("_")[2][:-4])) + assert ( + len(fileshards) > 0 + ), "Checkpoint directory must contain checkpoint files with 'loader' in the name" + self.load_worldsize = len(fileshards) + # Grab only the shard files holding data we currently own + my_fileshards = _shard_inclusive(fileshards, self.rank, self.worldsize) + states = [torch.load(os.path.join(path, x)) for x in my_fileshards] + self.load_state_dict(states, True) + + def save_to_path(self, path: str): + """ + Grab recursive shard states and save all shard states to the specified checkpoint folder + """ + os.makedirs(path, exist_ok=True) + state = self.state_dict() + torch.save(state, os.path.join(path, f"loader_state_{self.rank}.pth")) + + +class _Wrapper_Dataset(_Stateful_Dataset): + """ + Stub for nested wrappers of _Stateful_Datasets. Extends state fns with recursion. + Requires a single instantiated sub-dataset. + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + ): + self.dataset = dataset + super().__init__(self.dataset.rank, self.dataset.worldsize) + + def load_state_dict(self, state_dicts, sharded_input=False): + """ + Sets all specified flags at the current level, then recurses into wrapped dataset. + """ + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + self.dataset.load_worldsize = self.load_worldsize + self.dataset.load_state_dict(sharded_dicts, True) + return sharded_dicts + + def state_dict(self): + """ + Fetches state dict recursively from wrapped layers, then adds specified flags. + Overlapping flags are overwritten with a warning. + """ + out = self.dataset.state_dict() + state = super().state_dict() + for flag in self.state_params + self.reshard_params: + if flag in out: + logging.warning( + f"Loader {self.rank}: flag {flag} already present in state_dict with value {out[flag]}. " + + f"Overwriting with value {state[flag]}" + ) + out.update(state) + return out + + +class Preprocess_Dataset(_Wrapper_Dataset): + """ + Wrapper for a _Stateful_Dataset that applies a specified preprocessing + or augmentation function to dataset outputs. + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + aug_fn : function (any -> any) + The augmentation function to apply to each dataset item. + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + aug_fn: Callable, + ): + super().__init__(dataset) + self.aug_fn = aug_fn + + def __iter__(self): + dataset = iter(self.dataset) + while True: + out = next(dataset) + yield self.aug_fn(out) + + +class Checkpoint_Dataset(_Wrapper_Dataset): + """ + Wrapper for a _Stateful_Dataset that implements auto-checkpoint saving every n steps. + Useful for setting n_workers > 0, so that workers do not rely on the master process + for state saving (inter-process communication unsupported in PyTorch datasets). + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + load_path : str + Absolute path to checkpoint load directory. If a checkpoint exists, loads it. + interval : int + Saves a new checkpoint every interval. + steps_per_batch : optional[int] + Number of steps required to fill a single batch. Increments interval only + when a full batch is formed. Defaults to 1. + save_path : optional[str] + Absolute path to checkpoint save directory. Defaults to load_path. + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + load_path: str, + interval: int, + steps_per_batch: int = 1, + save_path: str = "", + ): + super().__init__(dataset) + self.interval = interval + self.spb = steps_per_batch + load_path = os.path.join(load_path, "checkpoints") + if len(save_path) == 0: + save_path = load_path + else: + save_path = os.path.join(save_path, "checkpoints") + self.path = save_path + self.step = 0 + self.ministep = 0 + self.load_from_path(load_path) + + def __iter__(self): + dataset = iter(self.dataset) + while True: + yield next(dataset) + self.ministep += 1 + if self.ministep == self.spb: + self.ministep = 0 + self.step += 1 + if self.step % self.interval == 0: + newpath = os.path.join(self.path, "step_" + str(self.step) + "_ckp") + self.save_to_path(newpath) + + def report(self, msg): + if self.rank == 0: + print(msg) + + def save_to_path(self, path: str): + self.report(f"Saving dataset to {path}") + start = time.time() + super().save_to_path(path) + self.report( + f"Dataset successfully saved to {path}! Save time: {time.time() - start}" + ) + + def load_from_path(self, path: str): + # If path does not exist, or exists but is empty, exit early + if not os.path.exists(path) or len(os.listdir(path)) == 0: + self.report( + f"No valid checkpoint detected at {path}, dataset starting from scratch." + ) + return + # Grab latest item in path + latest = os.path.join(path, get_latest(path)) + self.report(f"Dataset checkpoint detected at {latest}") + # If item is not a folder, exit early + if os.path.isfile(latest): + self.report( + f"Checkpoint exists but contains no dataset! Dataset starting from scratch." + ) + return + # If item is a folder, get the step count + self.step = int(latest.split("_")[-2]) + # Proceed + start = time.time() + self.dataset.load_from_path(latest) + self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}") + + +class Preload_Buffer_Dataset(_Wrapper_Dataset): + """ + Wrapper for a Stateful_Dataset that implements data shuffling via a single in/out buffer. + Fills buffer two at a time, up to desired size, then switches to one at a time to maintain size. + Passes randomly sampled outputs one by one. + Ensures local mixing of data without relying on sliding windows or shuffling of large buffers. + Any two consecutive inputs will be separated by window_size steps in expectation. + Rescaling-enabled: buffers that shrink will re-grow to window_size, buffers that expand stay large. + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + window_size : int + Max size of input/output buffer + """ + + def __init__(self, dataset: _Stateful_Dataset, window_size: int): + super().__init__(dataset) + assert ( + window_size > 1 + ), f"Window size {window_size} must be greater than 1 for shuffling to occur" + self.window_size = window_size + self.g_state = None + self.generator = torch.Generator().manual_seed(self.rank) + self.buffer: List[List[Any]] = [] + self.buffer_size = 0 + self.state_params = ["g_state"] + self.reshard_params = ["buffer"] + + def __iter__(self): + dataset = iter(self.dataset) + while True: + # Pad out buffer if needed + self._pad_buffer() + + # Load a point to buffer if necessary + if self.buffer_size < self.window_size: + self.buffer[self.buffer_size] = next(dataset) + self.buffer_size += 1 + + # Swap out randomly sampled value from buffer + i = torch.randint(self.buffer_size, (1,), generator=self.generator).item() + out = self.buffer[i] + self.buffer[i] = next(dataset) + yield out + + def _pad_buffer(self): + if self.buffer_size < self.window_size: + self.buffer += [ + [], + ] * (self.window_size - self.buffer_size) + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + # Prune buffer so it can be resharded in future + self.buffer = self.buffer[: self.buffer_size] + out = super().state_dict() + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Manually set buffer size + self.buffer_size = len(self.buffer) + return sharded_dicts + + +class Buffer_Dataset(_Wrapper_Dataset): + """ + Wrapper for a _Stateful_Dataset that takes in sequences of varying lengths, and packs/pads them + into sequences of desired length. Input sequences are packed greedily until the buffer would + otherwise overrun, then remaining values are filled depending on initialization flags. + Also injects BOS/EOS into the packed output sequence if desired, and if BOS/EOS tokens are + not already in those positions. Implements rescaling by simply dropping (buffer) state. + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + seq_len : int + The desired sequence length + pack_hard : bool + Split input sequences to fill output buffer, or use pad tokens to fill remaining space? + bos_token : any | None + Token to prepend to every output sequence. If None, no token is added. Type should match data type. + eos_token : any | None + Token to append to every output sequence. If None, no token is added. Type should match data type. + pad_token : any | None + Token used to fill out output sequence. Type should match data type. + drop_final_token : any | None + Drop the final token of each document if it matches this value? + (For edge case where bos=eos=None, and sep already appears at beginning of each doc - + drop added extra sep from end of doc) + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + seq_len: int, + pack_hard: bool, + bos_token=None, + eos_token=None, + pad_token=None, + ): + super().__init__(dataset) + self.len = seq_len + + # Buffer args + self.buffer: List[str] = [] + self.bos = bos_token + self.eos = eos_token + self.pad = pad_token + self.pack_hard = pack_hard + if not pack_hard: + assert ( + pad_token is not None + ), "Error: if using pads, you must supply a pad_token" + + self.state_params = ["buffer"] + + def _get_buffer(self, iterable, length, buffer): + # Pull data until buffer is about to overrun, return exactly proper length + new = [] + while len(buffer) + len(new) < length: + buffer += new + new = next(iterable) + + # Add bos if needed + if self.bos is not None and (len(buffer) == 0 or buffer[0] != self.bos): + buffer = [self.bos] + buffer + + # Handle buffer splitting + if len(buffer) >= length: + # If buffer is too long, force split + out = buffer[:length] + buffer = buffer[length:] + if self.eos is not None and out[-1] != self.eos: + buffer = [out[-1]] + buffer + out[-1] = self.eos + buffer = buffer + new + else: + if self.pack_hard: + # Pack in as much of new sequence as will fit + buffer = buffer + new + out = buffer[:length] + buffer = buffer[length:] + if self.eos is not None and out[-1] != self.eos: + buffer = [out[-1]] + buffer + out[-1] = self.eos + else: + # Fill out with pads as needed + if self.eos is not None and buffer[-1] != self.eos: + buffer.append(self.eos) + if self.pad is not None: + out = buffer + [self.pad] * (length - len(buffer)) + else: + out = buffer + buffer = new + return out, buffer + + # Fill buffer line by line, delimiters and packing/splitting as appropriate + def __iter__(self): + dataset = iter(self.dataset) + while True: + out, buffer = self._get_buffer(dataset, self.len, self.buffer) + self.buffer = buffer + yield out + + +class Streaming_Doc_Dataset(_Stateful_Dataset): + """ + The base distributed dataset for loading sequences/documents from pyarrow shards. + Pyarrow shard files are expected to hold multiple recordBatches, where each recordBatch has a "tokens" + field consisting of a single token list. (i.e. each document is a single sequence under a "token" field, + and the file is a list of such sequences) + Relies on a compiled metadata file to fetch shardfile lengths, assumes file already exists in the parent directory, + and is in proper csv format (first row "dataset/filename,documents,tokens", subsequent rows these values). + + For a single dataset directory, splits shard files into x=worldsize fragments and grabs a 1/n contiguous + span of shard fragments (contiguous to limit file reads from cloud/disk). + Logs the number of documents owned from each shardfile, and relies on ZCG random bijection to + map contiguous range of indices to shuffled, noncontiguous set of documents from each shard file. + Shuffles the file list deterministically to hop from file to file. + + At runtime, iterates through documents in each shuffled shard file, pulling each shard on demand. + Shards are thus pulled no more than once per epoch. + Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents. + + Streaming_Doc_Dataset grabs files from a flat directory representing a single dataset. + For percentage-based sampling of multiple subdatasets, see Sampling_Dataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects directory containing pyarrow shardfiles. + Parent directory should contain 'meta' folder with metadata csv file inside. + rank : int + Current worker index + worldsize : int + Total number of workers + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. Required for downstream + sampling logic (can be removed later via PreProcess_Dataset if needed). + bos_token : Any | None + Optional token used to indicate sequence/document start. Type should match data type. + strip_tokens : set[Any] + Token values that should be removed if detected at beginning or end of document + (i.e. any eos/bos tokens already present in the data). Type should match data type. + seed : int + The random seed for deterministic shuffling/sharding + min_length : int + Sequences below this length are skipped + max_chunksize : int + Maximum sequence length to return. Break long docs into chunks of this size or shorter. + verbose : bool + Track setup progress? + shuffle : bool + Shuffle shard file and document orders? (Disable for simple testing) + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + delimiter_token: Any, + bos_token: Optional[Any] = None, + strip_tokens: Optional[Set[Any]] = set(), + seed: int = 42, + min_length: int = 1, + max_chunksize: int = 1024, + verbose: bool = False, + shuffle: bool = True, + ): + super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) + self.seed = seed + self.data = datapath + self.min_length = min_length + assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" + self.chunksize = max_chunksize + self.eos = delimiter_token + self.bos = bos_token + self.drop = strip_tokens + self.verbose = verbose + self.docset: List[ + Any + ] = [] # map of doc indices to (shardid, min docid, max docid) + self.docs_per_shard = {} + + # Guaranteed inconsistent shuffling across workers + random.seed(self.seed + rank) + + # Gather per-file document counts from metadata count file(s) + countfiles = [ + x + for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) + if "counts" in x and "csv" in x + ] + assert len(countfiles) == 1 + doc_counts = {} + pathsplit = (datapath, "") + while len(pathsplit[1]) == 0: + pathsplit = os.path.split(pathsplit[0]) + pardir, dataset = pathsplit + self.dataset = dataset + with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find("/" + dataset) + 1 + if prefix > 0: + key = fullpath[prefix:] + doc_counts[key] = int(row["documents"]) + + # Assemble document set owned by this worker: + # listdir, assemble shardfraglist (ind -> shard, frag) + shards = [ + shard + for shard in os.listdir(datapath) + if os.path.isfile(os.path.join(datapath, shard)) + and "arrow" in os.path.join(datapath, shard) + ] + shards.sort() # Ensure consistent sharding across machines + start_frag = (rank * worldsize * len(shards)) // worldsize + end_frag = ((rank + 1) * worldsize * len(shards)) // worldsize + shardfrags = [ + (shards[i // worldsize], i % worldsize) for i in range(start_frag, end_frag) + ] + + # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): + ndocs = -1 + docset = {} # shardid -> (min docid, max docid) + for i, (shard, frag) in enumerate(shardfrags): + ndocs = doc_counts[os.path.join(dataset, shard)] + self.docs_per_shard[shard] = ndocs + doc_start = (ndocs * frag) // worldsize + doc_end = (ndocs * frag + ndocs) // worldsize - 1 # Inclusive upper bound + if shard not in docset: + docset[shard] = [doc_start, doc_end] + min_d, max_d = docset[shard] + if doc_start < min_d: + docset[shard][0] = doc_start + if doc_end > max_d: + docset[shard][1] = doc_end + + # Add all of this dataset's shard entries to self.docset + doccount = 0 + for shardid in docset: + min_d = docset[shardid][0] + max_d = docset[shardid][1] + self.docset.append((shardid, min_d, max_d)) + doccount += max_d - min_d + 1 + self._len = doccount + + if verbose: + logging.info( + f" Worker {rank} ingested {len(shardfrags)} shard fragments from {dataset}" + ) + + # Shuffle shard files + if shuffle: + random.shuffle(self.docset) + + self.docset_index = 0 + self.chunk_index = -1 + + # Stats + self.epochs_seen = -1 + self.tokens_seen = 0 + self.docs_seen = 0 + self.percent_seen = 0 + self.lcg_state = seed + rank + + self.state_params = [ + "dataset", + "docset_index", + "chunk_index", + "epochs_seen", + "tokens_seen", + "docs_seen", + "percent_seen", + "lcg_state", + ] + + def _get_docid(self, i): + """ + Given a global doc index over the set of docs owned by this worker, + return the corresponding data/shard/local index + """ + cur = 0 + assert ( + i <= self._len + ), f"You have requested an illegal doc index {i}, docset length is {self._len}" + for shardid, min_d, max_d in self.docset: + docrange = max_d - min_d + 1 + cur += docrange + if cur > i: + return shardid, docrange, min_d + + def _get_reader(self, path, newpath, reader): + """ + If new filepath does not match the current one, + open a new reader on that filepath (pull file on demand) + """ + if newpath != path: + del reader + if self.verbose: + logging.info(f"Worker {self.rank} opening new file {newpath}") + reader = pa.ipc.open_file(newpath) + path = newpath + return path, reader + + def _construct_chunk(self, j, doc, n_chunks): + """ + Grab a chunk of the desired size from the pyarrow document, + avoiding unnecessary overhead in case of large docs + """ + start_index = j * self.chunksize + n_pull = self.chunksize + if self.bos is not None: + if j == 0: + n_pull -= 1 + else: + start_index -= 1 + chunk = doc.slice(start_index, n_pull).to_pylist() + self.tokens_seen += len(chunk) + # Add bos/eos tokens if needed + if self.bos is not None and j == 0: + chunk = [self.bos] + chunk + if j == n_chunks - 1: + chunk = chunk + [self.eos] + return chunk + + def _random_map_docid(self, size): + """ + Given size of document pool, use saved state (prior index) to generate the next index via LCG. + Implements within-shard document shuffling without materializing any large doc lists. + """ + m = 2 ** math.ceil(math.log2(size)) # Round up to nearest power of 2 + a = 5 # A,C values known to work well with powers of 2 (Knuth, 1997, 3.2.1.3) + c = (self.rank + self.seed) * 2 + 1 + state = self.lcg_state + while True: + state = (a * state + c) % m + if state < size: + return state + + def __iter__(self): + docset_offset = self.docset_index + lcg_offset = self.lcg_state + residual_chunks = self.chunk_index + 1 # pick up AFTER where the ckp left off + ndocs = self._len + path = "" + reader = None + while True: + # Iterate through docs, starting at desired offset + for i in range(ndocs): + doc_index = (docset_offset + i) % ndocs + + # Update stats + if doc_index == 0: + self.epochs_seen += 1 + self.docset_index = doc_index + # Map doc id to shard, id in file + shardid, docrange, mindoc = self._get_docid(doc_index) + + # Read doc + newpath = os.path.join(self.data, shardid) + path, reader = self._get_reader(path, newpath, reader) + # Map id in range of owned docs to new (consistently) shuffled id + doclcg = self._random_map_docid(docrange) + docid = doclcg + mindoc + doc = reader.get_batch(docid)["tokens"] + if doc[0].as_py() in self.drop: + doc = doc.slice(1, len(doc) - 1) + if doc[-1].as_py() in self.drop: + doc = doc.slice(0, len(doc) - 1) + doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 + if doclen >= self.min_length: + n_chunks = math.ceil(doclen / self.chunksize) + for j in range(n_chunks): + if i == 0 and j < residual_chunks: + pass + else: + self.chunk_index = j + # Document complete, update stats + if j == n_chunks - 1: + self.docs_seen += 1 + self.percent_seen = ( + self.docs_seen * 100 / (self._len + 1e-9) + ) + yield self._construct_chunk(j, doc, n_chunks) + + # Advance RNG state + self.lcg_state = doclcg + + # Load any chunks initially skipped in first doc + self.docset_index = docset_offset + self.lcg_state = lcg_offset + shardid, docrange, mindoc = self._get_docid(docset_offset) + docid = self._random_map_docid(docrange) + mindoc + newpath = os.path.join(self.data, shardid) + path, reader = self._get_reader(path, newpath, reader) + doc = reader.get_batch(docid)["tokens"] + if doc[0].as_py() in self.drop: + doc = doc.slice(1, len(doc) - 1) + if doc[-1].as_py() in self.drop: + doc = doc.slice(0, len(doc) - 1) + doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 + if doclen >= self.min_length: + n_chunks = math.ceil(doclen / self.chunksize) + for j in range(residual_chunks): + self.chunk_index = j + yield self._construct_chunk(j, doc, n_chunks) + + def load_state_dict(self, state_dicts, sharded_input=False): + assert ( + self.load_worldsize == self.worldsize + ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." + d = self.dataset + out = super().load_state_dict(state_dicts, sharded_input) + assert ( + d == self.dataset + ), f"Dataset mismatch: checkpoint contains {self.dataset}, expected {d}" + return out + + +class Sampling_Dataset(_Stateful_Dataset): + """ + A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the + number of tokens seen from each subdataset will match those weights as closely as possible. + This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking + the number of tokens emitted by each. Whichever loader is furthest from its target will be + the next to pass a document. + + All args except for dataset_type, datasets, weights and delimiter are pass-through args for + the component _Stateful_Datasets and are documented in the appropriate classes. + ... + Args + ---- + dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset + Underlying iterator for each desired subdataset + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + datasets : list[str] | None + A list of subdatasets to draw from. If None, draws from all subfolders of datapath. + weights : list(float) | None + Weights describing what percent of emitted tokens should come from each subdataset. + Need not sum to 1. If None, tokens are drawn evenly. + ... + Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset + """ + + def __init__( + self, + datapath: str, + dataset_type: Union[ + Type["Streaming_Doc_Dataset"], + Type["Scalable_Shard_Dataset"], + ], + rank: int, + worldsize: int, + delimiter_token: Any, + datasets=None, + weights=None, + verbose=False, + **kwargs, + ): + super().__init__(rank, worldsize) + self.delimiter = delimiter_token + self.datasets = ( + datasets + if datasets is not None + else [ + f + for f in os.listdir(datapath) + if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f + ] + ) + assert len(self.datasets) > 0, "You must specify at least one dataset" + + if weights is not None: + assert len(weights) == len( + self.datasets + ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" + for w in weights: + assert w > 0, f"Sampling rate {w} must be positive" + self.weights = [1] * len(self.datasets) if weights is None else weights + self.weights = [w / sum(self.weights) for w in self.weights] + + self.tokens_seen = [0] * len(self.datasets) + + # Build subdataset iterators + self.data = [] + for i, d in enumerate(self.datasets): + self.data.append( + dataset_type( + datapath=os.path.join(datapath, d), + rank=rank, + worldsize=worldsize, + delimiter_token=delimiter_token, + verbose=verbose, + **kwargs, + ) + ) + if verbose: + logging.info( + f"Worker {rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" + ) + + self.current_iterator = -1 + self.state_params = ["tokens_seen", "current_iterator"] + + def __iter__(self): + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + if self.current_iterator != -1: + # Finish current document + out = next(data[self.current_iterator]) + self.tokens_seen[self.current_iterator] += len(out) + if out[-1] == self.delimiter: + self.current_iterator = -1 + yield out + else: + # Choose new subdataset to draw from + # (whichever is currently most underrepresented compared to target rate) + offset = [ + self.weights[i] + - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) + for i in range(len(self.datasets)) + ] + offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] + self.current_iterator = offset_argmax + + def state_dict(self): + # Manually add state of all subloaders to self state + out = { + self.statename("sample_iterator_states"): [ + d.state_dict() for d in self.data + ] + } + out.update(super().state_dict()) + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + # Load stats + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Load sub-iterator states + for i, subdata in enumerate(self.data): + # Grab just that sub-iterator across all ranks + subdata.load_worldsize = self.load_worldsize + subdata.load_state_dict( + [ + sd[self.statename("sample_iterator_states")][i] + for sd in sharded_dicts + ], + True, + ) + return sharded_dicts + + +class Scalable_Shard_Dataset(_Stateful_Dataset): + """ + A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track + state individually and reshard over n_gpus. + + All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. + rank : int + Current worker index + worldsize : int + Total number of workers + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + n_logical_shards : int + Number of logical shards. Must be a multiple of world size. + ... + Pass-through args, see Streaming_Doc_Dataset + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + delimiter_token: Any, + n_logical_shards: int = 2048, + verbose=False, + **kwargs, + ): + assert ( + n_logical_shards % worldsize == 0 + ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert ( + n_logical_shards > 0 + ), f"n_logical_shards {n_logical_shards} must be a positive integer" + + super().__init__(rank, worldsize) + self.data = [] + self.n_logicals = n_logical_shards // worldsize + self.total_shards = n_logical_shards + self.delimiter = delimiter_token + + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + assert len(self.logicals_owned) == self.n_logicals + + # Build logical shards + for i in range(self.n_logicals): + self.data.append( + Streaming_Doc_Dataset( + datapath=datapath, + worldsize=n_logical_shards, + rank=self.logicals_owned[i], + delimiter_token=delimiter_token, + verbose=(rank == 0), + **kwargs, + ) + ) + if verbose: + logging.info( + f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + + # Fetch logical shard sampling stats + self.n_docs_remaining = [d._len for d in self.data] + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = None + self.logical_shard_states = None + self.generator = torch.Generator().manual_seed(self.rank) + self.g_state = None + self.state_params = ["current_reader", "g_state"] + self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + + def __iter__(self): + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + # Sample logical shard (or load from ckp) + if self.current_reader is not None: + ind = self.current_reader + else: + ind = torch.multinomial( + torch.tensor(self.n_docs_remaining, dtype=torch.float), + 1, + generator=self.generator, + ).item() + self.current_reader = ind + # Read doc + out = next(data[ind]) + while out[-1] != self.delimiter: + yield out + out = next(data[ind]) + # Update state to show we've finished the doc + self.current_reader = None + self.n_docs_remaining[ind] -= 1 + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) + # Return final piece of doc + yield out + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + # Recursive fetch + self.logical_shard_states = [d.state_dict() for d in self.data] + return super().state_dict() + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Recursive set + for i in range(self.n_logicals): + self.data[i].load_state_dict([self.logical_shard_states[i]], True) + return sharded_dicts From c72692cef8190e0265fcb326bb66870a9524ce64 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 13:08:28 -0400 Subject: [PATCH 25/73] Add v3 (incremental changes), test those, linting --- fms_fsdp/utils/dataset_utils_v3.py | 1108 ++++++++++++++++++++++++++++ main_training.py | 4 +- tests/test_datasets.py | 2 +- 3 files changed, 1112 insertions(+), 2 deletions(-) create mode 100644 fms_fsdp/utils/dataset_utils_v3.py diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py new file mode 100644 index 00000000..c8e263d0 --- /dev/null +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -0,0 +1,1108 @@ +import csv +import logging +import math +import os +import random +import time +from typing import Any, Callable, List, Optional, Set, Type, Union + +import pyarrow as pa +import torch +import torch.utils.data as data + +from fms_fsdp.utils.checkpointing_utils import get_latest + + +""" +The following distributed dataloaders are designed around 3 main principles: + +1. Efficient, asynchronous operation. Workers on different devices do not communicate. +2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator + loading from disk and additional layers adding levels of post-processing (shuffling, + packing, padding, etc.). +3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal + state that can be written/read on disk via implemented recursive `state_dict()` and + `load_state_dict()` calls. +4. Rescalability. Users can save and load checkpoints to/from different numbers of workers + without losing the global state. This is accomplished by splitting state fields for each + layer into `state_params`, which are typically scalar-valued and can be discarded when + rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be + re-distributed over workers (i.e. buffers). + +Our loaders obey the following type heirarchy: +torch.data.IterableDataset -> _Stateful_Dataset -> _Wrapper_Dataset. +`_Stateful_Dataset` implements state and checkpointing logic. A `_Wrapper_Dataset` holds a +single `_Stateful_Dataset` and iterates via calling its wrapped dataset any number of times, +then applying some sort of post-processing and yielding the result. Users build data processing +pipelines by wrapping a base `_Stateful_Dataset` in any number of `_Wrapper_Dataset` layers, +which is then passed to the torch DataLoader. + +NOTE: `_Wrapper_Dataset` currently only implements wrapping a single instantiated sub-dataset layer. +Many layers need multiple sub-layers (i.e. random sampling from distinct data sources). These are +currently implemented as base `_Stateful_Datasets` that take the class of their sub-layers plus any +pass-through arguments, and instantiate all those sub-layers. This is easy on the user, who no longer +needs to instantiate large sets of sub-layers in their code, but leads to awkwardness in this file. +Cleanup is planned for the future. +""" + + +def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: + """ + Partition itemlist into worldsize chunks, grab chunk corresponding to rank and return. + """ + return itemlist[ + (rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize + ] + + +def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: + """ + In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, + and return the span including all owned items, fractional or otherwise. + """ + start = math.floor(len(itemlist) * rank / worldsize) + end = math.ceil(len(itemlist) * (rank + 1) / worldsize) + return itemlist[start:end] + + +class _Stateful_Dataset(data.IterableDataset): + """ + Stub for stateful datasets, extends data.IterableDataset with state_dict methods. + All subclasses should specify the params to be considered stateful or reshardable in the + self.state_params and self.reshard_params lists. + """ + + def __init__( + self, + rank: int, + worldsize: int, + ): + assert rank >= 0, f"Rank {rank} must be a positive integer" + assert ( + worldsize > rank + ), f"Worldsize {worldsize} must be greater than rank {rank}" + self.state_params: List[str] = [] + self.reshard_params: List[str] = [] + self.rank = rank + self.worldsize = worldsize + self.load_worldsize = ( + worldsize # Enable calling load_state_dict() directly, assume no rescaling + ) + + def statename(self, x: str): + # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline + return self.__class__.__name__ + "." + x + + def state_dict(self): + """ + Retrieve all state and reshard flags (each worker/process saves its own state dict shard) + """ + return { + self.statename(flag): getattr(self, flag) + for flag in self.state_params + self.reshard_params + } + + def _reshard(self, sharded_list): + """ + Sharded_list is a list of lists, where each "shard" sublist must have the same length. + These shards should tightly span only the partition of data owned by this worker. + (i.e. if global_list is the list of all entries, sharded_list = _shard_inclusive(global_list) ). + Determine fractional ownership of shards, and get the flattened partition owned by this worker. + """ + # How many shards did _shard_inclusive() drop to the left of sharded_list? + shard_offset = math.floor(self.load_worldsize * self.rank / self.worldsize) + # How long are the list shards? + shard_len = len(sharded_list[0]) + for i, shard in enumerate(sharded_list): + assert ( + len(shard) == shard_len + ), f"Shard {i} with length {len(shard)} does not match expected {shard_len}" + # How many list items did _shard_inclusive() drop to the left of the flattened sharded_list? + item_offset = shard_len * shard_offset + # How many list items are there in total? + n_items = self.load_worldsize * shard_len + # The indices of the flattened sharded_list that this worker owns + my_items = range( + int(n_items * self.rank / self.worldsize) - item_offset, + int(n_items * (self.rank + 1) / self.worldsize) - item_offset, + ) + # Pull out owned items + return [sharded_list[i // shard_len][i % shard_len] for i in my_items] + + def load_state_dict(self, state_dicts, sharded_input=False): + """ + Input state_dicts is a list of state_dicts. If sharded_input=False, this is expected to be the + global list of states across all checkpoint shard files. If sharded_input=True, this expects + _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading. + Workflow: + 1. if sharded_inputs is false, shard the inputs. + 2. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint + shard (state_dicts is a singleton list). + 3. If worldsize does not match checkpoint, toss state params and assemble reshard params from + across given state_dicts. In this case state_dicts may be singleton (for fractional ownership) + or multi-element (for multiple/partitioned ownership). + 4. Return reduced input for use by downstream loading functions + """ + if not sharded_input: + self.load_worldsize = len(state_dicts) + state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) + if self.load_worldsize == self.worldsize: + [ + setattr(self, flag, state_dicts[0][self.statename(flag)]) + for flag in self.state_params + self.reshard_params + ] + else: + for flag in self.reshard_params: + reshard = self._reshard( + [sd[self.statename(flag)] for sd in state_dicts] + ) + setattr(self, flag, reshard) + return state_dicts + + def load_from_path(self, path: str): + """ + Count shard files in the specified checkpoint folder and determine overlap with current + rank and worldsize partition. Load only matching shardfile(s) and pass to load_state_dict. + This is more efficient than sharding the full loaded state. + """ + assert os.path.exists(path), "Specified checkpoint does not exist" + assert not os.path.isfile(path), "Checkpoint should be a folder of shard states" + fileshards = [x for x in os.listdir(path) if "loader" in x] + fileshards = sorted(fileshards, key=lambda x: int(x.split("_")[2][:-4])) + assert ( + len(fileshards) > 0 + ), "Checkpoint directory must contain checkpoint files with 'loader' in the name" + self.load_worldsize = len(fileshards) + # Grab only the shard files holding data we currently own + my_fileshards = _shard_inclusive(fileshards, self.rank, self.worldsize) + states = [torch.load(os.path.join(path, x)) for x in my_fileshards] + self.load_state_dict(states, True) + + def save_to_path(self, path: str): + """ + Grab recursive shard states and save all shard states to the specified checkpoint folder + """ + os.makedirs(path, exist_ok=True) + state = self.state_dict() + torch.save(state, os.path.join(path, f"loader_state_{self.rank}.pth")) + + +class _Wrapper_Dataset(_Stateful_Dataset): + """ + Stub for nested wrappers of _Stateful_Datasets. Extends state fns with recursion. + Requires a single instantiated sub-dataset. + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + ): + self.dataset = dataset + super().__init__(self.dataset.rank, self.dataset.worldsize) + + def load_state_dict(self, state_dicts, sharded_input=False): + """ + Sets all specified flags at the current level, then recurses into wrapped dataset. + """ + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + self.dataset.load_worldsize = self.load_worldsize + self.dataset.load_state_dict(sharded_dicts, True) + return sharded_dicts + + def state_dict(self): + """ + Fetches state dict recursively from wrapped layers, then adds specified flags. + Overlapping flags are overwritten with a warning. + """ + out = self.dataset.state_dict() + state = super().state_dict() + for flag in self.state_params + self.reshard_params: + if flag in out: + logging.warning( + f"Loader {self.rank}: flag {flag} already present in state_dict with value {out[flag]}. " + + f"Overwriting with value {state[flag]}" + ) + out.update(state) + return out + + +class Preprocess_Dataset(_Wrapper_Dataset): + """ + Wrapper for a _Stateful_Dataset that applies a specified preprocessing + or augmentation function to dataset outputs. + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + aug_fn : function (any -> any) + The augmentation function to apply to each dataset item. + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + aug_fn: Callable, + ): + super().__init__(dataset) + self.aug_fn = aug_fn + + def __iter__(self): + dataset = iter(self.dataset) + while True: + out = next(dataset) + yield self.aug_fn(out) + + +class Checkpoint_Dataset(_Wrapper_Dataset): + """ + Wrapper for a _Stateful_Dataset that implements auto-checkpoint saving every n steps. + Useful for setting n_workers > 0, so that workers do not rely on the master process + for state saving (inter-process communication unsupported in PyTorch datasets). + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + load_path : str + Absolute path to checkpoint load directory. If a checkpoint exists, loads it. + interval : int + Saves a new checkpoint every interval. + steps_per_batch : optional[int] + Number of steps required to fill a single batch. Increments interval only + when a full batch is formed. Defaults to 1. + save_path : optional[str] + Absolute path to checkpoint save directory. Defaults to load_path. + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + load_path: str, + interval: int, + steps_per_batch: int = 1, + save_path: str = "", + ): + super().__init__(dataset) + self.interval = interval + self.spb = steps_per_batch + load_path = os.path.join(load_path, "checkpoints") + if len(save_path) == 0: + save_path = load_path + else: + save_path = os.path.join(save_path, "checkpoints") + self.path = save_path + self.step = 0 + self.ministep = 0 + self.load_from_path(load_path) + + def __iter__(self): + dataset = iter(self.dataset) + while True: + yield next(dataset) + self.ministep += 1 + if self.ministep == self.spb: + self.ministep = 0 + self.step += 1 + if self.step % self.interval == 0: + newpath = os.path.join(self.path, "step_" + str(self.step) + "_ckp") + self.save_to_path(newpath) + + def report(self, msg): + if self.rank == 0: + print(msg) + + def save_to_path(self, path: str): + self.report(f"Saving dataset to {path}") + start = time.time() + super().save_to_path(path) + self.report( + f"Dataset successfully saved to {path}! Save time: {time.time() - start}" + ) + + def load_from_path(self, path: str): + # If path does not exist, or exists but is empty, exit early + if not os.path.exists(path) or len(os.listdir(path)) == 0: + self.report( + f"No valid checkpoint detected at {path}, dataset starting from scratch." + ) + return + # Grab latest item in path + latest = os.path.join(path, get_latest(path)) + self.report(f"Dataset checkpoint detected at {latest}") + # If item is not a folder, exit early + if os.path.isfile(latest): + self.report( + f"Checkpoint exists but contains no dataset! Dataset starting from scratch." + ) + return + # If item is a folder, get the step count + self.step = int(latest.split("_")[-2]) + # Proceed + start = time.time() + self.dataset.load_from_path(latest) + self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}") + + +class Preload_Buffer_Dataset(_Wrapper_Dataset): + """ + Wrapper for a Stateful_Dataset that implements data shuffling via a single in/out buffer. + Fills buffer two at a time, up to desired size, then switches to one at a time to maintain size. + Passes randomly sampled outputs one by one. + Ensures local mixing of data without relying on sliding windows or shuffling of large buffers. + Any two consecutive inputs will be separated by window_size steps in expectation. + Rescaling-enabled: buffers that shrink will re-grow to window_size, buffers that expand stay large. + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + window_size : int + Max size of input/output buffer + """ + + def __init__(self, dataset: _Stateful_Dataset, window_size: int): + super().__init__(dataset) + assert ( + window_size > 1 + ), f"Window size {window_size} must be greater than 1 for shuffling to occur" + self.window_size = window_size + self.g_state = None + self.generator = torch.Generator().manual_seed(self.rank) + self.buffer: List[List[Any]] = [] + self.buffer_size = 0 + self.state_params = ["g_state"] + self.reshard_params = ["buffer"] + + def __iter__(self): + dataset = iter(self.dataset) + while True: + # Pad out buffer if needed + self._pad_buffer() + + # Load a point to buffer if necessary + if self.buffer_size < self.window_size: + self.buffer[self.buffer_size] = next(dataset) + self.buffer_size += 1 + + # Swap out randomly sampled value from buffer + i = torch.randint(self.buffer_size, (1,), generator=self.generator).item() + out = self.buffer[i] + self.buffer[i] = next(dataset) + yield out + + def _pad_buffer(self): + if self.buffer_size < self.window_size: + self.buffer += [ + [], + ] * (self.window_size - self.buffer_size) + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + # Prune buffer so it can be resharded in future + self.buffer = self.buffer[: self.buffer_size] + out = super().state_dict() + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Manually set buffer size + self.buffer_size = len(self.buffer) + return sharded_dicts + + +class Buffer_Dataset(_Wrapper_Dataset): + """ + Wrapper for a _Stateful_Dataset that takes in sequences of varying lengths, and packs/pads them + into sequences of desired length. Input sequences are packed greedily until the buffer would + otherwise overrun, then remaining values are filled depending on initialization flags. + Also injects BOS/EOS into the packed output sequence if desired, and if BOS/EOS tokens are + not already in those positions. Implements rescaling by simply dropping (buffer) state. + ... + Args + ---- + dataset : _Stateful_Dataset + Fully instantiated dataset + seq_len : int + The desired sequence length + pack_hard : bool + Split input sequences to fill output buffer, or use pad tokens to fill remaining space? + bos_token : any | None + Token to prepend to every output sequence. If None, no token is added. Type should match data type. + eos_token : any | None + Token to append to every output sequence. If None, no token is added. Type should match data type. + pad_token : any | None + Token used to fill out output sequence. Type should match data type. + drop_final_token : any | None + Drop the final token of each document if it matches this value? + (For edge case where bos=eos=None, and sep already appears at beginning of each doc - + drop added extra sep from end of doc) + """ + + def __init__( + self, + dataset: _Stateful_Dataset, + seq_len: int, + pack_hard: bool, + bos_token=None, + eos_token=None, + pad_token=None, + ): + super().__init__(dataset) + self.len = seq_len + + # Buffer args + self.buffer: List[str] = [] + self.bos = bos_token + self.eos = eos_token + self.pad = pad_token + self.pack_hard = pack_hard + if not pack_hard: + assert ( + pad_token is not None + ), "Error: if using pads, you must supply a pad_token" + + self.state_params = ["buffer"] + + def _get_buffer(self, iterable, length, buffer): + # Pull data until buffer is about to overrun, return exactly proper length + new = [] + while len(buffer) + len(new) < length: + buffer += new + new = next(iterable) + + # Add bos if needed + if self.bos is not None and (len(buffer) == 0 or buffer[0] != self.bos): + buffer = [self.bos] + buffer + + # Handle buffer splitting + if len(buffer) >= length: + # If buffer is too long, force split + out = buffer[:length] + buffer = buffer[length:] + if self.eos is not None and out[-1] != self.eos: + buffer = [out[-1]] + buffer + out[-1] = self.eos + buffer = buffer + new + else: + if self.pack_hard: + # Pack in as much of new sequence as will fit + buffer = buffer + new + out = buffer[:length] + buffer = buffer[length:] + if self.eos is not None and out[-1] != self.eos: + buffer = [out[-1]] + buffer + out[-1] = self.eos + else: + # Fill out with pads as needed + if self.eos is not None and buffer[-1] != self.eos: + buffer.append(self.eos) + if self.pad is not None: + out = buffer + [self.pad] * (length - len(buffer)) + else: + out = buffer + buffer = new + return out, buffer + + # Fill buffer line by line, delimiters and packing/splitting as appropriate + def __iter__(self): + dataset = iter(self.dataset) + while True: + out, buffer = self._get_buffer(dataset, self.len, self.buffer) + self.buffer = buffer + yield out + + +class Streaming_Doc_Dataset(_Stateful_Dataset): + """ + The base distributed dataset for loading sequences/documents from pyarrow shards. + Pyarrow shard files are expected to hold multiple recordBatches, where each recordBatch has a "tokens" + field consisting of a single token list. (i.e. each document is a single sequence under a "token" field, + and the file is a list of such sequences) + Relies on a compiled metadata file to fetch shardfile lengths, assumes file already exists in the parent directory, + and is in proper csv format (first row "dataset/filename,documents,tokens", subsequent rows these values). + + For a single dataset directory, splits shard files into x=worldsize fragments and grabs a 1/n contiguous + span of shard fragments (contiguous to limit file reads from cloud/disk). + Logs the number of documents owned from each shardfile, and relies on ZCG random bijection to + map contiguous range of indices to shuffled, noncontiguous set of documents from each shard file. + Shuffles the file list deterministically to hop from file to file. + + At runtime, iterates through documents in each shuffled shard file, pulling each shard on demand. + Shards are thus pulled no more than once per epoch. + Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents. + + Streaming_Doc_Dataset grabs files from a flat directory representing a single dataset. + For percentage-based sampling of multiple subdatasets, see Sampling_Dataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects directory containing pyarrow shardfiles. + Parent directory should contain 'meta' folder with metadata csv file inside. + rank : int + Current worker index + worldsize : int + Total number of workers + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. Required for downstream + sampling logic (can be removed later via PreProcess_Dataset if needed). + bos_token : Any | None + Optional token used to indicate sequence/document start. Type should match data type. + strip_tokens : set[Any] + Token values that should be removed if detected at beginning or end of document + (i.e. any eos/bos tokens already present in the data). Type should match data type. + seed : int + The random seed for deterministic shuffling/sharding + min_length : int + Sequences below this length are skipped + max_chunksize : int + Maximum sequence length to return. Break long docs into chunks of this size or shorter. + verbose : bool + Track setup progress? + shuffle : bool + Shuffle shard file and document orders? (Disable for simple testing) + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + delimiter_token: Any, + bos_token: Optional[Any] = None, + strip_tokens: Optional[Set[Any]] = set(), + seed: int = 42, + min_length: int = 1, + max_chunksize: int = 1024, + verbose: bool = False, + ): + super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) + self.seed = seed + self.data = datapath + self.min_length = min_length + assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" + self.chunksize = max_chunksize + self.eos = delimiter_token + self.bos = bos_token + self.drop = strip_tokens + self.verbose = verbose + self.docset: List[ + Any + ] = [] # map of doc indices to (shardid, min docid, max docid) + self.docs_per_shard = {} + + # Position + self.docset_index = 0 + self.chunk_index = -1 + + # Stats + self.epochs_seen = -1 + self.tokens_seen = 0 + self.docs_seen = 0 + self.percent_seen = 0 + self.lcg_state = 0 + + self.state_params = [ + "dataset", + "docset_index", + "chunk_index", + "epochs_seen", + "tokens_seen", + "docs_seen", + "percent_seen", + "lcg_state", + ] + + def setup(self): + """ + All rank-dependent setup, which must occur after init + (rank assignment, subdataset splitting, etc.) + """ + datapath = self.data + + # Gather per-file document counts from metadata count file(s) + countfiles = [ + x + for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) + if "counts" in x and "csv" in x + ] + assert len(countfiles) == 1 + doc_counts = {} + pathsplit = (datapath, "") + while len(pathsplit[1]) == 0: + pathsplit = os.path.split(pathsplit[0]) + pardir, dataset = pathsplit + self.dataset = dataset + with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find("/" + dataset) + 1 + if prefix > 0: + key = fullpath[prefix:] + doc_counts[key] = int(row["documents"]) + + # Assemble document set owned by this worker: + # listdir, assemble shardfraglist (ind -> shard, frag) + shards = [ + shard + for shard in os.listdir(datapath) + if os.path.isfile(os.path.join(datapath, shard)) + and "arrow" in os.path.join(datapath, shard) + ] + shards.sort() # Ensure consistent sharding across machines + start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize + end_frag = ((self.rank + 1) * self.worldsize * len(shards)) // self.worldsize + shardfrags = [ + (shards[i // self.worldsize], i % self.worldsize) + for i in range(start_frag, end_frag) + ] + + # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): + ndocs = -1 + docset = {} # shardid -> (min docid, max docid) + for i, (shard, frag) in enumerate(shardfrags): + ndocs = doc_counts[os.path.join(dataset, shard)] + self.docs_per_shard[shard] = ndocs + doc_start = (ndocs * frag) // self.worldsize + doc_end = ( + ndocs * frag + ndocs + ) // self.worldsize - 1 # Inclusive upper bound + if shard not in docset: + docset[shard] = [doc_start, doc_end] + min_d, max_d = docset[shard] + if doc_start < min_d: + docset[shard][0] = doc_start + if doc_end > max_d: + docset[shard][1] = doc_end + + # Add all of this dataset's shard entries to self.docset + doccount = 0 + for shardid in docset: + min_d = docset[shardid][0] + max_d = docset[shardid][1] + self.docset.append((shardid, min_d, max_d)) + doccount += max_d - min_d + 1 + self._len = doccount + + if self.verbose: + logging.info( + f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" + ) + + # Shuffle shard files - guaranteed inconsistent across workers + seed = self.seed + self.rank + random.seed(seed) + random.shuffle(self.docset) + # Setup doc shuffle - same guarantee + self.lcg_state = seed + + def _get_docid(self, i): + """ + Given a global doc index over the set of docs owned by this worker, + return the corresponding data/shard/local index + """ + cur = 0 + assert ( + i <= self._len + ), f"You have requested an illegal doc index {i}, docset length is {self._len}" + for shardid, min_d, max_d in self.docset: + docrange = max_d - min_d + 1 + cur += docrange + if cur > i: + return shardid, docrange, min_d + + def _get_reader(self, path, newpath, reader): + """ + If new filepath does not match the current one, + open a new reader on that filepath (pull file on demand) + """ + if newpath != path: + del reader + if self.verbose: + logging.info(f"Worker {self.rank} opening new file {newpath}") + reader = pa.ipc.open_file(newpath) + path = newpath + return path, reader + + def _construct_chunk(self, j, doc, n_chunks): + """ + Grab a chunk of the desired size from the pyarrow document, + avoiding unnecessary overhead in case of large docs + """ + start_index = j * self.chunksize + n_pull = self.chunksize + if self.bos is not None: + if j == 0: + n_pull -= 1 + else: + start_index -= 1 + chunk = doc.slice(start_index, n_pull).to_pylist() + self.tokens_seen += len(chunk) + # Add bos/eos tokens if needed + if self.bos is not None and j == 0: + chunk = [self.bos] + chunk + if j == n_chunks - 1: + chunk = chunk + [self.eos] + return chunk + + def _random_map_docid(self, size): + """ + Given size of document pool, use saved state (prior index) to generate the next index via LCG. + Implements within-shard document shuffling without materializing any large doc lists. + """ + m = 2 ** math.ceil(math.log2(size)) # Round up to nearest power of 2 + a = 5 # A,C values known to work well with powers of 2 (Knuth, 1997, 3.2.1.3) + c = (self.rank + self.seed) * 2 + 1 + state = self.lcg_state + while True: + state = (a * state + c) % m + if state < size: + return state + + def __iter__(self): + self.setup() + docset_offset = self.docset_index + lcg_offset = self.lcg_state + residual_chunks = self.chunk_index + 1 # pick up AFTER where the ckp left off + ndocs = self._len + path = "" + reader = None + while True: + # Iterate through docs, starting at desired offset + for i in range(ndocs): + doc_index = (docset_offset + i) % ndocs + + # Update stats + if doc_index == 0: + self.epochs_seen += 1 + self.docset_index = doc_index + # Map doc id to shard, id in file + shardid, docrange, mindoc = self._get_docid(doc_index) + + # Read doc + newpath = os.path.join(self.data, shardid) + path, reader = self._get_reader(path, newpath, reader) + # Map id in range of owned docs to new (consistently) shuffled id + doclcg = self._random_map_docid(docrange) + docid = doclcg + mindoc + doc = reader.get_batch(docid)["tokens"] + if doc[0].as_py() in self.drop: + doc = doc.slice(1, len(doc) - 1) + if doc[-1].as_py() in self.drop: + doc = doc.slice(0, len(doc) - 1) + doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 + if doclen >= self.min_length: + n_chunks = math.ceil(doclen / self.chunksize) + for j in range(n_chunks): + if i == 0 and j < residual_chunks: + pass + else: + self.chunk_index = j + # Document complete, update stats + if j == n_chunks - 1: + self.docs_seen += 1 + self.percent_seen = ( + self.docs_seen * 100 / (self._len + 1e-9) + ) + yield self._construct_chunk(j, doc, n_chunks) + + # Advance RNG state + self.lcg_state = doclcg + + # Load any chunks initially skipped in first doc + self.docset_index = docset_offset + self.lcg_state = lcg_offset + shardid, docrange, mindoc = self._get_docid(docset_offset) + docid = self._random_map_docid(docrange) + mindoc + newpath = os.path.join(self.data, shardid) + path, reader = self._get_reader(path, newpath, reader) + doc = reader.get_batch(docid)["tokens"] + if doc[0].as_py() in self.drop: + doc = doc.slice(1, len(doc) - 1) + if doc[-1].as_py() in self.drop: + doc = doc.slice(0, len(doc) - 1) + doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 + if doclen >= self.min_length: + n_chunks = math.ceil(doclen / self.chunksize) + for j in range(residual_chunks): + self.chunk_index = j + yield self._construct_chunk(j, doc, n_chunks) + + def load_state_dict(self, state_dicts, sharded_input=False): + assert ( + self.load_worldsize == self.worldsize + ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." + d = self.dataset + out = super().load_state_dict(state_dicts, sharded_input) + assert ( + d == self.dataset + ), f"Dataset mismatch: checkpoint contains {self.dataset}, expected {d}" + return out + + +class Sampling_Dataset(_Stateful_Dataset): + """ + A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the + number of tokens seen from each subdataset will match those weights as closely as possible. + This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking + the number of tokens emitted by each. Whichever loader is furthest from its target will be + the next to pass a document. + + All args except for dataset_type, datasets, weights and delimiter are pass-through args for + the component _Stateful_Datasets and are documented in the appropriate classes. + ... + Args + ---- + dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset + Underlying iterator for each desired subdataset + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + datasets : list[str] | None + A list of subdatasets to draw from. If None, draws from all subfolders of datapath. + weights : list(float) | None + Weights describing what percent of emitted tokens should come from each subdataset. + Need not sum to 1. If None, tokens are drawn evenly. + ... + Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset + """ + + def __init__( + self, + datapath: str, + dataset_type: Union[ + Type["Streaming_Doc_Dataset"], + Type["Scalable_Shard_Dataset"], + ], + rank: int, + worldsize: int, + delimiter_token: Any, + datasets=None, + weights=None, + verbose=False, + **kwargs, + ): + super().__init__(rank, worldsize) + self.delimiter = delimiter_token + self.datasets = ( + datasets + if datasets is not None + else [ + f + for f in os.listdir(datapath) + if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f + ] + ) + assert len(self.datasets) > 0, "You must specify at least one dataset" + + if weights is not None: + assert len(weights) == len( + self.datasets + ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" + for w in weights: + assert w > 0, f"Sampling rate {w} must be positive" + self.weights = [1] * len(self.datasets) if weights is None else weights + self.weights = [w / sum(self.weights) for w in self.weights] + + self.tokens_seen = [0] * len(self.datasets) + + # Build subdataset iterators + self.data = [] + for i, d in enumerate(self.datasets): + self.data.append( + dataset_type( + datapath=os.path.join(datapath, d), + rank=rank, + worldsize=worldsize, + delimiter_token=delimiter_token, + verbose=verbose, + **kwargs, + ) + ) + if verbose: + logging.info( + f"Worker {rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" + ) + + self.current_iterator = -1 + self.state_params = ["tokens_seen", "current_iterator"] + + def __iter__(self): + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + if self.current_iterator != -1: + # Finish current document + out = next(data[self.current_iterator]) + self.tokens_seen[self.current_iterator] += len(out) + if out[-1] == self.delimiter: + self.current_iterator = -1 + yield out + else: + # Choose new subdataset to draw from + # (whichever is currently most underrepresented compared to target rate) + offset = [ + self.weights[i] + - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) + for i in range(len(self.datasets)) + ] + offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] + self.current_iterator = offset_argmax + + def state_dict(self): + # Manually add state of all subloaders to self state + out = { + self.statename("sample_iterator_states"): [ + d.state_dict() for d in self.data + ] + } + out.update(super().state_dict()) + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + # Load stats + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Load sub-iterator states + for i, subdata in enumerate(self.data): + # Grab just that sub-iterator across all ranks + subdata.load_worldsize = self.load_worldsize + subdata.load_state_dict( + [ + sd[self.statename("sample_iterator_states")][i] + for sd in sharded_dicts + ], + True, + ) + return sharded_dicts + + +class Scalable_Shard_Dataset(_Stateful_Dataset): + """ + A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track + state individually and reshard over n_gpus. + + All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. + rank : int + Current worker index + worldsize : int + Total number of workers + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + n_logical_shards : int + Number of logical shards. Must be a multiple of world size. + ... + Pass-through args, see Streaming_Doc_Dataset + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + delimiter_token: Any, + n_logical_shards: int = 2048, + verbose=False, + **kwargs, + ): + assert ( + n_logical_shards % worldsize == 0 + ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert ( + n_logical_shards > 0 + ), f"n_logical_shards {n_logical_shards} must be a positive integer" + + super().__init__(rank, worldsize) + self.data = [] + self.n_logicals = n_logical_shards // worldsize + self.total_shards = n_logical_shards + self.delimiter = delimiter_token + + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + assert len(self.logicals_owned) == self.n_logicals + + # Build logical shards + for i in range(self.n_logicals): + self.data.append( + Streaming_Doc_Dataset( + datapath=datapath, + worldsize=n_logical_shards, + rank=self.logicals_owned[i], + delimiter_token=delimiter_token, + verbose=(rank == 0), + **kwargs, + ) + ) + if verbose: + logging.info( + f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + + # Fetch logical shard sampling stats + self.n_docs_remaining = [d._len for d in self.data] + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = None + self.logical_shard_states = None + self.generator = torch.Generator().manual_seed(self.rank) + self.g_state = None + self.state_params = ["current_reader", "g_state"] + self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + + def __iter__(self): + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + # Sample logical shard (or load from ckp) + if self.current_reader is not None: + ind = self.current_reader + else: + ind = torch.multinomial( + torch.tensor(self.n_docs_remaining, dtype=torch.float), + 1, + generator=self.generator, + ).item() + self.current_reader = ind + # Read doc + out = next(data[ind]) + while out[-1] != self.delimiter: + yield out + out = next(data[ind]) + # Update state to show we've finished the doc + self.current_reader = None + self.n_docs_remaining[ind] -= 1 + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) + # Return final piece of doc + yield out + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + # Recursive fetch + self.logical_shard_states = [d.state_dict() for d in self.data] + return super().state_dict() + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Recursive set + for i in range(self.n_logicals): + self.data[i].load_state_dict([self.logical_shard_states[i]], True) + return sharded_dicts diff --git a/main_training.py b/main_training.py index fc3625ee..bae6dbad 100644 --- a/main_training.py +++ b/main_training.py @@ -122,7 +122,9 @@ def main(**kwargs): model, optimizer, None, - path=os.path.join(cfg.ckpt_load_path, "checkpoints/") if not os.path.isfile(cfg.ckpt_load_path) else cfg.ckpt_load_path, + path=os.path.join(cfg.ckpt_load_path, "checkpoints/") + if not os.path.isfile(cfg.ckpt_load_path) + else cfg.ckpt_load_path, strict=False, ) if cfg.reset_stepcount: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index a581494f..6ceab089 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -8,7 +8,7 @@ import pyarrow as pa import torch -from fms_fsdp.utils.dataset_utils import * +from fms_fsdp.utils.dataset_utils_v3 import * # Generates test data in a temp directory, and returns that tempdir object. From 736e431dec23d1fd906b823a55ee2faa4fe7bdc8 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 13:20:21 -0400 Subject: [PATCH 26/73] call setup when loading but haven't stepped yet --- fms_fsdp/utils/dataset_utils_v3.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index c8e263d0..a72763d2 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -4,6 +4,7 @@ import os import random import time +from copy import deepcopy from typing import Any, Callable, List, Optional, Set, Type, Union import pyarrow as pa @@ -584,6 +585,7 @@ def __init__( super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) self.seed = seed self.data = datapath + self.dataset = "" self.min_length = min_length assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" self.chunksize = max_chunksize @@ -618,12 +620,15 @@ def __init__( "lcg_state", ] + self.is_setup = False + def setup(self): """ All rank-dependent setup, which must occur after init (rank assignment, subdataset splitting, etc.) """ datapath = self.data + self.is_setup = True # Gather per-file document counts from metadata count file(s) countfiles = [ @@ -766,7 +771,8 @@ def _random_map_docid(self, size): return state def __iter__(self): - self.setup() + if not self.is_setup: + self.setup() docset_offset = self.docset_index lcg_offset = self.lcg_state residual_chunks = self.chunk_index + 1 # pick up AFTER where the ckp left off @@ -835,6 +841,8 @@ def __iter__(self): yield self._construct_chunk(j, doc, n_chunks) def load_state_dict(self, state_dicts, sharded_input=False): + if not self.is_setup: + self.setup() assert ( self.load_worldsize == self.worldsize ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." From 12679badfbfe93c04c30ccec23c9822db836bcf2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 13:26:21 -0400 Subject: [PATCH 27/73] Remove redundant buggy tracking field --- fms_fsdp/utils/dataset_utils_v3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index a72763d2..aef6a96d 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -596,7 +596,6 @@ def __init__( self.docset: List[ Any ] = [] # map of doc indices to (shardid, min docid, max docid) - self.docs_per_shard = {} # Position self.docset_index = 0 @@ -673,7 +672,6 @@ def setup(self): docset = {} # shardid -> (min docid, max docid) for i, (shard, frag) in enumerate(shardfrags): ndocs = doc_counts[os.path.join(dataset, shard)] - self.docs_per_shard[shard] = ndocs doc_start = (ndocs * frag) // self.worldsize doc_end = ( ndocs * frag + ndocs From 225879f564bf2dc69db5a363c4f53f676a91a7d0 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 13:33:11 -0400 Subject: [PATCH 28/73] Shift _len back to init --- fms_fsdp/utils/dataset_utils_v3.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index aef6a96d..46a18bc6 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -585,7 +585,6 @@ def __init__( super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) self.seed = seed self.data = datapath - self.dataset = "" self.min_length = min_length assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" self.chunksize = max_chunksize @@ -606,7 +605,6 @@ def __init__( self.tokens_seen = 0 self.docs_seen = 0 self.percent_seen = 0 - self.lcg_state = 0 self.state_params = [ "dataset", @@ -619,7 +617,11 @@ def __init__( "lcg_state", ] + # Setup flags self.is_setup = False + self._len = 0 + self.dataset = "" + self.lcg_state = 0 def setup(self): """ From 301a5307eafb8ffc0fa86f649652fe17b2d46f37 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 13:41:56 -0400 Subject: [PATCH 29/73] Get sampling probs after setup --- fms_fsdp/utils/dataset_utils_v3.py | 126 ++++++++++++++++++++++++++++- 1 file changed, 124 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index 46a18bc6..def5ea39 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -1058,8 +1058,8 @@ def __init__( f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" ) - # Fetch logical shard sampling stats - self.n_docs_remaining = [d._len for d in self.data] + # Logical shard sampling stats - populate after subdataset setup + self.n_docs_remaining = [] # Position "state", used only for maintaining order when n_workers is unchanged # For scaling up or down, logical position is meaningless, and reset @@ -1071,6 +1071,8 @@ def __init__( self.reshard_params = ["n_docs_remaining", "logical_shard_states"] def __iter__(self): + [d.setup() for d in self.data] + self.n_docs_remaining = [d._len for d in self.data] # Grab one doc at a time in random order data = [iter(d) for d in self.data] while True: @@ -1114,3 +1116,123 @@ def load_state_dict(self, state_dicts, sharded_input=False): for i in range(self.n_logicals): self.data[i].load_state_dict([self.logical_shard_states[i]], True) return sharded_dicts + + + +# class Scalable_Shard_Dataset(_Wrapper_Dataset): +# """ +# A _Wrapper_Dataset implementing rescalability: loading from checkpoint into a different +# number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. +# This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track +# state individually and reshard over n_gpus. + +# Because only one Streaming_Doc_Dataset is provided to the wrapper, the wrapper clones it, and +# all rank-dependent setup for the base datasets is deferred to the first step. + +# ... +# Args +# ---- +# dataset : Streaming_Doc_Dataset +# The worker to instantiate in each logical shard +# rank : int +# Current worker index +# worldsize : int +# Total number of workers +# delimiter_token : Any +# Token used to indicate sequence/document breaks. Type should match data type. +# n_logical_shards : int +# Number of logical shards. Must be a multiple of world size. +# """ + +# def __init__( +# self, +# dataset: Streaming_Doc_Dataset, +# rank: int, +# worldsize: int, +# n_logical_shards: int = 2048, +# verbose=False, +# ): +# assert ( +# n_logical_shards % worldsize == 0 +# ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" +# assert ( +# n_logical_shards > 0 +# ), f"n_logical_shards {n_logical_shards} must be a positive integer" + +# super().__init__(rank, worldsize) +# self.data = [] +# self.n_logicals = n_logical_shards // worldsize +# self.total_shards = n_logical_shards +# self.delimiter = dataset.delimiter_token + +# logicals = list(range(n_logical_shards)) +# self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) +# assert len(self.logicals_owned) == self.n_logicals + +# # Build logical shards +# self.data = [deepcopy(dataset) for _ in range(self.n_logicals)] +# for i,d in enumerate(self.data): +# d.worldsize = n_logical_shards +# d.rank = self.logicals_owned[i] +# d.verbose = (rank == 0) +# if verbose: +# logging.info( +# f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" +# ) + +# # Fetch logical shard sampling stats +# self.n_docs_remaining = [d._len for d in self.data] + +# # Position "state", used only for maintaining order when n_workers is unchanged +# # For scaling up or down, logical position is meaningless, and reset +# self.current_reader = None +# self.logical_shard_states = None +# self.generator = torch.Generator().manual_seed(self.rank) +# self.g_state = None +# self.state_params = ["current_reader", "g_state"] +# self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + +# def __iter__(self): +# # Grab one doc at a time in random order +# data = [iter(d) for d in self.data] +# while True: +# # Sample logical shard (or load from ckp) +# if self.current_reader is not None: +# ind = self.current_reader +# else: +# ind = torch.multinomial( +# torch.tensor(self.n_docs_remaining, dtype=torch.float), +# 1, +# generator=self.generator, +# ).item() +# self.current_reader = ind +# # Read doc +# out = next(data[ind]) +# while out[-1] != self.delimiter: +# yield out +# out = next(data[ind]) +# # Update state to show we've finished the doc +# self.current_reader = None +# self.n_docs_remaining[ind] -= 1 +# if sum(self.n_docs_remaining) == 0: +# self.n_docs_remaining = [d._len for d in self.data] +# self.generator.manual_seed(self.rank) +# # Return final piece of doc +# yield out + +# def state_dict(self): +# # Write generator state manually +# self.g_state = self.generator.get_state() +# # Recursive fetch +# self.logical_shard_states = [d.state_dict() for d in self.data] +# return _Stateful_Dataset.state_dict() + +# def load_state_dict(self, state_dicts, sharded_input=False): +# sharded_dicts = _Stateful_Dataset.load_state_dict(state_dicts, sharded_input) +# # Manually set generator state if it exists +# if self.g_state is not None: +# self.generator.set_state(self.g_state) +# # Recursive set +# for i in range(self.n_logicals): +# self.data[i].load_state_dict([self.logical_shard_states[i]], True) +# return sharded_dicts From f61a3db2cecc6b9f65768a50c319b63bf1d0f124 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 13:55:17 -0400 Subject: [PATCH 30/73] Make setup properly conditional --- fms_fsdp/utils/dataset_utils_v3.py | 158 ++++++++++++++--------------- 1 file changed, 79 insertions(+), 79 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index def5ea39..f161362b 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -582,7 +582,7 @@ def __init__( max_chunksize: int = 1024, verbose: bool = False, ): - super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) + super().__init__(rank, worldsize) self.seed = seed self.data = datapath self.min_length = min_length @@ -628,84 +628,85 @@ def setup(self): All rank-dependent setup, which must occur after init (rank assignment, subdataset splitting, etc.) """ - datapath = self.data - self.is_setup = True - - # Gather per-file document counts from metadata count file(s) - countfiles = [ - x - for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) - if "counts" in x and "csv" in x - ] - assert len(countfiles) == 1 - doc_counts = {} - pathsplit = (datapath, "") - while len(pathsplit[1]) == 0: - pathsplit = os.path.split(pathsplit[0]) - pardir, dataset = pathsplit - self.dataset = dataset - with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: - key = fullpath[prefix:] - doc_counts[key] = int(row["documents"]) - - # Assemble document set owned by this worker: - # listdir, assemble shardfraglist (ind -> shard, frag) - shards = [ - shard - for shard in os.listdir(datapath) - if os.path.isfile(os.path.join(datapath, shard)) - and "arrow" in os.path.join(datapath, shard) - ] - shards.sort() # Ensure consistent sharding across machines - start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize - end_frag = ((self.rank + 1) * self.worldsize * len(shards)) // self.worldsize - shardfrags = [ - (shards[i // self.worldsize], i % self.worldsize) - for i in range(start_frag, end_frag) - ] + if not self.is_setup: + datapath = self.data + self.is_setup = True + + # Gather per-file document counts from metadata count file(s) + countfiles = [ + x + for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) + if "counts" in x and "csv" in x + ] + assert len(countfiles) == 1 + doc_counts = {} + pathsplit = (datapath, "") + while len(pathsplit[1]) == 0: + pathsplit = os.path.split(pathsplit[0]) + pardir, dataset = pathsplit + self.dataset = dataset + with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find("/" + dataset) + 1 + if prefix > 0: + key = fullpath[prefix:] + doc_counts[key] = int(row["documents"]) + + # Assemble document set owned by this worker: + # listdir, assemble shardfraglist (ind -> shard, frag) + shards = [ + shard + for shard in os.listdir(datapath) + if os.path.isfile(os.path.join(datapath, shard)) + and "arrow" in os.path.join(datapath, shard) + ] + shards.sort() # Ensure consistent sharding across machines + start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize + end_frag = ((self.rank + 1) * self.worldsize * len(shards)) // self.worldsize + shardfrags = [ + (shards[i // self.worldsize], i % self.worldsize) + for i in range(start_frag, end_frag) + ] - # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): - ndocs = -1 - docset = {} # shardid -> (min docid, max docid) - for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[os.path.join(dataset, shard)] - doc_start = (ndocs * frag) // self.worldsize - doc_end = ( - ndocs * frag + ndocs - ) // self.worldsize - 1 # Inclusive upper bound - if shard not in docset: - docset[shard] = [doc_start, doc_end] - min_d, max_d = docset[shard] - if doc_start < min_d: - docset[shard][0] = doc_start - if doc_end > max_d: - docset[shard][1] = doc_end - - # Add all of this dataset's shard entries to self.docset - doccount = 0 - for shardid in docset: - min_d = docset[shardid][0] - max_d = docset[shardid][1] - self.docset.append((shardid, min_d, max_d)) - doccount += max_d - min_d + 1 - self._len = doccount - - if self.verbose: - logging.info( - f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" - ) + # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): + ndocs = -1 + docset = {} # shardid -> (min docid, max docid) + for i, (shard, frag) in enumerate(shardfrags): + ndocs = doc_counts[os.path.join(dataset, shard)] + doc_start = (ndocs * frag) // self.worldsize + doc_end = ( + ndocs * frag + ndocs + ) // self.worldsize - 1 # Inclusive upper bound + if shard not in docset: + docset[shard] = [doc_start, doc_end] + min_d, max_d = docset[shard] + if doc_start < min_d: + docset[shard][0] = doc_start + if doc_end > max_d: + docset[shard][1] = doc_end + + # Add all of this dataset's shard entries to self.docset + doccount = 0 + for shardid in docset: + min_d = docset[shardid][0] + max_d = docset[shardid][1] + self.docset.append((shardid, min_d, max_d)) + doccount += max_d - min_d + 1 + self._len = doccount + + if self.verbose: + logging.info( + f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" + ) - # Shuffle shard files - guaranteed inconsistent across workers - seed = self.seed + self.rank - random.seed(seed) - random.shuffle(self.docset) - # Setup doc shuffle - same guarantee - self.lcg_state = seed + # Shuffle shard files - guaranteed inconsistent across workers + seed = self.seed + self.rank + random.seed(seed) + random.shuffle(self.docset) + # Setup doc shuffle - same guarantee + self.lcg_state = seed def _get_docid(self, i): """ @@ -841,8 +842,7 @@ def __iter__(self): yield self._construct_chunk(j, doc, n_chunks) def load_state_dict(self, state_dicts, sharded_input=False): - if not self.is_setup: - self.setup() + self.setup() assert ( self.load_worldsize == self.worldsize ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." From 4b504d114bbef75c3a6838f5404009e3db74d9e1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 14:41:38 -0400 Subject: [PATCH 31/73] Add setup to scalable, fix test_reload_epoch sampler call --- fms_fsdp/utils/dataset_utils_v3.py | 71 ++++++++++++++++++------------ tests/test_datasets.py | 2 +- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index f161362b..6acca817 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -1032,47 +1032,63 @@ def __init__( ), f"n_logical_shards {n_logical_shards} must be a positive integer" super().__init__(rank, worldsize) - self.data = [] - self.n_logicals = n_logical_shards // worldsize + self.datapath = datapath self.total_shards = n_logical_shards self.delimiter = delimiter_token + self.kwargs = kwargs + self.verbose = verbose - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - assert len(self.logicals_owned) == self.n_logicals - - # Build logical shards - for i in range(self.n_logicals): - self.data.append( - Streaming_Doc_Dataset( - datapath=datapath, - worldsize=n_logical_shards, - rank=self.logicals_owned[i], - delimiter_token=delimiter_token, - verbose=(rank == 0), - **kwargs, - ) - ) - if verbose: - logging.info( - f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - - # Logical shard sampling stats - populate after subdataset setup + # Fields to be populated during setup / subdataset setup + self.data = [] + self.logicals_owned = [] + self.n_logicals = 0 self.n_docs_remaining = [] # Position "state", used only for maintaining order when n_workers is unchanged # For scaling up or down, logical position is meaningless, and reset self.current_reader = None self.logical_shard_states = None - self.generator = torch.Generator().manual_seed(self.rank) + self.generator = torch.Generator() + self.g_state = None self.state_params = ["current_reader", "g_state"] self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + self.is_setup = False + + def setup(self): + if not self.is_setup: + self.is_setup = True + n_logical_shards = self.total_shards + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + self.n_logicals = n_logical_shards // self.worldsize + assert len(self.logicals_owned) == self.n_logicals + + # Build logical shards + for i in range(self.n_logicals): + self.data.append( + Streaming_Doc_Dataset( + datapath=self.datapath, + worldsize=n_logical_shards, + rank=self.logicals_owned[i], + delimiter_token=self.delimiter, + verbose=(self.rank == 0), + **self.kwargs, + ) + ) + if self.verbose: + logging.info( + f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + + self.generator.manual_seed(self.rank) + + [d.setup() for d in self.data] + self.n_docs_remaining = [d._len for d in self.data] + def __iter__(self): - [d.setup() for d in self.data] - self.n_docs_remaining = [d._len for d in self.data] + self.setup() # Grab one doc at a time in random order data = [iter(d) for d in self.data] while True: @@ -1108,6 +1124,7 @@ def state_dict(self): return super().state_dict() def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() sharded_dicts = super().load_state_dict(state_dicts, sharded_input) # Manually set generator state if it exists if self.g_state is not None: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 6ceab089..232beada 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -505,7 +505,7 @@ def test_reload_complete_epoch(): reload_single_epoch_check(basic_loader) reload_single_epoch_check(functools.partial(basic_scalable, n_logical_shards=8)) reload_single_epoch_check(basic_sampler) - reload_single_epoch_check(functools.partial(basic_scalable, n_logical_shards=8)) + reload_single_epoch_check(functools.partial(basic_scalable_sampler, n_logical_shards=8)) def test_eos_bos_chunking(): From ee6d0c7f02fe452ed9cc968ff811ff65761c615c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 23 Jul 2024 15:54:18 -0400 Subject: [PATCH 32/73] setup in scalable, not yet wrapper --- fms_fsdp/utils/dataset_utils_v3.py | 23 ++++++++++++++++++----- tests/test_datasets.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index 6acca817..d0b99d40 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -89,6 +89,12 @@ def __init__( self.load_worldsize = ( worldsize # Enable calling load_state_dict() directly, assume no rescaling ) + self.is_setup = False + + def setup(self): + if not self.is_setup: + self.is_setup = True + pass def statename(self, x: str): # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline @@ -96,8 +102,10 @@ def statename(self, x: str): def state_dict(self): """ - Retrieve all state and reshard flags (each worker/process saves its own state dict shard) + Retrieve all state and reshard flags (each worker/process saves its own state dict shard). + On the off chance that you're saving a checkpoint with zero steps, run setup first. """ + self.setup() return { self.statename(flag): getattr(self, flag) for flag in self.state_params + self.reshard_params @@ -136,14 +144,16 @@ def load_state_dict(self, state_dicts, sharded_input=False): global list of states across all checkpoint shard files. If sharded_input=True, this expects _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading. Workflow: - 1. if sharded_inputs is false, shard the inputs. - 2. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint + 1. Run setup to prepare dataset + 2. if sharded_inputs is false, shard the inputs. + 3. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint shard (state_dicts is a singleton list). - 3. If worldsize does not match checkpoint, toss state params and assemble reshard params from + 4. If worldsize does not match checkpoint, toss state params and assemble reshard params from across given state_dicts. In this case state_dicts may be singleton (for fractional ownership) or multi-element (for multiple/partitioned ownership). - 4. Return reduced input for use by downstream loading functions + 5. Return reduced input for use by downstream loading functions """ + self.setup() if not sharded_input: self.load_worldsize = len(state_dicts) state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) @@ -205,6 +215,7 @@ def load_state_dict(self, state_dicts, sharded_input=False): """ Sets all specified flags at the current level, then recurses into wrapped dataset. """ + self.setup() sharded_dicts = super().load_state_dict(state_dicts, sharded_input) self.dataset.load_worldsize = self.load_worldsize self.dataset.load_state_dict(sharded_dicts, True) @@ -215,6 +226,7 @@ def state_dict(self): Fetches state dict recursively from wrapped layers, then adds specified flags. Overlapping flags are overwritten with a warning. """ + self.setup() out = self.dataset.state_dict() state = super().state_dict() for flag in self.state_params + self.reshard_params: @@ -1117,6 +1129,7 @@ def __iter__(self): yield out def state_dict(self): + self.setup() # Write generator state manually self.g_state = self.generator.get_state() # Recursive fetch diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 232beada..576b3304 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -107,7 +107,7 @@ def reload_stress(datasets, datasets2, steps1, steps2): out1[j] == out2[j] ), f"Dataloader {i} in step {k} has mismatched token in position {j}: {out1[j]} vs {out2[j]}" - steps1 = [0, 1, 10, 100, 1000] + steps1 = [1, 10, 100, 1000] steps2 = [100, 200, 300, 400, 500] for i in range(len(steps1)): # Reset between tests (instantiate fresh datasets) From 468762097c9c2190286d8fbefeec523cf949a395 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 24 Jul 2024 17:23:44 -0400 Subject: [PATCH 33/73] Make scalable and sampler wrappers --- fms_fsdp/utils/dataset_utils_v2.py | 1097 ---------------------------- fms_fsdp/utils/dataset_utils_v3.py | 485 +++++------- tests/test_datasets.py | 208 ++---- 3 files changed, 243 insertions(+), 1547 deletions(-) delete mode 100644 fms_fsdp/utils/dataset_utils_v2.py diff --git a/fms_fsdp/utils/dataset_utils_v2.py b/fms_fsdp/utils/dataset_utils_v2.py deleted file mode 100644 index 08e47eba..00000000 --- a/fms_fsdp/utils/dataset_utils_v2.py +++ /dev/null @@ -1,1097 +0,0 @@ -import csv -import logging -import math -import os -import random -import time -from typing import Any, Callable, List, Optional, Set, Type, Union - -import pyarrow as pa -import torch -import torch.utils.data as data - -from fms_fsdp.utils.checkpointing_utils import get_latest - - -""" -The following distributed dataloaders are designed around 3 main principles: - -1. Efficient, asynchronous operation. Workers on different devices do not communicate. -2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator - loading from disk and additional layers adding levels of post-processing (shuffling, - packing, padding, etc.). -3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal - state that can be written/read on disk via implemented recursive `state_dict()` and - `load_state_dict()` calls. -4. Rescalability. Users can save and load checkpoints to/from different numbers of workers - without losing the global state. This is accomplished by splitting state fields for each - layer into `state_params`, which are typically scalar-valued and can be discarded when - rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be - re-distributed over workers (i.e. buffers). - -Our loaders obey the following type heirarchy: -torch.data.IterableDataset -> _Stateful_Dataset -> _Wrapper_Dataset. -`_Stateful_Dataset` implements state and checkpointing logic. A `_Wrapper_Dataset` holds a -single `_Stateful_Dataset` and iterates via calling its wrapped dataset any number of times, -then applying some sort of post-processing and yielding the result. Users build data processing -pipelines by wrapping a base `_Stateful_Dataset` in any number of `_Wrapper_Dataset` layers, -which is then passed to the torch DataLoader. - -NOTE: `_Wrapper_Dataset` currently only implements wrapping a single instantiated sub-dataset layer. -Many layers need multiple sub-layers (i.e. random sampling from distinct data sources). These are -currently implemented as base `_Stateful_Datasets` that take the class of their sub-layers plus any -pass-through arguments, and instantiate all those sub-layers. This is easy on the user, who no longer -needs to instantiate large sets of sub-layers in their code, but leads to awkwardness in this file. -Cleanup is planned for the future. -""" - - -def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: - """ - Partition itemlist into worldsize chunks, grab chunk corresponding to rank and return. - """ - return itemlist[ - (rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize - ] - - -def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: - """ - In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, - and return the span including all owned items, fractional or otherwise. - """ - start = math.floor(len(itemlist) * rank / worldsize) - end = math.ceil(len(itemlist) * (rank + 1) / worldsize) - return itemlist[start:end] - - -class _Stateful_Dataset(data.IterableDataset): - """ - Stub for stateful datasets, extends data.IterableDataset with state_dict methods. - All subclasses should specify the params to be considered stateful or reshardable in the - self.state_params and self.reshard_params lists. - """ - - def __init__( - self, - rank: int, - worldsize: int, - ): - assert rank >= 0, f"Rank {rank} must be a positive integer" - assert ( - worldsize > rank - ), f"Worldsize {worldsize} must be greater than rank {rank}" - self.state_params: List[str] = [] - self.reshard_params: List[str] = [] - self.rank = rank - self.worldsize = worldsize - self.load_worldsize = ( - worldsize # Enable calling load_state_dict() directly, assume no rescaling - ) - - def statename(self, x: str): - # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline - return self.__class__.__name__ + "." + x - - def state_dict(self): - """ - Retrieve all state and reshard flags (each worker/process saves its own state dict shard) - """ - return { - self.statename(flag): getattr(self, flag) - for flag in self.state_params + self.reshard_params - } - - def _reshard(self, sharded_list): - """ - Sharded_list is a list of lists, where each "shard" sublist must have the same length. - These shards should tightly span only the partition of data owned by this worker. - (i.e. if global_list is the list of all entries, sharded_list = _shard_inclusive(global_list) ). - Determine fractional ownership of shards, and get the flattened partition owned by this worker. - """ - # How many shards did _shard_inclusive() drop to the left of sharded_list? - shard_offset = math.floor(self.load_worldsize * self.rank / self.worldsize) - # How long are the list shards? - shard_len = len(sharded_list[0]) - for i, shard in enumerate(sharded_list): - assert ( - len(shard) == shard_len - ), f"Shard {i} with length {len(shard)} does not match expected {shard_len}" - # How many list items did _shard_inclusive() drop to the left of the flattened sharded_list? - item_offset = shard_len * shard_offset - # How many list items are there in total? - n_items = self.load_worldsize * shard_len - # The indices of the flattened sharded_list that this worker owns - my_items = range( - int(n_items * self.rank / self.worldsize) - item_offset, - int(n_items * (self.rank + 1) / self.worldsize) - item_offset, - ) - # Pull out owned items - return [sharded_list[i // shard_len][i % shard_len] for i in my_items] - - def load_state_dict(self, state_dicts, sharded_input=False): - """ - Input state_dicts is a list of state_dicts. If sharded_input=False, this is expected to be the - global list of states across all checkpoint shard files. If sharded_input=True, this expects - _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading. - Workflow: - 1. if sharded_inputs is false, shard the inputs. - 2. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint - shard (state_dicts is a singleton list). - 3. If worldsize does not match checkpoint, toss state params and assemble reshard params from - across given state_dicts. In this case state_dicts may be singleton (for fractional ownership) - or multi-element (for multiple/partitioned ownership). - 4. Return reduced input for use by downstream loading functions - """ - if not sharded_input: - self.load_worldsize = len(state_dicts) - state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) - if self.load_worldsize == self.worldsize: - [ - setattr(self, flag, state_dicts[0][self.statename(flag)]) - for flag in self.state_params + self.reshard_params - ] - else: - for flag in self.reshard_params: - reshard = self._reshard( - [sd[self.statename(flag)] for sd in state_dicts] - ) - setattr(self, flag, reshard) - return state_dicts - - def load_from_path(self, path: str): - """ - Count shard files in the specified checkpoint folder and determine overlap with current - rank and worldsize partition. Load only matching shardfile(s) and pass to load_state_dict. - This is more efficient than sharding the full loaded state. - """ - assert os.path.exists(path), "Specified checkpoint does not exist" - assert not os.path.isfile(path), "Checkpoint should be a folder of shard states" - fileshards = [x for x in os.listdir(path) if "loader" in x] - fileshards = sorted(fileshards, key=lambda x: int(x.split("_")[2][:-4])) - assert ( - len(fileshards) > 0 - ), "Checkpoint directory must contain checkpoint files with 'loader' in the name" - self.load_worldsize = len(fileshards) - # Grab only the shard files holding data we currently own - my_fileshards = _shard_inclusive(fileshards, self.rank, self.worldsize) - states = [torch.load(os.path.join(path, x)) for x in my_fileshards] - self.load_state_dict(states, True) - - def save_to_path(self, path: str): - """ - Grab recursive shard states and save all shard states to the specified checkpoint folder - """ - os.makedirs(path, exist_ok=True) - state = self.state_dict() - torch.save(state, os.path.join(path, f"loader_state_{self.rank}.pth")) - - -class _Wrapper_Dataset(_Stateful_Dataset): - """ - Stub for nested wrappers of _Stateful_Datasets. Extends state fns with recursion. - Requires a single instantiated sub-dataset. - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - ): - self.dataset = dataset - super().__init__(self.dataset.rank, self.dataset.worldsize) - - def load_state_dict(self, state_dicts, sharded_input=False): - """ - Sets all specified flags at the current level, then recurses into wrapped dataset. - """ - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - self.dataset.load_worldsize = self.load_worldsize - self.dataset.load_state_dict(sharded_dicts, True) - return sharded_dicts - - def state_dict(self): - """ - Fetches state dict recursively from wrapped layers, then adds specified flags. - Overlapping flags are overwritten with a warning. - """ - out = self.dataset.state_dict() - state = super().state_dict() - for flag in self.state_params + self.reshard_params: - if flag in out: - logging.warning( - f"Loader {self.rank}: flag {flag} already present in state_dict with value {out[flag]}. " - + f"Overwriting with value {state[flag]}" - ) - out.update(state) - return out - - -class Preprocess_Dataset(_Wrapper_Dataset): - """ - Wrapper for a _Stateful_Dataset that applies a specified preprocessing - or augmentation function to dataset outputs. - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - aug_fn : function (any -> any) - The augmentation function to apply to each dataset item. - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - aug_fn: Callable, - ): - super().__init__(dataset) - self.aug_fn = aug_fn - - def __iter__(self): - dataset = iter(self.dataset) - while True: - out = next(dataset) - yield self.aug_fn(out) - - -class Checkpoint_Dataset(_Wrapper_Dataset): - """ - Wrapper for a _Stateful_Dataset that implements auto-checkpoint saving every n steps. - Useful for setting n_workers > 0, so that workers do not rely on the master process - for state saving (inter-process communication unsupported in PyTorch datasets). - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - load_path : str - Absolute path to checkpoint load directory. If a checkpoint exists, loads it. - interval : int - Saves a new checkpoint every interval. - steps_per_batch : optional[int] - Number of steps required to fill a single batch. Increments interval only - when a full batch is formed. Defaults to 1. - save_path : optional[str] - Absolute path to checkpoint save directory. Defaults to load_path. - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - load_path: str, - interval: int, - steps_per_batch: int = 1, - save_path: str = "", - ): - super().__init__(dataset) - self.interval = interval - self.spb = steps_per_batch - load_path = os.path.join(load_path, "checkpoints") - if len(save_path) == 0: - save_path = load_path - else: - save_path = os.path.join(save_path, "checkpoints") - self.path = save_path - self.step = 0 - self.ministep = 0 - self.load_from_path(load_path) - - def __iter__(self): - dataset = iter(self.dataset) - while True: - yield next(dataset) - self.ministep += 1 - if self.ministep == self.spb: - self.ministep = 0 - self.step += 1 - if self.step % self.interval == 0: - newpath = os.path.join(self.path, "step_" + str(self.step) + "_ckp") - self.save_to_path(newpath) - - def report(self, msg): - if self.rank == 0: - print(msg) - - def save_to_path(self, path: str): - self.report(f"Saving dataset to {path}") - start = time.time() - super().save_to_path(path) - self.report( - f"Dataset successfully saved to {path}! Save time: {time.time() - start}" - ) - - def load_from_path(self, path: str): - # If path does not exist, or exists but is empty, exit early - if not os.path.exists(path) or len(os.listdir(path)) == 0: - self.report( - f"No valid checkpoint detected at {path}, dataset starting from scratch." - ) - return - # Grab latest item in path - latest = os.path.join(path, get_latest(path)) - self.report(f"Dataset checkpoint detected at {latest}") - # If item is not a folder, exit early - if os.path.isfile(latest): - self.report( - f"Checkpoint exists but contains no dataset! Dataset starting from scratch." - ) - return - # If item is a folder, get the step count - self.step = int(latest.split("_")[-2]) - # Proceed - start = time.time() - self.dataset.load_from_path(latest) - self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}") - - -class Preload_Buffer_Dataset(_Wrapper_Dataset): - """ - Wrapper for a Stateful_Dataset that implements data shuffling via a single in/out buffer. - Fills buffer two at a time, up to desired size, then switches to one at a time to maintain size. - Passes randomly sampled outputs one by one. - Ensures local mixing of data without relying on sliding windows or shuffling of large buffers. - Any two consecutive inputs will be separated by window_size steps in expectation. - Rescaling-enabled: buffers that shrink will re-grow to window_size, buffers that expand stay large. - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - window_size : int - Max size of input/output buffer - """ - - def __init__(self, dataset: _Stateful_Dataset, window_size: int): - super().__init__(dataset) - assert ( - window_size > 1 - ), f"Window size {window_size} must be greater than 1 for shuffling to occur" - self.window_size = window_size - self.g_state = None - self.generator = torch.Generator().manual_seed(self.rank) - self.buffer: List[List[Any]] = [] - self.buffer_size = 0 - self.state_params = ["g_state"] - self.reshard_params = ["buffer"] - - def __iter__(self): - dataset = iter(self.dataset) - while True: - # Pad out buffer if needed - self._pad_buffer() - - # Load a point to buffer if necessary - if self.buffer_size < self.window_size: - self.buffer[self.buffer_size] = next(dataset) - self.buffer_size += 1 - - # Swap out randomly sampled value from buffer - i = torch.randint(self.buffer_size, (1,), generator=self.generator).item() - out = self.buffer[i] - self.buffer[i] = next(dataset) - yield out - - def _pad_buffer(self): - if self.buffer_size < self.window_size: - self.buffer += [ - [], - ] * (self.window_size - self.buffer_size) - - def state_dict(self): - # Write generator state manually - self.g_state = self.generator.get_state() - # Prune buffer so it can be resharded in future - self.buffer = self.buffer[: self.buffer_size] - out = super().state_dict() - return out - - def load_state_dict(self, state_dicts, sharded_input=False): - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Manually set buffer size - self.buffer_size = len(self.buffer) - return sharded_dicts - - -class Buffer_Dataset(_Wrapper_Dataset): - """ - Wrapper for a _Stateful_Dataset that takes in sequences of varying lengths, and packs/pads them - into sequences of desired length. Input sequences are packed greedily until the buffer would - otherwise overrun, then remaining values are filled depending on initialization flags. - Also injects BOS/EOS into the packed output sequence if desired, and if BOS/EOS tokens are - not already in those positions. Implements rescaling by simply dropping (buffer) state. - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - seq_len : int - The desired sequence length - pack_hard : bool - Split input sequences to fill output buffer, or use pad tokens to fill remaining space? - bos_token : any | None - Token to prepend to every output sequence. If None, no token is added. Type should match data type. - eos_token : any | None - Token to append to every output sequence. If None, no token is added. Type should match data type. - pad_token : any | None - Token used to fill out output sequence. Type should match data type. - drop_final_token : any | None - Drop the final token of each document if it matches this value? - (For edge case where bos=eos=None, and sep already appears at beginning of each doc - - drop added extra sep from end of doc) - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - seq_len: int, - pack_hard: bool, - bos_token=None, - eos_token=None, - pad_token=None, - ): - super().__init__(dataset) - self.len = seq_len - - # Buffer args - self.buffer: List[str] = [] - self.bos = bos_token - self.eos = eos_token - self.pad = pad_token - self.pack_hard = pack_hard - if not pack_hard: - assert ( - pad_token is not None - ), "Error: if using pads, you must supply a pad_token" - - self.state_params = ["buffer"] - - def _get_buffer(self, iterable, length, buffer): - # Pull data until buffer is about to overrun, return exactly proper length - new = [] - while len(buffer) + len(new) < length: - buffer += new - new = next(iterable) - - # Add bos if needed - if self.bos is not None and (len(buffer) == 0 or buffer[0] != self.bos): - buffer = [self.bos] + buffer - - # Handle buffer splitting - if len(buffer) >= length: - # If buffer is too long, force split - out = buffer[:length] - buffer = buffer[length:] - if self.eos is not None and out[-1] != self.eos: - buffer = [out[-1]] + buffer - out[-1] = self.eos - buffer = buffer + new - else: - if self.pack_hard: - # Pack in as much of new sequence as will fit - buffer = buffer + new - out = buffer[:length] - buffer = buffer[length:] - if self.eos is not None and out[-1] != self.eos: - buffer = [out[-1]] + buffer - out[-1] = self.eos - else: - # Fill out with pads as needed - if self.eos is not None and buffer[-1] != self.eos: - buffer.append(self.eos) - if self.pad is not None: - out = buffer + [self.pad] * (length - len(buffer)) - else: - out = buffer - buffer = new - return out, buffer - - # Fill buffer line by line, delimiters and packing/splitting as appropriate - def __iter__(self): - dataset = iter(self.dataset) - while True: - out, buffer = self._get_buffer(dataset, self.len, self.buffer) - self.buffer = buffer - yield out - - -class Streaming_Doc_Dataset(_Stateful_Dataset): - """ - The base distributed dataset for loading sequences/documents from pyarrow shards. - Pyarrow shard files are expected to hold multiple recordBatches, where each recordBatch has a "tokens" - field consisting of a single token list. (i.e. each document is a single sequence under a "token" field, - and the file is a list of such sequences) - Relies on a compiled metadata file to fetch shardfile lengths, assumes file already exists in the parent directory, - and is in proper csv format (first row "dataset/filename,documents,tokens", subsequent rows these values). - - For a single dataset directory, splits shard files into x=worldsize fragments and grabs a 1/n contiguous - span of shard fragments (contiguous to limit file reads from cloud/disk). - Logs the number of documents owned from each shardfile, and relies on ZCG random bijection to - map contiguous range of indices to shuffled, noncontiguous set of documents from each shard file. - Shuffles the file list deterministically to hop from file to file. - - At runtime, iterates through documents in each shuffled shard file, pulling each shard on demand. - Shards are thus pulled no more than once per epoch. - Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents. - - Streaming_Doc_Dataset grabs files from a flat directory representing a single dataset. - For percentage-based sampling of multiple subdatasets, see Sampling_Dataset. - ... - Args - ---- - datapath : str - Absolute path to the dataset directory. Expects directory containing pyarrow shardfiles. - Parent directory should contain 'meta' folder with metadata csv file inside. - rank : int - Current worker index - worldsize : int - Total number of workers - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. Required for downstream - sampling logic (can be removed later via PreProcess_Dataset if needed). - bos_token : Any | None - Optional token used to indicate sequence/document start. Type should match data type. - strip_tokens : set[Any] - Token values that should be removed if detected at beginning or end of document - (i.e. any eos/bos tokens already present in the data). Type should match data type. - seed : int - The random seed for deterministic shuffling/sharding - min_length : int - Sequences below this length are skipped - max_chunksize : int - Maximum sequence length to return. Break long docs into chunks of this size or shorter. - verbose : bool - Track setup progress? - shuffle : bool - Shuffle shard file and document orders? (Disable for simple testing) - """ - - def __init__( - self, - datapath: str, - rank: int, - worldsize: int, - delimiter_token: Any, - bos_token: Optional[Any] = None, - strip_tokens: Optional[Set[Any]] = set(), - seed: int = 42, - min_length: int = 1, - max_chunksize: int = 1024, - verbose: bool = False, - shuffle: bool = True, - ): - super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) - self.seed = seed - self.data = datapath - self.min_length = min_length - assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" - self.chunksize = max_chunksize - self.eos = delimiter_token - self.bos = bos_token - self.drop = strip_tokens - self.verbose = verbose - self.docset: List[ - Any - ] = [] # map of doc indices to (shardid, min docid, max docid) - self.docs_per_shard = {} - - # Guaranteed inconsistent shuffling across workers - random.seed(self.seed + rank) - - # Gather per-file document counts from metadata count file(s) - countfiles = [ - x - for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) - if "counts" in x and "csv" in x - ] - assert len(countfiles) == 1 - doc_counts = {} - pathsplit = (datapath, "") - while len(pathsplit[1]) == 0: - pathsplit = os.path.split(pathsplit[0]) - pardir, dataset = pathsplit - self.dataset = dataset - with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: - key = fullpath[prefix:] - doc_counts[key] = int(row["documents"]) - - # Assemble document set owned by this worker: - # listdir, assemble shardfraglist (ind -> shard, frag) - shards = [ - shard - for shard in os.listdir(datapath) - if os.path.isfile(os.path.join(datapath, shard)) - and "arrow" in os.path.join(datapath, shard) - ] - shards.sort() # Ensure consistent sharding across machines - start_frag = (rank * worldsize * len(shards)) // worldsize - end_frag = ((rank + 1) * worldsize * len(shards)) // worldsize - shardfrags = [ - (shards[i // worldsize], i % worldsize) for i in range(start_frag, end_frag) - ] - - # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): - ndocs = -1 - docset = {} # shardid -> (min docid, max docid) - for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[os.path.join(dataset, shard)] - self.docs_per_shard[shard] = ndocs - doc_start = (ndocs * frag) // worldsize - doc_end = (ndocs * frag + ndocs) // worldsize - 1 # Inclusive upper bound - if shard not in docset: - docset[shard] = [doc_start, doc_end] - min_d, max_d = docset[shard] - if doc_start < min_d: - docset[shard][0] = doc_start - if doc_end > max_d: - docset[shard][1] = doc_end - - # Add all of this dataset's shard entries to self.docset - doccount = 0 - for shardid in docset: - min_d = docset[shardid][0] - max_d = docset[shardid][1] - self.docset.append((shardid, min_d, max_d)) - doccount += max_d - min_d + 1 - self._len = doccount - - if verbose: - logging.info( - f" Worker {rank} ingested {len(shardfrags)} shard fragments from {dataset}" - ) - - # Shuffle shard files - if shuffle: - random.shuffle(self.docset) - - self.docset_index = 0 - self.chunk_index = -1 - - # Stats - self.epochs_seen = -1 - self.tokens_seen = 0 - self.docs_seen = 0 - self.percent_seen = 0 - self.lcg_state = seed + rank - - self.state_params = [ - "dataset", - "docset_index", - "chunk_index", - "epochs_seen", - "tokens_seen", - "docs_seen", - "percent_seen", - "lcg_state", - ] - - def _get_docid(self, i): - """ - Given a global doc index over the set of docs owned by this worker, - return the corresponding data/shard/local index - """ - cur = 0 - assert ( - i <= self._len - ), f"You have requested an illegal doc index {i}, docset length is {self._len}" - for shardid, min_d, max_d in self.docset: - docrange = max_d - min_d + 1 - cur += docrange - if cur > i: - return shardid, docrange, min_d - - def _get_reader(self, path, newpath, reader): - """ - If new filepath does not match the current one, - open a new reader on that filepath (pull file on demand) - """ - if newpath != path: - del reader - if self.verbose: - logging.info(f"Worker {self.rank} opening new file {newpath}") - reader = pa.ipc.open_file(newpath) - path = newpath - return path, reader - - def _construct_chunk(self, j, doc, n_chunks): - """ - Grab a chunk of the desired size from the pyarrow document, - avoiding unnecessary overhead in case of large docs - """ - start_index = j * self.chunksize - n_pull = self.chunksize - if self.bos is not None: - if j == 0: - n_pull -= 1 - else: - start_index -= 1 - chunk = doc.slice(start_index, n_pull).to_pylist() - self.tokens_seen += len(chunk) - # Add bos/eos tokens if needed - if self.bos is not None and j == 0: - chunk = [self.bos] + chunk - if j == n_chunks - 1: - chunk = chunk + [self.eos] - return chunk - - def _random_map_docid(self, size): - """ - Given size of document pool, use saved state (prior index) to generate the next index via LCG. - Implements within-shard document shuffling without materializing any large doc lists. - """ - m = 2 ** math.ceil(math.log2(size)) # Round up to nearest power of 2 - a = 5 # A,C values known to work well with powers of 2 (Knuth, 1997, 3.2.1.3) - c = (self.rank + self.seed) * 2 + 1 - state = self.lcg_state - while True: - state = (a * state + c) % m - if state < size: - return state - - def __iter__(self): - docset_offset = self.docset_index - lcg_offset = self.lcg_state - residual_chunks = self.chunk_index + 1 # pick up AFTER where the ckp left off - ndocs = self._len - path = "" - reader = None - while True: - # Iterate through docs, starting at desired offset - for i in range(ndocs): - doc_index = (docset_offset + i) % ndocs - - # Update stats - if doc_index == 0: - self.epochs_seen += 1 - self.docset_index = doc_index - # Map doc id to shard, id in file - shardid, docrange, mindoc = self._get_docid(doc_index) - - # Read doc - newpath = os.path.join(self.data, shardid) - path, reader = self._get_reader(path, newpath, reader) - # Map id in range of owned docs to new (consistently) shuffled id - doclcg = self._random_map_docid(docrange) - docid = doclcg + mindoc - doc = reader.get_batch(docid)["tokens"] - if doc[0].as_py() in self.drop: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in self.drop: - doc = doc.slice(0, len(doc) - 1) - doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: - n_chunks = math.ceil(doclen / self.chunksize) - for j in range(n_chunks): - if i == 0 and j < residual_chunks: - pass - else: - self.chunk_index = j - # Document complete, update stats - if j == n_chunks - 1: - self.docs_seen += 1 - self.percent_seen = ( - self.docs_seen * 100 / (self._len + 1e-9) - ) - yield self._construct_chunk(j, doc, n_chunks) - - # Advance RNG state - self.lcg_state = doclcg - - # Load any chunks initially skipped in first doc - self.docset_index = docset_offset - self.lcg_state = lcg_offset - shardid, docrange, mindoc = self._get_docid(docset_offset) - docid = self._random_map_docid(docrange) + mindoc - newpath = os.path.join(self.data, shardid) - path, reader = self._get_reader(path, newpath, reader) - doc = reader.get_batch(docid)["tokens"] - if doc[0].as_py() in self.drop: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in self.drop: - doc = doc.slice(0, len(doc) - 1) - doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: - n_chunks = math.ceil(doclen / self.chunksize) - for j in range(residual_chunks): - self.chunk_index = j - yield self._construct_chunk(j, doc, n_chunks) - - def load_state_dict(self, state_dicts, sharded_input=False): - assert ( - self.load_worldsize == self.worldsize - ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." - d = self.dataset - out = super().load_state_dict(state_dicts, sharded_input) - assert ( - d == self.dataset - ), f"Dataset mismatch: checkpoint contains {self.dataset}, expected {d}" - return out - - -class Sampling_Dataset(_Stateful_Dataset): - """ - A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the - number of tokens seen from each subdataset will match those weights as closely as possible. - This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking - the number of tokens emitted by each. Whichever loader is furthest from its target will be - the next to pass a document. - - All args except for dataset_type, datasets, weights and delimiter are pass-through args for - the component _Stateful_Datasets and are documented in the appropriate classes. - ... - Args - ---- - dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset - Underlying iterator for each desired subdataset - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - datasets : list[str] | None - A list of subdatasets to draw from. If None, draws from all subfolders of datapath. - weights : list(float) | None - Weights describing what percent of emitted tokens should come from each subdataset. - Need not sum to 1. If None, tokens are drawn evenly. - ... - Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset - """ - - def __init__( - self, - datapath: str, - dataset_type: Union[ - Type["Streaming_Doc_Dataset"], - Type["Scalable_Shard_Dataset"], - ], - rank: int, - worldsize: int, - delimiter_token: Any, - datasets=None, - weights=None, - verbose=False, - **kwargs, - ): - super().__init__(rank, worldsize) - self.delimiter = delimiter_token - self.datasets = ( - datasets - if datasets is not None - else [ - f - for f in os.listdir(datapath) - if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f - ] - ) - assert len(self.datasets) > 0, "You must specify at least one dataset" - - if weights is not None: - assert len(weights) == len( - self.datasets - ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" - for w in weights: - assert w > 0, f"Sampling rate {w} must be positive" - self.weights = [1] * len(self.datasets) if weights is None else weights - self.weights = [w / sum(self.weights) for w in self.weights] - - self.tokens_seen = [0] * len(self.datasets) - - # Build subdataset iterators - self.data = [] - for i, d in enumerate(self.datasets): - self.data.append( - dataset_type( - datapath=os.path.join(datapath, d), - rank=rank, - worldsize=worldsize, - delimiter_token=delimiter_token, - verbose=verbose, - **kwargs, - ) - ) - if verbose: - logging.info( - f"Worker {rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" - ) - - self.current_iterator = -1 - self.state_params = ["tokens_seen", "current_iterator"] - - def __iter__(self): - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - if self.current_iterator != -1: - # Finish current document - out = next(data[self.current_iterator]) - self.tokens_seen[self.current_iterator] += len(out) - if out[-1] == self.delimiter: - self.current_iterator = -1 - yield out - else: - # Choose new subdataset to draw from - # (whichever is currently most underrepresented compared to target rate) - offset = [ - self.weights[i] - - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) - for i in range(len(self.datasets)) - ] - offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] - self.current_iterator = offset_argmax - - def state_dict(self): - # Manually add state of all subloaders to self state - out = { - self.statename("sample_iterator_states"): [ - d.state_dict() for d in self.data - ] - } - out.update(super().state_dict()) - return out - - def load_state_dict(self, state_dicts, sharded_input=False): - # Load stats - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - # Load sub-iterator states - for i, subdata in enumerate(self.data): - # Grab just that sub-iterator across all ranks - subdata.load_worldsize = self.load_worldsize - subdata.load_state_dict( - [ - sd[self.statename("sample_iterator_states")][i] - for sd in sharded_dicts - ], - True, - ) - return sharded_dicts - - -class Scalable_Shard_Dataset(_Stateful_Dataset): - """ - A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different - number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track - state individually and reshard over n_gpus. - - All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. - ... - Args - ---- - datapath : str - Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. - rank : int - Current worker index - worldsize : int - Total number of workers - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - n_logical_shards : int - Number of logical shards. Must be a multiple of world size. - ... - Pass-through args, see Streaming_Doc_Dataset - """ - - def __init__( - self, - datapath: str, - rank: int, - worldsize: int, - delimiter_token: Any, - n_logical_shards: int = 2048, - verbose=False, - **kwargs, - ): - assert ( - n_logical_shards % worldsize == 0 - ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" - assert ( - n_logical_shards > 0 - ), f"n_logical_shards {n_logical_shards} must be a positive integer" - - super().__init__(rank, worldsize) - self.data = [] - self.n_logicals = n_logical_shards // worldsize - self.total_shards = n_logical_shards - self.delimiter = delimiter_token - - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - assert len(self.logicals_owned) == self.n_logicals - - # Build logical shards - for i in range(self.n_logicals): - self.data.append( - Streaming_Doc_Dataset( - datapath=datapath, - worldsize=n_logical_shards, - rank=self.logicals_owned[i], - delimiter_token=delimiter_token, - verbose=(rank == 0), - **kwargs, - ) - ) - if verbose: - logging.info( - f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - - # Fetch logical shard sampling stats - self.n_docs_remaining = [d._len for d in self.data] - - # Position "state", used only for maintaining order when n_workers is unchanged - # For scaling up or down, logical position is meaningless, and reset - self.current_reader = None - self.logical_shard_states = None - self.generator = torch.Generator().manual_seed(self.rank) - self.g_state = None - self.state_params = ["current_reader", "g_state"] - self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - - def __iter__(self): - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - # Sample logical shard (or load from ckp) - if self.current_reader is not None: - ind = self.current_reader - else: - ind = torch.multinomial( - torch.tensor(self.n_docs_remaining, dtype=torch.float), - 1, - generator=self.generator, - ).item() - self.current_reader = ind - # Read doc - out = next(data[ind]) - while out[-1] != self.delimiter: - yield out - out = next(data[ind]) - # Update state to show we've finished the doc - self.current_reader = None - self.n_docs_remaining[ind] -= 1 - if sum(self.n_docs_remaining) == 0: - self.n_docs_remaining = [d._len for d in self.data] - self.generator.manual_seed(self.rank) - # Return final piece of doc - yield out - - def state_dict(self): - # Write generator state manually - self.g_state = self.generator.get_state() - # Recursive fetch - self.logical_shard_states = [d.state_dict() for d in self.data] - return super().state_dict() - - def load_state_dict(self, state_dicts, sharded_input=False): - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Recursive set - for i in range(self.n_logicals): - self.data[i].load_state_dict([self.logical_shard_states[i]], True) - return sharded_dicts diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index d0b99d40..97c9f14a 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -75,6 +75,7 @@ class _Stateful_Dataset(data.IterableDataset): def __init__( self, + datapath: str, rank: int, worldsize: int, ): @@ -82,8 +83,12 @@ def __init__( assert ( worldsize > rank ), f"Worldsize {worldsize} must be greater than rank {rank}" + assert datapath is None or ( + os.path.isdir(datapath) and len(os.listdir(datapath)) > 0 + ), f"Data path {datapath} must be a non-empty folder or None" self.state_params: List[str] = [] self.reshard_params: List[str] = [] + self.datapath = datapath self.rank = rank self.worldsize = worldsize self.load_worldsize = ( @@ -209,7 +214,22 @@ def __init__( dataset: _Stateful_Dataset, ): self.dataset = dataset - super().__init__(self.dataset.rank, self.dataset.worldsize) + super().__init__( + self.dataset.datapath, self.dataset.rank, self.dataset.worldsize + ) + + def setup(self): + """ + Datapath/rank/worldsize percolate upwards recursively during initialization, now + project any desired changes downward, also recursively. + """ + if not self.is_setup: + self.is_setup = True + self.dataset.datapath = self.datapath + self.dataset.rank = self.rank + self.dataset.worldsize = self.worldsize + self.dataset.load_worldsize = self.load_worldsize + self.dataset.setup() def load_state_dict(self, state_dicts, sharded_input=False): """ @@ -594,9 +614,9 @@ def __init__( max_chunksize: int = 1024, verbose: bool = False, ): - super().__init__(rank, worldsize) + super().__init__(datapath, rank, worldsize) self.seed = seed - self.data = datapath + self.datapath = datapath self.min_length = min_length assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" self.chunksize = max_chunksize @@ -641,7 +661,7 @@ def setup(self): (rank assignment, subdataset splitting, etc.) """ if not self.is_setup: - datapath = self.data + datapath = self.datapath self.is_setup = True # Gather per-file document counts from metadata count file(s) @@ -676,7 +696,9 @@ def setup(self): ] shards.sort() # Ensure consistent sharding across machines start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize - end_frag = ((self.rank + 1) * self.worldsize * len(shards)) // self.worldsize + end_frag = ( + (self.rank + 1) * self.worldsize * len(shards) + ) // self.worldsize shardfrags = [ (shards[i // self.worldsize], i % self.worldsize) for i in range(start_frag, end_frag) @@ -805,7 +827,7 @@ def __iter__(self): shardid, docrange, mindoc = self._get_docid(doc_index) # Read doc - newpath = os.path.join(self.data, shardid) + newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) # Map id in range of owned docs to new (consistently) shuffled id doclcg = self._random_map_docid(docrange) @@ -839,7 +861,7 @@ def __iter__(self): self.lcg_state = lcg_offset shardid, docrange, mindoc = self._get_docid(docset_offset) docid = self._random_map_docid(docrange) + mindoc - newpath = os.path.join(self.data, shardid) + newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) doc = reader.get_batch(docid)["tokens"] if doc[0].as_py() in self.drop: @@ -857,7 +879,7 @@ def load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( self.load_worldsize == self.worldsize - ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." + ), f"Streaming_Doc_Dataset does not support rescaling ({self.load_worldsize, self.worldsize}). Please use a Scalable_Shard_Dataset." d = self.dataset out = super().load_state_dict(state_dicts, sharded_input) assert ( @@ -866,142 +888,7 @@ def load_state_dict(self, state_dicts, sharded_input=False): return out -class Sampling_Dataset(_Stateful_Dataset): - """ - A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the - number of tokens seen from each subdataset will match those weights as closely as possible. - This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking - the number of tokens emitted by each. Whichever loader is furthest from its target will be - the next to pass a document. - - All args except for dataset_type, datasets, weights and delimiter are pass-through args for - the component _Stateful_Datasets and are documented in the appropriate classes. - ... - Args - ---- - dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset - Underlying iterator for each desired subdataset - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - datasets : list[str] | None - A list of subdatasets to draw from. If None, draws from all subfolders of datapath. - weights : list(float) | None - Weights describing what percent of emitted tokens should come from each subdataset. - Need not sum to 1. If None, tokens are drawn evenly. - ... - Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset - """ - - def __init__( - self, - datapath: str, - dataset_type: Union[ - Type["Streaming_Doc_Dataset"], - Type["Scalable_Shard_Dataset"], - ], - rank: int, - worldsize: int, - delimiter_token: Any, - datasets=None, - weights=None, - verbose=False, - **kwargs, - ): - super().__init__(rank, worldsize) - self.delimiter = delimiter_token - self.datasets = ( - datasets - if datasets is not None - else [ - f - for f in os.listdir(datapath) - if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f - ] - ) - assert len(self.datasets) > 0, "You must specify at least one dataset" - - if weights is not None: - assert len(weights) == len( - self.datasets - ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" - for w in weights: - assert w > 0, f"Sampling rate {w} must be positive" - self.weights = [1] * len(self.datasets) if weights is None else weights - self.weights = [w / sum(self.weights) for w in self.weights] - - self.tokens_seen = [0] * len(self.datasets) - - # Build subdataset iterators - self.data = [] - for i, d in enumerate(self.datasets): - self.data.append( - dataset_type( - datapath=os.path.join(datapath, d), - rank=rank, - worldsize=worldsize, - delimiter_token=delimiter_token, - verbose=verbose, - **kwargs, - ) - ) - if verbose: - logging.info( - f"Worker {rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" - ) - - self.current_iterator = -1 - self.state_params = ["tokens_seen", "current_iterator"] - - def __iter__(self): - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - if self.current_iterator != -1: - # Finish current document - out = next(data[self.current_iterator]) - self.tokens_seen[self.current_iterator] += len(out) - if out[-1] == self.delimiter: - self.current_iterator = -1 - yield out - else: - # Choose new subdataset to draw from - # (whichever is currently most underrepresented compared to target rate) - offset = [ - self.weights[i] - - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) - for i in range(len(self.datasets)) - ] - offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] - self.current_iterator = offset_argmax - - def state_dict(self): - # Manually add state of all subloaders to self state - out = { - self.statename("sample_iterator_states"): [ - d.state_dict() for d in self.data - ] - } - out.update(super().state_dict()) - return out - - def load_state_dict(self, state_dicts, sharded_input=False): - # Load stats - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - # Load sub-iterator states - for i, subdata in enumerate(self.data): - # Grab just that sub-iterator across all ranks - subdata.load_worldsize = self.load_worldsize - subdata.load_state_dict( - [ - sd[self.statename("sample_iterator_states")][i] - for sd in sharded_dicts - ], - True, - ) - return sharded_dicts - - -class Scalable_Shard_Dataset(_Stateful_Dataset): +class Scalable_Shard_Dataset(_Wrapper_Dataset): """ A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. @@ -1028,26 +915,21 @@ class Scalable_Shard_Dataset(_Stateful_Dataset): def __init__( self, - datapath: str, - rank: int, - worldsize: int, - delimiter_token: Any, + dataset: Streaming_Doc_Dataset, + resample_condition: Callable = lambda _: True, n_logical_shards: int = 2048, verbose=False, - **kwargs, ): + super().__init__(dataset) assert ( - n_logical_shards % worldsize == 0 - ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" + n_logical_shards % self.worldsize == 0 + ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" assert ( n_logical_shards > 0 ), f"n_logical_shards {n_logical_shards} must be a positive integer" - super().__init__(rank, worldsize) - self.datapath = datapath self.total_shards = n_logical_shards - self.delimiter = delimiter_token - self.kwargs = kwargs + self.resample = resample_condition self.verbose = verbose # Fields to be populated during setup / subdataset setup @@ -1055,19 +937,17 @@ def __init__( self.logicals_owned = [] self.n_logicals = 0 self.n_docs_remaining = [] + self.generator = None # Position "state", used only for maintaining order when n_workers is unchanged # For scaling up or down, logical position is meaningless, and reset self.current_reader = None self.logical_shard_states = None - self.generator = torch.Generator() - self.g_state = None + self.state_params = ["current_reader", "g_state"] self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - self.is_setup = False - def setup(self): if not self.is_setup: self.is_setup = True @@ -1079,26 +959,21 @@ def setup(self): # Build logical shards for i in range(self.n_logicals): - self.data.append( - Streaming_Doc_Dataset( - datapath=self.datapath, - worldsize=n_logical_shards, - rank=self.logicals_owned[i], - delimiter_token=self.delimiter, - verbose=(self.rank == 0), - **self.kwargs, - ) - ) + self.data.append(deepcopy(self.dataset)) + self.data[-1].worldsize = n_logical_shards + self.data[-1].load_worldsize = n_logical_shards + self.data[-1].rank = self.logicals_owned[i] + self.data[-1].datapath = self.datapath + self.data[-1].verbose = self.rank == 0 if self.verbose: logging.info( f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" ) - - self.generator.manual_seed(self.rank) - [d.setup() for d in self.data] self.n_docs_remaining = [d._len for d in self.data] + self.generator = torch.Generator().manual_seed(self.rank) + def __iter__(self): self.setup() # Grab one doc at a time in random order @@ -1116,7 +991,7 @@ def __iter__(self): self.current_reader = ind # Read doc out = next(data[ind]) - while out[-1] != self.delimiter: + while not self.resample(out): yield out out = next(data[ind]) # Update state to show we've finished the doc @@ -1134,11 +1009,13 @@ def state_dict(self): self.g_state = self.generator.get_state() # Recursive fetch self.logical_shard_states = [d.state_dict() for d in self.data] - return super().state_dict() + return _Stateful_Dataset.state_dict(self) def load_state_dict(self, state_dicts, sharded_input=False): self.setup() - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + sharded_dicts = _Stateful_Dataset.load_state_dict( + self, state_dicts, sharded_input + ) # Manually set generator state if it exists if self.g_state is not None: self.generator.set_state(self.g_state) @@ -1146,123 +1023,141 @@ def load_state_dict(self, state_dicts, sharded_input=False): for i in range(self.n_logicals): self.data[i].load_state_dict([self.logical_shard_states[i]], True) return sharded_dicts - - - -# class Scalable_Shard_Dataset(_Wrapper_Dataset): -# """ -# A _Wrapper_Dataset implementing rescalability: loading from checkpoint into a different -# number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. -# This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track -# state individually and reshard over n_gpus. - -# Because only one Streaming_Doc_Dataset is provided to the wrapper, the wrapper clones it, and -# all rank-dependent setup for the base datasets is deferred to the first step. - -# ... -# Args -# ---- -# dataset : Streaming_Doc_Dataset -# The worker to instantiate in each logical shard -# rank : int -# Current worker index -# worldsize : int -# Total number of workers -# delimiter_token : Any -# Token used to indicate sequence/document breaks. Type should match data type. -# n_logical_shards : int -# Number of logical shards. Must be a multiple of world size. -# """ - -# def __init__( -# self, -# dataset: Streaming_Doc_Dataset, -# rank: int, -# worldsize: int, -# n_logical_shards: int = 2048, -# verbose=False, -# ): -# assert ( -# n_logical_shards % worldsize == 0 -# ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" -# assert ( -# n_logical_shards > 0 -# ), f"n_logical_shards {n_logical_shards} must be a positive integer" - -# super().__init__(rank, worldsize) -# self.data = [] -# self.n_logicals = n_logical_shards // worldsize -# self.total_shards = n_logical_shards -# self.delimiter = dataset.delimiter_token - -# logicals = list(range(n_logical_shards)) -# self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) -# assert len(self.logicals_owned) == self.n_logicals - -# # Build logical shards -# self.data = [deepcopy(dataset) for _ in range(self.n_logicals)] -# for i,d in enumerate(self.data): -# d.worldsize = n_logical_shards -# d.rank = self.logicals_owned[i] -# d.verbose = (rank == 0) -# if verbose: -# logging.info( -# f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" -# ) - -# # Fetch logical shard sampling stats -# self.n_docs_remaining = [d._len for d in self.data] - -# # Position "state", used only for maintaining order when n_workers is unchanged -# # For scaling up or down, logical position is meaningless, and reset -# self.current_reader = None -# self.logical_shard_states = None -# self.generator = torch.Generator().manual_seed(self.rank) -# self.g_state = None -# self.state_params = ["current_reader", "g_state"] -# self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - -# def __iter__(self): -# # Grab one doc at a time in random order -# data = [iter(d) for d in self.data] -# while True: -# # Sample logical shard (or load from ckp) -# if self.current_reader is not None: -# ind = self.current_reader -# else: -# ind = torch.multinomial( -# torch.tensor(self.n_docs_remaining, dtype=torch.float), -# 1, -# generator=self.generator, -# ).item() -# self.current_reader = ind -# # Read doc -# out = next(data[ind]) -# while out[-1] != self.delimiter: -# yield out -# out = next(data[ind]) -# # Update state to show we've finished the doc -# self.current_reader = None -# self.n_docs_remaining[ind] -= 1 -# if sum(self.n_docs_remaining) == 0: -# self.n_docs_remaining = [d._len for d in self.data] -# self.generator.manual_seed(self.rank) -# # Return final piece of doc -# yield out - -# def state_dict(self): -# # Write generator state manually -# self.g_state = self.generator.get_state() -# # Recursive fetch -# self.logical_shard_states = [d.state_dict() for d in self.data] -# return _Stateful_Dataset.state_dict() - -# def load_state_dict(self, state_dicts, sharded_input=False): -# sharded_dicts = _Stateful_Dataset.load_state_dict(state_dicts, sharded_input) -# # Manually set generator state if it exists -# if self.g_state is not None: -# self.generator.set_state(self.g_state) -# # Recursive set -# for i in range(self.n_logicals): -# self.data[i].load_state_dict([self.logical_shard_states[i]], True) -# return sharded_dicts + + +class Sampling_Dataset(_Wrapper_Dataset): + """ + A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the + number of tokens seen from each subdataset will match those weights as closely as possible. + This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking + the number of tokens emitted by each. Whichever loader is furthest from its target will be + the next to pass a document. + + All args except for dataset_type, datasets, weights and delimiter are pass-through args for + the component _Stateful_Datasets and are documented in the appropriate classes. + ... + Args + ---- + dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset + Underlying iterator for each desired subdataset + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + datasets : list[str] | None + A list of subdatasets to draw from. If None, draws from all subfolders of datapath. + weights : list(float) | None + Weights describing what percent of emitted tokens should come from each subdataset. + Need not sum to 1. If None, tokens are drawn evenly. + ... + Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset + """ + + def __init__( + self, + datapath: str, + dataset: Union[Scalable_Shard_Dataset, Streaming_Doc_Dataset], + resample_condition: Callable = lambda _: True, + datasets=None, + weights=None, + verbose=False, + ): + super().__init__(dataset) + self.datapath = datapath + self.resample = resample_condition + self.verbose = verbose + self.datasets = ( + datasets + if datasets is not None + else [ + f + for f in os.listdir(datapath) + if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f + ] + ) + assert len(self.datasets) > 0, "You must specify at least one dataset" + + if weights is not None: + assert len(weights) == len( + self.datasets + ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" + for w in weights: + assert w > 0, f"Sampling rate {w} must be positive" + self.weights = [1] * len(self.datasets) if weights is None else weights + self.weights = [w / sum(self.weights) for w in self.weights] + + self.tokens_seen = [0] * len(self.datasets) + + self.current_iterator = -1 + self.state_params = ["tokens_seen", "current_iterator"] + + def setup(self): + if not self.is_setup: + self.is_setup = True + # Build subdataset iterators + self.data = [] + for i, d in enumerate(self.datasets): + self.data.append(deepcopy(self.dataset)) + self.data[-1].datapath = os.path.join(self.datapath, d) + for flag in ["rank", "worldsize", "load_worldsize"]: + setattr(self.data[-1], flag, getattr(self, flag)) + if self.verbose: + logging.info( + f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" + ) + [d.setup() for d in self.data] + + def __iter__(self): + self.setup() + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + if self.current_iterator != -1: + # Finish current document + out = next(data[self.current_iterator]) + self.tokens_seen[self.current_iterator] += len(out) + if self.resample(out): + self.current_iterator = -1 + yield out + else: + # Choose new subdataset to draw from + # (whichever is currently most underrepresented compared to target rate) + offset = [ + self.weights[i] + - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) + for i in range(len(self.datasets)) + ] + offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] + self.current_iterator = offset_argmax + + def state_dict(self): + self.setup() + # Manually add state of all subloaders to self state + out = { + self.statename("sample_iterator_states"): [ + d.state_dict() for d in self.data + ] + } + out.update(_Stateful_Dataset.state_dict(self)) + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() + # Load stats + sharded_dicts = _Stateful_Dataset.load_state_dict( + self, state_dicts, sharded_input + ) + # Load sub-iterator states + for i, subdata in enumerate(self.data): + # Grab just that sub-iterator across all ranks + subdata.load_worldsize = self.load_worldsize + subdata.load_state_dict( + [ + sd[self.statename("sample_iterator_states")][i] + for sd in sharded_dicts + ], + True, + ) + return sharded_dicts + + +def delimiter_condition(delimiter): + return lambda x: x[-1] == delimiter diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 576b3304..85063fe8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -82,7 +82,7 @@ def multi_reload_stress_check(d): # d is a lambda for a fully-defined dataset (i.e. d() instantiates the dataset) def reload_stress(datasets, datasets2, steps1, steps2): - # Perform the 5-step reload stress test (see test_reload_stress_all) + # Perform the 5-step reload stress test (see test_multi_reload_stress) loaders = [iter(d) for d in datasets] @@ -107,7 +107,7 @@ def reload_stress(datasets, datasets2, steps1, steps2): out1[j] == out2[j] ), f"Dataloader {i} in step {k} has mismatched token in position {j}: {out1[j]} vs {out2[j]}" - steps1 = [1, 10, 100, 1000] + steps1 = [0, 1, 10, 100, 1000] steps2 = [100, 200, 300, 400, 500] for i in range(len(steps1)): # Reset between tests (instantiate fresh datasets) @@ -401,13 +401,10 @@ def basic_sampler( ): return Sampling_Dataset( tmpdir.name, - Streaming_Doc_Dataset, - rank, - worldsize, - -1, - datasets=datasets, - weights=weights, - max_chunksize=max_chunksize, + basic_loader(rank, worldsize, datasets[:1], max_chunksize, None), + delimiter_condition(-1), + datasets, + weights, ) @@ -421,17 +418,13 @@ def basic_scalable( ): assert len(datasets) == 1, "Basic loader takes only 1 dataset" return Scalable_Shard_Dataset( - os.path.join(tmpdir.name, datasets[0]), - rank, - worldsize, - -1, - max_chunksize=max_chunksize, - n_logical_shards=n_logical_shards, - bos_token=bos_token, + basic_loader(rank, worldsize, datasets, max_chunksize, bos_token), + delimiter_condition(-1), + n_logical_shards, ) -def basic_scalable_sampler( +def basic_sampler_scalable( rank=0, worldsize=1, datasets=["dataset_1"], @@ -441,14 +434,12 @@ def basic_scalable_sampler( ): return Sampling_Dataset( tmpdir.name, - Scalable_Shard_Dataset, - rank, - worldsize, - -1, - datasets=datasets, - weights=weights, - max_chunksize=max_chunksize, - n_logical_shards=n_logical_shards, + basic_scalable( + rank, worldsize, datasets[:1], max_chunksize, n_logical_shards, None + ), + delimiter_condition(-1), + datasets, + weights, ) @@ -457,7 +448,7 @@ def test_single_epoch(): single_epoch_check(basic_loader, True) single_epoch_check(basic_scalable) single_epoch_check(basic_sampler) - single_epoch_check(basic_scalable_sampler) + single_epoch_check(basic_sampler_scalable) def test_two_epoch(): @@ -465,7 +456,7 @@ def test_two_epoch(): two_epoch_check(basic_loader, True) two_epoch_check(basic_scalable) two_epoch_check(basic_sampler) - two_epoch_check(basic_scalable_sampler) + two_epoch_check(basic_sampler_scalable) def test_chunk(): @@ -473,7 +464,7 @@ def test_chunk(): chunk_check(functools.partial(basic_loader, max_chunksize=50), True) chunk_check(functools.partial(basic_scalable, max_chunksize=50)) chunk_check(functools.partial(basic_sampler, max_chunksize=50)) - chunk_check(functools.partial(basic_scalable_sampler, max_chunksize=50)) + chunk_check(functools.partial(basic_sampler_scalable, max_chunksize=50)) def test_two_loader(): @@ -481,7 +472,7 @@ def test_two_loader(): two_loader_check(basic_loader, True) two_loader_check(functools.partial(basic_scalable, n_logical_shards=8)) two_loader_check(basic_sampler) - two_loader_check(functools.partial(basic_scalable_sampler, n_logical_shards=8)) + two_loader_check(functools.partial(basic_sampler_scalable, n_logical_shards=8)) def test_multi_file(): @@ -489,7 +480,7 @@ def test_multi_file(): multi_file_check(basic_loader, True) multi_file_check(basic_scalable) multi_file_check(basic_sampler) - multi_file_check(basic_scalable_sampler) + multi_file_check(basic_sampler_scalable) def test_reload_epoch(): @@ -497,7 +488,7 @@ def test_reload_epoch(): reload_epoch_check(basic_loader) reload_epoch_check(functools.partial(basic_scalable, n_logical_shards=8)) reload_epoch_check(basic_sampler) - reload_epoch_check(functools.partial(basic_scalable_sampler, n_logical_shards=8)) + reload_epoch_check(functools.partial(basic_sampler_scalable, n_logical_shards=8)) def test_reload_complete_epoch(): @@ -505,7 +496,9 @@ def test_reload_complete_epoch(): reload_single_epoch_check(basic_loader) reload_single_epoch_check(functools.partial(basic_scalable, n_logical_shards=8)) reload_single_epoch_check(basic_sampler) - reload_single_epoch_check(functools.partial(basic_scalable_sampler, n_logical_shards=8)) + reload_single_epoch_check( + functools.partial(basic_sampler_scalable, n_logical_shards=8) + ) def test_eos_bos_chunking(): @@ -555,7 +548,7 @@ def check_rates(w, t, b, m): ), f"Output {i} length {len(out)} does not match expected 51. Sequence so far: {s}" for i in range(3): - for m in [basic_sampler, basic_scalable_sampler]: + for m in [basic_sampler, basic_sampler_scalable]: check_rates(weights[i], target_rate[i], burnin[i], m) @@ -586,63 +579,38 @@ def test_multi_reload_stress(): ] multi_reload_stress_check(d1) - # Sampling dataset - d2 = lambda: [ - Sampling_Dataset( - tmpdir.name, - Streaming_Doc_Dataset, - i, - 3, - -1, - datasets=["dataset_1", "dataset_2"], - weights=[3, 5], - max_chunksize=17, - ) - for i in range(3) - ] - multi_reload_stress_check(d2) - # Scalable shard dataset - d3 = lambda: [ - Scalable_Shard_Dataset( - os.path.join(tmpdir.name, "dataset_2"), - i, - 3, - -1, - n_logical_shards=15, - max_chunksize=17, - ) - for i in range(3) + d2 = lambda x: [ + Scalable_Shard_Dataset(d, delimiter_condition(-1), n_logical_shards=15) + for d in x ] - multi_reload_stress_check(d3) + multi_reload_stress_check(lambda: d2(d1())) - # Nested scalable sampling dataset - d4 = lambda: [ + # Sampling dataset + d3 = lambda x: [ Sampling_Dataset( tmpdir.name, - Scalable_Shard_Dataset, - i, - 3, - -1, - n_logical_shards=15, + d, + delimiter_condition(-1), datasets=["dataset_1", "dataset_2"], weights=[3, 5], - max_chunksize=17, ) - for i in range(3) + for d in x ] + multi_reload_stress_check(lambda: d3(d1())) + + # Nested scalable sampling dataset + d4 = lambda: d3(d2(d1())) multi_reload_stress_check(d4) # Add buffer dataset d5 = lambda x: [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in x] - # Stress test buffer + all bases - for d in [d1, d2, d3, d4]: - multi_reload_stress_check(lambda: d5(d())) + multi_reload_stress_check(lambda: d5(d4())) # Add preload buffer dataset - d7 = lambda x: [Preload_Buffer_Dataset(d, 99) for d in x] + d6 = lambda x: [Preload_Buffer_Dataset(d, 99) for d in x] # preload / sample / scale / doc pipeline - multi_reload_stress_check(lambda: d7(d5(d4()))) + multi_reload_stress_check(lambda: d6(d5(d4()))) # SCALABLE_DATASET TESTS @@ -654,24 +622,12 @@ def test_scalable_partitioning(): physical worker count. Start with 4 workers with 12 logical shards, and for each of [1,2,3,6,12], verify that: 1) no overlap exists between workers and 2) in over one epoch's worth of steps, each data point appears at least once """ - for layer in [Scalable_Shard_Dataset, Sampling_Dataset]: - kwargs = { - "n_logical_shards": 12, - "max_chunksize": 200, - "worldsize": 4, - "delimiter_token": -1, - } - src = ( - tmpdir.name - if layer == Sampling_Dataset - else os.path.join(tmpdir.name, "dataset_1") - ) - datasets = [ - layer(src, Scalable_Shard_Dataset, i, datasets=["dataset_1"], **kwargs) - if layer == Sampling_Dataset - else layer(src, i, **kwargs) - for i in range(4) - ] # 25 steps per epoch + l1 = lambda r, w: basic_scalable(r, w, max_chunksize=200, n_logical_shards=12) + l2 = lambda r, w: basic_sampler_scalable( + r, w, max_chunksize=200, n_logical_shards=12 + ) + for layer in [l1, l2]: + datasets = [layer(i, 4) for i in range(4)] # 25 steps per epoch loaders = [iter(d) for d in datasets] for _ in range(50): @@ -679,30 +635,8 @@ def test_scalable_partitioning(): states = [d.state_dict() for d in datasets] - kwargs = { - "n_logical_shards": 12, - "max_chunksize": 200, - "delimiter_token": -1, - } for worldsize in [1, 2, 3, 6, 12]: - datasets = [ - layer( - src, - Scalable_Shard_Dataset, - i, - worldsize, - datasets=["dataset_1"], - **kwargs, - ) - if layer == Sampling_Dataset - else layer( - src, - i, - worldsize, - **kwargs, - ) - for i in range(worldsize) - ] + datasets = [layer(i, worldsize) for i in range(worldsize)] [d.load_state_dict(states) for d in datasets] loaders = [iter(d) for d in datasets] outs = [[] for _ in datasets] @@ -737,15 +671,7 @@ def test_scalable_shard_reload_scale(): Because logical shards won't all be the exact same length when checkpointed, we complete the epoch of the shortest of the new workers. """ datasets = [ - Scalable_Shard_Dataset( - os.path.join(tmpdir.name, "dataset_1"), - i, - 2, - -1, - n_logical_shards=8, - max_chunksize=40, - ) - for i in range(2) + basic_scalable(i, 2, max_chunksize=40, n_logical_shards=8) for i in range(2) ] # Length 300 loaders = [iter(d) for d in datasets] @@ -760,15 +686,7 @@ def test_scalable_shard_reload_scale(): states = [d.state_dict() for d in datasets] datasets2 = [ - Scalable_Shard_Dataset( - os.path.join(tmpdir.name, "dataset_1"), - i, - 4, - -1, - n_logical_shards=8, - max_chunksize=40, - ) - for i in range(4) + basic_scalable(i, 4, max_chunksize=40, n_logical_shards=8) for i in range(4) ] # Length 300 [d.load_state_dict(states) for d in datasets2] ndocs = [sum(d.n_docs_remaining) for d in datasets] @@ -793,17 +711,7 @@ def test_scalable_sampler_reload_scale(): Because logical shards and sampling ratios won't be exact, take a few extra steps then check that epoch is complete. """ datasets = [ - Sampling_Dataset( - tmpdir.name, - Scalable_Shard_Dataset, - i, - 2, - -1, - n_logical_shards=8, - datasets=["dataset_1"], - weights=[1], - max_chunksize=40, - ) + basic_sampler_scalable(i, 2, max_chunksize=40, n_logical_shards=8) for i in range(2) ] # Length 300 loaders = [iter(d) for d in datasets] @@ -819,17 +727,7 @@ def test_scalable_sampler_reload_scale(): states = [d.state_dict() for d in datasets] datasets2 = [ - Sampling_Dataset( - tmpdir.name, - Scalable_Shard_Dataset, - i, - 4, - -1, - n_logical_shards=8, - datasets=["dataset_1"], - weights=[1], - max_chunksize=40, - ) + basic_sampler_scalable(i, 4, max_chunksize=40, n_logical_shards=8) for i in range(4) ] # Length 300 [d.load_state_dict(states) for d in datasets2] From bad408b6cfb5d5590d635611b246f4304f5f46c5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 24 Jul 2024 17:44:56 -0400 Subject: [PATCH 34/73] Restore delimiter over condition logic - lambda not picklable --- fms_fsdp/utils/dataset_utils_v3.py | 15 +++++-------- tests/test_datasets.py | 36 ++++++++---------------------- 2 files changed, 15 insertions(+), 36 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index 97c9f14a..af8d4903 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -916,7 +916,7 @@ class Scalable_Shard_Dataset(_Wrapper_Dataset): def __init__( self, dataset: Streaming_Doc_Dataset, - resample_condition: Callable = lambda _: True, + delimiter_token: Any, n_logical_shards: int = 2048, verbose=False, ): @@ -929,7 +929,7 @@ def __init__( ), f"n_logical_shards {n_logical_shards} must be a positive integer" self.total_shards = n_logical_shards - self.resample = resample_condition + self.delimiter = delimiter_token self.verbose = verbose # Fields to be populated during setup / subdataset setup @@ -991,7 +991,7 @@ def __iter__(self): self.current_reader = ind # Read doc out = next(data[ind]) - while not self.resample(out): + while out[-1] != self.delimiter: yield out out = next(data[ind]) # Update state to show we've finished the doc @@ -1055,14 +1055,14 @@ def __init__( self, datapath: str, dataset: Union[Scalable_Shard_Dataset, Streaming_Doc_Dataset], - resample_condition: Callable = lambda _: True, + delimiter_token: Any, datasets=None, weights=None, verbose=False, ): super().__init__(dataset) self.datapath = datapath - self.resample = resample_condition + self.delimiter = delimiter_token self.verbose = verbose self.datasets = ( datasets @@ -1114,7 +1114,7 @@ def __iter__(self): # Finish current document out = next(data[self.current_iterator]) self.tokens_seen[self.current_iterator] += len(out) - if self.resample(out): + if out[-1] == self.delimiter: self.current_iterator = -1 yield out else: @@ -1158,6 +1158,3 @@ def load_state_dict(self, state_dicts, sharded_input=False): ) return sharded_dicts - -def delimiter_condition(delimiter): - return lambda x: x[-1] == delimiter diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 85063fe8..2bdd064d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -402,7 +402,7 @@ def basic_sampler( return Sampling_Dataset( tmpdir.name, basic_loader(rank, worldsize, datasets[:1], max_chunksize, None), - delimiter_condition(-1), + -1, datasets, weights, ) @@ -419,7 +419,7 @@ def basic_scalable( assert len(datasets) == 1, "Basic loader takes only 1 dataset" return Scalable_Shard_Dataset( basic_loader(rank, worldsize, datasets, max_chunksize, bos_token), - delimiter_condition(-1), + -1, n_logical_shards, ) @@ -437,7 +437,7 @@ def basic_sampler_scalable( basic_scalable( rank, worldsize, datasets[:1], max_chunksize, n_logical_shards, None ), - delimiter_condition(-1), + -1, datasets, weights, ) @@ -581,7 +581,7 @@ def test_multi_reload_stress(): # Scalable shard dataset d2 = lambda x: [ - Scalable_Shard_Dataset(d, delimiter_condition(-1), n_logical_shards=15) + Scalable_Shard_Dataset(d, -1, n_logical_shards=15) for d in x ] multi_reload_stress_check(lambda: d2(d1())) @@ -591,7 +591,7 @@ def test_multi_reload_stress(): Sampling_Dataset( tmpdir.name, d, - delimiter_condition(-1), + -1, datasets=["dataset_1", "dataset_2"], weights=[3, 5], ) @@ -754,6 +754,7 @@ def __init__(self): self.i = 0 self.rank = 0 self.worldsize = 1 + self.datapath = tmpdir.name def __iter__(self): while True: @@ -840,6 +841,7 @@ def __init__(self, l): self.i = 0 self.rank = 0 self.worldsize = 1 + self.datapath = tmpdir.name self.l = l def __iter__(self): @@ -873,17 +875,7 @@ def test_checkpoint_reload_match(): Check that the auto-checkpointer saves and loads correctly, and that loaded checkpoints resume properly (matching the continued behavior of the saved ones) """ - datasets = [ - Sampling_Dataset( - tmpdir.name, - Streaming_Doc_Dataset, - i, - 3, - -1, - datasets=["dataset_1", "dataset_2"], - weights=[3, 5], - max_chunksize=17, - ) + datasets = [basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3,5], max_chunksize=17) for i in range(3) ] datasets = [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets] @@ -913,17 +905,7 @@ def test_checkpoint_reload_match(): ), f"Expected three checkpoint shards (found {len(ckp_shards)})" # Create a second loader, pointing to first's checkpoint - datasets2 = [ - Sampling_Dataset( - tmpdir.name, - Streaming_Doc_Dataset, - i, - 3, - -1, - datasets=["dataset_1", "dataset_2"], - weights=[3, 5], - max_chunksize=17, - ) + datasets2 = [basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3,5], max_chunksize=17) for i in range(3) ] datasets2 = [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets2] From e344cb7991b90af952c40866c553c8db7c9931bb Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 24 Jul 2024 17:46:35 -0400 Subject: [PATCH 35/73] Type hints --- fms_fsdp/utils/dataset_utils_v3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index af8d4903..b0360dda 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -933,10 +933,10 @@ def __init__( self.verbose = verbose # Fields to be populated during setup / subdataset setup - self.data = [] - self.logicals_owned = [] + self.data: List[Streaming_Doc_Dataset] = [] + self.logicals_owned: List[int] = [] self.n_logicals = 0 - self.n_docs_remaining = [] + self.n_docs_remaining: List[int] = [] self.generator = None # Position "state", used only for maintaining order when n_workers is unchanged From 02d1e72df0dd2fc4482a72bb934f4a20d40da62a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 24 Jul 2024 17:51:27 -0400 Subject: [PATCH 36/73] Linting --- fms_fsdp/utils/dataset_utils_v3.py | 1 - tests/test_datasets.py | 11 +++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py index b0360dda..bcc24887 100644 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ b/fms_fsdp/utils/dataset_utils_v3.py @@ -1157,4 +1157,3 @@ def load_state_dict(self, state_dicts, sharded_input=False): True, ) return sharded_dicts - diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 2bdd064d..33d47d3f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -580,10 +580,7 @@ def test_multi_reload_stress(): multi_reload_stress_check(d1) # Scalable shard dataset - d2 = lambda x: [ - Scalable_Shard_Dataset(d, -1, n_logical_shards=15) - for d in x - ] + d2 = lambda x: [Scalable_Shard_Dataset(d, -1, n_logical_shards=15) for d in x] multi_reload_stress_check(lambda: d2(d1())) # Sampling dataset @@ -875,7 +872,8 @@ def test_checkpoint_reload_match(): Check that the auto-checkpointer saves and loads correctly, and that loaded checkpoints resume properly (matching the continued behavior of the saved ones) """ - datasets = [basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3,5], max_chunksize=17) + datasets = [ + basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17) for i in range(3) ] datasets = [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets] @@ -905,7 +903,8 @@ def test_checkpoint_reload_match(): ), f"Expected three checkpoint shards (found {len(ckp_shards)})" # Create a second loader, pointing to first's checkpoint - datasets2 = [basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3,5], max_chunksize=17) + datasets2 = [ + basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17) for i in range(3) ] datasets2 = [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets2] From 0be1dea95924671ca45a933c5bf3c0c05779ad51 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 25 Jul 2024 13:24:14 -0400 Subject: [PATCH 37/73] Overwrite old dataset file, update tests and constructor --- fms_fsdp/utils/dataloader_utils.py | 20 +- fms_fsdp/utils/dataset_utils.py | 550 +++++++------ fms_fsdp/utils/dataset_utils_v3.py | 1159 ---------------------------- tests/test_datasets.py | 2 +- 4 files changed, 323 insertions(+), 1408 deletions(-) delete mode 100644 fms_fsdp/utils/dataset_utils_v3.py diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 565f6223..50026456 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -7,6 +7,7 @@ Preprocess_Dataset, Sampling_Dataset, Scalable_Shard_Dataset, + Streaming_Doc_Dataset, ) @@ -68,20 +69,31 @@ def causal_lm(data_seq, prompt_len=0): int(x.strip()) for x in cfg.strip_tokens.split(",") if len(x.strip()) > 0 ] droplist = droplist + [cfg.bos_token, cfg.eos_token, cfg.bol_token, cfg.eol_token] - data = Sampling_Dataset( + # Base reader layer + data = Streaming_Doc_Dataset( cfg.data_path, - Scalable_Shard_Dataset, rank, world_size, cfg.eos_token, bos_token=cfg.bos_token, strip_tokens=set(droplist), min_length=3, + seed=cfg.seed, + ) + # Add rescaling/resharding + data = Scalable_Shard_Dataset( + data, + cfg.eos_token, + n_logical_shards=cfg.logical_shards, + ) + # Add multi-dataset handling + data = Sampling_Dataset( + cfg.data_path, + data, + cfg.eos_token, datasets=datasets, weights=weights, - seed=cfg.seed, verbose=(rank == 0), - n_logical_shards=cfg.logical_shards, ) # Wrap above dataset in packing logic to form constant-length lines. data = Buffer_Dataset( diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 08e47eba..bcc24887 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -4,6 +4,7 @@ import os import random import time +from copy import deepcopy from typing import Any, Callable, List, Optional, Set, Type, Union import pyarrow as pa @@ -74,6 +75,7 @@ class _Stateful_Dataset(data.IterableDataset): def __init__( self, + datapath: str, rank: int, worldsize: int, ): @@ -81,13 +83,23 @@ def __init__( assert ( worldsize > rank ), f"Worldsize {worldsize} must be greater than rank {rank}" + assert datapath is None or ( + os.path.isdir(datapath) and len(os.listdir(datapath)) > 0 + ), f"Data path {datapath} must be a non-empty folder or None" self.state_params: List[str] = [] self.reshard_params: List[str] = [] + self.datapath = datapath self.rank = rank self.worldsize = worldsize self.load_worldsize = ( worldsize # Enable calling load_state_dict() directly, assume no rescaling ) + self.is_setup = False + + def setup(self): + if not self.is_setup: + self.is_setup = True + pass def statename(self, x: str): # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline @@ -95,8 +107,10 @@ def statename(self, x: str): def state_dict(self): """ - Retrieve all state and reshard flags (each worker/process saves its own state dict shard) + Retrieve all state and reshard flags (each worker/process saves its own state dict shard). + On the off chance that you're saving a checkpoint with zero steps, run setup first. """ + self.setup() return { self.statename(flag): getattr(self, flag) for flag in self.state_params + self.reshard_params @@ -135,14 +149,16 @@ def load_state_dict(self, state_dicts, sharded_input=False): global list of states across all checkpoint shard files. If sharded_input=True, this expects _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading. Workflow: - 1. if sharded_inputs is false, shard the inputs. - 2. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint + 1. Run setup to prepare dataset + 2. if sharded_inputs is false, shard the inputs. + 3. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint shard (state_dicts is a singleton list). - 3. If worldsize does not match checkpoint, toss state params and assemble reshard params from + 4. If worldsize does not match checkpoint, toss state params and assemble reshard params from across given state_dicts. In this case state_dicts may be singleton (for fractional ownership) or multi-element (for multiple/partitioned ownership). - 4. Return reduced input for use by downstream loading functions + 5. Return reduced input for use by downstream loading functions """ + self.setup() if not sharded_input: self.load_worldsize = len(state_dicts) state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) @@ -198,12 +214,28 @@ def __init__( dataset: _Stateful_Dataset, ): self.dataset = dataset - super().__init__(self.dataset.rank, self.dataset.worldsize) + super().__init__( + self.dataset.datapath, self.dataset.rank, self.dataset.worldsize + ) + + def setup(self): + """ + Datapath/rank/worldsize percolate upwards recursively during initialization, now + project any desired changes downward, also recursively. + """ + if not self.is_setup: + self.is_setup = True + self.dataset.datapath = self.datapath + self.dataset.rank = self.rank + self.dataset.worldsize = self.worldsize + self.dataset.load_worldsize = self.load_worldsize + self.dataset.setup() def load_state_dict(self, state_dicts, sharded_input=False): """ Sets all specified flags at the current level, then recurses into wrapped dataset. """ + self.setup() sharded_dicts = super().load_state_dict(state_dicts, sharded_input) self.dataset.load_worldsize = self.load_worldsize self.dataset.load_state_dict(sharded_dicts, True) @@ -214,6 +246,7 @@ def state_dict(self): Fetches state dict recursively from wrapped layers, then adds specified flags. Overlapping flags are overwritten with a warning. """ + self.setup() out = self.dataset.state_dict() state = super().state_dict() for flag in self.state_params + self.reshard_params: @@ -580,11 +613,10 @@ def __init__( min_length: int = 1, max_chunksize: int = 1024, verbose: bool = False, - shuffle: bool = True, ): - super(Streaming_Doc_Dataset, self).__init__(rank, worldsize) + super().__init__(datapath, rank, worldsize) self.seed = seed - self.data = datapath + self.datapath = datapath self.min_length = min_length assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" self.chunksize = max_chunksize @@ -595,82 +627,8 @@ def __init__( self.docset: List[ Any ] = [] # map of doc indices to (shardid, min docid, max docid) - self.docs_per_shard = {} - - # Guaranteed inconsistent shuffling across workers - random.seed(self.seed + rank) - - # Gather per-file document counts from metadata count file(s) - countfiles = [ - x - for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) - if "counts" in x and "csv" in x - ] - assert len(countfiles) == 1 - doc_counts = {} - pathsplit = (datapath, "") - while len(pathsplit[1]) == 0: - pathsplit = os.path.split(pathsplit[0]) - pardir, dataset = pathsplit - self.dataset = dataset - with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: - key = fullpath[prefix:] - doc_counts[key] = int(row["documents"]) - - # Assemble document set owned by this worker: - # listdir, assemble shardfraglist (ind -> shard, frag) - shards = [ - shard - for shard in os.listdir(datapath) - if os.path.isfile(os.path.join(datapath, shard)) - and "arrow" in os.path.join(datapath, shard) - ] - shards.sort() # Ensure consistent sharding across machines - start_frag = (rank * worldsize * len(shards)) // worldsize - end_frag = ((rank + 1) * worldsize * len(shards)) // worldsize - shardfrags = [ - (shards[i // worldsize], i % worldsize) for i in range(start_frag, end_frag) - ] - - # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): - ndocs = -1 - docset = {} # shardid -> (min docid, max docid) - for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[os.path.join(dataset, shard)] - self.docs_per_shard[shard] = ndocs - doc_start = (ndocs * frag) // worldsize - doc_end = (ndocs * frag + ndocs) // worldsize - 1 # Inclusive upper bound - if shard not in docset: - docset[shard] = [doc_start, doc_end] - min_d, max_d = docset[shard] - if doc_start < min_d: - docset[shard][0] = doc_start - if doc_end > max_d: - docset[shard][1] = doc_end - - # Add all of this dataset's shard entries to self.docset - doccount = 0 - for shardid in docset: - min_d = docset[shardid][0] - max_d = docset[shardid][1] - self.docset.append((shardid, min_d, max_d)) - doccount += max_d - min_d + 1 - self._len = doccount - - if verbose: - logging.info( - f" Worker {rank} ingested {len(shardfrags)} shard fragments from {dataset}" - ) - - # Shuffle shard files - if shuffle: - random.shuffle(self.docset) + # Position self.docset_index = 0 self.chunk_index = -1 @@ -679,7 +637,6 @@ def __init__( self.tokens_seen = 0 self.docs_seen = 0 self.percent_seen = 0 - self.lcg_state = seed + rank self.state_params = [ "dataset", @@ -692,6 +649,99 @@ def __init__( "lcg_state", ] + # Setup flags + self.is_setup = False + self._len = 0 + self.dataset = "" + self.lcg_state = 0 + + def setup(self): + """ + All rank-dependent setup, which must occur after init + (rank assignment, subdataset splitting, etc.) + """ + if not self.is_setup: + datapath = self.datapath + self.is_setup = True + + # Gather per-file document counts from metadata count file(s) + countfiles = [ + x + for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) + if "counts" in x and "csv" in x + ] + assert len(countfiles) == 1 + doc_counts = {} + pathsplit = (datapath, "") + while len(pathsplit[1]) == 0: + pathsplit = os.path.split(pathsplit[0]) + pardir, dataset = pathsplit + self.dataset = dataset + with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find("/" + dataset) + 1 + if prefix > 0: + key = fullpath[prefix:] + doc_counts[key] = int(row["documents"]) + + # Assemble document set owned by this worker: + # listdir, assemble shardfraglist (ind -> shard, frag) + shards = [ + shard + for shard in os.listdir(datapath) + if os.path.isfile(os.path.join(datapath, shard)) + and "arrow" in os.path.join(datapath, shard) + ] + shards.sort() # Ensure consistent sharding across machines + start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize + end_frag = ( + (self.rank + 1) * self.worldsize * len(shards) + ) // self.worldsize + shardfrags = [ + (shards[i // self.worldsize], i % self.worldsize) + for i in range(start_frag, end_frag) + ] + + # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): + ndocs = -1 + docset = {} # shardid -> (min docid, max docid) + for i, (shard, frag) in enumerate(shardfrags): + ndocs = doc_counts[os.path.join(dataset, shard)] + doc_start = (ndocs * frag) // self.worldsize + doc_end = ( + ndocs * frag + ndocs + ) // self.worldsize - 1 # Inclusive upper bound + if shard not in docset: + docset[shard] = [doc_start, doc_end] + min_d, max_d = docset[shard] + if doc_start < min_d: + docset[shard][0] = doc_start + if doc_end > max_d: + docset[shard][1] = doc_end + + # Add all of this dataset's shard entries to self.docset + doccount = 0 + for shardid in docset: + min_d = docset[shardid][0] + max_d = docset[shardid][1] + self.docset.append((shardid, min_d, max_d)) + doccount += max_d - min_d + 1 + self._len = doccount + + if self.verbose: + logging.info( + f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" + ) + + # Shuffle shard files - guaranteed inconsistent across workers + seed = self.seed + self.rank + random.seed(seed) + random.shuffle(self.docset) + # Setup doc shuffle - same guarantee + self.lcg_state = seed + def _get_docid(self, i): """ Given a global doc index over the set of docs owned by this worker, @@ -756,6 +806,8 @@ def _random_map_docid(self, size): return state def __iter__(self): + if not self.is_setup: + self.setup() docset_offset = self.docset_index lcg_offset = self.lcg_state residual_chunks = self.chunk_index + 1 # pick up AFTER where the ckp left off @@ -775,7 +827,7 @@ def __iter__(self): shardid, docrange, mindoc = self._get_docid(doc_index) # Read doc - newpath = os.path.join(self.data, shardid) + newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) # Map id in range of owned docs to new (consistently) shuffled id doclcg = self._random_map_docid(docrange) @@ -809,7 +861,7 @@ def __iter__(self): self.lcg_state = lcg_offset shardid, docrange, mindoc = self._get_docid(docset_offset) docid = self._random_map_docid(docrange) + mindoc - newpath = os.path.join(self.data, shardid) + newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) doc = reader.get_batch(docid)["tokens"] if doc[0].as_py() in self.drop: @@ -824,9 +876,10 @@ def __iter__(self): yield self._construct_chunk(j, doc, n_chunks) def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() assert ( self.load_worldsize == self.worldsize - ), "Streaming_Doc_Dataset does not support rescaling. Please use a Scalable_Shard_Dataset." + ), f"Streaming_Doc_Dataset does not support rescaling ({self.load_worldsize, self.worldsize}). Please use a Scalable_Shard_Dataset." d = self.dataset out = super().load_state_dict(state_dicts, sharded_input) assert ( @@ -835,7 +888,144 @@ def load_state_dict(self, state_dicts, sharded_input=False): return out -class Sampling_Dataset(_Stateful_Dataset): +class Scalable_Shard_Dataset(_Wrapper_Dataset): + """ + A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track + state individually and reshard over n_gpus. + + All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. + rank : int + Current worker index + worldsize : int + Total number of workers + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + n_logical_shards : int + Number of logical shards. Must be a multiple of world size. + ... + Pass-through args, see Streaming_Doc_Dataset + """ + + def __init__( + self, + dataset: Streaming_Doc_Dataset, + delimiter_token: Any, + n_logical_shards: int = 2048, + verbose=False, + ): + super().__init__(dataset) + assert ( + n_logical_shards % self.worldsize == 0 + ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert ( + n_logical_shards > 0 + ), f"n_logical_shards {n_logical_shards} must be a positive integer" + + self.total_shards = n_logical_shards + self.delimiter = delimiter_token + self.verbose = verbose + + # Fields to be populated during setup / subdataset setup + self.data: List[Streaming_Doc_Dataset] = [] + self.logicals_owned: List[int] = [] + self.n_logicals = 0 + self.n_docs_remaining: List[int] = [] + self.generator = None + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = None + self.logical_shard_states = None + self.g_state = None + + self.state_params = ["current_reader", "g_state"] + self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + + def setup(self): + if not self.is_setup: + self.is_setup = True + n_logical_shards = self.total_shards + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + self.n_logicals = n_logical_shards // self.worldsize + assert len(self.logicals_owned) == self.n_logicals + + # Build logical shards + for i in range(self.n_logicals): + self.data.append(deepcopy(self.dataset)) + self.data[-1].worldsize = n_logical_shards + self.data[-1].load_worldsize = n_logical_shards + self.data[-1].rank = self.logicals_owned[i] + self.data[-1].datapath = self.datapath + self.data[-1].verbose = self.rank == 0 + if self.verbose: + logging.info( + f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + [d.setup() for d in self.data] + self.n_docs_remaining = [d._len for d in self.data] + + self.generator = torch.Generator().manual_seed(self.rank) + + def __iter__(self): + self.setup() + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + # Sample logical shard (or load from ckp) + if self.current_reader is not None: + ind = self.current_reader + else: + ind = torch.multinomial( + torch.tensor(self.n_docs_remaining, dtype=torch.float), + 1, + generator=self.generator, + ).item() + self.current_reader = ind + # Read doc + out = next(data[ind]) + while out[-1] != self.delimiter: + yield out + out = next(data[ind]) + # Update state to show we've finished the doc + self.current_reader = None + self.n_docs_remaining[ind] -= 1 + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) + # Return final piece of doc + yield out + + def state_dict(self): + self.setup() + # Write generator state manually + self.g_state = self.generator.get_state() + # Recursive fetch + self.logical_shard_states = [d.state_dict() for d in self.data] + return _Stateful_Dataset.state_dict(self) + + def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() + sharded_dicts = _Stateful_Dataset.load_state_dict( + self, state_dicts, sharded_input + ) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Recursive set + for i in range(self.n_logicals): + self.data[i].load_state_dict([self.logical_shard_states[i]], True) + return sharded_dicts + + +class Sampling_Dataset(_Wrapper_Dataset): """ A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the number of tokens seen from each subdataset will match those weights as closely as possible. @@ -864,20 +1054,16 @@ class Sampling_Dataset(_Stateful_Dataset): def __init__( self, datapath: str, - dataset_type: Union[ - Type["Streaming_Doc_Dataset"], - Type["Scalable_Shard_Dataset"], - ], - rank: int, - worldsize: int, + dataset: Union[Scalable_Shard_Dataset, Streaming_Doc_Dataset], delimiter_token: Any, datasets=None, weights=None, verbose=False, - **kwargs, ): - super().__init__(rank, worldsize) + super().__init__(dataset) + self.datapath = datapath self.delimiter = delimiter_token + self.verbose = verbose self.datasets = ( datasets if datasets is not None @@ -900,28 +1086,27 @@ def __init__( self.tokens_seen = [0] * len(self.datasets) - # Build subdataset iterators - self.data = [] - for i, d in enumerate(self.datasets): - self.data.append( - dataset_type( - datapath=os.path.join(datapath, d), - rank=rank, - worldsize=worldsize, - delimiter_token=delimiter_token, - verbose=verbose, - **kwargs, - ) - ) - if verbose: - logging.info( - f"Worker {rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" - ) - self.current_iterator = -1 self.state_params = ["tokens_seen", "current_iterator"] + def setup(self): + if not self.is_setup: + self.is_setup = True + # Build subdataset iterators + self.data = [] + for i, d in enumerate(self.datasets): + self.data.append(deepcopy(self.dataset)) + self.data[-1].datapath = os.path.join(self.datapath, d) + for flag in ["rank", "worldsize", "load_worldsize"]: + setattr(self.data[-1], flag, getattr(self, flag)) + if self.verbose: + logging.info( + f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" + ) + [d.setup() for d in self.data] + def __iter__(self): + self.setup() # Grab one doc at a time in random order data = [iter(d) for d in self.data] while True: @@ -944,18 +1129,22 @@ def __iter__(self): self.current_iterator = offset_argmax def state_dict(self): + self.setup() # Manually add state of all subloaders to self state out = { self.statename("sample_iterator_states"): [ d.state_dict() for d in self.data ] } - out.update(super().state_dict()) + out.update(_Stateful_Dataset.state_dict(self)) return out def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() # Load stats - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + sharded_dicts = _Stateful_Dataset.load_state_dict( + self, state_dicts, sharded_input + ) # Load sub-iterator states for i, subdata in enumerate(self.data): # Grab just that sub-iterator across all ranks @@ -968,130 +1157,3 @@ def load_state_dict(self, state_dicts, sharded_input=False): True, ) return sharded_dicts - - -class Scalable_Shard_Dataset(_Stateful_Dataset): - """ - A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different - number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track - state individually and reshard over n_gpus. - - All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. - ... - Args - ---- - datapath : str - Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. - rank : int - Current worker index - worldsize : int - Total number of workers - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - n_logical_shards : int - Number of logical shards. Must be a multiple of world size. - ... - Pass-through args, see Streaming_Doc_Dataset - """ - - def __init__( - self, - datapath: str, - rank: int, - worldsize: int, - delimiter_token: Any, - n_logical_shards: int = 2048, - verbose=False, - **kwargs, - ): - assert ( - n_logical_shards % worldsize == 0 - ), f"World size {worldsize} must divide n_logical_shards {n_logical_shards} evenly" - assert ( - n_logical_shards > 0 - ), f"n_logical_shards {n_logical_shards} must be a positive integer" - - super().__init__(rank, worldsize) - self.data = [] - self.n_logicals = n_logical_shards // worldsize - self.total_shards = n_logical_shards - self.delimiter = delimiter_token - - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - assert len(self.logicals_owned) == self.n_logicals - - # Build logical shards - for i in range(self.n_logicals): - self.data.append( - Streaming_Doc_Dataset( - datapath=datapath, - worldsize=n_logical_shards, - rank=self.logicals_owned[i], - delimiter_token=delimiter_token, - verbose=(rank == 0), - **kwargs, - ) - ) - if verbose: - logging.info( - f"Worker {rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - - # Fetch logical shard sampling stats - self.n_docs_remaining = [d._len for d in self.data] - - # Position "state", used only for maintaining order when n_workers is unchanged - # For scaling up or down, logical position is meaningless, and reset - self.current_reader = None - self.logical_shard_states = None - self.generator = torch.Generator().manual_seed(self.rank) - self.g_state = None - self.state_params = ["current_reader", "g_state"] - self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - - def __iter__(self): - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - # Sample logical shard (or load from ckp) - if self.current_reader is not None: - ind = self.current_reader - else: - ind = torch.multinomial( - torch.tensor(self.n_docs_remaining, dtype=torch.float), - 1, - generator=self.generator, - ).item() - self.current_reader = ind - # Read doc - out = next(data[ind]) - while out[-1] != self.delimiter: - yield out - out = next(data[ind]) - # Update state to show we've finished the doc - self.current_reader = None - self.n_docs_remaining[ind] -= 1 - if sum(self.n_docs_remaining) == 0: - self.n_docs_remaining = [d._len for d in self.data] - self.generator.manual_seed(self.rank) - # Return final piece of doc - yield out - - def state_dict(self): - # Write generator state manually - self.g_state = self.generator.get_state() - # Recursive fetch - self.logical_shard_states = [d.state_dict() for d in self.data] - return super().state_dict() - - def load_state_dict(self, state_dicts, sharded_input=False): - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Recursive set - for i in range(self.n_logicals): - self.data[i].load_state_dict([self.logical_shard_states[i]], True) - return sharded_dicts diff --git a/fms_fsdp/utils/dataset_utils_v3.py b/fms_fsdp/utils/dataset_utils_v3.py deleted file mode 100644 index bcc24887..00000000 --- a/fms_fsdp/utils/dataset_utils_v3.py +++ /dev/null @@ -1,1159 +0,0 @@ -import csv -import logging -import math -import os -import random -import time -from copy import deepcopy -from typing import Any, Callable, List, Optional, Set, Type, Union - -import pyarrow as pa -import torch -import torch.utils.data as data - -from fms_fsdp.utils.checkpointing_utils import get_latest - - -""" -The following distributed dataloaders are designed around 3 main principles: - -1. Efficient, asynchronous operation. Workers on different devices do not communicate. -2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator - loading from disk and additional layers adding levels of post-processing (shuffling, - packing, padding, etc.). -3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal - state that can be written/read on disk via implemented recursive `state_dict()` and - `load_state_dict()` calls. -4. Rescalability. Users can save and load checkpoints to/from different numbers of workers - without losing the global state. This is accomplished by splitting state fields for each - layer into `state_params`, which are typically scalar-valued and can be discarded when - rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be - re-distributed over workers (i.e. buffers). - -Our loaders obey the following type heirarchy: -torch.data.IterableDataset -> _Stateful_Dataset -> _Wrapper_Dataset. -`_Stateful_Dataset` implements state and checkpointing logic. A `_Wrapper_Dataset` holds a -single `_Stateful_Dataset` and iterates via calling its wrapped dataset any number of times, -then applying some sort of post-processing and yielding the result. Users build data processing -pipelines by wrapping a base `_Stateful_Dataset` in any number of `_Wrapper_Dataset` layers, -which is then passed to the torch DataLoader. - -NOTE: `_Wrapper_Dataset` currently only implements wrapping a single instantiated sub-dataset layer. -Many layers need multiple sub-layers (i.e. random sampling from distinct data sources). These are -currently implemented as base `_Stateful_Datasets` that take the class of their sub-layers plus any -pass-through arguments, and instantiate all those sub-layers. This is easy on the user, who no longer -needs to instantiate large sets of sub-layers in their code, but leads to awkwardness in this file. -Cleanup is planned for the future. -""" - - -def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: - """ - Partition itemlist into worldsize chunks, grab chunk corresponding to rank and return. - """ - return itemlist[ - (rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize - ] - - -def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: - """ - In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, - and return the span including all owned items, fractional or otherwise. - """ - start = math.floor(len(itemlist) * rank / worldsize) - end = math.ceil(len(itemlist) * (rank + 1) / worldsize) - return itemlist[start:end] - - -class _Stateful_Dataset(data.IterableDataset): - """ - Stub for stateful datasets, extends data.IterableDataset with state_dict methods. - All subclasses should specify the params to be considered stateful or reshardable in the - self.state_params and self.reshard_params lists. - """ - - def __init__( - self, - datapath: str, - rank: int, - worldsize: int, - ): - assert rank >= 0, f"Rank {rank} must be a positive integer" - assert ( - worldsize > rank - ), f"Worldsize {worldsize} must be greater than rank {rank}" - assert datapath is None or ( - os.path.isdir(datapath) and len(os.listdir(datapath)) > 0 - ), f"Data path {datapath} must be a non-empty folder or None" - self.state_params: List[str] = [] - self.reshard_params: List[str] = [] - self.datapath = datapath - self.rank = rank - self.worldsize = worldsize - self.load_worldsize = ( - worldsize # Enable calling load_state_dict() directly, assume no rescaling - ) - self.is_setup = False - - def setup(self): - if not self.is_setup: - self.is_setup = True - pass - - def statename(self, x: str): - # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline - return self.__class__.__name__ + "." + x - - def state_dict(self): - """ - Retrieve all state and reshard flags (each worker/process saves its own state dict shard). - On the off chance that you're saving a checkpoint with zero steps, run setup first. - """ - self.setup() - return { - self.statename(flag): getattr(self, flag) - for flag in self.state_params + self.reshard_params - } - - def _reshard(self, sharded_list): - """ - Sharded_list is a list of lists, where each "shard" sublist must have the same length. - These shards should tightly span only the partition of data owned by this worker. - (i.e. if global_list is the list of all entries, sharded_list = _shard_inclusive(global_list) ). - Determine fractional ownership of shards, and get the flattened partition owned by this worker. - """ - # How many shards did _shard_inclusive() drop to the left of sharded_list? - shard_offset = math.floor(self.load_worldsize * self.rank / self.worldsize) - # How long are the list shards? - shard_len = len(sharded_list[0]) - for i, shard in enumerate(sharded_list): - assert ( - len(shard) == shard_len - ), f"Shard {i} with length {len(shard)} does not match expected {shard_len}" - # How many list items did _shard_inclusive() drop to the left of the flattened sharded_list? - item_offset = shard_len * shard_offset - # How many list items are there in total? - n_items = self.load_worldsize * shard_len - # The indices of the flattened sharded_list that this worker owns - my_items = range( - int(n_items * self.rank / self.worldsize) - item_offset, - int(n_items * (self.rank + 1) / self.worldsize) - item_offset, - ) - # Pull out owned items - return [sharded_list[i // shard_len][i % shard_len] for i in my_items] - - def load_state_dict(self, state_dicts, sharded_input=False): - """ - Input state_dicts is a list of state_dicts. If sharded_input=False, this is expected to be the - global list of states across all checkpoint shard files. If sharded_input=True, this expects - _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading. - Workflow: - 1. Run setup to prepare dataset - 2. if sharded_inputs is false, shard the inputs. - 3. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint - shard (state_dicts is a singleton list). - 4. If worldsize does not match checkpoint, toss state params and assemble reshard params from - across given state_dicts. In this case state_dicts may be singleton (for fractional ownership) - or multi-element (for multiple/partitioned ownership). - 5. Return reduced input for use by downstream loading functions - """ - self.setup() - if not sharded_input: - self.load_worldsize = len(state_dicts) - state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) - if self.load_worldsize == self.worldsize: - [ - setattr(self, flag, state_dicts[0][self.statename(flag)]) - for flag in self.state_params + self.reshard_params - ] - else: - for flag in self.reshard_params: - reshard = self._reshard( - [sd[self.statename(flag)] for sd in state_dicts] - ) - setattr(self, flag, reshard) - return state_dicts - - def load_from_path(self, path: str): - """ - Count shard files in the specified checkpoint folder and determine overlap with current - rank and worldsize partition. Load only matching shardfile(s) and pass to load_state_dict. - This is more efficient than sharding the full loaded state. - """ - assert os.path.exists(path), "Specified checkpoint does not exist" - assert not os.path.isfile(path), "Checkpoint should be a folder of shard states" - fileshards = [x for x in os.listdir(path) if "loader" in x] - fileshards = sorted(fileshards, key=lambda x: int(x.split("_")[2][:-4])) - assert ( - len(fileshards) > 0 - ), "Checkpoint directory must contain checkpoint files with 'loader' in the name" - self.load_worldsize = len(fileshards) - # Grab only the shard files holding data we currently own - my_fileshards = _shard_inclusive(fileshards, self.rank, self.worldsize) - states = [torch.load(os.path.join(path, x)) for x in my_fileshards] - self.load_state_dict(states, True) - - def save_to_path(self, path: str): - """ - Grab recursive shard states and save all shard states to the specified checkpoint folder - """ - os.makedirs(path, exist_ok=True) - state = self.state_dict() - torch.save(state, os.path.join(path, f"loader_state_{self.rank}.pth")) - - -class _Wrapper_Dataset(_Stateful_Dataset): - """ - Stub for nested wrappers of _Stateful_Datasets. Extends state fns with recursion. - Requires a single instantiated sub-dataset. - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - ): - self.dataset = dataset - super().__init__( - self.dataset.datapath, self.dataset.rank, self.dataset.worldsize - ) - - def setup(self): - """ - Datapath/rank/worldsize percolate upwards recursively during initialization, now - project any desired changes downward, also recursively. - """ - if not self.is_setup: - self.is_setup = True - self.dataset.datapath = self.datapath - self.dataset.rank = self.rank - self.dataset.worldsize = self.worldsize - self.dataset.load_worldsize = self.load_worldsize - self.dataset.setup() - - def load_state_dict(self, state_dicts, sharded_input=False): - """ - Sets all specified flags at the current level, then recurses into wrapped dataset. - """ - self.setup() - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - self.dataset.load_worldsize = self.load_worldsize - self.dataset.load_state_dict(sharded_dicts, True) - return sharded_dicts - - def state_dict(self): - """ - Fetches state dict recursively from wrapped layers, then adds specified flags. - Overlapping flags are overwritten with a warning. - """ - self.setup() - out = self.dataset.state_dict() - state = super().state_dict() - for flag in self.state_params + self.reshard_params: - if flag in out: - logging.warning( - f"Loader {self.rank}: flag {flag} already present in state_dict with value {out[flag]}. " - + f"Overwriting with value {state[flag]}" - ) - out.update(state) - return out - - -class Preprocess_Dataset(_Wrapper_Dataset): - """ - Wrapper for a _Stateful_Dataset that applies a specified preprocessing - or augmentation function to dataset outputs. - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - aug_fn : function (any -> any) - The augmentation function to apply to each dataset item. - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - aug_fn: Callable, - ): - super().__init__(dataset) - self.aug_fn = aug_fn - - def __iter__(self): - dataset = iter(self.dataset) - while True: - out = next(dataset) - yield self.aug_fn(out) - - -class Checkpoint_Dataset(_Wrapper_Dataset): - """ - Wrapper for a _Stateful_Dataset that implements auto-checkpoint saving every n steps. - Useful for setting n_workers > 0, so that workers do not rely on the master process - for state saving (inter-process communication unsupported in PyTorch datasets). - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - load_path : str - Absolute path to checkpoint load directory. If a checkpoint exists, loads it. - interval : int - Saves a new checkpoint every interval. - steps_per_batch : optional[int] - Number of steps required to fill a single batch. Increments interval only - when a full batch is formed. Defaults to 1. - save_path : optional[str] - Absolute path to checkpoint save directory. Defaults to load_path. - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - load_path: str, - interval: int, - steps_per_batch: int = 1, - save_path: str = "", - ): - super().__init__(dataset) - self.interval = interval - self.spb = steps_per_batch - load_path = os.path.join(load_path, "checkpoints") - if len(save_path) == 0: - save_path = load_path - else: - save_path = os.path.join(save_path, "checkpoints") - self.path = save_path - self.step = 0 - self.ministep = 0 - self.load_from_path(load_path) - - def __iter__(self): - dataset = iter(self.dataset) - while True: - yield next(dataset) - self.ministep += 1 - if self.ministep == self.spb: - self.ministep = 0 - self.step += 1 - if self.step % self.interval == 0: - newpath = os.path.join(self.path, "step_" + str(self.step) + "_ckp") - self.save_to_path(newpath) - - def report(self, msg): - if self.rank == 0: - print(msg) - - def save_to_path(self, path: str): - self.report(f"Saving dataset to {path}") - start = time.time() - super().save_to_path(path) - self.report( - f"Dataset successfully saved to {path}! Save time: {time.time() - start}" - ) - - def load_from_path(self, path: str): - # If path does not exist, or exists but is empty, exit early - if not os.path.exists(path) or len(os.listdir(path)) == 0: - self.report( - f"No valid checkpoint detected at {path}, dataset starting from scratch." - ) - return - # Grab latest item in path - latest = os.path.join(path, get_latest(path)) - self.report(f"Dataset checkpoint detected at {latest}") - # If item is not a folder, exit early - if os.path.isfile(latest): - self.report( - f"Checkpoint exists but contains no dataset! Dataset starting from scratch." - ) - return - # If item is a folder, get the step count - self.step = int(latest.split("_")[-2]) - # Proceed - start = time.time() - self.dataset.load_from_path(latest) - self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}") - - -class Preload_Buffer_Dataset(_Wrapper_Dataset): - """ - Wrapper for a Stateful_Dataset that implements data shuffling via a single in/out buffer. - Fills buffer two at a time, up to desired size, then switches to one at a time to maintain size. - Passes randomly sampled outputs one by one. - Ensures local mixing of data without relying on sliding windows or shuffling of large buffers. - Any two consecutive inputs will be separated by window_size steps in expectation. - Rescaling-enabled: buffers that shrink will re-grow to window_size, buffers that expand stay large. - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - window_size : int - Max size of input/output buffer - """ - - def __init__(self, dataset: _Stateful_Dataset, window_size: int): - super().__init__(dataset) - assert ( - window_size > 1 - ), f"Window size {window_size} must be greater than 1 for shuffling to occur" - self.window_size = window_size - self.g_state = None - self.generator = torch.Generator().manual_seed(self.rank) - self.buffer: List[List[Any]] = [] - self.buffer_size = 0 - self.state_params = ["g_state"] - self.reshard_params = ["buffer"] - - def __iter__(self): - dataset = iter(self.dataset) - while True: - # Pad out buffer if needed - self._pad_buffer() - - # Load a point to buffer if necessary - if self.buffer_size < self.window_size: - self.buffer[self.buffer_size] = next(dataset) - self.buffer_size += 1 - - # Swap out randomly sampled value from buffer - i = torch.randint(self.buffer_size, (1,), generator=self.generator).item() - out = self.buffer[i] - self.buffer[i] = next(dataset) - yield out - - def _pad_buffer(self): - if self.buffer_size < self.window_size: - self.buffer += [ - [], - ] * (self.window_size - self.buffer_size) - - def state_dict(self): - # Write generator state manually - self.g_state = self.generator.get_state() - # Prune buffer so it can be resharded in future - self.buffer = self.buffer[: self.buffer_size] - out = super().state_dict() - return out - - def load_state_dict(self, state_dicts, sharded_input=False): - sharded_dicts = super().load_state_dict(state_dicts, sharded_input) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Manually set buffer size - self.buffer_size = len(self.buffer) - return sharded_dicts - - -class Buffer_Dataset(_Wrapper_Dataset): - """ - Wrapper for a _Stateful_Dataset that takes in sequences of varying lengths, and packs/pads them - into sequences of desired length. Input sequences are packed greedily until the buffer would - otherwise overrun, then remaining values are filled depending on initialization flags. - Also injects BOS/EOS into the packed output sequence if desired, and if BOS/EOS tokens are - not already in those positions. Implements rescaling by simply dropping (buffer) state. - ... - Args - ---- - dataset : _Stateful_Dataset - Fully instantiated dataset - seq_len : int - The desired sequence length - pack_hard : bool - Split input sequences to fill output buffer, or use pad tokens to fill remaining space? - bos_token : any | None - Token to prepend to every output sequence. If None, no token is added. Type should match data type. - eos_token : any | None - Token to append to every output sequence. If None, no token is added. Type should match data type. - pad_token : any | None - Token used to fill out output sequence. Type should match data type. - drop_final_token : any | None - Drop the final token of each document if it matches this value? - (For edge case where bos=eos=None, and sep already appears at beginning of each doc - - drop added extra sep from end of doc) - """ - - def __init__( - self, - dataset: _Stateful_Dataset, - seq_len: int, - pack_hard: bool, - bos_token=None, - eos_token=None, - pad_token=None, - ): - super().__init__(dataset) - self.len = seq_len - - # Buffer args - self.buffer: List[str] = [] - self.bos = bos_token - self.eos = eos_token - self.pad = pad_token - self.pack_hard = pack_hard - if not pack_hard: - assert ( - pad_token is not None - ), "Error: if using pads, you must supply a pad_token" - - self.state_params = ["buffer"] - - def _get_buffer(self, iterable, length, buffer): - # Pull data until buffer is about to overrun, return exactly proper length - new = [] - while len(buffer) + len(new) < length: - buffer += new - new = next(iterable) - - # Add bos if needed - if self.bos is not None and (len(buffer) == 0 or buffer[0] != self.bos): - buffer = [self.bos] + buffer - - # Handle buffer splitting - if len(buffer) >= length: - # If buffer is too long, force split - out = buffer[:length] - buffer = buffer[length:] - if self.eos is not None and out[-1] != self.eos: - buffer = [out[-1]] + buffer - out[-1] = self.eos - buffer = buffer + new - else: - if self.pack_hard: - # Pack in as much of new sequence as will fit - buffer = buffer + new - out = buffer[:length] - buffer = buffer[length:] - if self.eos is not None and out[-1] != self.eos: - buffer = [out[-1]] + buffer - out[-1] = self.eos - else: - # Fill out with pads as needed - if self.eos is not None and buffer[-1] != self.eos: - buffer.append(self.eos) - if self.pad is not None: - out = buffer + [self.pad] * (length - len(buffer)) - else: - out = buffer - buffer = new - return out, buffer - - # Fill buffer line by line, delimiters and packing/splitting as appropriate - def __iter__(self): - dataset = iter(self.dataset) - while True: - out, buffer = self._get_buffer(dataset, self.len, self.buffer) - self.buffer = buffer - yield out - - -class Streaming_Doc_Dataset(_Stateful_Dataset): - """ - The base distributed dataset for loading sequences/documents from pyarrow shards. - Pyarrow shard files are expected to hold multiple recordBatches, where each recordBatch has a "tokens" - field consisting of a single token list. (i.e. each document is a single sequence under a "token" field, - and the file is a list of such sequences) - Relies on a compiled metadata file to fetch shardfile lengths, assumes file already exists in the parent directory, - and is in proper csv format (first row "dataset/filename,documents,tokens", subsequent rows these values). - - For a single dataset directory, splits shard files into x=worldsize fragments and grabs a 1/n contiguous - span of shard fragments (contiguous to limit file reads from cloud/disk). - Logs the number of documents owned from each shardfile, and relies on ZCG random bijection to - map contiguous range of indices to shuffled, noncontiguous set of documents from each shard file. - Shuffles the file list deterministically to hop from file to file. - - At runtime, iterates through documents in each shuffled shard file, pulling each shard on demand. - Shards are thus pulled no more than once per epoch. - Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents. - - Streaming_Doc_Dataset grabs files from a flat directory representing a single dataset. - For percentage-based sampling of multiple subdatasets, see Sampling_Dataset. - ... - Args - ---- - datapath : str - Absolute path to the dataset directory. Expects directory containing pyarrow shardfiles. - Parent directory should contain 'meta' folder with metadata csv file inside. - rank : int - Current worker index - worldsize : int - Total number of workers - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. Required for downstream - sampling logic (can be removed later via PreProcess_Dataset if needed). - bos_token : Any | None - Optional token used to indicate sequence/document start. Type should match data type. - strip_tokens : set[Any] - Token values that should be removed if detected at beginning or end of document - (i.e. any eos/bos tokens already present in the data). Type should match data type. - seed : int - The random seed for deterministic shuffling/sharding - min_length : int - Sequences below this length are skipped - max_chunksize : int - Maximum sequence length to return. Break long docs into chunks of this size or shorter. - verbose : bool - Track setup progress? - shuffle : bool - Shuffle shard file and document orders? (Disable for simple testing) - """ - - def __init__( - self, - datapath: str, - rank: int, - worldsize: int, - delimiter_token: Any, - bos_token: Optional[Any] = None, - strip_tokens: Optional[Set[Any]] = set(), - seed: int = 42, - min_length: int = 1, - max_chunksize: int = 1024, - verbose: bool = False, - ): - super().__init__(datapath, rank, worldsize) - self.seed = seed - self.datapath = datapath - self.min_length = min_length - assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" - self.chunksize = max_chunksize - self.eos = delimiter_token - self.bos = bos_token - self.drop = strip_tokens - self.verbose = verbose - self.docset: List[ - Any - ] = [] # map of doc indices to (shardid, min docid, max docid) - - # Position - self.docset_index = 0 - self.chunk_index = -1 - - # Stats - self.epochs_seen = -1 - self.tokens_seen = 0 - self.docs_seen = 0 - self.percent_seen = 0 - - self.state_params = [ - "dataset", - "docset_index", - "chunk_index", - "epochs_seen", - "tokens_seen", - "docs_seen", - "percent_seen", - "lcg_state", - ] - - # Setup flags - self.is_setup = False - self._len = 0 - self.dataset = "" - self.lcg_state = 0 - - def setup(self): - """ - All rank-dependent setup, which must occur after init - (rank assignment, subdataset splitting, etc.) - """ - if not self.is_setup: - datapath = self.datapath - self.is_setup = True - - # Gather per-file document counts from metadata count file(s) - countfiles = [ - x - for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) - if "counts" in x and "csv" in x - ] - assert len(countfiles) == 1 - doc_counts = {} - pathsplit = (datapath, "") - while len(pathsplit[1]) == 0: - pathsplit = os.path.split(pathsplit[0]) - pardir, dataset = pathsplit - self.dataset = dataset - with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: - key = fullpath[prefix:] - doc_counts[key] = int(row["documents"]) - - # Assemble document set owned by this worker: - # listdir, assemble shardfraglist (ind -> shard, frag) - shards = [ - shard - for shard in os.listdir(datapath) - if os.path.isfile(os.path.join(datapath, shard)) - and "arrow" in os.path.join(datapath, shard) - ] - shards.sort() # Ensure consistent sharding across machines - start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize - end_frag = ( - (self.rank + 1) * self.worldsize * len(shards) - ) // self.worldsize - shardfrags = [ - (shards[i // self.worldsize], i % self.worldsize) - for i in range(start_frag, end_frag) - ] - - # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): - ndocs = -1 - docset = {} # shardid -> (min docid, max docid) - for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[os.path.join(dataset, shard)] - doc_start = (ndocs * frag) // self.worldsize - doc_end = ( - ndocs * frag + ndocs - ) // self.worldsize - 1 # Inclusive upper bound - if shard not in docset: - docset[shard] = [doc_start, doc_end] - min_d, max_d = docset[shard] - if doc_start < min_d: - docset[shard][0] = doc_start - if doc_end > max_d: - docset[shard][1] = doc_end - - # Add all of this dataset's shard entries to self.docset - doccount = 0 - for shardid in docset: - min_d = docset[shardid][0] - max_d = docset[shardid][1] - self.docset.append((shardid, min_d, max_d)) - doccount += max_d - min_d + 1 - self._len = doccount - - if self.verbose: - logging.info( - f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" - ) - - # Shuffle shard files - guaranteed inconsistent across workers - seed = self.seed + self.rank - random.seed(seed) - random.shuffle(self.docset) - # Setup doc shuffle - same guarantee - self.lcg_state = seed - - def _get_docid(self, i): - """ - Given a global doc index over the set of docs owned by this worker, - return the corresponding data/shard/local index - """ - cur = 0 - assert ( - i <= self._len - ), f"You have requested an illegal doc index {i}, docset length is {self._len}" - for shardid, min_d, max_d in self.docset: - docrange = max_d - min_d + 1 - cur += docrange - if cur > i: - return shardid, docrange, min_d - - def _get_reader(self, path, newpath, reader): - """ - If new filepath does not match the current one, - open a new reader on that filepath (pull file on demand) - """ - if newpath != path: - del reader - if self.verbose: - logging.info(f"Worker {self.rank} opening new file {newpath}") - reader = pa.ipc.open_file(newpath) - path = newpath - return path, reader - - def _construct_chunk(self, j, doc, n_chunks): - """ - Grab a chunk of the desired size from the pyarrow document, - avoiding unnecessary overhead in case of large docs - """ - start_index = j * self.chunksize - n_pull = self.chunksize - if self.bos is not None: - if j == 0: - n_pull -= 1 - else: - start_index -= 1 - chunk = doc.slice(start_index, n_pull).to_pylist() - self.tokens_seen += len(chunk) - # Add bos/eos tokens if needed - if self.bos is not None and j == 0: - chunk = [self.bos] + chunk - if j == n_chunks - 1: - chunk = chunk + [self.eos] - return chunk - - def _random_map_docid(self, size): - """ - Given size of document pool, use saved state (prior index) to generate the next index via LCG. - Implements within-shard document shuffling without materializing any large doc lists. - """ - m = 2 ** math.ceil(math.log2(size)) # Round up to nearest power of 2 - a = 5 # A,C values known to work well with powers of 2 (Knuth, 1997, 3.2.1.3) - c = (self.rank + self.seed) * 2 + 1 - state = self.lcg_state - while True: - state = (a * state + c) % m - if state < size: - return state - - def __iter__(self): - if not self.is_setup: - self.setup() - docset_offset = self.docset_index - lcg_offset = self.lcg_state - residual_chunks = self.chunk_index + 1 # pick up AFTER where the ckp left off - ndocs = self._len - path = "" - reader = None - while True: - # Iterate through docs, starting at desired offset - for i in range(ndocs): - doc_index = (docset_offset + i) % ndocs - - # Update stats - if doc_index == 0: - self.epochs_seen += 1 - self.docset_index = doc_index - # Map doc id to shard, id in file - shardid, docrange, mindoc = self._get_docid(doc_index) - - # Read doc - newpath = os.path.join(self.datapath, shardid) - path, reader = self._get_reader(path, newpath, reader) - # Map id in range of owned docs to new (consistently) shuffled id - doclcg = self._random_map_docid(docrange) - docid = doclcg + mindoc - doc = reader.get_batch(docid)["tokens"] - if doc[0].as_py() in self.drop: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in self.drop: - doc = doc.slice(0, len(doc) - 1) - doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: - n_chunks = math.ceil(doclen / self.chunksize) - for j in range(n_chunks): - if i == 0 and j < residual_chunks: - pass - else: - self.chunk_index = j - # Document complete, update stats - if j == n_chunks - 1: - self.docs_seen += 1 - self.percent_seen = ( - self.docs_seen * 100 / (self._len + 1e-9) - ) - yield self._construct_chunk(j, doc, n_chunks) - - # Advance RNG state - self.lcg_state = doclcg - - # Load any chunks initially skipped in first doc - self.docset_index = docset_offset - self.lcg_state = lcg_offset - shardid, docrange, mindoc = self._get_docid(docset_offset) - docid = self._random_map_docid(docrange) + mindoc - newpath = os.path.join(self.datapath, shardid) - path, reader = self._get_reader(path, newpath, reader) - doc = reader.get_batch(docid)["tokens"] - if doc[0].as_py() in self.drop: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in self.drop: - doc = doc.slice(0, len(doc) - 1) - doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: - n_chunks = math.ceil(doclen / self.chunksize) - for j in range(residual_chunks): - self.chunk_index = j - yield self._construct_chunk(j, doc, n_chunks) - - def load_state_dict(self, state_dicts, sharded_input=False): - self.setup() - assert ( - self.load_worldsize == self.worldsize - ), f"Streaming_Doc_Dataset does not support rescaling ({self.load_worldsize, self.worldsize}). Please use a Scalable_Shard_Dataset." - d = self.dataset - out = super().load_state_dict(state_dicts, sharded_input) - assert ( - d == self.dataset - ), f"Dataset mismatch: checkpoint contains {self.dataset}, expected {d}" - return out - - -class Scalable_Shard_Dataset(_Wrapper_Dataset): - """ - A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different - number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track - state individually and reshard over n_gpus. - - All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. - ... - Args - ---- - datapath : str - Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. - rank : int - Current worker index - worldsize : int - Total number of workers - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - n_logical_shards : int - Number of logical shards. Must be a multiple of world size. - ... - Pass-through args, see Streaming_Doc_Dataset - """ - - def __init__( - self, - dataset: Streaming_Doc_Dataset, - delimiter_token: Any, - n_logical_shards: int = 2048, - verbose=False, - ): - super().__init__(dataset) - assert ( - n_logical_shards % self.worldsize == 0 - ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" - assert ( - n_logical_shards > 0 - ), f"n_logical_shards {n_logical_shards} must be a positive integer" - - self.total_shards = n_logical_shards - self.delimiter = delimiter_token - self.verbose = verbose - - # Fields to be populated during setup / subdataset setup - self.data: List[Streaming_Doc_Dataset] = [] - self.logicals_owned: List[int] = [] - self.n_logicals = 0 - self.n_docs_remaining: List[int] = [] - self.generator = None - - # Position "state", used only for maintaining order when n_workers is unchanged - # For scaling up or down, logical position is meaningless, and reset - self.current_reader = None - self.logical_shard_states = None - self.g_state = None - - self.state_params = ["current_reader", "g_state"] - self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - - def setup(self): - if not self.is_setup: - self.is_setup = True - n_logical_shards = self.total_shards - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - self.n_logicals = n_logical_shards // self.worldsize - assert len(self.logicals_owned) == self.n_logicals - - # Build logical shards - for i in range(self.n_logicals): - self.data.append(deepcopy(self.dataset)) - self.data[-1].worldsize = n_logical_shards - self.data[-1].load_worldsize = n_logical_shards - self.data[-1].rank = self.logicals_owned[i] - self.data[-1].datapath = self.datapath - self.data[-1].verbose = self.rank == 0 - if self.verbose: - logging.info( - f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - [d.setup() for d in self.data] - self.n_docs_remaining = [d._len for d in self.data] - - self.generator = torch.Generator().manual_seed(self.rank) - - def __iter__(self): - self.setup() - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - # Sample logical shard (or load from ckp) - if self.current_reader is not None: - ind = self.current_reader - else: - ind = torch.multinomial( - torch.tensor(self.n_docs_remaining, dtype=torch.float), - 1, - generator=self.generator, - ).item() - self.current_reader = ind - # Read doc - out = next(data[ind]) - while out[-1] != self.delimiter: - yield out - out = next(data[ind]) - # Update state to show we've finished the doc - self.current_reader = None - self.n_docs_remaining[ind] -= 1 - if sum(self.n_docs_remaining) == 0: - self.n_docs_remaining = [d._len for d in self.data] - self.generator.manual_seed(self.rank) - # Return final piece of doc - yield out - - def state_dict(self): - self.setup() - # Write generator state manually - self.g_state = self.generator.get_state() - # Recursive fetch - self.logical_shard_states = [d.state_dict() for d in self.data] - return _Stateful_Dataset.state_dict(self) - - def load_state_dict(self, state_dicts, sharded_input=False): - self.setup() - sharded_dicts = _Stateful_Dataset.load_state_dict( - self, state_dicts, sharded_input - ) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Recursive set - for i in range(self.n_logicals): - self.data[i].load_state_dict([self.logical_shard_states[i]], True) - return sharded_dicts - - -class Sampling_Dataset(_Wrapper_Dataset): - """ - A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the - number of tokens seen from each subdataset will match those weights as closely as possible. - This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking - the number of tokens emitted by each. Whichever loader is furthest from its target will be - the next to pass a document. - - All args except for dataset_type, datasets, weights and delimiter are pass-through args for - the component _Stateful_Datasets and are documented in the appropriate classes. - ... - Args - ---- - dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset - Underlying iterator for each desired subdataset - delimiter_token : Any - Token used to indicate sequence/document breaks. Type should match data type. - datasets : list[str] | None - A list of subdatasets to draw from. If None, draws from all subfolders of datapath. - weights : list(float) | None - Weights describing what percent of emitted tokens should come from each subdataset. - Need not sum to 1. If None, tokens are drawn evenly. - ... - Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset - """ - - def __init__( - self, - datapath: str, - dataset: Union[Scalable_Shard_Dataset, Streaming_Doc_Dataset], - delimiter_token: Any, - datasets=None, - weights=None, - verbose=False, - ): - super().__init__(dataset) - self.datapath = datapath - self.delimiter = delimiter_token - self.verbose = verbose - self.datasets = ( - datasets - if datasets is not None - else [ - f - for f in os.listdir(datapath) - if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f - ] - ) - assert len(self.datasets) > 0, "You must specify at least one dataset" - - if weights is not None: - assert len(weights) == len( - self.datasets - ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" - for w in weights: - assert w > 0, f"Sampling rate {w} must be positive" - self.weights = [1] * len(self.datasets) if weights is None else weights - self.weights = [w / sum(self.weights) for w in self.weights] - - self.tokens_seen = [0] * len(self.datasets) - - self.current_iterator = -1 - self.state_params = ["tokens_seen", "current_iterator"] - - def setup(self): - if not self.is_setup: - self.is_setup = True - # Build subdataset iterators - self.data = [] - for i, d in enumerate(self.datasets): - self.data.append(deepcopy(self.dataset)) - self.data[-1].datapath = os.path.join(self.datapath, d) - for flag in ["rank", "worldsize", "load_worldsize"]: - setattr(self.data[-1], flag, getattr(self, flag)) - if self.verbose: - logging.info( - f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" - ) - [d.setup() for d in self.data] - - def __iter__(self): - self.setup() - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - if self.current_iterator != -1: - # Finish current document - out = next(data[self.current_iterator]) - self.tokens_seen[self.current_iterator] += len(out) - if out[-1] == self.delimiter: - self.current_iterator = -1 - yield out - else: - # Choose new subdataset to draw from - # (whichever is currently most underrepresented compared to target rate) - offset = [ - self.weights[i] - - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) - for i in range(len(self.datasets)) - ] - offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] - self.current_iterator = offset_argmax - - def state_dict(self): - self.setup() - # Manually add state of all subloaders to self state - out = { - self.statename("sample_iterator_states"): [ - d.state_dict() for d in self.data - ] - } - out.update(_Stateful_Dataset.state_dict(self)) - return out - - def load_state_dict(self, state_dicts, sharded_input=False): - self.setup() - # Load stats - sharded_dicts = _Stateful_Dataset.load_state_dict( - self, state_dicts, sharded_input - ) - # Load sub-iterator states - for i, subdata in enumerate(self.data): - # Grab just that sub-iterator across all ranks - subdata.load_worldsize = self.load_worldsize - subdata.load_state_dict( - [ - sd[self.statename("sample_iterator_states")][i] - for sd in sharded_dicts - ], - True, - ) - return sharded_dicts diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 33d47d3f..48bf5b01 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -8,7 +8,7 @@ import pyarrow as pa import torch -from fms_fsdp.utils.dataset_utils_v3 import * +from fms_fsdp.utils.dataset_utils import * # Generates test data in a temp directory, and returns that tempdir object. From ae9fc298da9a375609e01f6ab29fe1fc7ed341d4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 25 Jul 2024 16:00:59 -0400 Subject: [PATCH 38/73] Update docs, cleanup --- fms_fsdp/utils/dataset_utils.py | 98 +++++++++++++++++---------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index bcc24887..cefcba32 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -36,14 +36,7 @@ single `_Stateful_Dataset` and iterates via calling its wrapped dataset any number of times, then applying some sort of post-processing and yielding the result. Users build data processing pipelines by wrapping a base `_Stateful_Dataset` in any number of `_Wrapper_Dataset` layers, -which is then passed to the torch DataLoader. - -NOTE: `_Wrapper_Dataset` currently only implements wrapping a single instantiated sub-dataset layer. -Many layers need multiple sub-layers (i.e. random sampling from distinct data sources). These are -currently implemented as base `_Stateful_Datasets` that take the class of their sub-layers plus any -pass-through arguments, and instantiate all those sub-layers. This is easy on the user, who no longer -needs to instantiate large sets of sub-layers in their code, but leads to awkwardness in this file. -Cleanup is planned for the future. +which is then passed to the torch DataLoader. """ @@ -88,15 +81,25 @@ def __init__( ), f"Data path {datapath} must be a non-empty folder or None" self.state_params: List[str] = [] self.reshard_params: List[str] = [] + + # Default fields self.datapath = datapath self.rank = rank self.worldsize = worldsize - self.load_worldsize = ( - worldsize # Enable calling load_state_dict() directly, assume no rescaling - ) + + # Setup / loading flags + self.load_worldsize = worldsize self.is_setup = False def setup(self): + """ + This method should contain all setup depending on datapath or rank. + It is called after init, but immediately before any other operation. + Certain operations higher up in the pipeline may change rank or datapath + after init (for example, wrapping in a subdataset sampler layer, or copying + to worker processes), so all rank- and datapth- dependent ops are deferred to + this function. + """ if not self.is_setup: self.is_setup = True pass @@ -206,7 +209,7 @@ def save_to_path(self, path: str): class _Wrapper_Dataset(_Stateful_Dataset): """ Stub for nested wrappers of _Stateful_Datasets. Extends state fns with recursion. - Requires a single instantiated sub-dataset. + Requires a single instantiated sub-dataset (which may be replicated during setup fn). """ def __init__( @@ -214,21 +217,22 @@ def __init__( dataset: _Stateful_Dataset, ): self.dataset = dataset + # Inherit default flags from sub-dataset super().__init__( self.dataset.datapath, self.dataset.rank, self.dataset.worldsize ) def setup(self): """ - Datapath/rank/worldsize percolate upwards recursively during initialization, now - project any desired changes downward, also recursively. + Datapath/rank/worldsize percolate upwards recursively during initialization, so + now we project any desired changes downward, also recursively. + Any code overriding this function should still include this functionality. """ if not self.is_setup: self.is_setup = True self.dataset.datapath = self.datapath self.dataset.rank = self.rank self.dataset.worldsize = self.worldsize - self.dataset.load_worldsize = self.load_worldsize self.dataset.setup() def load_state_dict(self, state_dicts, sharded_input=False): @@ -384,7 +388,8 @@ class Preload_Buffer_Dataset(_Wrapper_Dataset): Passes randomly sampled outputs one by one. Ensures local mixing of data without relying on sliding windows or shuffling of large buffers. Any two consecutive inputs will be separated by window_size steps in expectation. - Rescaling-enabled: buffers that shrink will re-grow to window_size, buffers that expand stay large. + Rescaling-enabled: buffers that shrink will re-grow to window_size, + buffers that expand will shrink back down to window_size. ... Args ---- @@ -413,15 +418,21 @@ def __iter__(self): # Pad out buffer if needed self._pad_buffer() - # Load a point to buffer if necessary + # If buffer is undersized, add a datapoint if self.buffer_size < self.window_size: self.buffer[self.buffer_size] = next(dataset) self.buffer_size += 1 - # Swap out randomly sampled value from buffer + # Swap out randomly sampled value from buffer. + # If buffer is small, add new item. + # If buffer is large, pop last item into that slot. i = torch.randint(self.buffer_size, (1,), generator=self.generator).item() out = self.buffer[i] - self.buffer[i] = next(dataset) + if self.buffer_size > self.window_size: + self.buffer[i] = self.buffer[self.buffer_size - 1] + self.buffer_size -= 1 + else: + self.buffer[i] = next(dataset) yield out def _pad_buffer(self): @@ -470,10 +481,6 @@ class Buffer_Dataset(_Wrapper_Dataset): Token to append to every output sequence. If None, no token is added. Type should match data type. pad_token : any | None Token used to fill out output sequence. Type should match data type. - drop_final_token : any | None - Drop the final token of each document if it matches this value? - (For edge case where bos=eos=None, and sep already appears at beginning of each doc - - drop added extra sep from end of doc) """ def __init__( @@ -658,7 +665,7 @@ def __init__( def setup(self): """ All rank-dependent setup, which must occur after init - (rank assignment, subdataset splitting, etc.) + (rank assignment, data partitioning, shuffling) """ if not self.is_setup: datapath = self.datapath @@ -721,7 +728,7 @@ def setup(self): if doc_end > max_d: docset[shard][1] = doc_end - # Add all of this dataset's shard entries to self.docset + # Add shard entries to self.docset doccount = 0 for shardid in docset: min_d = docset[shardid][0] @@ -879,7 +886,7 @@ def load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( self.load_worldsize == self.worldsize - ), f"Streaming_Doc_Dataset does not support rescaling ({self.load_worldsize, self.worldsize}). Please use a Scalable_Shard_Dataset." + ), f"Streaming_Doc_Dataset does not support rescaling (ckp size: {self.load_worldsize}, world size: {self.worldsize}). Please use a Scalable_Shard_Dataset." d = self.dataset out = super().load_state_dict(state_dicts, sharded_input) assert ( @@ -890,27 +897,21 @@ def load_state_dict(self, state_dicts, sharded_input=False): class Scalable_Shard_Dataset(_Wrapper_Dataset): """ - A _Stateful_Dataset implementing rescalability: loading from checkpoint into a different + A _Wrapper_Dataset implementing rescalability: loading from checkpoint into a different number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, which track - state individually and reshard over n_gpus. - - All keywords except the first are simple pass-through arguments and are documented in Streaming_Doc_Dataset. + This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, cloned from the + original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. ... Args ---- - datapath : str - Absolute path to the dataset directory. Expects folder containing pyarrow shardfiles. - rank : int - Current worker index - worldsize : int - Total number of workers - delimiter_token : Any + dataset : Streaming_Doc_Dataset + Fully instantiated dataset. Cloned into logical workers during setup fn. + delimiter_token : any Token used to indicate sequence/document breaks. Type should match data type. n_logical_shards : int Number of logical shards. Must be a multiple of world size. - ... - Pass-through args, see Streaming_Doc_Dataset + verbose : bool + Track setup progress? """ def __init__( @@ -1027,7 +1028,7 @@ def load_state_dict(self, state_dicts, sharded_input=False): class Sampling_Dataset(_Wrapper_Dataset): """ - A _Stateful_Dataset implementing percentage-based sampling: weights can be floats, and the + A _Wrapper_Dataset implementing percentage-based sampling: weights can be floats, and the number of tokens seen from each subdataset will match those weights as closely as possible. This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking the number of tokens emitted by each. Whichever loader is furthest from its target will be @@ -1038,8 +1039,11 @@ class Sampling_Dataset(_Wrapper_Dataset): ... Args ---- - dataset_type : Scalable_Shard_Dataset | Streaming_Doc_Dataset - Underlying iterator for each desired subdataset + datapath : str + Absolute path to the dataset directory. Expects directory to contain subfolders with + pyarrow shardfiles, and also a 'meta' folder with metadata csv file inside. + dataset : Scalable_Shard_Dataset | Streaming_Doc_Dataset + Fully instantiated dataset. Cloned across desired subdatasets during setup. delimiter_token : Any Token used to indicate sequence/document breaks. Type should match data type. datasets : list[str] | None @@ -1047,8 +1051,8 @@ class Sampling_Dataset(_Wrapper_Dataset): weights : list(float) | None Weights describing what percent of emitted tokens should come from each subdataset. Need not sum to 1. If None, tokens are drawn evenly. - ... - Pass-through args, see Streaming_Doc_Dataset or Scalable_Shard_Dataset + verbose : bool + Track setup progress? """ def __init__( @@ -1097,8 +1101,8 @@ def setup(self): for i, d in enumerate(self.datasets): self.data.append(deepcopy(self.dataset)) self.data[-1].datapath = os.path.join(self.datapath, d) - for flag in ["rank", "worldsize", "load_worldsize"]: - setattr(self.data[-1], flag, getattr(self, flag)) + self.data[-1].rank = self.rank + self.data[-1].worldsize = self.worldsize if self.verbose: logging.info( f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" From 217a1470d52359173a8770d4107b6541f913622f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 25 Jul 2024 16:13:24 -0400 Subject: [PATCH 39/73] Put sampling dataset earlier again? --- fms_fsdp/utils/dataset_utils.py | 262 ++++++++++++++++---------------- 1 file changed, 131 insertions(+), 131 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index cefcba32..691555ec 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -895,137 +895,6 @@ def load_state_dict(self, state_dicts, sharded_input=False): return out -class Scalable_Shard_Dataset(_Wrapper_Dataset): - """ - A _Wrapper_Dataset implementing rescalability: loading from checkpoint into a different - number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, cloned from the - original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. - ... - Args - ---- - dataset : Streaming_Doc_Dataset - Fully instantiated dataset. Cloned into logical workers during setup fn. - delimiter_token : any - Token used to indicate sequence/document breaks. Type should match data type. - n_logical_shards : int - Number of logical shards. Must be a multiple of world size. - verbose : bool - Track setup progress? - """ - - def __init__( - self, - dataset: Streaming_Doc_Dataset, - delimiter_token: Any, - n_logical_shards: int = 2048, - verbose=False, - ): - super().__init__(dataset) - assert ( - n_logical_shards % self.worldsize == 0 - ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" - assert ( - n_logical_shards > 0 - ), f"n_logical_shards {n_logical_shards} must be a positive integer" - - self.total_shards = n_logical_shards - self.delimiter = delimiter_token - self.verbose = verbose - - # Fields to be populated during setup / subdataset setup - self.data: List[Streaming_Doc_Dataset] = [] - self.logicals_owned: List[int] = [] - self.n_logicals = 0 - self.n_docs_remaining: List[int] = [] - self.generator = None - - # Position "state", used only for maintaining order when n_workers is unchanged - # For scaling up or down, logical position is meaningless, and reset - self.current_reader = None - self.logical_shard_states = None - self.g_state = None - - self.state_params = ["current_reader", "g_state"] - self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - - def setup(self): - if not self.is_setup: - self.is_setup = True - n_logical_shards = self.total_shards - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - self.n_logicals = n_logical_shards // self.worldsize - assert len(self.logicals_owned) == self.n_logicals - - # Build logical shards - for i in range(self.n_logicals): - self.data.append(deepcopy(self.dataset)) - self.data[-1].worldsize = n_logical_shards - self.data[-1].load_worldsize = n_logical_shards - self.data[-1].rank = self.logicals_owned[i] - self.data[-1].datapath = self.datapath - self.data[-1].verbose = self.rank == 0 - if self.verbose: - logging.info( - f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - [d.setup() for d in self.data] - self.n_docs_remaining = [d._len for d in self.data] - - self.generator = torch.Generator().manual_seed(self.rank) - - def __iter__(self): - self.setup() - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - # Sample logical shard (or load from ckp) - if self.current_reader is not None: - ind = self.current_reader - else: - ind = torch.multinomial( - torch.tensor(self.n_docs_remaining, dtype=torch.float), - 1, - generator=self.generator, - ).item() - self.current_reader = ind - # Read doc - out = next(data[ind]) - while out[-1] != self.delimiter: - yield out - out = next(data[ind]) - # Update state to show we've finished the doc - self.current_reader = None - self.n_docs_remaining[ind] -= 1 - if sum(self.n_docs_remaining) == 0: - self.n_docs_remaining = [d._len for d in self.data] - self.generator.manual_seed(self.rank) - # Return final piece of doc - yield out - - def state_dict(self): - self.setup() - # Write generator state manually - self.g_state = self.generator.get_state() - # Recursive fetch - self.logical_shard_states = [d.state_dict() for d in self.data] - return _Stateful_Dataset.state_dict(self) - - def load_state_dict(self, state_dicts, sharded_input=False): - self.setup() - sharded_dicts = _Stateful_Dataset.load_state_dict( - self, state_dicts, sharded_input - ) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Recursive set - for i in range(self.n_logicals): - self.data[i].load_state_dict([self.logical_shard_states[i]], True) - return sharded_dicts - - class Sampling_Dataset(_Wrapper_Dataset): """ A _Wrapper_Dataset implementing percentage-based sampling: weights can be floats, and the @@ -1161,3 +1030,134 @@ def load_state_dict(self, state_dicts, sharded_input=False): True, ) return sharded_dicts + + +class Scalable_Shard_Dataset(_Wrapper_Dataset): + """ + A _Wrapper_Dataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, cloned from the + original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. + ... + Args + ---- + dataset : Streaming_Doc_Dataset + Fully instantiated dataset. Cloned into logical workers during setup fn. + delimiter_token : any + Token used to indicate sequence/document breaks. Type should match data type. + n_logical_shards : int + Number of logical shards. Must be a multiple of world size. + verbose : bool + Track setup progress? + """ + + def __init__( + self, + dataset: Streaming_Doc_Dataset, + delimiter_token: Any, + n_logical_shards: int = 2048, + verbose=False, + ): + super().__init__(dataset) + assert ( + n_logical_shards % self.worldsize == 0 + ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert ( + n_logical_shards > 0 + ), f"n_logical_shards {n_logical_shards} must be a positive integer" + + self.total_shards = n_logical_shards + self.delimiter = delimiter_token + self.verbose = verbose + + # Fields to be populated during setup / subdataset setup + self.data: List[Streaming_Doc_Dataset] = [] + self.logicals_owned: List[int] = [] + self.n_logicals = 0 + self.n_docs_remaining: List[int] = [] + self.generator = None + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = None + self.logical_shard_states = None + self.g_state = None + + self.state_params = ["current_reader", "g_state"] + self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + + def setup(self): + if not self.is_setup: + self.is_setup = True + n_logical_shards = self.total_shards + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + self.n_logicals = n_logical_shards // self.worldsize + assert len(self.logicals_owned) == self.n_logicals + + # Build logical shards + for i in range(self.n_logicals): + self.data.append(deepcopy(self.dataset)) + self.data[-1].worldsize = n_logical_shards + self.data[-1].load_worldsize = n_logical_shards + self.data[-1].rank = self.logicals_owned[i] + self.data[-1].datapath = self.datapath + self.data[-1].verbose = self.rank == 0 + if self.verbose: + logging.info( + f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + [d.setup() for d in self.data] + self.n_docs_remaining = [d._len for d in self.data] + + self.generator = torch.Generator().manual_seed(self.rank) + + def __iter__(self): + self.setup() + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + # Sample logical shard (or load from ckp) + if self.current_reader is not None: + ind = self.current_reader + else: + ind = torch.multinomial( + torch.tensor(self.n_docs_remaining, dtype=torch.float), + 1, + generator=self.generator, + ).item() + self.current_reader = ind + # Read doc + out = next(data[ind]) + while out[-1] != self.delimiter: + yield out + out = next(data[ind]) + # Update state to show we've finished the doc + self.current_reader = None + self.n_docs_remaining[ind] -= 1 + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) + # Return final piece of doc + yield out + + def state_dict(self): + self.setup() + # Write generator state manually + self.g_state = self.generator.get_state() + # Recursive fetch + self.logical_shard_states = [d.state_dict() for d in self.data] + return _Stateful_Dataset.state_dict(self) + + def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() + sharded_dicts = _Stateful_Dataset.load_state_dict( + self, state_dicts, sharded_input + ) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Recursive set + for i in range(self.n_logicals): + self.data[i].load_state_dict([self.logical_shard_states[i]], True) + return sharded_dicts From 2b479fd0e7904c4f59c1fb34f90e141dc7e52be1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 25 Jul 2024 16:20:20 -0400 Subject: [PATCH 40/73] Remove Weird_Casing --- fms_fsdp/utils/dataloader_utils.py | 28 +++++----- fms_fsdp/utils/dataset_utils.py | 90 +++++++++++++++--------------- tests/test_datasets.py | 46 +++++++-------- 3 files changed, 82 insertions(+), 82 deletions(-) diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 50026456..b9e4e772 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -1,13 +1,13 @@ import torch from fms_fsdp.utils.dataset_utils import ( - Buffer_Dataset, - Checkpoint_Dataset, - Preload_Buffer_Dataset, - Preprocess_Dataset, - Sampling_Dataset, - Scalable_Shard_Dataset, - Streaming_Doc_Dataset, + BufferDataset, + CheckpointDataset, + PreloadBufferDataset, + PreprocessDataset, + SamplingDataset, + ScalableShardDataset, + StreamingDocDataset, ) @@ -70,7 +70,7 @@ def causal_lm(data_seq, prompt_len=0): ] droplist = droplist + [cfg.bos_token, cfg.eos_token, cfg.bol_token, cfg.eol_token] # Base reader layer - data = Streaming_Doc_Dataset( + data = StreamingDocDataset( cfg.data_path, rank, world_size, @@ -81,13 +81,13 @@ def causal_lm(data_seq, prompt_len=0): seed=cfg.seed, ) # Add rescaling/resharding - data = Scalable_Shard_Dataset( + data = ScalableShardDataset( data, cfg.eos_token, n_logical_shards=cfg.logical_shards, ) # Add multi-dataset handling - data = Sampling_Dataset( + data = SamplingDataset( cfg.data_path, data, cfg.eos_token, @@ -96,7 +96,7 @@ def causal_lm(data_seq, prompt_len=0): verbose=(rank == 0), ) # Wrap above dataset in packing logic to form constant-length lines. - data = Buffer_Dataset( + data = BufferDataset( data, cfg.seq_length + 1, bos_token=cfg.bol_token, @@ -104,11 +104,11 @@ def causal_lm(data_seq, prompt_len=0): pack_hard=True, ) # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. - data = Preload_Buffer_Dataset(data, 10000) + data = PreloadBufferDataset(data, 10000) # Split line into input and target for the CLM task. - data = Preprocess_Dataset(data, causal_lm) + data = PreprocessDataset(data, causal_lm) # Enable auto-saving - data = Checkpoint_Dataset( + data = CheckpointDataset( data, cfg.ckpt_load_path, cfg.checkpoint_interval, diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 691555ec..44c1fc92 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -31,11 +31,11 @@ re-distributed over workers (i.e. buffers). Our loaders obey the following type heirarchy: -torch.data.IterableDataset -> _Stateful_Dataset -> _Wrapper_Dataset. -`_Stateful_Dataset` implements state and checkpointing logic. A `_Wrapper_Dataset` holds a -single `_Stateful_Dataset` and iterates via calling its wrapped dataset any number of times, +torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset. +`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a +single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times, then applying some sort of post-processing and yielding the result. Users build data processing -pipelines by wrapping a base `_Stateful_Dataset` in any number of `_Wrapper_Dataset` layers, +pipelines by wrapping a base `_StatefulDataset` in any number of `_WrapperDataset` layers, which is then passed to the torch DataLoader. """ @@ -59,7 +59,7 @@ def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any return itemlist[start:end] -class _Stateful_Dataset(data.IterableDataset): +class _StatefulDataset(data.IterableDataset): """ Stub for stateful datasets, extends data.IterableDataset with state_dict methods. All subclasses should specify the params to be considered stateful or reshardable in the @@ -206,15 +206,15 @@ def save_to_path(self, path: str): torch.save(state, os.path.join(path, f"loader_state_{self.rank}.pth")) -class _Wrapper_Dataset(_Stateful_Dataset): +class _WrapperDataset(_StatefulDataset): """ - Stub for nested wrappers of _Stateful_Datasets. Extends state fns with recursion. + Stub for nested wrappers of _StatefulDatasets. Extends state fns with recursion. Requires a single instantiated sub-dataset (which may be replicated during setup fn). """ def __init__( self, - dataset: _Stateful_Dataset, + dataset: _StatefulDataset, ): self.dataset = dataset # Inherit default flags from sub-dataset @@ -263,14 +263,14 @@ def state_dict(self): return out -class Preprocess_Dataset(_Wrapper_Dataset): +class PreprocessDataset(_WrapperDataset): """ - Wrapper for a _Stateful_Dataset that applies a specified preprocessing + Wrapper for a _StatefulDataset that applies a specified preprocessing or augmentation function to dataset outputs. ... Args ---- - dataset : _Stateful_Dataset + dataset : _StatefulDataset Fully instantiated dataset aug_fn : function (any -> any) The augmentation function to apply to each dataset item. @@ -278,7 +278,7 @@ class Preprocess_Dataset(_Wrapper_Dataset): def __init__( self, - dataset: _Stateful_Dataset, + dataset: _StatefulDataset, aug_fn: Callable, ): super().__init__(dataset) @@ -291,15 +291,15 @@ def __iter__(self): yield self.aug_fn(out) -class Checkpoint_Dataset(_Wrapper_Dataset): +class CheckpointDataset(_WrapperDataset): """ - Wrapper for a _Stateful_Dataset that implements auto-checkpoint saving every n steps. + Wrapper for a _StatefulDataset that implements auto-checkpoint saving every n steps. Useful for setting n_workers > 0, so that workers do not rely on the master process for state saving (inter-process communication unsupported in PyTorch datasets). ... Args ---- - dataset : _Stateful_Dataset + dataset : _StatefulDataset Fully instantiated dataset load_path : str Absolute path to checkpoint load directory. If a checkpoint exists, loads it. @@ -314,7 +314,7 @@ class Checkpoint_Dataset(_Wrapper_Dataset): def __init__( self, - dataset: _Stateful_Dataset, + dataset: _StatefulDataset, load_path: str, interval: int, steps_per_batch: int = 1, @@ -381,9 +381,9 @@ def load_from_path(self, path: str): self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}") -class Preload_Buffer_Dataset(_Wrapper_Dataset): +class PreloadBufferDataset(_WrapperDataset): """ - Wrapper for a Stateful_Dataset that implements data shuffling via a single in/out buffer. + Wrapper for a StatefulDataset that implements data shuffling via a single in/out buffer. Fills buffer two at a time, up to desired size, then switches to one at a time to maintain size. Passes randomly sampled outputs one by one. Ensures local mixing of data without relying on sliding windows or shuffling of large buffers. @@ -393,13 +393,13 @@ class Preload_Buffer_Dataset(_Wrapper_Dataset): ... Args ---- - dataset : _Stateful_Dataset + dataset : _StatefulDataset Fully instantiated dataset window_size : int Max size of input/output buffer """ - def __init__(self, dataset: _Stateful_Dataset, window_size: int): + def __init__(self, dataset: _StatefulDataset, window_size: int): super().__init__(dataset) assert ( window_size > 1 @@ -459,9 +459,9 @@ def load_state_dict(self, state_dicts, sharded_input=False): return sharded_dicts -class Buffer_Dataset(_Wrapper_Dataset): +class BufferDataset(_WrapperDataset): """ - Wrapper for a _Stateful_Dataset that takes in sequences of varying lengths, and packs/pads them + Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them into sequences of desired length. Input sequences are packed greedily until the buffer would otherwise overrun, then remaining values are filled depending on initialization flags. Also injects BOS/EOS into the packed output sequence if desired, and if BOS/EOS tokens are @@ -469,7 +469,7 @@ class Buffer_Dataset(_Wrapper_Dataset): ... Args ---- - dataset : _Stateful_Dataset + dataset : _StatefulDataset Fully instantiated dataset seq_len : int The desired sequence length @@ -485,7 +485,7 @@ class Buffer_Dataset(_Wrapper_Dataset): def __init__( self, - dataset: _Stateful_Dataset, + dataset: _StatefulDataset, seq_len: int, pack_hard: bool, bos_token=None, @@ -557,7 +557,7 @@ def __iter__(self): yield out -class Streaming_Doc_Dataset(_Stateful_Dataset): +class StreamingDocDataset(_StatefulDataset): """ The base distributed dataset for loading sequences/documents from pyarrow shards. Pyarrow shard files are expected to hold multiple recordBatches, where each recordBatch has a "tokens" @@ -576,8 +576,8 @@ class Streaming_Doc_Dataset(_Stateful_Dataset): Shards are thus pulled no more than once per epoch. Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents. - Streaming_Doc_Dataset grabs files from a flat directory representing a single dataset. - For percentage-based sampling of multiple subdatasets, see Sampling_Dataset. + StreamingDocDataset grabs files from a flat directory representing a single dataset. + For percentage-based sampling of multiple subdatasets, see SamplingDataset. ... Args ---- @@ -590,7 +590,7 @@ class Streaming_Doc_Dataset(_Stateful_Dataset): Total number of workers delimiter_token : Any Token used to indicate sequence/document breaks. Type should match data type. Required for downstream - sampling logic (can be removed later via PreProcess_Dataset if needed). + sampling logic (can be removed later via PreProcessDataset if needed). bos_token : Any | None Optional token used to indicate sequence/document start. Type should match data type. strip_tokens : set[Any] @@ -886,7 +886,7 @@ def load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( self.load_worldsize == self.worldsize - ), f"Streaming_Doc_Dataset does not support rescaling (ckp size: {self.load_worldsize}, world size: {self.worldsize}). Please use a Scalable_Shard_Dataset." + ), f"StreamingDocDataset does not support rescaling (ckp size: {self.load_worldsize}, world size: {self.worldsize}). Please use a ScalableShardDataset." d = self.dataset out = super().load_state_dict(state_dicts, sharded_input) assert ( @@ -895,23 +895,23 @@ def load_state_dict(self, state_dicts, sharded_input=False): return out -class Sampling_Dataset(_Wrapper_Dataset): +class SamplingDataset(_WrapperDataset): """ - A _Wrapper_Dataset implementing percentage-based sampling: weights can be floats, and the + A _WrapperDataset implementing percentage-based sampling: weights can be floats, and the number of tokens seen from each subdataset will match those weights as closely as possible. - This is accomplished by maintaining a _Stateful_Dataset for each subdataset, and tracking + This is accomplished by maintaining a _StatefulDataset for each subdataset, and tracking the number of tokens emitted by each. Whichever loader is furthest from its target will be the next to pass a document. All args except for dataset_type, datasets, weights and delimiter are pass-through args for - the component _Stateful_Datasets and are documented in the appropriate classes. + the component _StatefulDatasets and are documented in the appropriate classes. ... Args ---- datapath : str Absolute path to the dataset directory. Expects directory to contain subfolders with pyarrow shardfiles, and also a 'meta' folder with metadata csv file inside. - dataset : Scalable_Shard_Dataset | Streaming_Doc_Dataset + dataset : ScalableShardDataset | StreamingDocDataset Fully instantiated dataset. Cloned across desired subdatasets during setup. delimiter_token : Any Token used to indicate sequence/document breaks. Type should match data type. @@ -927,7 +927,7 @@ class Sampling_Dataset(_Wrapper_Dataset): def __init__( self, datapath: str, - dataset: Union[Scalable_Shard_Dataset, Streaming_Doc_Dataset], + dataset: Union[ScalableShardDataset, StreamingDocDataset], delimiter_token: Any, datasets=None, weights=None, @@ -1009,13 +1009,13 @@ def state_dict(self): d.state_dict() for d in self.data ] } - out.update(_Stateful_Dataset.state_dict(self)) + out.update(_StatefulDataset.state_dict(self)) return out def load_state_dict(self, state_dicts, sharded_input=False): self.setup() # Load stats - sharded_dicts = _Stateful_Dataset.load_state_dict( + sharded_dicts = _StatefulDataset.load_state_dict( self, state_dicts, sharded_input ) # Load sub-iterator states @@ -1032,16 +1032,16 @@ def load_state_dict(self, state_dicts, sharded_input=False): return sharded_dicts -class Scalable_Shard_Dataset(_Wrapper_Dataset): +class ScalableShardDataset(_WrapperDataset): """ - A _Wrapper_Dataset implementing rescalability: loading from checkpoint into a different + A _WrapperDataset implementing rescalability: loading from checkpoint into a different number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small Streaming_Doc_Datasets, cloned from the + This is accomplished by maintaining a large number of small StreamingDocDatasets, cloned from the original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. ... Args ---- - dataset : Streaming_Doc_Dataset + dataset : StreamingDocDataset Fully instantiated dataset. Cloned into logical workers during setup fn. delimiter_token : any Token used to indicate sequence/document breaks. Type should match data type. @@ -1053,7 +1053,7 @@ class Scalable_Shard_Dataset(_Wrapper_Dataset): def __init__( self, - dataset: Streaming_Doc_Dataset, + dataset: StreamingDocDataset, delimiter_token: Any, n_logical_shards: int = 2048, verbose=False, @@ -1071,7 +1071,7 @@ def __init__( self.verbose = verbose # Fields to be populated during setup / subdataset setup - self.data: List[Streaming_Doc_Dataset] = [] + self.data: List[StreamingDocDataset] = [] self.logicals_owned: List[int] = [] self.n_logicals = 0 self.n_docs_remaining: List[int] = [] @@ -1147,11 +1147,11 @@ def state_dict(self): self.g_state = self.generator.get_state() # Recursive fetch self.logical_shard_states = [d.state_dict() for d in self.data] - return _Stateful_Dataset.state_dict(self) + return _StatefulDataset.state_dict(self) def load_state_dict(self, state_dicts, sharded_input=False): self.setup() - sharded_dicts = _Stateful_Dataset.load_state_dict( + sharded_dicts = _StatefulDataset.load_state_dict( self, state_dicts, sharded_input ) # Manually set generator state if it exists diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 48bf5b01..96f78d6e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -386,7 +386,7 @@ def basic_loader( bos_token=None, ): assert len(datasets) == 1, "Basic loader takes only 1 dataset" - return Streaming_Doc_Dataset( + return StreamingDocDataset( os.path.join(tmpdir.name, datasets[0]), rank, worldsize, @@ -399,7 +399,7 @@ def basic_loader( def basic_sampler( rank=0, worldsize=1, datasets=["dataset_1"], weights=[1], max_chunksize=1000 ): - return Sampling_Dataset( + return SamplingDataset( tmpdir.name, basic_loader(rank, worldsize, datasets[:1], max_chunksize, None), -1, @@ -417,7 +417,7 @@ def basic_scalable( bos_token=None, ): assert len(datasets) == 1, "Basic loader takes only 1 dataset" - return Scalable_Shard_Dataset( + return ScalableShardDataset( basic_loader(rank, worldsize, datasets, max_chunksize, bos_token), -1, n_logical_shards, @@ -432,7 +432,7 @@ def basic_sampler_scalable( max_chunksize=1000, n_logical_shards=7, ): - return Sampling_Dataset( + return SamplingDataset( tmpdir.name, basic_scalable( rank, worldsize, datasets[:1], max_chunksize, n_logical_shards, None @@ -514,7 +514,7 @@ def test_eos_bos_chunking(): def test_sampler_rates(): """ - A test for Sampling_Dataset with Streaming_ and Scalable_ subdatasets. + A test for SamplingDataset with Streaming_ and Scalable_ subdatasets. On the full dataset, with varying weights, on a single worker: verify that loaders pull subdatasets at regular intervals (verifying that they're regularly picking the most-underviewed subdataset at each step). """ @@ -568,7 +568,7 @@ def test_multi_reload_stress(): """ # Shard doc dataset d1 = lambda: [ - Streaming_Doc_Dataset( + StreamingDocDataset( os.path.join(tmpdir.name, "dataset_2"), i, 3, @@ -580,12 +580,12 @@ def test_multi_reload_stress(): multi_reload_stress_check(d1) # Scalable shard dataset - d2 = lambda x: [Scalable_Shard_Dataset(d, -1, n_logical_shards=15) for d in x] + d2 = lambda x: [ScalableShardDataset(d, -1, n_logical_shards=15) for d in x] multi_reload_stress_check(lambda: d2(d1())) # Sampling dataset d3 = lambda x: [ - Sampling_Dataset( + SamplingDataset( tmpdir.name, d, -1, @@ -601,16 +601,16 @@ def test_multi_reload_stress(): multi_reload_stress_check(d4) # Add buffer dataset - d5 = lambda x: [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in x] + d5 = lambda x: [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in x] multi_reload_stress_check(lambda: d5(d4())) # Add preload buffer dataset - d6 = lambda x: [Preload_Buffer_Dataset(d, 99) for d in x] + d6 = lambda x: [PreloadBufferDataset(d, 99) for d in x] # preload / sample / scale / doc pipeline multi_reload_stress_check(lambda: d6(d5(d4()))) -# SCALABLE_DATASET TESTS +# SCALABLEDATASET TESTS def test_scalable_partitioning(): @@ -742,7 +742,7 @@ def test_scalable_sampler_reload_scale(): ), f"Expected value {i*100+suf} not found in output set {ins}" -# BUFFER_DATASET TESTS +# BUFFERDATASET TESTS class RandCounter: @@ -767,7 +767,7 @@ def test_buffer_format(): for _ in range(100): # 100 trials of random length inputs base = RandCounter() - dataset = Buffer_Dataset(base, 100, pack_hard=True) + dataset = BufferDataset(base, 100, pack_hard=True) loader = iter(dataset) for _ in range(100): out = next(loader) @@ -781,7 +781,7 @@ def test_buffer_format(): # As above, but now with EOS tokens for _ in range(100): base = RandCounter() - dataset = Buffer_Dataset(base, 100, pack_hard=True, eos_token=-1) + dataset = BufferDataset(base, 100, pack_hard=True, eos_token=-1) loader = iter(dataset) for i in range(100): out = next(loader) @@ -796,7 +796,7 @@ def test_buffer_format(): # As above, but now with BOS tokens for _ in range(100): base = RandCounter() - dataset = Buffer_Dataset(base, 100, pack_hard=True, bos_token=-1) + dataset = BufferDataset(base, 100, pack_hard=True, bos_token=-1) loader = iter(dataset) for i in range(100): out = next(loader) @@ -816,7 +816,7 @@ def test_buffer_delimiter_overlap(): into the first slot in the next (and all subsequent) outputs. BOS should then refrain from adding. """ dataset = basic_loader(max_chunksize=101) - dataset = Buffer_Dataset(dataset, 101, pack_hard=True, bos_token=-1) + dataset = BufferDataset(dataset, 101, pack_hard=True, bos_token=-1) loader = iter(dataset) for _ in range(100): out = next(loader) @@ -829,7 +829,7 @@ def test_buffer_delimiter_overlap(): ), f"Final token {out[-1]} does not end in expected value 99" -# PRELOAD_BUFFER_DATASET TESTS +# PRELOADBUFFERDATASET TESTS class SteadyCounter: @@ -852,7 +852,7 @@ def test_preload_buffer_uniformity(): With underlying SteadyCounter and window size 200, take 1000 steps. Ensure 95% of values between 0 and 100 are emitted. """ - dataset = Preload_Buffer_Dataset(SteadyCounter(1), 200) + dataset = PreloadBufferDataset(SteadyCounter(1), 200) loader = iter(dataset) outs = [] @@ -864,7 +864,7 @@ def test_preload_buffer_uniformity(): assert len(outs) > 95, f"Only {len(outs)} values <100 detected" -# CHECKPOINT_DATASET TESTS +# CHECKPOINTDATASET TESTS def test_checkpoint_reload_match(): @@ -876,9 +876,9 @@ def test_checkpoint_reload_match(): basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17) for i in range(3) ] - datasets = [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets] + datasets = [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets] datasets = [ - Checkpoint_Dataset(x, os.path.join(tmpdir.name, "ckp_test"), 100, 2) + CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 100, 2) for x in datasets ] loaders = [ @@ -907,9 +907,9 @@ def test_checkpoint_reload_match(): basic_sampler(i, 3, ["dataset_1", "dataset_2"], [3, 5], max_chunksize=17) for i in range(3) ] - datasets2 = [Buffer_Dataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets2] + datasets2 = [BufferDataset(d, 73, pack_hard=True, bos_token=-1) for d in datasets2] datasets2 = [ - Checkpoint_Dataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2) + CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2) for x in datasets2 ] From 90ec624fdf183217b7b432826f007e29b403f57a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 25 Jul 2024 16:39:37 -0400 Subject: [PATCH 41/73] Move sampling back to end (sorry for the crazy diff) --- fms_fsdp/utils/dataset_utils.py | 262 ++++++++++++++++---------------- 1 file changed, 131 insertions(+), 131 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 44c1fc92..c15e3e9d 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -895,6 +895,137 @@ def load_state_dict(self, state_dicts, sharded_input=False): return out +class ScalableShardDataset(_WrapperDataset): + """ + A _WrapperDataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of small StreamingDocDatasets, cloned from the + original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. + ... + Args + ---- + dataset : StreamingDocDataset + Fully instantiated dataset. Cloned into logical workers during setup fn. + delimiter_token : any + Token used to indicate sequence/document breaks. Type should match data type. + n_logical_shards : int + Number of logical shards. Must be a multiple of world size. + verbose : bool + Track setup progress? + """ + + def __init__( + self, + dataset: StreamingDocDataset, + delimiter_token: Any, + n_logical_shards: int = 2048, + verbose=False, + ): + super().__init__(dataset) + assert ( + n_logical_shards % self.worldsize == 0 + ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert ( + n_logical_shards > 0 + ), f"n_logical_shards {n_logical_shards} must be a positive integer" + + self.total_shards = n_logical_shards + self.delimiter = delimiter_token + self.verbose = verbose + + # Fields to be populated during setup / subdataset setup + self.data: List[StreamingDocDataset] = [] + self.logicals_owned: List[int] = [] + self.n_logicals = 0 + self.n_docs_remaining: List[int] = [] + self.generator = None + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = None + self.logical_shard_states = None + self.g_state = None + + self.state_params = ["current_reader", "g_state"] + self.reshard_params = ["n_docs_remaining", "logical_shard_states"] + + def setup(self): + if not self.is_setup: + self.is_setup = True + n_logical_shards = self.total_shards + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + self.n_logicals = n_logical_shards // self.worldsize + assert len(self.logicals_owned) == self.n_logicals + + # Build logical shards + for i in range(self.n_logicals): + self.data.append(deepcopy(self.dataset)) + self.data[-1].worldsize = n_logical_shards + self.data[-1].load_worldsize = n_logical_shards + self.data[-1].rank = self.logicals_owned[i] + self.data[-1].datapath = self.datapath + self.data[-1].verbose = self.rank == 0 + if self.verbose: + logging.info( + f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + [d.setup() for d in self.data] + self.n_docs_remaining = [d._len for d in self.data] + + self.generator = torch.Generator().manual_seed(self.rank) + + def __iter__(self): + self.setup() + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + # Sample logical shard (or load from ckp) + if self.current_reader is not None: + ind = self.current_reader + else: + ind = torch.multinomial( + torch.tensor(self.n_docs_remaining, dtype=torch.float), + 1, + generator=self.generator, + ).item() + self.current_reader = ind + # Read doc + out = next(data[ind]) + while out[-1] != self.delimiter: + yield out + out = next(data[ind]) + # Update state to show we've finished the doc + self.current_reader = None + self.n_docs_remaining[ind] -= 1 + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) + # Return final piece of doc + yield out + + def state_dict(self): + self.setup() + # Write generator state manually + self.g_state = self.generator.get_state() + # Recursive fetch + self.logical_shard_states = [d.state_dict() for d in self.data] + return _StatefulDataset.state_dict(self) + + def load_state_dict(self, state_dicts, sharded_input=False): + self.setup() + sharded_dicts = _StatefulDataset.load_state_dict( + self, state_dicts, sharded_input + ) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + # Recursive set + for i in range(self.n_logicals): + self.data[i].load_state_dict([self.logical_shard_states[i]], True) + return sharded_dicts + + class SamplingDataset(_WrapperDataset): """ A _WrapperDataset implementing percentage-based sampling: weights can be floats, and the @@ -1030,134 +1161,3 @@ def load_state_dict(self, state_dicts, sharded_input=False): True, ) return sharded_dicts - - -class ScalableShardDataset(_WrapperDataset): - """ - A _WrapperDataset implementing rescalability: loading from checkpoint into a different - number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. - This is accomplished by maintaining a large number of small StreamingDocDatasets, cloned from the - original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. - ... - Args - ---- - dataset : StreamingDocDataset - Fully instantiated dataset. Cloned into logical workers during setup fn. - delimiter_token : any - Token used to indicate sequence/document breaks. Type should match data type. - n_logical_shards : int - Number of logical shards. Must be a multiple of world size. - verbose : bool - Track setup progress? - """ - - def __init__( - self, - dataset: StreamingDocDataset, - delimiter_token: Any, - n_logical_shards: int = 2048, - verbose=False, - ): - super().__init__(dataset) - assert ( - n_logical_shards % self.worldsize == 0 - ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" - assert ( - n_logical_shards > 0 - ), f"n_logical_shards {n_logical_shards} must be a positive integer" - - self.total_shards = n_logical_shards - self.delimiter = delimiter_token - self.verbose = verbose - - # Fields to be populated during setup / subdataset setup - self.data: List[StreamingDocDataset] = [] - self.logicals_owned: List[int] = [] - self.n_logicals = 0 - self.n_docs_remaining: List[int] = [] - self.generator = None - - # Position "state", used only for maintaining order when n_workers is unchanged - # For scaling up or down, logical position is meaningless, and reset - self.current_reader = None - self.logical_shard_states = None - self.g_state = None - - self.state_params = ["current_reader", "g_state"] - self.reshard_params = ["n_docs_remaining", "logical_shard_states"] - - def setup(self): - if not self.is_setup: - self.is_setup = True - n_logical_shards = self.total_shards - logicals = list(range(n_logical_shards)) - self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) - self.n_logicals = n_logical_shards // self.worldsize - assert len(self.logicals_owned) == self.n_logicals - - # Build logical shards - for i in range(self.n_logicals): - self.data.append(deepcopy(self.dataset)) - self.data[-1].worldsize = n_logical_shards - self.data[-1].load_worldsize = n_logical_shards - self.data[-1].rank = self.logicals_owned[i] - self.data[-1].datapath = self.datapath - self.data[-1].verbose = self.rank == 0 - if self.verbose: - logging.info( - f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" - ) - [d.setup() for d in self.data] - self.n_docs_remaining = [d._len for d in self.data] - - self.generator = torch.Generator().manual_seed(self.rank) - - def __iter__(self): - self.setup() - # Grab one doc at a time in random order - data = [iter(d) for d in self.data] - while True: - # Sample logical shard (or load from ckp) - if self.current_reader is not None: - ind = self.current_reader - else: - ind = torch.multinomial( - torch.tensor(self.n_docs_remaining, dtype=torch.float), - 1, - generator=self.generator, - ).item() - self.current_reader = ind - # Read doc - out = next(data[ind]) - while out[-1] != self.delimiter: - yield out - out = next(data[ind]) - # Update state to show we've finished the doc - self.current_reader = None - self.n_docs_remaining[ind] -= 1 - if sum(self.n_docs_remaining) == 0: - self.n_docs_remaining = [d._len for d in self.data] - self.generator.manual_seed(self.rank) - # Return final piece of doc - yield out - - def state_dict(self): - self.setup() - # Write generator state manually - self.g_state = self.generator.get_state() - # Recursive fetch - self.logical_shard_states = [d.state_dict() for d in self.data] - return _StatefulDataset.state_dict(self) - - def load_state_dict(self, state_dicts, sharded_input=False): - self.setup() - sharded_dicts = _StatefulDataset.load_state_dict( - self, state_dicts, sharded_input - ) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(self.g_state) - # Recursive set - for i in range(self.n_logicals): - self.data[i].load_state_dict([self.logical_shard_states[i]], True) - return sharded_dicts From 91b84f9edb542f02be2ae7ffd866963143ce09c3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 26 Jul 2024 15:34:46 -0400 Subject: [PATCH 42/73] Begin testing --- fms_fsdp/utils/dataset_utils.py | 39 +++++++++++++++++++++++++-------- main_training.py | 4 +++- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index c15e3e9d..d53072df 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -684,14 +684,6 @@ def setup(self): pathsplit = os.path.split(pathsplit[0]) pardir, dataset = pathsplit self.dataset = dataset - with open(os.path.join(pardir, "meta", countfiles[0]), "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: - key = fullpath[prefix:] - doc_counts[key] = int(row["documents"]) # Assemble document set owned by this worker: # listdir, assemble shardfraglist (ind -> shard, frag) @@ -711,11 +703,40 @@ def setup(self): for i in range(start_frag, end_frag) ] + # Assemble length of each owned shard file + + countfiles = [ + x + for x in os.listdir(os.path.join(pardir, "meta")) + if "counts" in x and "csv" in x + ] + doc_counts = {} + if len(countfiles) > 0: + # Count file exists, use it + countpath = os.path.join(pardir, "meta", countfiles[0]) + with open(countpath, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find("/" + dataset) + 1 + if prefix > 0: + key = fullpath[prefix + len(dataset) + 1 :] + doc_counts[key] = int(row["documents"]) + else: + # Count file does not exist, touch every owned file for length + unique_shardfiles = set(shard for shard, frag in shardfrags) + doc_counts = { + shard: pa.ipc.open_file( + pa.memory_map(os.path.join(datapath, shard)) + ).num_record_batches + for shard in unique_shardfiles + } + # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): ndocs = -1 docset = {} # shardid -> (min docid, max docid) for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[os.path.join(dataset, shard)] + ndocs = doc_counts[shard] doc_start = (ndocs * frag) // self.worldsize doc_end = ( ndocs * frag + ndocs diff --git a/main_training.py b/main_training.py index bae6dbad..c7835576 100644 --- a/main_training.py +++ b/main_training.py @@ -1,5 +1,6 @@ import math import os +import time import fire import torch @@ -74,9 +75,10 @@ def main(**kwargs): if not cfg.use_dummy_dataset: train_loader = get_data_loader(cfg, rank, world_size) else: + start = time.time() train_loader = get_dummy_loader(cfg, rank, world_size) if rank == 0: - print("Datasets constructed!") + print("Datasets constructed!", time.time()-start) # FSDP model = FSDP( From 6dba5193566ffad9720cbd398cddb1f344b34e06 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 26 Jul 2024 15:39:16 -0400 Subject: [PATCH 43/73] Quit blocking on missing countfile --- fms_fsdp/utils/dataset_utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index d53072df..e6f87b00 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -670,16 +670,8 @@ def setup(self): if not self.is_setup: datapath = self.datapath self.is_setup = True - - # Gather per-file document counts from metadata count file(s) - countfiles = [ - x - for x in os.listdir(os.path.join(os.path.dirname(datapath), "meta")) - if "counts" in x and "csv" in x - ] - assert len(countfiles) == 1 - doc_counts = {} pathsplit = (datapath, "") + # May take an extra round to account for any trailing slashes while len(pathsplit[1]) == 0: pathsplit = os.path.split(pathsplit[0]) pardir, dataset = pathsplit From d96a78f357a406d94fc32621fe4fe4d4a077dd57 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 26 Jul 2024 15:42:51 -0400 Subject: [PATCH 44/73] Time tracking non conditional --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index c7835576..9ad1c1fb 100644 --- a/main_training.py +++ b/main_training.py @@ -72,10 +72,10 @@ def main(**kwargs): # get data loader if rank == 0: print("Constructing datasets...") + start = time.time() if not cfg.use_dummy_dataset: train_loader = get_data_loader(cfg, rank, world_size) else: - start = time.time() train_loader = get_dummy_loader(cfg, rank, world_size) if rank == 0: print("Datasets constructed!", time.time()-start) From 10038d37960a1c388962c016929f8c6c2d3c2b38 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 26 Jul 2024 15:49:37 -0400 Subject: [PATCH 45/73] Build dataset synchronously for timing testing --- fms_fsdp/utils/dataloader_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index b9e4e772..7b7288c7 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -115,7 +115,7 @@ def causal_lm(data_seq, prompt_len=0): cfg.batch_size, cfg.ckpt_save_path, ) - return torch.utils.data.DataLoader(data, num_workers=1, batch_size=cfg.batch_size) + return torch.utils.data.DataLoader(data, num_workers=0, batch_size=cfg.batch_size) def parse_data_args(datas, weights): From afbb2dd59a3dd7ad25ede84ba4a8f9fe0068f05f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 26 Jul 2024 15:57:35 -0400 Subject: [PATCH 46/73] Remove timing stuff --- fms_fsdp/utils/dataloader_utils.py | 2 +- main_training.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 7b7288c7..b9e4e772 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -115,7 +115,7 @@ def causal_lm(data_seq, prompt_len=0): cfg.batch_size, cfg.ckpt_save_path, ) - return torch.utils.data.DataLoader(data, num_workers=0, batch_size=cfg.batch_size) + return torch.utils.data.DataLoader(data, num_workers=1, batch_size=cfg.batch_size) def parse_data_args(datas, weights): diff --git a/main_training.py b/main_training.py index 9ad1c1fb..bae6dbad 100644 --- a/main_training.py +++ b/main_training.py @@ -1,6 +1,5 @@ import math import os -import time import fire import torch @@ -72,13 +71,12 @@ def main(**kwargs): # get data loader if rank == 0: print("Constructing datasets...") - start = time.time() if not cfg.use_dummy_dataset: train_loader = get_data_loader(cfg, rank, world_size) else: train_loader = get_dummy_loader(cfg, rank, world_size) if rank == 0: - print("Datasets constructed!", time.time()-start) + print("Datasets constructed!") # FSDP model = FSDP( From 5a1032e6f970f435f1068e8f9077fe93fe294378 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 12:32:50 -0400 Subject: [PATCH 47/73] Defer ckp loading to setup (post-rank/path adjustment) --- fms_fsdp/utils/dataset_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index c15e3e9d..d1133e84 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -328,12 +328,18 @@ def __init__( save_path = load_path else: save_path = os.path.join(save_path, "checkpoints") + self.load_path = load_path self.path = save_path self.step = 0 self.ministep = 0 - self.load_from_path(load_path) + + def setup(self): + if not self.is_setup(): + super().setup() + self.load_from_path(self.load_path) def __iter__(self): + self.setup() dataset = iter(self.dataset) while True: yield next(dataset) From fd1bba15d077961fdb1eb7b9113d5a868c483be6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 12:36:32 -0400 Subject: [PATCH 48/73] Don't call bool --- fms_fsdp/utils/dataset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index d1133e84..17522f96 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -334,7 +334,7 @@ def __init__( self.ministep = 0 def setup(self): - if not self.is_setup(): + if not self.is_setup: super().setup() self.load_from_path(self.load_path) From 825304a861fc8198726e727b1309c21c283f7e5c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 13:13:05 -0400 Subject: [PATCH 49/73] Add validation, reset, step override to ckp dataset --- fms_fsdp/utils/dataset_utils.py | 72 ++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 17522f96..6823a3f7 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -310,6 +310,8 @@ class CheckpointDataset(_WrapperDataset): when a full batch is formed. Defaults to 1. save_path : optional[str] Absolute path to checkpoint save directory. Defaults to load_path. + reset_stepcount : bool + After loading an external checkpoint, start counting new checkpoints from zero, or from loaded step? """ def __init__( @@ -319,6 +321,7 @@ def __init__( interval: int, steps_per_batch: int = 1, save_path: str = "", + reset_stepcount: bool = False ): super().__init__(dataset) self.interval = interval @@ -330,6 +333,7 @@ def __init__( save_path = os.path.join(save_path, "checkpoints") self.load_path = load_path self.path = save_path + self.reset_stepcount = reset_stepcount self.step = 0 self.ministep = 0 @@ -355,6 +359,43 @@ def report(self, msg): if self.rank == 0: print(msg) + def _validate_ckp_path(self, path: str, verbose: bool = False): + """ + Interpret path to appropriate checkpoint. + If found, return modified path. + If not found, return empty string. + """ + # Does path exists, and if it exists, is it non-empty? + if not os.path.exists(path) or len(os.listdir(path)) == 0: + if verbose: + self.report( + f" Dataset: No valid checkpoint detected at {path}, dataset starting from scratch." + ) + return "" + # Check latest path + latest = os.path.join(path, get_latest(path)) + if verbose: + self.report(f"Checkpoint detected at {latest}") + # If item is not a folder, exit early + if os.path.isfile(latest): + if verbose: + self.report( + f" Dataset: Detected checkpoint {latest} is a single file with no dataset info." + + " Dataset starting from scratch." + ) + return "" + # If item is a folder, check that it contains shard files + if len([x for x in os.listdir(latest) if "loader" in x]) == 0: + if verbose: + self.report( + f" Dataset: Detected checkpoint {latest} exists but contains no dataset checkpoints." + + " Dataset starting from scratch." + ) + return "" + # If item is a folder, get the step count + self.step = int(latest.split("_")[-2]) + return latest + def save_to_path(self, path: str): self.report(f"Saving dataset to {path}") start = time.time() @@ -364,26 +405,21 @@ def save_to_path(self, path: str): ) def load_from_path(self, path: str): - # If path does not exist, or exists but is empty, exit early - if not os.path.exists(path) or len(os.listdir(path)) == 0: - self.report( - f"No valid checkpoint detected at {path}, dataset starting from scratch." - ) - return - # Grab latest item in path - latest = os.path.join(path, get_latest(path)) - self.report(f"Dataset checkpoint detected at {latest}") - # If item is not a folder, exit early - if os.path.isfile(latest): - self.report( - f"Checkpoint exists but contains no dataset! Dataset starting from scratch." - ) - return - # If item is a folder, get the step count - self.step = int(latest.split("_")[-2]) + save_path = self._validate_ckp_path(self.path, False) + if len(save_path) > 0: + self.report(f" Dataset: Detected a checkpoint in the save directory {save_path}. Restoring from this checkpoint.") + path = save_path + else: + load_path = self._validate_ckp_path(self.load_path, True) + if len(load_path) == 0: + return + else: + path = load_path + if self.reset_stepcount: + self.step = 0 # Proceed start = time.time() - self.dataset.load_from_path(latest) + self.dataset.load_from_path(path) self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}") From 77d46c60f92f1a656d857d3264248ce8b8c1842e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 13:16:27 -0400 Subject: [PATCH 50/73] Update ckpdataset test to setup() before check --- tests/test_datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 96f78d6e..d7b9cec7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -912,6 +912,7 @@ def test_checkpoint_reload_match(): CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2) for x in datasets2 ] + [d.setup() for d in datasets2] # Assert checkpoints have loaded correctly for d in datasets2: From d3571b1e74e984b3d393ad571f1de7c377c82638 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 13:27:37 -0400 Subject: [PATCH 51/73] Add support for fresh data at user lvl --- fms_fsdp/config/training.py | 3 +++ fms_fsdp/utils/dataloader_utils.py | 3 ++- fms_fsdp/utils/dataset_utils.py | 8 +++++--- tests/test_datasets.py | 2 ++ 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 6b1f3888..8aea8f17 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -36,7 +36,10 @@ class train_config: learning_rate: float = 3e-4 grad_clip_thresh: float = 1.0 seed: int = 2023 + + # continued training spec reset_stepcount: bool = False + new_dataset: bool = False # profiling use_profiler: bool = False diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index b9e4e772..1c4aa984 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -110,10 +110,11 @@ def causal_lm(data_seq, prompt_len=0): # Enable auto-saving data = CheckpointDataset( data, - cfg.ckpt_load_path, + cfg.ckpt_load_path if not cfg.new_dataset else cfg.ckpt_save_path, cfg.checkpoint_interval, cfg.batch_size, cfg.ckpt_save_path, + cfg.reset_stepcount, ) return torch.utils.data.DataLoader(data, num_workers=1, batch_size=cfg.batch_size) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 6823a3f7..9e2c079f 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -321,7 +321,7 @@ def __init__( interval: int, steps_per_batch: int = 1, save_path: str = "", - reset_stepcount: bool = False + reset_stepcount: bool = False, ): super().__init__(dataset) self.interval = interval @@ -362,7 +362,7 @@ def report(self, msg): def _validate_ckp_path(self, path: str, verbose: bool = False): """ Interpret path to appropriate checkpoint. - If found, return modified path. + If found, return modified path. If not found, return empty string. """ # Does path exists, and if it exists, is it non-empty? @@ -407,7 +407,9 @@ def save_to_path(self, path: str): def load_from_path(self, path: str): save_path = self._validate_ckp_path(self.path, False) if len(save_path) > 0: - self.report(f" Dataset: Detected a checkpoint in the save directory {save_path}. Restoring from this checkpoint.") + self.report( + f" Dataset: Detected a checkpoint in the save directory {save_path}. Restoring from this checkpoint." + ) path = save_path else: load_path = self._validate_ckp_path(self.load_path, True) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 96f78d6e..fe04826c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -913,6 +913,8 @@ def test_checkpoint_reload_match(): for x in datasets2 ] + [d.setup() for d in datasets2] + # Assert checkpoints have loaded correctly for d in datasets2: assert d.step == 100, f"Expected to load back to step 100, got {d.step}" From d46c7c0faa4e92e601185963227053171a5b95cf Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 13:56:23 -0400 Subject: [PATCH 52/73] Remove reset_stepcount, always do so when loading externally --- fms_fsdp/config/training.py | 3 +-- fms_fsdp/utils/checkpointing_utils.py | 11 ++++++----- fms_fsdp/utils/dataloader_utils.py | 3 +-- fms_fsdp/utils/dataset_utils.py | 8 ++------ main_training.py | 4 ++-- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 8aea8f17..97f1b21f 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -38,8 +38,7 @@ class train_config: seed: int = 2023 # continued training spec - reset_stepcount: bool = False - new_dataset: bool = False + resuming_dataset: bool = False # profiling use_profiler: bool = False diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 41dd8e2d..5381dc9d 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -178,15 +178,16 @@ def load( Strict determines whether to use strict loading or not FOR SINGLEFILE LOADING ONLY. Returns model, optimizer, dataloader, current step, and current tokens seen. """ + is_resuming = False if self._validate_ckp_path(self.ckp_path) is not None: path = self.ckp_path - reset_stepcount = False + is_resuming = True load_path = self._validate_ckp_path(path) if load_path is None: self.report( f"No valid checkpoint detected at {path}, starting from scratch." ) - return model, optimizer, dataloader, 0, 0 + return model, optimizer, dataloader, 0, 0, False else: self.report(f"Prior checkpoint {load_path} detected.") model_load_time = time.time() @@ -198,7 +199,7 @@ def load( f"Checkpoint {load_path} is a single-file checkpoint containing only a model. Optimizer and dataloader are from scratch.", model_load_time=time.time() - model_load_time, ) - return model, optimizer, dataloader, 0, 0 + return model, optimizer, dataloader, 0, 0, is_resuming else: # Load model with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): @@ -215,7 +216,7 @@ def load( step = 0 ntok = 0 # Load metadata - if not reset_stepcount: + if is_resuming: metadata = torch.load(os.path.join(load_path, "metadata.pth")) step = metadata.get("step", 0) ntok = metadata.get("tokens_seen", 0) @@ -243,7 +244,7 @@ def load( self.report(dataset_load_time=time.time() - data_load_time) else: self.report("Skipping dataset load, no dataloader provided.") - return model, optimizer, dataloader, step, ntok + return model, optimizer, dataloader, step, ntok, is_resuming def save( self, diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 1c4aa984..43f72d06 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -110,11 +110,10 @@ def causal_lm(data_seq, prompt_len=0): # Enable auto-saving data = CheckpointDataset( data, - cfg.ckpt_load_path if not cfg.new_dataset else cfg.ckpt_save_path, + cfg.ckpt_load_path if cfg.resuming_dataset else cfg.ckpt_save_path, cfg.checkpoint_interval, cfg.batch_size, cfg.ckpt_save_path, - cfg.reset_stepcount, ) return torch.utils.data.DataLoader(data, num_workers=1, batch_size=cfg.batch_size) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 9e2c079f..adf19b72 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -310,8 +310,6 @@ class CheckpointDataset(_WrapperDataset): when a full batch is formed. Defaults to 1. save_path : optional[str] Absolute path to checkpoint save directory. Defaults to load_path. - reset_stepcount : bool - After loading an external checkpoint, start counting new checkpoints from zero, or from loaded step? """ def __init__( @@ -321,7 +319,6 @@ def __init__( interval: int, steps_per_batch: int = 1, save_path: str = "", - reset_stepcount: bool = False, ): super().__init__(dataset) self.interval = interval @@ -333,7 +330,6 @@ def __init__( save_path = os.path.join(save_path, "checkpoints") self.load_path = load_path self.path = save_path - self.reset_stepcount = reset_stepcount self.step = 0 self.ministep = 0 @@ -417,8 +413,8 @@ def load_from_path(self, path: str): return else: path = load_path - if self.reset_stepcount: - self.step = 0 + # When loading from external ckp, always reset step count + self.step = 0 # Proceed start = time.time() self.dataset.load_from_path(path) diff --git a/main_training.py b/main_training.py index bae6dbad..b6b27c69 100644 --- a/main_training.py +++ b/main_training.py @@ -118,7 +118,7 @@ def main(**kwargs): checkpointer = Checkpointer( cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank ) - model, optimizer, _, start_step, tokens_seen = checkpointer.load( + model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load( model, optimizer, None, @@ -127,7 +127,7 @@ def main(**kwargs): else cfg.ckpt_load_path, strict=False, ) - if cfg.reset_stepcount: + if not is_resuming: start_step = 0 # Override loaded optim hyperparams with the current values for g in optimizer.param_groups: From e391de54782b4d38541114ab606fe17c93ba2e9b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 14:31:35 -0400 Subject: [PATCH 53/73] Correct impl of setup() in ckptdataset --- fms_fsdp/utils/dataset_utils.py | 8 +++++++- tests/test_datasets.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index e6f87b00..a508c313 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -328,12 +328,18 @@ def __init__( save_path = load_path else: save_path = os.path.join(save_path, "checkpoints") + self.load_path = load_path self.path = save_path self.step = 0 self.ministep = 0 - self.load_from_path(load_path) + + def setup(self): + if not self.is_setup: + super().setup() + self.load_from_path(self.load_path) def __iter__(self): + self.setup() dataset = iter(self.dataset) while True: yield next(dataset) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 96f78d6e..d7b9cec7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -912,6 +912,7 @@ def test_checkpoint_reload_match(): CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2) for x in datasets2 ] + [d.setup() for d in datasets2] # Assert checkpoints have loaded correctly for d in datasets2: From 3dedcd4d593174e734cd9e289bd57b3537162090 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 26 Jul 2024 17:35:01 -0400 Subject: [PATCH 54/73] Add filehandlers and support for HF parquet (testing needed) Signed-off-by: Davis Wertheimer --- fms_fsdp/config/training.py | 2 + fms_fsdp/utils/dataloader_utils.py | 16 ++++ fms_fsdp/utils/dataset_utils.py | 124 ++++++++++++++++++++++++++--- tests/test_datasets.py | 1 + 4 files changed, 130 insertions(+), 13 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 6b1f3888..8b26a389 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -12,6 +12,8 @@ class train_config: # dataset and dataloader use_dummy_dataset: bool = False data_path: str = "/fsx/data" + file_type: str = "arrow" + tokenizer_path: str = "/fsx/tokenizer" datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange" weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100" seq_length: int = 4096 diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index b9e4e772..fdb82305 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -1,8 +1,10 @@ import torch from fms_fsdp.utils.dataset_utils import ( + ArrowHandler, BufferDataset, CheckpointDataset, + ParquetHandler, PreloadBufferDataset, PreprocessDataset, SamplingDataset, @@ -11,6 +13,12 @@ ) +_handler_map = { + "arrow": ArrowHandler, + "hf_parquet": ParquetHandler, +} + + def get_dummy_loader(cfg, rank, world_size): """ A simple dummy dataloader yielding incrementing vocab indices in an infinite loop @@ -69,11 +77,19 @@ def causal_lm(data_seq, prompt_len=0): int(x.strip()) for x in cfg.strip_tokens.split(",") if len(x.strip()) > 0 ] droplist = droplist + [cfg.bos_token, cfg.eos_token, cfg.bol_token, cfg.eol_token] + assert ( + cfg.file_type in _handler_map + ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" + if cfg.file_type == "hf_parquet": + filehandler = ParquetHandler(cfg.tokenizer_path) + else: + filehandler = _handler_map[cfg.file_type]() # Base reader layer data = StreamingDocDataset( cfg.data_path, rank, world_size, + filehandler, cfg.eos_token, bos_token=cfg.bos_token, strip_tokens=set(droplist), diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index c15e3e9d..a59ac40a 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -8,8 +8,10 @@ from typing import Any, Callable, List, Optional, Set, Type, Union import pyarrow as pa +import pyarrow.parquet as pq import torch import torch.utils.data as data +from transformers import AutoTokenizer from fms_fsdp.utils.checkpointing_utils import get_latest @@ -263,6 +265,107 @@ def state_dict(self): return out +#### ------------------------- FILE READERS ------------------------- #### + + +class _ShardFileHandler: + """ + Stub for shard file readers of different formats. + Must implement open, length, indexing, and slicing functions. + """ + + def open(self, path: str): + """ + Open the file, to be indexed via self.get() method. + Avoid reading entire multi-Gb files when possible! + """ + raise NotImplementedError + + def length(self, path: str): + """ + Calculate the number of documents in the given file. + Avoid reading entire multi-Gb files when possible! + """ + raise NotImplementedError + + def get(self, reader, index: int, drop_tokens: Set): + """ + Given the output of self.open() and an index, return the document at that index. + Then, remove the first and/or last items if they appear in drop_tokens. + Try to avoid reading entire documents at a time in case of long documents, + but this is less important than avoiding reading entire files as above. + Output must support len(). + """ + raise NotImplementedError + + def slice(self, doc, index: int, n_pull: int) -> List: + """ + Given a long document, retrieve n_pull consecutive items starting from index. + Again, try to be memory-efficient when doing so, but efficiency in self.get() + and self.open() is far more important. + Must return a python list. + """ + raise NotImplementedError + + +class ArrowHandler(_ShardFileHandler): + """ + Reader for indexable, pre-tokenized PyArrow shard files. + A preferred format as we can load document chunks without having to ever pull + the entire document or shard file, allowing for graceful handling of large documents. + Non-standard data format, though. + """ + + def open(self, path: str): + return pa.ipc.open_file(pa.memory_map(path)) + + def length(self, path: str): + return self.open(path).num_record_batches + + def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): + doc = reader.get_batch(index)["tokens"] + if doc[0].as_py() in drop_tokens: + doc = doc.slice(1, len(doc) - 1) + if doc[-1].as_py() in drop_tokens: + doc = doc.slice(0, len(doc) - 1) + return doc + + def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: + return doc.slice(index, n_pull).to_pylist() + + +class ParquetHandler(_ShardFileHandler): + """ + Reader for indexable parquet shard files, common in HF datasets. + Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens), + as we rely on parquet/pandas for efficient file reading, and tokenize entire documents + before getting/slicing. However, this is a standard and widely-used data format. + """ + + def __init__(self, tokenizer_path): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + def open(self, path: str): + return pq.read_pandas(path, columns=["text"])["text"] + + def length(self, path: str): + return pq.read_pandas(path, columns=[]).num_rows + + def get(self, reader, index: int, drop_tokens: Set): + doc = self.tokenizer(str(reader[index]))["input_ids"] + if doc[0] in drop_tokens: + doc = doc[1:] + if doc[-1] in drop_tokens: + doc = doc[:-1] + return doc + + def slice(self, doc: List, index: int, n_pull: int) -> List: + return doc[index : index + n_pull] + + +#### ------------------------- PIPELINE LAYERS ------------------------- #### + + class PreprocessDataset(_WrapperDataset): """ Wrapper for a _StatefulDataset that applies a specified preprocessing @@ -588,6 +691,8 @@ class StreamingDocDataset(_StatefulDataset): Current worker index worldsize : int Total number of workers + filereader : _ShardFileReader + A file reader handling specific data shard file formats delimiter_token : Any Token used to indicate sequence/document breaks. Type should match data type. Required for downstream sampling logic (can be removed later via PreProcessDataset if needed). @@ -613,6 +718,7 @@ def __init__( datapath: str, rank: int, worldsize: int, + filehandler: _ShardFileHandler, delimiter_token: Any, bos_token: Optional[Any] = None, strip_tokens: Optional[Set[Any]] = set(), @@ -624,6 +730,7 @@ def __init__( super().__init__(datapath, rank, worldsize) self.seed = seed self.datapath = datapath + self.filehandler = filehandler self.min_length = min_length assert max_chunksize > 0, f"Max chunksize must be a nonzero positive integer" self.chunksize = max_chunksize @@ -699,7 +806,6 @@ def setup(self): shard for shard in os.listdir(datapath) if os.path.isfile(os.path.join(datapath, shard)) - and "arrow" in os.path.join(datapath, shard) ] shards.sort() # Ensure consistent sharding across machines start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize @@ -773,7 +879,7 @@ def _get_reader(self, path, newpath, reader): del reader if self.verbose: logging.info(f"Worker {self.rank} opening new file {newpath}") - reader = pa.ipc.open_file(newpath) + reader = self.filehandler.open(newpath) path = newpath return path, reader @@ -789,7 +895,7 @@ def _construct_chunk(self, j, doc, n_chunks): n_pull -= 1 else: start_index -= 1 - chunk = doc.slice(start_index, n_pull).to_pylist() + chunk = self.filehandler.slice(doc, start_index, n_pull) self.tokens_seen += len(chunk) # Add bos/eos tokens if needed if self.bos is not None and j == 0: @@ -839,11 +945,7 @@ def __iter__(self): # Map id in range of owned docs to new (consistently) shuffled id doclcg = self._random_map_docid(docrange) docid = doclcg + mindoc - doc = reader.get_batch(docid)["tokens"] - if doc[0].as_py() in self.drop: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in self.drop: - doc = doc.slice(0, len(doc) - 1) + doc = self.filehandler.get(reader, docid, self.drop) doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 if doclen >= self.min_length: n_chunks = math.ceil(doclen / self.chunksize) @@ -870,11 +972,7 @@ def __iter__(self): docid = self._random_map_docid(docrange) + mindoc newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) - doc = reader.get_batch(docid)["tokens"] - if doc[0].as_py() in self.drop: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in self.drop: - doc = doc.slice(0, len(doc) - 1) + doc = self.filehandler.get(reader, docid, self.drop) doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 if doclen >= self.min_length: n_chunks = math.ceil(doclen / self.chunksize) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 96f78d6e..702b5a0f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -390,6 +390,7 @@ def basic_loader( os.path.join(tmpdir.name, datasets[0]), rank, worldsize, + ArrowHandler(), -1, max_chunksize=max_chunksize, bos_token=bos_token, From 3527a7aa3b567a821d2cf0ca084d22c21862799c Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 12:16:39 -0400 Subject: [PATCH 55/73] Fix tests and typing Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 4 ++-- tests/test_datasets.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index a59ac40a..14cabe47 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -5,13 +5,13 @@ import random import time from copy import deepcopy -from typing import Any, Callable, List, Optional, Set, Type, Union +from typing import Any, Callable, List, Optional, Set, Union import pyarrow as pa import pyarrow.parquet as pq import torch import torch.utils.data as data -from transformers import AutoTokenizer +from transformers import AutoTokenizer # type: ignore from fms_fsdp.utils.checkpointing_utils import get_latest diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 702b5a0f..f62c186b 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -573,6 +573,7 @@ def test_multi_reload_stress(): os.path.join(tmpdir.name, "dataset_2"), i, 3, + ArrowHandler(), -1, max_chunksize=17, ) From c8bb20b4e3a91e0c6a7c902e68728dc42293b539 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 29 Jul 2024 14:33:56 -0400 Subject: [PATCH 56/73] Proper impl setup() in ckptdataset Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 8 +++++++- tests/test_datasets.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 14cabe47..5824f97b 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -431,12 +431,18 @@ def __init__( save_path = load_path else: save_path = os.path.join(save_path, "checkpoints") + self.load_path = load_path self.path = save_path self.step = 0 self.ministep = 0 - self.load_from_path(load_path) + + def setup(self): + if not self.is_setup: + super().setup() + self.load_from_path(self.load_path) def __iter__(self): + self.setup() dataset = iter(self.dataset) while True: yield next(dataset) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index f62c186b..dfa56d6c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -914,6 +914,7 @@ def test_checkpoint_reload_match(): CheckpointDataset(x, os.path.join(tmpdir.name, "ckp_test"), 1000, 2) for x in datasets2 ] + [d.setup() for d in datasets2] # Assert checkpoints have loaded correctly for d in datasets2: From 4b7fe373bfdd23de9ff6f781cf0c56f5a4992592 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 30 Jul 2024 12:50:18 -0400 Subject: [PATCH 57/73] Add col_name support Signed-off-by: Davis Wertheimer --- fms_fsdp/config/training.py | 1 + fms_fsdp/utils/dataloader_utils.py | 4 ++-- fms_fsdp/utils/dataset_utils.py | 9 ++++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 8b26a389..ea8a9b90 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -13,6 +13,7 @@ class train_config: use_dummy_dataset: bool = False data_path: str = "/fsx/data" file_type: str = "arrow" + col_name: str = "tokens" tokenizer_path: str = "/fsx/tokenizer" datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange" weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100" diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index fdb82305..57860f45 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -81,9 +81,9 @@ def causal_lm(data_seq, prompt_len=0): cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" if cfg.file_type == "hf_parquet": - filehandler = ParquetHandler(cfg.tokenizer_path) + filehandler = ParquetHandler(cfg.tokenizer_path, cfg.col_name) else: - filehandler = _handler_map[cfg.file_type]() + filehandler = _handler_map[cfg.file_type](cfg.col_name) # Base reader layer data = StreamingDocDataset( cfg.data_path, diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 5824f97b..6a197814 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -315,6 +315,8 @@ class ArrowHandler(_ShardFileHandler): the entire document or shard file, allowing for graceful handling of large documents. Non-standard data format, though. """ + def __init__(self, col_name: str = "tokens"): + self.col_name = col_name def open(self, path: str): return pa.ipc.open_file(pa.memory_map(path)) @@ -323,7 +325,7 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - doc = reader.get_batch(index)["tokens"] + doc = reader.get_batch(index)[self.col_name] if doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) if doc[-1].as_py() in drop_tokens: @@ -342,11 +344,12 @@ class ParquetHandler(_ShardFileHandler): before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path): + def __init__(self, tokenizer_path: str, col_name: str = "text"): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.col_name = col_name def open(self, path: str): - return pq.read_pandas(path, columns=["text"])["text"] + return pq.read_pandas(path, columns=[self.col_name])[self.col_name] def length(self, path: str): return pq.read_pandas(path, columns=[]).num_rows From b513855ba7ba5e82e86bd183996561bc81acb4fb Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 30 Jul 2024 12:52:16 -0400 Subject: [PATCH 58/73] Blacking Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 4e4d9931..52a7aa20 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -315,6 +315,7 @@ class ArrowHandler(_ShardFileHandler): the entire document or shard file, allowing for graceful handling of large documents. Non-standard data format, though. """ + def __init__(self, col_name: str = "tokens"): self.col_name = col_name From 9b89637849c62ec5bf6f5e7e542e4a0089342974 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 30 Jul 2024 16:02:31 -0400 Subject: [PATCH 59/73] Add multiprocess support Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 29 +++++++++++++++++++----- tests/test_datasets.py | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 52a7aa20..e1436649 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -88,6 +88,7 @@ def __init__( self.datapath = datapath self.rank = rank self.worldsize = worldsize + self.local_worldsize = -1 # Setup / loading flags self.load_worldsize = worldsize @@ -101,10 +102,21 @@ def setup(self): after init (for example, wrapping in a subdataset sampler layer, or copying to worker processes), so all rank- and datapth- dependent ops are deferred to this function. + Currently, this function simply adjusts rank/worldsize to account for + multiprocess dataloaders. """ if not self.is_setup: self.is_setup = True - pass + # Perform adjustment only if not already adjusted (i.e. via _WrapperDataset) + if self.local_worldsize == -1: + info = data.get_worker_info() + if info is None or info.num_workers == 1: + # No multi-worker rank adjustment needed + self.local_worldsize = 1 + else: + self.local_worldsize = info.num_workers + self.worldsize = self.worldsize * self.local_worldsize + self.rank = self.local_worldsize * self.rank + info.id def statename(self, x: str): # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline @@ -228,13 +240,16 @@ def setup(self): """ Datapath/rank/worldsize percolate upwards recursively during initialization, so now we project any desired changes downward, also recursively. + We also project local_worldsize downward to prevent subsequent layers from + further inflating the rank/worldsize - we only need to account for multiprocessing once! Any code overriding this function should still include this functionality. """ if not self.is_setup: - self.is_setup = True + super().setup() self.dataset.datapath = self.datapath self.dataset.rank = self.rank self.dataset.worldsize = self.worldsize + self.dataset.local_worldsize = self.local_worldsize self.dataset.setup() def load_state_dict(self, state_dicts, sharded_input=False): @@ -294,7 +309,7 @@ def get(self, reader, index: int, drop_tokens: Set): Then, remove the first and/or last items if they appear in drop_tokens. Try to avoid reading entire documents at a time in case of long documents, but this is less important than avoiding reading entire files as above. - Output must support len(). + Output must support len() method. """ raise NotImplementedError @@ -819,8 +834,8 @@ def setup(self): (rank assignment, data partitioning, shuffling) """ if not self.is_setup: + super().setup() datapath = self.datapath - self.is_setup = True pathsplit = (datapath, "") # May take an extra round to account for any trailing slashes while len(pathsplit[1]) == 0: @@ -1106,7 +1121,7 @@ def __init__( def setup(self): if not self.is_setup: - self.is_setup = True + _StatefulDataset.setup(self) n_logical_shards = self.total_shards logicals = list(range(n_logical_shards)) self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) @@ -1119,6 +1134,7 @@ def setup(self): self.data[-1].worldsize = n_logical_shards self.data[-1].load_worldsize = n_logical_shards self.data[-1].rank = self.logicals_owned[i] + self.data[-1].local_worldsize = 1 self.data[-1].datapath = self.datapath self.data[-1].verbose = self.rank == 0 if self.verbose: @@ -1250,7 +1266,7 @@ def __init__( def setup(self): if not self.is_setup: - self.is_setup = True + _StatefulDataset.setup(self) # Build subdataset iterators self.data = [] for i, d in enumerate(self.datasets): @@ -1258,6 +1274,7 @@ def setup(self): self.data[-1].datapath = os.path.join(self.datapath, d) self.data[-1].rank = self.rank self.data[-1].worldsize = self.worldsize + self.data[-1].local_worldsize = self.local_worldsize if self.verbose: logging.info( f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dfa56d6c..63685c88 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -375,6 +375,26 @@ def single_doc_bos_eos_check(loader, do_bos): ), f"Expected chunk 2 to follow chunk1, got {c1[-1]} and {c2[0]}" +def single_epoch_loader_worker_check(d, n_workers=0): + # For dataset_1 partitioned over logical shards / workers / ranks, + # check that every doc appears once per epoch + loaders = [ + torch.utils.data.DataLoader(x, num_workers=n_workers, batch_size=1) for x in d + ] + loaders = [iter(l) for l in loaders] + n_steps = 100 // len(loaders) + ins = [] + for _ in range(n_steps): + for l in loaders: + out = next(l) + ins.append(out[0].item()) + + for i in range(100): + assert ( + i * 100 in ins + ), f"Line starting with {i * 100} failed to appear in generated data: worldsize {len(loaders)}, n_workers {n_workers}" + + # BASE DATASET TESTS @@ -937,3 +957,23 @@ def test_checkpoint_reload_match(): ), f"Expected same output lengths, got {len(out)}, {len(targ)}" for i, (x, y) in enumerate(zip(out, targ)): assert x == y, f"Mismatch in position {i}: got {x}, {y}" + + +# MULTIPROCESS DATALOADER WORKER TESTS + + +def test_multiprocess_epoch(): + """ + Check that ScalableShardDataset partitions correctly over various worldsize / n_worker + combinations. A single epoch should contain each datapoint exactly once. + """ + n_workers = [0, 1, 5] + worldsizes = [2, 5] + logicals = [50, 100] + for n in n_workers: + for w in worldsizes: + for l in logicals: + d = [basic_scalable(i, w, n_logical_shards=l) for i in range(w)] + # Add a dummy wrapper (append some pads) to test correct wrapper behavior + d = [BufferDataset(x, 110, False, pad_token=-1) for x in d] + single_epoch_loader_worker_check(d, n) From 4f0bd6706842d3c8ffbb1d56731432741958f51b Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 30 Jul 2024 16:05:12 -0400 Subject: [PATCH 60/73] num_workers cfg flag, clearer failure msg Signed-off-by: Davis Wertheimer --- fms_fsdp/config/training.py | 1 + fms_fsdp/utils/dataloader_utils.py | 4 +++- fms_fsdp/utils/dataset_utils.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index e9a61c10..18cf8c79 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -25,6 +25,7 @@ class train_config: eol_token: Optional[int] = None strip_tokens: str = "" logical_shards: int = 1024 + num_workers: int = 1 # fsdp policies sharding_strategy: str = "hsdp" diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 87eb608c..314364d6 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -131,7 +131,9 @@ def causal_lm(data_seq, prompt_len=0): cfg.batch_size, cfg.ckpt_save_path, ) - return torch.utils.data.DataLoader(data, num_workers=1, batch_size=cfg.batch_size) + return torch.utils.data.DataLoader( + data, num_workers=cfg.num_workers, batch_size=cfg.batch_size + ) def parse_data_args(datas, weights): diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index e1436649..689e40a5 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -1126,7 +1126,9 @@ def setup(self): logicals = list(range(n_logical_shards)) self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) self.n_logicals = n_logical_shards // self.worldsize - assert len(self.logicals_owned) == self.n_logicals + assert ( + len(self.logicals_owned) == self.n_logicals + ), "(world size * num workers) does not divide logical shards evenly" # Build logical shards for i in range(self.n_logicals): From 1052c3e8c98b76584a09d62f029a088d41309f6a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 30 Jul 2024 16:18:36 -0400 Subject: [PATCH 61/73] Lower max n_workers in test to respect CI limits Signed-off-by: Davis Wertheimer --- tests/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 63685c88..96f512e4 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -967,7 +967,7 @@ def test_multiprocess_epoch(): Check that ScalableShardDataset partitions correctly over various worldsize / n_worker combinations. A single epoch should contain each datapoint exactly once. """ - n_workers = [0, 1, 5] + n_workers = [0, 1, 2] worldsizes = [2, 5] logicals = [50, 100] for n in n_workers: From 3e1e21ca061b26bb185ea8533322fd855d1ce94e Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 30 Jul 2024 16:23:06 -0400 Subject: [PATCH 62/73] Shorten test, fix shard ratios Signed-off-by: Davis Wertheimer --- tests/test_datasets.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 96f512e4..b4b8378a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -967,13 +967,11 @@ def test_multiprocess_epoch(): Check that ScalableShardDataset partitions correctly over various worldsize / n_worker combinations. A single epoch should contain each datapoint exactly once. """ - n_workers = [0, 1, 2] + n_workers = [0, 2] worldsizes = [2, 5] - logicals = [50, 100] for n in n_workers: for w in worldsizes: - for l in logicals: - d = [basic_scalable(i, w, n_logical_shards=l) for i in range(w)] - # Add a dummy wrapper (append some pads) to test correct wrapper behavior - d = [BufferDataset(x, 110, False, pad_token=-1) for x in d] - single_epoch_loader_worker_check(d, n) + d = [basic_scalable(i, w, n_logical_shards=20) for i in range(w)] + # Add a dummy wrapper (append some pads) to test correct wrapper behavior + d = [BufferDataset(x, 110, False, pad_token=-1) for x in d] + single_epoch_loader_worker_check(d, n) From aba623cec54d1977030387f031bd0fa8481e5dd8 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 31 Jul 2024 10:17:12 -0400 Subject: [PATCH 63/73] Skip counting countfiles if meta does not exist Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 52a7aa20..edad7f46 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -847,11 +847,13 @@ def setup(self): # Assemble length of each owned shard file - countfiles = [ - x - for x in os.listdir(os.path.join(pardir, "meta")) - if "counts" in x and "csv" in x - ] + countfiles = [] + if os.path.exists(os.path.join(pardir, "meta")): + countfiles = [ + x + for x in os.listdir(os.path.join(pardir, "meta")) + if "counts" in x and "csv" in x + ] doc_counts = {} if len(countfiles) > 0: # Count file exists, use it From 6a1f51d6c2ed625cf68902de8cb907365961342a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 31 Jul 2024 11:47:47 -0400 Subject: [PATCH 64/73] Correct no countfile / parquet combo handling Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index edad7f46..e8c79c46 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -870,9 +870,7 @@ def setup(self): # Count file does not exist, touch every owned file for length unique_shardfiles = set(shard for shard, frag in shardfrags) doc_counts = { - shard: pa.ipc.open_file( - pa.memory_map(os.path.join(datapath, shard)) - ).num_record_batches + shard: self.filehandler.length(os.path.join(datapath, shard)) for shard in unique_shardfiles } From 864bcd72a9873a05ed616526724410abfee17c88 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Fri, 2 Aug 2024 00:44:47 -0400 Subject: [PATCH 65/73] fix zero length doc --- fms_fsdp/utils/dataset_utils.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index e8c79c46..6dba7016 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -327,10 +327,11 @@ def length(self, path: str): def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): doc = reader.get_batch(index)[self.col_name] - if doc[0].as_py() in drop_tokens: - doc = doc.slice(1, len(doc) - 1) - if doc[-1].as_py() in drop_tokens: - doc = doc.slice(0, len(doc) - 1) + if len(doc) > 0: + if doc[0].as_py() in drop_tokens: + doc = doc.slice(1, len(doc) - 1) + if doc[-1].as_py() in drop_tokens: + doc = doc.slice(0, len(doc) - 1) return doc def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: @@ -357,10 +358,11 @@ def length(self, path: str): def get(self, reader, index: int, drop_tokens: Set): doc = self.tokenizer(str(reader[index]))["input_ids"] - if doc[0] in drop_tokens: - doc = doc[1:] - if doc[-1] in drop_tokens: - doc = doc[:-1] + if len(doc) > 0: + if doc[0] in drop_tokens: + doc = doc[1:] + if doc[-1] in drop_tokens: + doc = doc[:-1] return doc def slice(self, doc: List, index: int, n_pull: int) -> List: @@ -1003,6 +1005,8 @@ def __iter__(self): doclcg = self._random_map_docid(docrange) docid = doclcg + mindoc doc = self.filehandler.get(reader, docid, self.drop) + if len(doc) == 0: + continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 if doclen >= self.min_length: n_chunks = math.ceil(doclen / self.chunksize) @@ -1030,6 +1034,8 @@ def __iter__(self): newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) doc = self.filehandler.get(reader, docid, self.drop) + if len(doc) == 0: + continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 if doclen >= self.min_length: n_chunks = math.ceil(doclen / self.chunksize) From 18332a6316132e45a70533c912cdfedcfc2dfaff Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Mon, 5 Aug 2024 13:14:51 -0400 Subject: [PATCH 66/73] add annealing lr scheduler --- fms_fsdp/config/training.py | 1 + main_training.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index e9a61c10..ece2f720 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -36,6 +36,7 @@ class train_config: # training spec batch_size: int = 2 num_steps: int = 1000000 + training_stage: str = "initial" learning_rate: float = 3e-4 grad_clip_thresh: float = 1.0 seed: int = 2023 diff --git a/main_training.py b/main_training.py index b6b27c69..67cccee2 100644 --- a/main_training.py +++ b/main_training.py @@ -134,14 +134,17 @@ def main(**kwargs): g["initial_lr"] = cfg.learning_rate # LR schedule - warmup_interval = min(2000, cfg.num_steps // 20) - schedule = lambda x: min( - 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, - 0.1 - + 0.5 - * (1 - 0.1) - * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), - ) + if cfg.training_stage == "annealing": + schedule = lambda x: 1 - x / cfg.num_steps + else: + warmup_interval = min(2000, cfg.num_steps // 20) + schedule = lambda x: min( + 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + 0.1 + + 0.5 + * (1 - 0.1) + * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), + ) scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) # profiler From bbcc2980798a847e7c7c04be75e65f567faf174d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 12 Aug 2024 15:08:46 -0400 Subject: [PATCH 67/73] Allow nonflat data dirs, impl walk rather than list Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/dataset_utils.py | 56 +++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 6a449529..fc70c60c 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -289,6 +289,13 @@ class _ShardFileHandler: Must implement open, length, indexing, and slicing functions. """ + def is_legal(self, filepath: str): + """ + Given a file path, determine if it qualifies for this handler. + Ideally does not involve opening the file. + """ + return os.path.isfile(filepath) + def open(self, path: str): """ Open the file, to be indexed via self.get() method. @@ -326,6 +333,11 @@ def slice(self, doc, index: int, n_pull: int) -> List: class ArrowHandler(_ShardFileHandler): """ Reader for indexable, pre-tokenized PyArrow shard files. + Pyarrow shard files are expected to hold multiple RecordBatches, + where each RecordBatch has a "tokens" field consisting of + a single token list (i.e. each document is a single sequence + under a "token" field, and the file is a list of such sequences). + A preferred format as we can load document chunks without having to ever pull the entire document or shard file, allowing for graceful handling of large documents. Non-standard data format, though. @@ -334,6 +346,9 @@ class ArrowHandler(_ShardFileHandler): def __init__(self, col_name: str = "tokens"): self.col_name = col_name + def is_legal(self, filepath: str): + return "arrow" in os.path.splitext(filepath)[1] + def open(self, path: str): return pa.ipc.open_file(pa.memory_map(path)) @@ -365,6 +380,9 @@ def __init__(self, tokenizer_path: str, col_name: str = "text"): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.col_name = col_name + def is_legal(self, filepath: str): + return "parquet" in os.path.splitext(filepath)[1] + def open(self, path: str): return pq.read_pandas(path, columns=[self.col_name])[self.col_name] @@ -723,16 +741,11 @@ def __iter__(self): class StreamingDocDataset(_StatefulDataset): """ - The base distributed dataset for loading sequences/documents from pyarrow shards. - Pyarrow shard files are expected to hold multiple recordBatches, where each recordBatch has a "tokens" - field consisting of a single token list. (i.e. each document is a single sequence under a "token" field, - and the file is a list of such sequences) - Relies on a compiled metadata file to fetch shardfile lengths, assumes file already exists in the parent directory, - and is in proper csv format (first row "dataset/filename,documents,tokens", subsequent rows these values). + The base distributed dataset for loading sequences/documents from file shards. For a single dataset directory, splits shard files into x=worldsize fragments and grabs a 1/n contiguous span of shard fragments (contiguous to limit file reads from cloud/disk). - Logs the number of documents owned from each shardfile, and relies on ZCG random bijection to + Logs the number of documents owned from each shardfile, and relies on LCG random bijection to map contiguous range of indices to shuffled, noncontiguous set of documents from each shard file. Shuffles the file list deterministically to hop from file to file. @@ -740,14 +753,19 @@ class StreamingDocDataset(_StatefulDataset): Shards are thus pulled no more than once per epoch. Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents. - StreamingDocDataset grabs files from a flat directory representing a single dataset. - For percentage-based sampling of multiple subdatasets, see SamplingDataset. + StreamingDocDataset grabs files from a directory representing a single dataset. + This directory need not be flat. + For percentage-based sampling over multiple such subdatasets, see SamplingDataset. + + When available in the parent directory, relies on a compiled metadata file to fetch shardfile lengths. + Expects csv file (first row "dataset/filename,documents,tokens", subsequent rows these values) under a 'meta' directory. + This can be removed in the future. ... Args ---- datapath : str - Absolute path to the dataset directory. Expects directory containing pyarrow shardfiles. - Parent directory should contain 'meta' folder with metadata csv file inside. + Absolute path to the dataset directory. Expects directory containing shardfiles. + Directory need not be flat. rank : int Current worker index worldsize : int @@ -765,7 +783,7 @@ class StreamingDocDataset(_StatefulDataset): seed : int The random seed for deterministic shuffling/sharding min_length : int - Sequences below this length are skipped + Documents below this length are skipped max_chunksize : int Maximum sequence length to return. Break long docs into chunks of this size or shorter. verbose : bool @@ -848,9 +866,10 @@ def setup(self): # Assemble document set owned by this worker: # listdir, assemble shardfraglist (ind -> shard, frag) shards = [ - shard - for shard in os.listdir(datapath) - if os.path.isfile(os.path.join(datapath, shard)) + os.path.join(root, name)[len(datapath) + 1 :] + for root, dirs, files in os.walk(datapath, topdown=False) + for name in files + if self.filehandler.is_legal(os.path.join(root, name)) ] shards.sort() # Ensure consistent sharding across machines start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize @@ -1212,15 +1231,12 @@ class SamplingDataset(_WrapperDataset): This is accomplished by maintaining a _StatefulDataset for each subdataset, and tracking the number of tokens emitted by each. Whichever loader is furthest from its target will be the next to pass a document. - - All args except for dataset_type, datasets, weights and delimiter are pass-through args for - the component _StatefulDatasets and are documented in the appropriate classes. ... Args ---- datapath : str - Absolute path to the dataset directory. Expects directory to contain subfolders with - pyarrow shardfiles, and also a 'meta' folder with metadata csv file inside. + Absolute path to the dataset directory. Expects directory to contain subfolders, + which in turn contain shard files. dataset : ScalableShardDataset | StreamingDocDataset Fully instantiated dataset. Cloned across desired subdatasets during setup. delimiter_token : Any From 68273290ee0bc74c6a6387bfe5559d4508166cb2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 12 Aug 2024 15:40:29 -0400 Subject: [PATCH 68/73] Update unit tests to include nonflat directory Signed-off-by: Davis Wertheimer --- tests/test_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b4b8378a..83b2426b 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -20,6 +20,7 @@ def generate_sequential_multidata(): os.mkdir(os.path.join(tmpdir.name, "dataset_1")) os.mkdir(os.path.join(tmpdir.name, "dataset_2")) + os.mkdir(os.path.join(tmpdir.name, "dataset_2", "subfolder")) with pa.ipc.new_file( os.path.join(tmpdir.name, "dataset_1/fullshard.arrow"), schema ) as writer: @@ -35,7 +36,7 @@ def generate_sequential_multidata(): writer.write(pa.record_batch([out], schema=schema)) with pa.ipc.new_file( - os.path.join(tmpdir.name, "dataset_2/quartershard_2.arrow"), schema + os.path.join(tmpdir.name, "dataset_2/subfolder/quartershard_2.arrow"), schema ) as writer: for i in range(50): out = list(range(2500 + i * 50, 2500 + i * 50 + 50)) @@ -47,7 +48,7 @@ def generate_sequential_multidata(): f.write("dataset/filename,documents,tokens\n") f.write("/dataset_1/fullshard.arrow,100,10000\n") f.write("/dataset_2/quartershard_1.arrow,50,2500\n") - f.write("/dataset_2/quartershard_2.arrow,50,2500\n") + f.write("/dataset_2/subfolder/quartershard_2.arrow,50,2500\n") f.close() return tmpdir From 948cb2a97ebcbec066d880dce618f6fe917372f3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 14 Aug 2024 15:20:22 -0400 Subject: [PATCH 69/73] Informative error msg for dataloading an empty dir (#109) The `ScalableShardDataset` assumes that at least one of its sub-loaders owns at least one document, and when this is not the case, it errors out via the `torch.multinomial()` call. The resulting error message is not particularly clear or useful, so add a simple assert message to catch this case and inform the user that none of the workers were able to find any documents to load. --- fms_fsdp/utils/dataset_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index fc70c60c..f8996a28 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -1182,6 +1182,9 @@ def __iter__(self): if self.current_reader is not None: ind = self.current_reader else: + assert ( + sum(self.n_docs_remaining) > 0 + ), f"No documents detected in {self.datapath}" ind = torch.multinomial( torch.tensor(self.n_docs_remaining, dtype=torch.float), 1, From e36d5283589fd4e52ab9d891ce1740f92874a983 Mon Sep 17 00:00:00 2001 From: sahil suneja Date: Mon, 9 Sep 2024 19:17:19 +0000 Subject: [PATCH 70/73] speculator training code refresh Signed-off-by: sahil suneja Originallt authored by: Davis Wertheimer --- fms_fsdp/config/training.py | 13 + fms_fsdp/utils/checkpointing_utils.py | 40 +- fms_fsdp/utils/config_utils.py | 15 +- fms_fsdp/utils/dataloader_utils.py | 38 +- fms_fsdp/utils/train_utils.py | 15 +- requirements-speculator.txt | 2 + scripts/README_SPECULATOR.md | 45 ++ scripts/train_speculator.sh | 38 ++ speculator/train_speculator.py | 330 +++++++++++++++ speculator/train_speculator_utils.py | 570 ++++++++++++++++++++++++++ 10 files changed, 1079 insertions(+), 27 deletions(-) create mode 100644 requirements-speculator.txt create mode 100644 scripts/README_SPECULATOR.md create mode 100644 scripts/train_speculator.sh create mode 100644 speculator/train_speculator.py create mode 100644 speculator/train_speculator_utils.py diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 5d5b9e7d..1d072958 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -59,3 +59,16 @@ class train_config: # compile use_torch_compile: bool = True + + # speculator training + tp_size: int = 8 + model_arch: str = "embedllama" + model_path: str = "/path/to/model/" + n_speculator_heads: int = 3 + speculator_width: int = 4096 + speculator_tie_weights: bool = True + speculator_scale_input: bool = True + stage2_start_step: int = 15000 + stage2_prompt_length: int = 64 + stage2_batch_size: int = 96 + stage2_seq_length: int = 256 diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 5381dc9d..e146ac94 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -69,6 +69,8 @@ class Checkpointer: report_fn : Callable or None Optional function for reporting or logging status updates. Expected to handle arbitrary *args, **kwargs. Defaults to self._selective_print(). + model_auto_placement : bool + Optional; If True, auto detect GPU device to move model to, as set in device mesh init Methods ------- @@ -87,6 +89,7 @@ def __init__( rank, local_rank, report_fn=None, + model_auto_placement=False, ): self.max_ckps = n_to_save self.rank = rank @@ -96,6 +99,7 @@ def __init__( self.p_mode = parallel_mode assert parallel_mode in ["fsdp", "hsdp", "ddp"] self.report = self._selective_print if report_fn is None else report_fn + self.model_auto_placement = model_auto_placement def _selective_print(self, *args, **kwargs): if self.rank == 0: @@ -168,7 +172,14 @@ def _validate_ckp_path(self, path): return None def load( - self, model, optimizer, dataloader, path="", reset_stepcount=False, strict=True + self, + model, + optimizer, + dataloader, + path="", + reset_stepcount=False, + strict=True, + is_compiled=False, ): """ Handle checkpoint loading for model/optimizer/dataloader from given path, according to arguments. @@ -193,8 +204,18 @@ def load( model_load_time = time.time() if os.path.isfile(load_path): checkpoint_data = torch.load(load_path, map_location="cpu") - model.load_state_dict(checkpoint_data.get("model_state"), strict=strict) - model.to(self.local_rank) + if is_compiled: + model._orig_mod.load_state_dict( + checkpoint_data.get("model_state"), strict=strict + ) + else: + model.load_state_dict( + checkpoint_data.get("model_state"), strict=strict + ) + if self.model_auto_placement: + model.to("cuda") + else: + model.to(self.local_rank) self.report( f"Checkpoint {load_path} is a single-file checkpoint containing only a model. Optimizer and dataloader are from scratch.", model_load_time=time.time() - model_load_time, @@ -211,7 +232,10 @@ def load( planner=DefaultLoadPlanner(), ) model.load_state_dict(model_ckp["model_state"]) - model.to(self.local_rank) + if self.model_auto_placement: + model.to("cuda") + else: + model.to(self.local_rank) self.report(model_load_time=time.time() - model_load_time) step = 0 ntok = 0 @@ -240,7 +264,7 @@ def load( # Load dataset if dataloader is not None: data_load_time = time.time() - dataloader.dataset.load_from_path(load_path) + dataloader.dataset.load_from_path(path) self.report(dataset_load_time=time.time() - data_load_time) else: self.report("Skipping dataset load, no dataloader provided.") @@ -285,6 +309,7 @@ def save_single_file( self, step, model, + is_compiled=False, **kwargs, ): # Note: metadata kwargs cannot contain any of: @@ -296,7 +321,10 @@ def save_single_file( StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): - model_state = model.state_dict() + if is_compiled: + model_state = model._orig_mod.state_dict() + else: + model_state = model.state_dict() if self.rank == 0: metadata = kwargs metadata["step"] = step diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index c0389b12..da5c6b40 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -39,6 +39,8 @@ def get_model_config(model_variant): kvheads=8, nlayers=48, hidden_grow_factor=22016 / 8192, + max_expected_seq_len=16384, + rope_theta=1000000.0, ) elif model_variant == "llama2_13b": llama_config = LLaMAConfig( @@ -49,8 +51,8 @@ def get_model_config(model_variant): ) elif model_variant == "llama2_7b": llama_config = LLaMAConfig( - hidden_grow_factor=3, - kvheads=8, + hidden_grow_factor=11008 / 4096, + kvheads=32, ) elif model_variant == "llama2_1.4b": llama_config = LLaMAConfig( @@ -69,6 +71,7 @@ def get_model_config(model_variant): nlayers=32, hidden_grow_factor=3.5, max_expected_seq_len=8192, + rope_theta=500000.0, ) elif model_variant == "llama3_8b_4k": llama_config = LLaMAConfig( @@ -79,6 +82,7 @@ def get_model_config(model_variant): nlayers=32, hidden_grow_factor=3.5, max_expected_seq_len=4096, + rope_theta=500000.0, ) elif model_variant == "llama3_1.8b": llama_config = LLaMAConfig( @@ -89,6 +93,7 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=3.5, max_expected_seq_len=8192, + rope_theta=500000.0, ) elif model_variant == "llama3_1.8b_4k": llama_config = LLaMAConfig( @@ -99,6 +104,7 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=3.5, max_expected_seq_len=4096, + rope_theta=500000.0, ) elif model_variant == "llama3_3.2b": llama_config = LLaMAConfig( @@ -109,6 +115,7 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=8 / 3, max_expected_seq_len=8192, + rope_theta=500000.0, ) elif model_variant == "llama3_3.2b_4k": llama_config = LLaMAConfig( @@ -119,6 +126,7 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=8 / 3, max_expected_seq_len=4096, + rope_theta=500000.0, ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( @@ -129,6 +137,7 @@ def get_model_config(model_variant): nlayers=80, hidden_grow_factor=3.5, max_expected_seq_len=8192, + rope_theta=500000.0, ) elif model_variant == "llama3_70b_4k": llama_config = LLaMAConfig( @@ -139,6 +148,7 @@ def get_model_config(model_variant): nlayers=80, hidden_grow_factor=3.5, max_expected_seq_len=4096, + rope_theta=500000.0, ) elif model_variant == "llama3_194m_4k": llama_config = LLaMAConfig( @@ -147,6 +157,7 @@ def get_model_config(model_variant): nheads=8, nlayers=10, max_expected_seq_len=4096, + rope_theta=500000.0, ) else: raise ValueError(f"model variant {model_variant} not supported.") diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 314364d6..2faeffb7 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -19,6 +19,18 @@ } +def causal_lm(data_seq, prompt_len=1): + """ + Perform causal language modeling by right-shifting the input sequence. + Sets first prompt_len tokens to be ignored by the loss. + """ + data_seq = torch.tensor(data_seq, dtype=torch.int) + t = data_seq.clone()[1:] + data_seq = data_seq[:-1] + t[:prompt_len] = -100 + return data_seq, t + + def get_dummy_loader(cfg, rank, world_size): """ A simple dummy dataloader yielding incrementing vocab indices in an infinite loop @@ -43,7 +55,7 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size): +def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): """ Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. Assumes underlying data is sequences of integer values. @@ -56,21 +68,13 @@ def get_data_loader(cfg, rank, world_size): Rank of current distributed worker. Used for handling dataset sharding logic. world_size : int Number of distributed workers. Used for handling dataset sharding logic. + postprocess : List[Callable] + Any task-specific postprocessing to apply before handing over data. Steps will apply in + the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ datasets, weights = parse_data_args(cfg.datasets, cfg.weights) - def causal_lm(data_seq, prompt_len=0): - """ - Perform causal language modeling by right-shifting the input sequence. - Sets first prompt_len tokens to be ignored by the loss. - """ - data_seq = torch.IntTensor(data_seq) - t = data_seq.clone()[1:] - data_seq = data_seq[:-1] - t[:prompt_len] = -100 - return data_seq, t - # Base streaming dataset. Returns doc chunks in sequence. # Implements dataset sampling and rescalability. droplist = [ @@ -114,15 +118,19 @@ def causal_lm(data_seq, prompt_len=0): # Wrap above dataset in packing logic to form constant-length lines. data = BufferDataset( data, - cfg.seq_length + 1, + cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, bos_token=cfg.bol_token, eos_token=cfg.eol_token, pack_hard=True, ) # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. data = PreloadBufferDataset(data, 10000) - # Split line into input and target for the CLM task. - data = PreprocessDataset(data, causal_lm) + + # Apply desired postprocessing steps in sequence + data = PreprocessDataset(data, torch.IntTensor) + for p in postprocess: + data = PreprocessDataset(data, p) + # Enable auto-saving data = CheckpointDataset( data, diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index 1a75bd44..ef421f6f 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -187,10 +187,7 @@ def setup_environ_flags(): os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) -def get_policies(cfg, rank, block): - """Get policies for mixed precision, wrapping, sharding, ac and param init function.""" - - # mixed precision +def get_mixed_precision_policy(cfg, rank): verify_bfloat_support = ( torch.version.cuda and torch.cuda.is_bf16_supported() @@ -198,6 +195,7 @@ def get_policies(cfg, rank, block): and dist.is_nccl_available() and nccl.version() >= (2, 10) ) + if cfg.mixed_precision: bf16_ready = verify_bfloat_support if bf16_ready: @@ -211,6 +209,15 @@ def get_policies(cfg, rank, block): else: mixed_precision_policy = None + return mixed_precision_policy + + +def get_policies(cfg, rank, block): + """Get policies for mixed precision, wrapping, sharding, ac and param init function.""" + + # mixed precision + mixed_precision_policy = get_mixed_precision_policy(cfg, rank) + # wrapping policy wrapping_policy = get_wrapper(block) diff --git a/requirements-speculator.txt b/requirements-speculator.txt new file mode 100644 index 00000000..4270aa6d --- /dev/null +++ b/requirements-speculator.txt @@ -0,0 +1,2 @@ +-r requirements.txt +fms-extras @ git+https://github.com/foundation-model-stack/fms-extras@main diff --git a/scripts/README_SPECULATOR.md b/scripts/README_SPECULATOR.md new file mode 100644 index 00000000..c1e850d2 --- /dev/null +++ b/scripts/README_SPECULATOR.md @@ -0,0 +1,45 @@ +### Following parameters are relevant for speculator training: + +- *model_arch*: architecture of the base model (one of: embedllama, embedmixtral, embedgpt_bigcode-- FMS implementations extending the base arch to also emit embedding vector together with the model output. See 'EmbedLLaMA' in train_spculator_utils.py) + +- *model_variant*: identifier with which a specific variant (e.g., 7b) is registered for the model architecture. See 'example model registrations' in train_spculator_utils.py. + +- *model_path*: path to dir containing base model weights + +- *ckpt_save_path*: path to dir for storing intermediate checkpoints during speculator training + +- *ckpt_load_path*: path to dir for loading intermediate speculator checkpoint to resume training + +- *sharding_strategy*: how to shard the model across process group: tp / fsdp / hsdp + +- *tp_size*: If loading base model using tensor parallel, no. of GPUs/ranks to split the model across + +- *seq_length*: sequence length of the base model + +- *batch_size*: batch size for stage 1 training for aligning speculator to base model input behavior + +- *report_interval*: no. of steps after which to report training stats + +- *checkpoint_interval*: no. of steps after which to save an intermediate speculator checkpoint + +- *num_steps*: total no. of speculator training steps (stage 1 + stage 2) + +- *stage2_start_step*: no. of steps after which to switch to stage 2 training + +- *stage2_batch_size*: batch size for stage 2 training for aligning speculator to base model output behavior + +- *n_speculator_heads*: no. of lookahead tokens to train the speculator for + +- *speculator_width*: embedding dimension of the speculator MLP + +- *use_torch_compile*: whether to compile base model and speculator-- may speed up training. + +- *learning_rate*: learning rate for speculator training + +- *seed*: random seed to use for training dataset shuffling + +- *data_path*: path to dir containing the training dataset. Expects directory to contain subfolders, which in turn contain shard files. + +- *datasets*: a list of subdatasets (e.g., commoncrawl, github, etc.) to draw from. If None, draws from all subfolders of data_path. + +- *weights*: list of weights reflecting the percentage of tokens to be used from each subdataset during training diff --git a/scripts/train_speculator.sh b/scripts/train_speculator.sh new file mode 100644 index 00000000..86a58dad --- /dev/null +++ b/scripts/train_speculator.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# On AWS, the EFA and OFI paths enable NCCL to use optimized networking. +export LD_LIBRARY_PATH=/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda:/usr/local/cuda/targets/x86_64-linux/lib/:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/lib:$LD_LIBRARY_PATH + +export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + +MODEL_ARGS="\ +--model_path=/path/to/models/meta-llama/Llama-2-7b-hf +--model_arch=embedllama +--model_variant=7b +--ckpt_load_path=/path/to/checkpoints/llama2-7b +--ckpt_save_path=/path/to/checkpoints/llama2-7b +--logical_shards=768 +--sharding_strategy=hsdp +--seq_length=4096 +--batch_size=8 +--report_interval=10 +--checkpoint_interval=3000 +--num_steps=21000 +--stage2_start_step=15000 +--stage2_batch_size=96 +--n_speculator_heads=3 +--speculator_width=4096 +--use_torch_compile=False +--learning_rate=1e-3 +--seed=42 +--data_path=/path/to/dataset/ +--datasets="'dataset=commoncrawl'" +--weights="'1'" +" + +torchrun \ + --nproc_per_node=8 \ + speculator/train_speculator.py \ + ${MODEL_ARGS} + + diff --git a/speculator/train_speculator.py b/speculator/train_speculator.py new file mode 100644 index 00000000..4c33fd2c --- /dev/null +++ b/speculator/train_speculator.py @@ -0,0 +1,330 @@ +import math +import os +import time + +import fire +import torch +import torch.optim as optim +from fms.models import get_model +from fms.models.llama import LLaMABlock +from fms.utils import generation, tokenizers +from fms_extras.models.speculator import MLPSpeculator # type: ignore +from torch import distributed as dist +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.optim.lr_scheduler import LambdaLR + +from fms_fsdp import config +from fms_fsdp.utils.checkpointing_utils import Checkpointer +from fms_fsdp.utils.config_utils import update_config +from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader +from fms_fsdp.utils.train_utils import ( + get_mixed_precision_policy, + get_profiler, + setup, + setup_environ_flags, +) +from speculator.train_speculator_utils import train_speculator + + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def test_model(rank, model, arch, cfg, prompt_type="chat"): + if rank == 0: + print("testing model output") + tokenizer = tokenizers.get_tokenizer(cfg.model_path) + if prompt_type == "chat": + template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:" + prompt = template.format( + "Provide a list of instructions for preparing chicken soup." + ) + else: + template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]" + prompt = template.format("Write a bubble sort function in python.") + + tokens = tokenizer.tokenize(prompt) + ids = tokenizer.convert_tokens_to_ids(tokens) + if "llama" in arch: + ids = [tokenizer.bos_token_id] + ids + ids = torch.tensor(ids, dtype=torch.long, device="cuda") + result = generation.generate( + model, + ids, + max_new_tokens=100, + use_cache=True, + do_sample=False, + max_seq_len=8192, + ) + result = generation.truncate_after_eos(result, tokenizer.eos_token_id) + if rank == 0: + print(f"{rank}: quick test of base model") + print( + tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(result)) + ) + + +def get_emb_dim(model): + if hasattr(model.config, "emb_dim"): + emb_dim = model.config.emb_dim + elif hasattr(model.config, "dim"): # Mixtral + emb_dim = model.config.dim + elif hasattr(model.config, "hidden_size"): # HF + emb_dim = model.config.hidden_size + else: + raise Exception("config missing embedding dimension") + return emb_dim + + +def get_vocab_size(model): + if hasattr(model.config, "src_vocab_size"): # FMS + vocab_size = model.config.src_vocab_size + elif hasattr(model.config, "vocab_size"): # HF + vocab_size = model.config.vocab_size + else: + raise Exception("config missing vocab size config") + return vocab_size + + +def get_training_data_loader(rank, cfg, world_size, speculator_mesh): + if rank == 0: + print(f"{time.time()} Constructing datasets...") + if not cfg.use_dummy_dataset: + if cfg.sharding_strategy == "tp": + train_loader = get_data_loader( + cfg, speculator_mesh.get_rank(), speculator_mesh.size(), postprocess=[] + ) + else: + train_loader = get_data_loader(cfg, rank, world_size, postprocess=[]) + else: + train_loader = get_dummy_loader(cfg, rank, world_size) + if rank == 0: + print(f"{time.time()} Datasets constructed!") + return train_loader + + +def main(**kwargs): + # get configs + cfg = config.train_config() + update_config(cfg, **kwargs) + cfg.seq_length = cfg.seq_length + cfg.n_speculator_heads + 1 + + # ensure reproducibility + torch.cuda.manual_seed(cfg.seed) + torch.manual_seed(cfg.seed) + + # torchrun specific + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + if rank == 0: + print(f"{time.time()} running with these configs {cfg}") + + # some setups + torch.cuda.set_device(local_rank) + + if cfg.sharding_strategy != "tp": + setup() + torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) + base_model_mesh = None + speculator_mesh = None + else: + base_model_mesh = dist.device_mesh.init_device_mesh( + "cuda", + (world_size // cfg.tp_size, cfg.tp_size), + mesh_dim_names=("dp", "tp"), + ) + speculator_mesh = dist.device_mesh.init_device_mesh("cuda", (world_size,)) + torch._C._distributed_c10d._register_process_group( + "default", base_model_mesh["tp"].get_group() + ) + + torch.cuda.empty_cache() + setup_environ_flags() + torch.set_default_dtype(torch.bfloat16) + + mixed_precision_policy = get_mixed_precision_policy(cfg, rank) + + model = get_model( + cfg.model_arch, + cfg.model_variant, + model_path=cfg.model_path, + device_type="cuda", + source="hf", + distributed_strategy=cfg.sharding_strategy, + group=( + base_model_mesh["tp"].get_group() if cfg.sharding_strategy == "tp" else None + ), + ) + + if rank == 0: + print(f"{time.time()}") + print(model.config) + print(model) + + model.eval() + with torch.no_grad(): + test_model(rank, model, cfg.model_arch, cfg) + + emb_dim = get_emb_dim(model) + vocab_size = get_vocab_size(model) + + # get speculator + if rank == 0: + print(f"{time.time()} Loading speculator") + speculator = MLPSpeculator( + emb_dim, + cfg.speculator_width, + vocab_size, + cfg.n_speculator_heads, + tie_weights=cfg.speculator_tie_weights, + scale_input=cfg.speculator_scale_input, + ) + speculator.reset_parameters() + + if rank == 0: + total_params = sum( + p.numel() for p in speculator.parameters() if p.requires_grad + ) + print(f"\n{time.time()} speculator has {total_params / 1e6} Million params\n") + + # get data loader + train_loader = get_training_data_loader(rank, cfg, world_size, speculator_mesh) + + # FSDP + speculator = FSDP( + speculator, + auto_wrap_policy=None, + mixed_precision=mixed_precision_policy, + sharding_strategy=ShardingStrategy.NO_SHARD, + use_orig_params=cfg.use_torch_compile, + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + sync_module_states=cfg.low_cpu_fsdp, + param_init_fn=lambda module: ( + module.to_empty(device=torch.device("cuda"), recurse=False) + if cfg.low_cpu_fsdp + else None + ), + device_mesh=speculator_mesh if cfg.sharding_strategy == "tp" else None, + ) + + # torch compile + if cfg.use_torch_compile: + if rank == 0: + print(f"enabling torch compile...") + if cfg.fsdp_activation_checkpointing: + raise ValueError( + "Compile does not yet work well with llama+ac, please" + "either use it without activation checkpointing, or disable" + "compile." + ) + # we need this post-fsdp call to avoid graph break with torch.compile, + if cfg.sharding_strategy != "tp" and hasattr(model, "rot_emb"): + model.rot_emb.compute_freqs_cis( + torch.device("cuda", torch.cuda.current_device()), + model.config.max_expected_seq_len + 10, + ) + model = torch.compile(model) + speculator = torch.compile(speculator) + + # Optimizer + optimizer = optim.AdamW( + speculator.parameters(), + lr=cfg.learning_rate, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # optionally load from checkpoint (when continue pretraining) + if cfg.sharding_strategy == "tp": + checkpointer = Checkpointer( + cfg.ckpt_save_path, + 1000, + "ddp", + speculator_mesh.get_rank(), + speculator_mesh.get_local_rank(), + model_auto_placement=True, + ) + else: + checkpointer = Checkpointer(cfg.ckpt_save_path, 1000, "ddp", rank, local_rank) + speculator, optimizer, train_loader, start_step, tokens_seen, _ = checkpointer.load( + speculator, + optimizer, + train_loader, + path=os.path.join(cfg.ckpt_load_path, "checkpoints/"), + is_compiled=cfg.use_torch_compile, + ) + + # LR schedule + # These functions map step count to LR scaling factor in [0,1]. + # Stage 1: warm up over first 2k or 5% of steps, whichever is smaller. + # Then cosine anneal to 10% of max LR. + warmup_interval1 = min(2000, cfg.stage2_start_step // 20) + stage1_schedule = lambda x: min( + # Parabolic warmup + 1 - (1 - min(x, warmup_interval1) / warmup_interval1) ** 2, + # Final .1 scaling factor + 0.1 + # Cosine anneal from 1 to .1 over stage2_start_step steps + + 0.5 * (1 - 0.1) * (1 + math.cos(x / cfg.stage2_start_step * math.pi)), + ) + # Stage 2: warm up over first 2k or 5% of steps, whichever is smaller. + # Then cosine anneal to 10% of stage 1's final LR. + warmup_interval2 = min(2000, (cfg.num_steps - cfg.stage2_start_step) // 20) + stage2_schedule = lambda x: min( + # Parabolic warmup to stage2's max LR (10% of stage1's max LR) + 0.1 * (1 - (1 - min(x, warmup_interval2) / warmup_interval2) ** 2), + # Final 10% of 10% scaling factor + 0.01 + # Cosine anneal from .1 to .01 over remaining stage2 steps + + 0.05 + * (1 - 0.1) + * ( + 1 + + math.cos( + min(x, cfg.num_steps - cfg.stage2_start_step) + / (cfg.num_steps - cfg.stage2_start_step) + * math.pi + ) + ), + ) + # Assemble full scheduling function with correct step offsets. + schedule = lambda x: ( + stage1_schedule(x) + if x <= cfg.stage2_start_step + else stage2_schedule(x - cfg.stage2_start_step) + ) + scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) + + # profiler + profiler = get_profiler(cfg, rank) + + # Train + if rank == 0: + print(f"{time.time()} Training for {cfg.num_steps} steps") + torch.cuda.empty_cache() + train_speculator( + cfg, + model, + speculator, + local_rank, + rank, + train_loader, + optimizer, + scheduler, + checkpointer, + start_step, + tokens_seen, + profiler, + base_model_mesh, + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py new file mode 100644 index 00000000..5064cf47 --- /dev/null +++ b/speculator/train_speculator_utils.py @@ -0,0 +1,570 @@ +import os +import re +import time +from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from fms.models import register_model +from fms.models.gpt_bigcode import GPTBigCode +from fms.models.gpt_bigcode import _20b_config as _gpt_bigcode_20b_config +from fms.models.gpt_bigcode import _hf_sd_to_fms_sd as _gptbigcode_hf_sd_to_fms_sd +from fms.models.llama import LLaMA +from fms.models.llama import _hf_sd_to_fms_sd as _llama_hf_sd_to_fms_sd +from fms.models.mixtral import Mixtral, MixtralConfig +from fms.models.mixtral import _hf_sd_to_fms_sd as _mixtral_hf_sd_to_fms_sd +from fms.utils import serialization, tokenizers +from fms.utils.generation import _make_cache_contiguous +from torch.nn import CrossEntropyLoss +from torch.utils.data import DataLoader + +from fms_fsdp.config import train_config +from fms_fsdp.utils.checkpointing_utils import Checkpointer +from fms_fsdp.utils.config_utils import get_model_config + + +def generate( + model: Union[Callable, torch.nn.Module], + input_ids: torch.Tensor, + max_seq_len: int = 2048, + max_new_tokens: int = 256, + temperature: float = 1.0, + top_k: int = 10, + do_sample: bool = True, + num_beams: int = 1, + use_cache: bool = False, + contiguous_cache: bool = False, + include_embeds: bool = True, +): + """ + A straightforward copy of the generate method in fms.utils.generation. + The only change is the include_embeds flag, which when true also returns + the embedding vectors corresponding to the tokens in the output sequence. + """ + batched = False + if num_beams != 1: + raise NotImplementedError("generate() does yet not support beam search") + if type(input_ids) == torch.Tensor: + if input_ids.dim() != 1: + batched = True + else: + raise RuntimeError("generate() requires a tensor of token ids as the prefix") + + if not batched: + input_ids = input_ids.unsqueeze(0) + + embeds = None + result = input_ids + next_input = input_ids + kwargs: MutableMapping[str, Any] = dict() + kwargs["past_key_value_states"] = None + kwargs["use_cache"] = use_cache + kwargs["include_embeds"] = include_embeds + + for _ in range(max_new_tokens): + input_ids = next_input[:, -max_seq_len:] + output = model(input_ids, **kwargs) + if not use_cache and not include_embeds: + logits = output + else: + logits = output[0] + if include_embeds: + z = output[-1] + if use_cache: + past_key_value_states = output[1] + # TODO: this should go away when reduce-overhead issues are fixed, or + # maybe could be moved into model code to be more portable. + if contiguous_cache: + kwargs["past_key_value_states"] = _make_cache_contiguous( + past_key_value_states + ) + else: + kwargs["past_key_value_states"] = past_key_value_states + logits = logits[:, -1, :] + + if do_sample: + # get logits from last value in sequence nad scale + logits = logits / temperature + if top_k: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float("inf") + + probs = F.softmax(logits, dim=-1) + next_val = torch.multinomial(probs, num_samples=1) + else: + next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t() + + result = torch.cat((result, next_val), dim=-1) + + if use_cache: + next_input = next_val + else: + next_input = result + + if include_embeds: + if embeds is None: + embeds = z + else: + embeds = torch.cat((embeds, z), dim=-2) + + if not batched: + result = result[0] + + if include_embeds: + return result, embeds + + return result + + +# Stage 1 training +def stage1_loss( + cfg, model, speculator, base_model_input, input, loss_fn, ddp_stats, base_model_mesh +): + """ + Perform a forward pass for stage 1 training and calculate the loss. + Given the sequence of embeddings produced in parallel by the base model, + get n+2,n+3,... speculator predictions and compare to ground truth tokens. + ... + Args + ---- + cfg: train_config + Set of training parameters. + model: nn.Module + The frozen base model. Must return output logits AND corresponding embedding vectors. + speculator: nn.Module + The speculator to be trained. Takes as input sequence of embeddings and token indices, + and return token prediction logits for each head. + input: torch.IntTensor + The ground truth token indices. If using TP, this is per TP rank, + with 'base_model_input' containing all-gathered input across all TP ranks + loss_fn: Callable + Torch loss function comparing logits to indices i.e. CrossEntropyLoss() + ddp_stats: torch.FloatTensor + Aggregate stat tracking buffer. + Entries are: grad norm, accumulation steps, head 1 loss, head 2 loss, etc. + base_model_mesh: torch.distributed.device_mesh.DeviceMesh + Device layout of the particiapting process group ranks + ---- + Returns: scalar loss value, updated ddp stats, number of tokens in input + """ + with torch.no_grad(): + _, embeds = model( + base_model_input[:, : -speculator.n_predict - 1], + include_embeds=True, + use_cache=False, + ) + if cfg.sharding_strategy == "tp": + embeds = embeds.chunk(base_model_mesh["tp"].size())[ + base_model_mesh["tp"].get_local_rank() + ] + + preds = speculator(embeds.detach(), input[:, 1:]) + losses = [] + for i in range(preds.size(0)): + targ = input[:, i + 2 : preds.size(2) + i + 2] # b n + loss = loss_fn(preds[i].reshape(-1, preds.size(3)), targ.long().reshape(-1)) + losses.append(loss) + ddp_stats[2 + i] += loss.item() + loss = sum(losses) + return loss, ddp_stats, input.numel() + + +# Stage 2 training: more heavyweight than stage 1; will take longer +def stage2_loss( + cfg, model, speculator, base_model_input, input, loss_fn, ddp_stats, base_model_mesh +): + """ + Perform a forward pass for stage 2 training and calculate the loss. + Given the sequence of embeddings produced in serial by the base model, + get n+1,n+2,... speculator predictions and compare to base model's generated tokens. + Reshapes input to more entries / shorter sequences, for more efficient generation. + ... + Args + ---- + cfg: train_config + Set of training parameters. Used here for reshaping input batches. + model: nn.Module + The frozen base model. Must return output logits AND corresponding embedding vectors. + speculator: nn.Module + The speculator to be trained. Takes as input sequence of embeddings and token indices, + and return token prediction logits for each head. + input: torch.IntTensor + The ground truth token indices. If using TP, this is per TP rank, + with 'base_model_input' containing all-gathered input across all TP ranks + loss_fn: Callable + Torch loss function comparing logits to indices i.e. CrossEntropyLoss() + ddp_stats: torch.FloatTensor + Aggregate stat tracking buffer. + Entries are: grad norm, accumulation steps, head 1 loss, head 2 loss, etc. + base_model_mesh: torch.distributed.device_mesh.DeviceMesh + Device layout of the particiapting process group ranks + ---- + Returns: scalar loss value, updated ddp stats, number of tokens in input + """ + with torch.no_grad(): + grow_factor = cfg.stage2_batch_size // cfg.batch_size + assert ( + cfg.stage2_prompt_length * grow_factor <= cfg.seq_length + ), "Error: batch is too small for specified partition" + base_model_input = base_model_input[ + :, : cfg.stage2_prompt_length * grow_factor + ].reshape(base_model_input.size(0) * grow_factor, cfg.stage2_prompt_length) + targs, embeds = generate( + model, + base_model_input, + cfg.seq_length, + cfg.stage2_seq_length, + do_sample=True, + use_cache=True, + include_embeds=True, + ) + + if cfg.sharding_strategy == "tp": + targs = targs.chunk(base_model_mesh["tp"].size())[ + base_model_mesh["tp"].get_local_rank() + ] + embeds = embeds.chunk(base_model_mesh["tp"].size())[ + base_model_mesh["tp"].get_local_rank() + ] + targs = targs[:, -cfg.stage2_seq_length :] + embeds = embeds[:, -cfg.stage2_seq_length : -speculator.n_predict] + preds = speculator(embeds.detach(), targs[:, :-1].detach()) + + losses = [] + for i in range(preds.size(0)): + targ = targs[:, i + 1 : preds.size(2) + i + 1] # b n + loss = loss_fn(preds[i].reshape(-1, preds.size(3)), targ.long().reshape(-1)) + losses.append(loss) + ddp_stats[2 + i] += loss.item() + loss = sum(losses) + return loss, ddp_stats, targs.numel() + + +# on demand checkpointing: echo 1 > /path/to/model_ckpt_dir/do_ckpt +def do_ckpt(ckpt_save_path, reset=False): + ckpt_cmd_file = ckpt_save_path + "/do_ckpt" + if not os.path.exists(ckpt_cmd_file): + return False + + if reset: + with open(ckpt_cmd_file, "w") as fd: + fd.write("0") + return False + + with open(ckpt_cmd_file) as fd: + if fd.read().strip() == "1": + return True + + return False + + +def train_speculator( + cfg: train_config, + model: nn.Module, + speculator: nn.Module, + local_rank: int, + rank: int, + train_loader: DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + checkpointer: Checkpointer, + start_step: int = 0, + n_tok: int = 0, + profiler: Optional[Union[torch.profiler.profile, None]] = None, + base_model_mesh=None, +): + """ + The training loop for speculator training. Handles at a high level: data loading, + forward and backward passes, model updates, stat tracking, reporting, and checkpointing. + ... + Args + ---- + cfg: train_config + The set of training parameters + model: nn.Module + The frozen base model. Must return output logits AND corresponding embedding vectors. + speculator: nn.Module + The speculator to be trained. Takes as input sequence of embeddings and token indices, + and returns token prediction logits for each head. + local_rank: int + The local rank of the current process. Used for stat tracking / aggregation across ranks. + rank: int + The global rank of the current process. Used for reporting. + train_loader: torch.utils.data.DataLoader + The dataloader used for reading in ground truth token sequences. Train_loader.dataset must + support save_to_path() for distributed checkpointing via checkpointer. + optimizer: torch.optim.Optimizer + The optimizer associated with the speculator's weights + scheduler: torch.optim.lr_scheduler.LRScheduler + A scheduler for the optimizer's LR. Scheduler.step() is called on every optimizer step. + checkpointer: fms_fsdp.utils.checkpointing_utils.Checkpointer + A checkpointer tied to the save directory. Used for saving distributed checkpoints. + start_step: optional[int] + If resuming from checkpoint, resume step count from this value. + n_tok: optional[int] + If resuming from checkpoint, resume token count from this value. + profiler: optional[torch.profiler.profile] + Optional torch profiler for performance benchmarking. + base_model_mesh: DeviceMesh + Device layout of the particiapting process group ranks + """ + model.eval() + speculator.train() + ddp_stats = torch.zeros(2 + speculator.n_predict).to(local_rank) + train_loss = 0 + + start = time.time() + loop_start = time.time() + loss_fn = CrossEntropyLoss() + elapsed_tokens = 0 + for batch_idx, input in enumerate(train_loader, start=start_step + 1): + if batch_idx > cfg.num_steps: + break + + input = input.to(local_rank) + + if cfg.sharding_strategy == "tp": + base_model_input = torch.zeros( + base_model_mesh["tp"].size() * input.size(0), + input.size(1), + dtype=input.dtype, + device=input.device, + ) + dist.all_gather_into_tensor( + base_model_input, input, group=base_model_mesh["tp"].get_group() + ) + else: + base_model_input = input + + optimizer.zero_grad() + + if batch_idx <= cfg.stage2_start_step: + loss, ddp_stats, step_tok = stage1_loss( + cfg, + model, + speculator, + base_model_input, + input, + loss_fn, + ddp_stats, + base_model_mesh, + ) + else: + loss, ddp_stats, step_tok = stage2_loss( + cfg, + model, + speculator, + base_model_input, + input, + loss_fn, + ddp_stats, + base_model_mesh, + ) + + loss.backward() + ddp_stats[0] += speculator.clip_grad_norm_(cfg.grad_clip_thresh).item() + optimizer.step() + scheduler.step() + + ddp_stats[1] += 1 + + if profiler: + profiler.step() + + if batch_idx % cfg.report_interval == 0: + dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) + train_loss = ddp_stats[2:] / ddp_stats[1] + g_norm = ddp_stats[0] / ddp_stats[1] + elapsed_time = time.time() - loop_start + world_size = int(os.environ["WORLD_SIZE"]) + elapsed_tokens += cfg.report_interval * world_size * step_tok + if rank == 0: + print(f"{time.time()}") + print("step:", batch_idx) + print("tokens seen:", n_tok + elapsed_tokens) + for i in range(len(train_loss)): + print(f"loss {i+1}:", train_loss[i].item()) + print("gradient norm:", g_norm.item()) + print( + f"speed for these {cfg.report_interval} steps:", + (time.time() - start) / cfg.report_interval, + ) + print("overall speed:", elapsed_time / (batch_idx - start_step)) + print("LR:", scheduler.get_last_lr()) + print( + "reserved memory:", + torch.cuda.max_memory_reserved(device=torch.cuda.current_device()), + ) + print( + "active memory:", + torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), + ) + print( + "overall token per gpu per sec:", + int(elapsed_tokens / world_size / elapsed_time), + ) + print("token per day:", int(elapsed_tokens / elapsed_time * 3600 * 24)) + print() + start = time.time() + ddp_stats.zero_() + torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device()) + + if ( + batch_idx % cfg.checkpoint_interval == 0 + or do_ckpt(cfg.ckpt_save_path) is True + ): + torch.cuda.empty_cache() + checkpointer.save( + batch_idx, + speculator, + optimizer, + train_loader, + tokens_seen=elapsed_tokens + n_tok, + ) + torch.cuda.empty_cache() + do_ckpt(cfg.ckpt_save_path, reset=True) + + checkpointer.save_single_file( + batch_idx, + speculator, + tokens_seen=elapsed_tokens + n_tok, + is_compiled=cfg.use_torch_compile, + ) + + return train_loss + + +class EmbedGPTBigCode(GPTBigCode): + # Overrides the forward function of GPTBigCode to allow returning embedding vectors + def forward( + self, + x: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + use_cache: bool = False, + attn_algorithm: Optional[str] = None, + include_embeds: bool = False, + ): + output, cache = self.base_model( + x, + mask, + position_ids=position_ids, + past_key_value_states=past_key_value_states, + use_cache=use_cache, + attn_algorithm=attn_algorithm, + ) + + preds = self.head(output) + + out = [preds] + if use_cache: + out.append(cache) + if include_embeds: + out.append(output) + if len(out) == 1: + return out[0] + return out + + +class EmbedLLaMA(LLaMA): + # Overrides the forward function of LLaMA to allow returning embedding vectors + def forward( + self, + x, + mask=None, + position_ids=None, + past_key_value_states=None, + use_cache=False, + only_last_token=False, + attn_algorithm=None, + include_embeds=False, + ): + output, cache = self._helper( + x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm + ) + + if only_last_token: + output = output[:, -1, :] + preds = self.shared(output, reverse=True) + + out = [preds] + if use_cache: + out.append(cache) + if include_embeds: + out.append(output) + if len(out) == 1: + return out[0] + return out + + +class EmbedMixtral(Mixtral): # FMS impl of Mixtral + # Overrides the forward function of Mixtral to allow returning embedding vectors + def forward( + self, + x, + mask=None, + position_ids=None, + past_key_value_states=None, + use_cache=False, + only_last_token=False, + attn_algorithm=None, + include_embeds=False, + ): + output, cache = self.base_model( + x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm + ) + + if only_last_token: + output = output[:, -1, :] + preds = self.head(output) + + out = [preds] + if use_cache: + out.append(cache) + if include_embeds: + out.append(output) + if len(out) == 1: + return out[0] + return out + + +def _gpt_bigcode_factory_factory(config): + def factory(**kwargs): + return EmbedGPTBigCode(config, **kwargs) + + return factory + + +def _llama_factory_factory(config): + def factory(**kwargs): + return EmbedLLaMA(config, **kwargs) + + return factory + + +def _mixtral_factory_factory(config): + def factory(**kwargs): + return EmbedMixtral(config, **kwargs) + + return factory + + +# example model registrations +register_model( + "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config) +) +serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd) + +register_model( + "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b")) +) +register_model( + "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b")) +) +serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd) + +register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig())) +serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd) From a972f228b41399a5578b1d616867c503bb406471 Mon Sep 17 00:00:00 2001 From: sahil suneja Date: Mon, 9 Sep 2024 19:33:01 +0000 Subject: [PATCH 71/73] mypy fix Signed-off-by: sahil suneja --- speculator/__init__.py | 0 speculator/train_speculator_utils.py | 1 - 2 files changed, 1 deletion(-) create mode 100644 speculator/__init__.py diff --git a/speculator/__init__.py b/speculator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index 5064cf47..89fb7ae1 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -313,7 +313,6 @@ def train_speculator( model.eval() speculator.train() ddp_stats = torch.zeros(2 + speculator.n_predict).to(local_rank) - train_loss = 0 start = time.time() loop_start = time.time() From cd4c9c329413bf4b978a59dadb745e2d244ad80c Mon Sep 17 00:00:00 2001 From: sahil suneja Date: Mon, 9 Sep 2024 21:47:40 +0000 Subject: [PATCH 72/73] mypy fix Signed-off-by: sahil suneja --- speculator/train_speculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speculator/train_speculator.py b/speculator/train_speculator.py index 4c33fd2c..7ef5e8f9 100644 --- a/speculator/train_speculator.py +++ b/speculator/train_speculator.py @@ -2,7 +2,7 @@ import os import time -import fire +import fire # type: ignore import torch import torch.optim as optim from fms.models import get_model From 879ae2164d13ca43cf07895539dda8ee120a062f Mon Sep 17 00:00:00 2001 From: sahil suneja Date: Tue, 10 Sep 2024 14:05:03 +0000 Subject: [PATCH 73/73] minor Signed-off-by: sahil suneja --- speculator/train_speculator_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py index 89fb7ae1..87b4e7b2 100644 --- a/speculator/train_speculator_utils.py +++ b/speculator/train_speculator_utils.py @@ -432,8 +432,6 @@ def train_speculator( is_compiled=cfg.use_torch_compile, ) - return train_loss - class EmbedGPTBigCode(GPTBigCode): # Overrides the forward function of GPTBigCode to allow returning embedding vectors