From 729a8d73089b9a56f7da18acc73547a5718e21d1 Mon Sep 17 00:00:00 2001 From: Manuel Burger Date: Wed, 6 Nov 2024 09:21:05 +0100 Subject: [PATCH] Update dataloader --- petagraph/run_train.py | 26 +- src/nanotron/data/petagraph_dataset.py | 444 ++++++++++++++++++++++++- src/nanotron/trainer.py | 52 +-- 3 files changed, 494 insertions(+), 28 deletions(-) diff --git a/petagraph/run_train.py b/petagraph/run_train.py index 0b8edf76..d8e4e970 100644 --- a/petagraph/run_train.py +++ b/petagraph/run_train.py @@ -46,7 +46,7 @@ hf_hub_version = None tf_version = None -from nanotron.data.petagraph_dataset import PetaGraphStreamDataset +from nanotron.data.petagraph_dataset import PetaGraphStreamDataset, PetaGraphStreamDatasetV2 logger = logging.get_logger(__name__) @@ -126,8 +126,14 @@ def get_dataloader_from_data_stage( # else: # raise ValueError("Data path must contain either 'unitig' or 'contig'") - contig_format = "s3://logan-pub/c/{accession}/{accession}.contigs.fa.zst" - unitig_format = "s3://logan-pub/u/{accession}/{accession}.unitigs.fa.zst" + # ----- URL FORMAT ----- + # contig_format = "s3://logan-pub/c/{accession}/{accession}.contigs.fa.zst" + # unitig_format = "s3://logan-pub/u/{accession}/{accession}.unitigs.fa.zst" + + unitig_format = "https://s3.amazonaws.com/logan-pub/u/{accession}/{accession}.unitigs.fa.zst" + contig_format = "https://s3.amazonaws.com/logan-pub/c/{accession}/{accession}.contigs.fa.zst" + log_rank(f"Contig format: {contig_format}", logger=logger, level=logging.INFO, rank=0) + # ---------------------- assert data.all_sequences_resources_path is not None, "all_sequences_resources_path must be provided" all_sequences_resources_path = Path(data.all_sequences_resources_path) @@ -226,14 +232,13 @@ def get_dataloader_from_data_stage( else: global_rank = trainer.parallel_context.world_pg.rank() - train_dataset = PetaGraphStreamDataset( + train_dataset = PetaGraphStreamDatasetV2( logger=logger, url_list=train_sequence_files, vocabulary=VOCABULARY, 
from_cloud=True, # not mock_data, maxlen=trainer.sequence_length + 1, create_attention_mask=True, - prefetch_sequences=data.prefetch_buffer_seq_size, log_directory=trainer.config.checkpoints.checkpoints_path, rank=global_rank, packed=True @@ -253,14 +258,23 @@ def get_dataloader_from_data_stage( log_rank(f"Using {num_dl_workers} dataloader workers", logger=logger, level=logging.INFO, rank=0) + prefetch_factor = None + worker_init_fn = None + if num_dl_workers > 0: + prefetch_factor = data.prefetch_buffer_seq_size + if isinstance(train_dataset, PetaGraphStreamDatasetV2): + worker_init_fn = train_dataset.worker_init_fn + + log_rank(f"Prefetch factor: {prefetch_factor}", logger=logger, level=logging.INFO, rank=0) return DataLoader( train_dataset, batch_size=trainer.micro_batch_size, collate_fn=data_collator, drop_last=True, + prefetch_factor=prefetch_factor, num_workers=num_dl_workers, pin_memory=True, - worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()), + worker_init_fn=worker_init_fn, ) else: diff --git a/src/nanotron/data/petagraph_dataset.py b/src/nanotron/data/petagraph_dataset.py index e1c90972..e9522e09 100644 --- a/src/nanotron/data/petagraph_dataset.py +++ b/src/nanotron/data/petagraph_dataset.py @@ -27,6 +27,8 @@ from nanotron.logging import log_rank from collections import defaultdict, deque +# import line_profiler +import requests # ============================================================================= @@ -147,7 +149,7 @@ def __init__(self, # Set the current epoch to the restart epoch self.current_epoch = restart_epoch - log_msg = f"[PetaGraphStreamDataset:{self.rank}] Restarting from epoch {self.current_epoch} with {len(self.consumed_files)} files" + log_msg = f"[PetaGAdd lockaphStreamDataset:{self.rank}] Restarting from epoch {self.current_epoch} with {len(self.consumed_files)} files" log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank) else: self.consumed_files = set() @@ -310,7 +312,6 @@ def 
find_overlaps_and_build_graph(sequences, k_mer: int = 31): return graph - # Perform random walk on the graph @staticmethod def dfs_paths(graph, start, path = None, all_paths = None, depth: int = 10): """Perform a depth-first search on the graph""" @@ -519,3 +520,442 @@ def __iter__(self) -> dict[str, np.ndarray]: + + + + + +class PetaGraphStreamDatasetV2(torch.utils.data.IterableDataset): + """Training dataset to stream from Logan + + Parameters + ---------- + sampling_seq_len_inflection : int + The sequence length at which to switch from sampling to keeping the sequence + below the inflection point we only keep the sequence with a probability pr + to its length. Above the inflection point we always keep the sequence. + """ + + def __init__(self, + logger, + url_list: list[str], + vocabulary: dict[str, int], + from_cloud: bool = False, + maxlen: int = 128, + samples_per_epoch: int = -1, + create_attention_mask: bool = True, + debug: bool = False, + log_directory: Path = None, + rank: int = 0, + packed: bool = False, + sampling_seq_len_inflection: int = 1024 + ): + + self.maxlen = maxlen + self.create_attention_mask = create_attention_mask + self.debug = debug + self.sampling_seq_len_inflection = sampling_seq_len_inflection + + + self.logger = logger + self.logging_func = partial(log_rank, logger=logger, level=logging.INFO, rank=0) + self.logging_func("=====================================") + self.logging_func(f"[PetaGraphStreamDataset] Creating PetaGraphStreamDataset with maxlen {maxlen}") + # self.logging_func(f"[PetaGraphStreamDataset] Samples per epoch: {samples_per_epoch}") + self.logging_func(f"[PetaGraphStreamDataset] Num. URLs: {len(url_list)}") + self.logging_func(f"[PetaGraphStreamDataset] From Cloud: {from_cloud}") + self.logging_func(f"[PetaGraphStreamDataset] Sampling Seq. Len. 
Inflection: {self.sampling_seq_len_inflection}") + + self.VOCAB = vocabulary + self._pad_token_id = self.VOCAB["PAD"] + self._eos_token_id = self.VOCAB["EOS"] + self._bos_token_id = self.VOCAB["BOS"] + self._unk_token_id = self.VOCAB["UNK"] + + + self.num_files = len(url_list) + self.current_epoch = 0 + + self.rank = rank + self.log_directory = log_directory + self.num_consumed_sequences = 0 + self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt" + self.consumed_files_lock = mp.Lock() + + # Save the vocabulary as json on head node + if self.rank == 0: + self.vocab_path = log_directory / "vocabulary.json" + with open(self.vocab_path, "w") as f: + json.dump(self.VOCAB, f) + + # Take list of already consumed lists and remove them from the + # url list, to continue training from the last checkpoint properly + # - Check if the consumed_files exist + # - If they exist, load them and assume we are restarting from a checkpoint + # - Find the largest epoch number in the consumed files + # - Filter the files that have been consumed/started in the latest epoch + # - Remove them from the url_list then append them to the end of the url_list + # - Set the current epoch to the latest epoch + if self.consumed_files_path.exists(): + log_msg = f"[PetaGraphStreamDataset:{self.rank}] Consumed files found at {self.consumed_files_path} loading..." 
+ log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank) + + restart_epoch, restart_consumed_files = self.load_restart_consumed_files(self.consumed_files_path) + log_msg = f"[PetaGraphStreamDataset:{self.rank}] Found {restart_epoch} epoch with {len(restart_consumed_files)} files" + log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank) + + # All files in restart_consumed_files should be present in the url_list + for f in restart_consumed_files: + assert f in url_list, f"File {f} from restart not found in the url_list" + + # Remove those files from the url list and append them to the end + # of the url list + restart_consumed_files_set = set(restart_consumed_files) + for f in restart_consumed_files_set: + url_list.remove(f) + url_list.extend(restart_consumed_files) + + # Add the consumed files to the consumed files set + self.consumed_files = set(restart_consumed_files) + + # Set the current epoch to the restart epoch + self.current_epoch = restart_epoch + + log_msg = f"[PetaGraphStreamDataset:{self.rank}] Restarting from epoch {self.current_epoch} with {len(self.consumed_files)} files" + log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank) + else: + self.consumed_files = set() + + # Setup the input + self.url_list = url_list + self.url_index = 0 + + self.consumed_seq_len_queue = deque(maxlen=5000) + if self.log_directory is not None: + self.logging_func(f"[PetaGraphStreamDataset] Logging to {self.log_directory} on rank {self.rank}") + + self.packed = packed + if self.packed: + self.logging_func(f"[PetaGraphStreamDataset] Packing sequences to maximize throughput") + + self.logging_func("=====================================") + + @staticmethod + def worker_init_fn(worker_id): + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + return + + dataset = worker_info.dataset + num_workers = worker_info.num_workers + # worker_id = worker_info.id + + num_urls = len(dataset.url_list) +
urls_per_worker = num_urls // num_workers + start_idx = worker_id * urls_per_worker + end_idx = (worker_id + 1) * urls_per_worker + if worker_id == num_workers - 1: + end_idx = num_urls + + worker_urls = dataset.url_list[start_idx:end_idx] + dataset.url_list = worker_urls + + log_msg = f"[Worker:{worker_id}] Worker {worker_id} processing {len(dataset.url_list)} urls {start_idx} to {end_idx}" + log_rank(log_msg, logger=dataset.logger, level=logging.INFO, rank=0) + + @staticmethod + def load_restart_consumed_files(restart_file: Path): + """Load the consumed files from the restart file + + Returns the latest epoch and the files consumed in the latest epoch + + Parameters: + ---------- + restart_file (Path): The path to the restart file + """ + with open(restart_file, "r") as f: + consumed_files = f.readlines() + consumed_files = [f.strip() for f in consumed_files] + consumed_files_tuples = [(int(f.split("_", 1)[0]), f.split("_", 1)[1]) for f in consumed_files] # split once: records are "<epoch>_<path>" and the path itself may contain "_" + + latest_epoch = max([f[0] for f in consumed_files_tuples]) + latest_files = [f[1] for f in consumed_files_tuples if f[0] == latest_epoch] + + return latest_epoch, latest_files + + @staticmethod + def chop_at_first_repeated_kmer(sequence: str, k: int): + """Chop the sequence at the first repeated kmer + + Python implementation of: + https://gitlab.pasteur.fr/rchikhi_pasteur/logan-circles/-/blob/master/fix_repeated_31kmers.cpp?ref_type=heads + + Parameters + ---------- + sequence : str + The sequence to chop + k : int + The kmer length + """ + kmers = set() + for i in range(len(sequence) - k + 1): + kmer = sequence[i:i+k] + if kmer in kmers: + return sequence[:i + k - 1] + kmers.add(kmer) + + return sequence + + @staticmethod + def find_overlaps_and_build_graph(sequences, k_mer: int = 31): + """Reconstruct assembly graph""" + min_overlap = k_mer - 1 + prefix_dict = defaultdict(list) + + # Precompute the prefixes (index every sequence by its first k-1 bases) + for i, seq in enumerate(sequences): + prefix_dict[seq[:min_overlap]].append(i) + + graph =
defaultdict(list) + + # Check for overlaps + for i, seq1 in enumerate(sequences): + seq1_suffix = seq1[-min_overlap:] + graph[i] = [] + for j in prefix_dict[seq1_suffix]: + if i != j: + graph[i].append(j) + + return graph + + @staticmethod + def dfs_paths(graph, start, path = None, all_paths = None, depth: int = 10): + """Perform a depth-first search on the graph""" + if path is None: + path = [start] # Initialize the path with the starting node + if all_paths is None: + all_paths = [] # Initialize the list to store all paths + + # If we revisit a node in the current path, it's a cycle, so we stop + if start in path[:-1]: + all_paths.append(path[:-1]) + return all_paths + + # Check if the current node is a leaf (no neighbors) + if start not in graph or not graph[start]: + all_paths.append(path) # Add the current path as a complete path + return all_paths + + if len(path) >= depth: + all_paths.append(path) + return all_paths + # Explore each neighbor recursively, ensuring no cycles + for neighbor in graph[start]: + PetaGraphStreamDatasetV2.dfs_paths(graph, neighbor, path + [neighbor], all_paths, depth) # forward the caller's depth limit + + return all_paths + + @staticmethod + def random_walk_graph_sequences(graph, sequences, k_mer: int = 31) -> list[str]: + """Perform random walk on the graph""" + random_walk_sequences = [] + for node in graph: + paths = PetaGraphStreamDatasetV2.dfs_paths(graph, node) + idx = np.random.randint(len(paths)) + path = paths[idx] + seq = sequences[path[0]] + "".join([sequences[p][k_mer-1:] for p in path[1:]]) + # seq = seq[:MAX_SEQ_LENGTH] + random_walk_sequences.append(seq) + + return random_walk_sequences + + + def length_sampling_filter(self, sequence: str) -> bool: + seq_len = len(sequence) + if seq_len >= self.sampling_seq_len_inflection: + return True + else: + prob = np.random.rand() + if prob < (seq_len / self.sampling_seq_len_inflection): + return True # keep a short sequence with probability seq_len / inflection + + return False + + + def fasta_parsing_func(self, input_data: Tuple[str, bytes]): + """Parse the fasta data and return the
sequences + + Parameters + ---------- + input_data : Tuple[str, bytes] + The path and the data to parse + """ + path, data = input_data + if data is None: + return [(path, "")] # empty record: the consumer unpacks (path, text) and skips empty text + + sequences = [] + decoded_lines = data.decode() + sequences = [str(s.seq) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")] + + # Following DNA-BERTv2: https://arxiv.org/pdf/2306.15006 + # Zhou et al.: "We exclude all sequences with N and retain only sequences that consist of A, T, C, and G. + sequences = [s for s in sequences if set(s).issubset(ALPHABET)] + + # Chop sequences in preparation for graph traversal + sequences = [self.chop_at_first_repeated_kmer(s, k=KMER_LENGTH) for s in sequences] + + # Construct sequence graph and perform random walks + sequences_arr = np.array(sequences) + sequence_graph = self.find_overlaps_and_build_graph(sequences_arr, k_mer=KMER_LENGTH) + random_walk_sequences = self.random_walk_graph_sequences(sequence_graph, sequences_arr, k_mer=KMER_LENGTH) + + # Sample sequences for training + keep_sequences = [(path, s) for s in filter(self.length_sampling_filter, random_walk_sequences)] + + # Test outputs + if len(keep_sequences) == 0: + return [(path, "")] # same (path, text) shape as a normal record so unpacking cannot fail + + assert isinstance(keep_sequences, list) + assert isinstance(keep_sequences[0], tuple) and len(keep_sequences[0]) == 2 + assert isinstance(keep_sequences[0][0], str) and isinstance(keep_sequences[0][1], str) + + return keep_sequences + + def crop_maxlen(self, input_sequence: str, maxlen: int = None): + # path, input_sequence = input_data + if len(input_sequence) <= maxlen: + return input_sequence + else: + # Crop the sequence to the maximum length + # Get random starting point + start = random.randint(0, len(input_sequence) - maxlen) + return input_sequence[start:start + maxlen] + + def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True): + # path, input_sequence = input_data + maxlen = self.maxlen + + # Tokenize the sequence + tokenized_sequence = [self._bos_token_id] # start with BOS token +
tokenized_sequence.extend([self.VOCAB.get(base, self._unk_token_id) for base in input_sequence]) # 3 is the UNK token + if len(tokenized_sequence) < maxlen: + tokenized_sequence.append(self._eos_token_id) # end with EOS token + tokenized_sequence = np.array(tokenized_sequence, dtype=np.int32) + + # Pad the sequence + if apply_pad and len(tokenized_sequence) < maxlen: + # 2 is the PAD token + tokenized_sequence = np.pad(tokenized_sequence, + (0, maxlen - len(tokenized_sequence)), + mode="constant", + constant_values=self._pad_token_id) + + return tokenized_sequence + + def generate(self): + current_tokens = None + current_sequences = [] + while True: + try: + + # Need to download the next file + if len(current_sequences) == 0: + new_url = self.url_list[self.url_index] + source_path = new_url + self.url_index += 1 + + # worker_info = torch.utils.data.get_worker_info() + # if worker_info is None: + # worker_id = 0 + # else: + # worker_id = worker_info.id + + raw_stream = requests.get(new_url, stream=True) + try: + dctx = zstandard.ZstdDecompressor() + decompressed_data = dctx.decompress(raw_stream.content) + except Exception as e: + self.logger.warning(f"[PetaGraphStreamDataset] Error decompressing {source_path}: {e}") + continue + + current_sequences = self.fasta_parsing_func((source_path, decompressed_data)) + + # Remove the first sequence + source_path, text_raw = current_sequences.pop(0) + if text_raw is None or len(text_raw) == 0: + continue + + # Log the consumed sequences + self.num_consumed_sequences += 1 + + # Log the consumed files + if self.log_directory is not None: + if source_path not in self.consumed_files: + + self.consumed_files_lock.acquire() + with open(self.consumed_files_path, "a") as f: + f.write(f"{self.current_epoch}_{source_path}\n") + self.consumed_files_lock.release() + + self.consumed_files.add(source_path) + if len(self.consumed_files) == self.num_files: + self.current_epoch += 1 + self.logging_func(f"Epoch {self.current_epoch} completed") 
+ self.consumed_files = set() + + except StopIteration as e: + self.logger.warning(f"Reached end of dataset: {e}") + + if not self.packed: + + # Crop the sequence to the maximum length + maxlen_without_special_tokens = self.maxlen - 1 # for BOS token + text_cropped = self.crop_maxlen(text_raw, maxlen=maxlen_without_special_tokens) + + # Log the consumed sequence length + text_length = len(text_cropped) + self.consumed_seq_len_queue.append(text_length) + + # Tokenize and pad the sequence + text_tokenized = self.tokenize_and_pad(text_cropped) + + yield {"input_ids": text_tokenized} + + else: + + # Crop the sequence to the maximum length + # Leave room for at least BOS + if len(text_raw) >= self.maxlen: + text_cropped = text_raw[:self.maxlen] + else: + text_cropped = text_raw + + # Log the consumed sequence length + text_length = len(text_cropped) + self.consumed_seq_len_queue.append(text_length) + + new_tokens = self.tokenize_and_pad(text_cropped, apply_pad=False) + if current_tokens is None: + current_tokens = new_tokens + else: + # Check the last token of the current sequence + # is an EOS token + assert current_tokens[-1] == self._eos_token_id + current_tokens = np.concatenate([current_tokens, new_tokens]) + + if len(current_tokens) >= self.maxlen: + current_tokens = current_tokens[:self.maxlen] + yield {"input_ids": current_tokens} + current_tokens = None + + def __iter__(self) -> dict[str, np.ndarray]: + + """Abstract method implementation + + Returns: + Dict[str, torch.Tensor]: The sample information wrapped in a dictionary + """ + return cyclic_iter(self.generate()) + diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 07490fd2..3f76b140 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -630,32 +630,41 @@ def train_step_logs( num_consumed_files_t = torch.tensor(num_consumed_files, device="cuda", dtype=torch.int64) num_consumed_files_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) - 
dist.all_gather_into_tensor( - output_tensor=num_consumed_files_t_all, - input_tensor=num_consumed_files_t, - group=self.parallel_context.dp_pg - ) + if world_size_dp_pg > 1: + dist.all_gather_into_tensor( + output_tensor=num_consumed_files_t_all, + input_tensor=num_consumed_files_t, + group=self.parallel_context.dp_pg + ) + else: + num_consumed_files_t_all = num_consumed_files_t num_consumed_files_ranks = num_consumed_files_t_all.cpu().numpy() num_consumed_files_all = num_consumed_files_ranks.sum() self.metadata.consumed_num_logan_files = int(num_consumed_files_all) current_epoch_t = torch.tensor(current_epoch, device="cuda", dtype=torch.int64) current_epoch_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) - dist.all_gather_into_tensor( - output_tensor=current_epoch_t_all, - input_tensor=current_epoch_t, - group=self.parallel_context.dp_pg - ) + if world_size_dp_pg > 1: + dist.all_gather_into_tensor( + output_tensor=current_epoch_t_all, + input_tensor=current_epoch_t, + group=self.parallel_context.dp_pg + ) + else: + current_epoch_t_all = current_epoch_t current_epoch_ranks = current_epoch_t_all.cpu().numpy() current_epoch_all = current_epoch_ranks.mean() num_consumed_seq_t = torch.tensor(num_consumed_sequences, device="cuda", dtype=torch.int64) num_consumed_seq_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) - dist.all_gather_into_tensor( - output_tensor=num_consumed_seq_t_all, - input_tensor=num_consumed_seq_t, - group=self.parallel_context.dp_pg - ) + if world_size_dp_pg > 1: + dist.all_gather_into_tensor( + output_tensor=num_consumed_seq_t_all, + input_tensor=num_consumed_seq_t, + group=self.parallel_context.dp_pg + ) + else: + num_consumed_seq_t_all = num_consumed_seq_t num_consumed_seq_ranks = num_consumed_seq_t_all.cpu().numpy() num_consumed_seq_all = num_consumed_seq_ranks.sum() self.metadata.consumed_num_sequences += int(num_consumed_seq_all) @@ -663,11 +672,14 @@ def train_step_logs( mean_consumed_seq_len_t 
= torch.tensor(mean_seq_len, device="cuda", dtype=torch.float32) mean_consumed_seq_len_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.float32) - dist.all_gather_into_tensor( - output_tensor=mean_consumed_seq_len_t_all, - input_tensor=mean_consumed_seq_len_t, - group=self.parallel_context.dp_pg - ) + if world_size_dp_pg > 1: + dist.all_gather_into_tensor( + output_tensor=mean_consumed_seq_len_t_all, + input_tensor=mean_consumed_seq_len_t, + group=self.parallel_context.dp_pg + ) + else: + mean_consumed_seq_len_t_all = mean_consumed_seq_len_t mean_consumed_seq_len_ranks = mean_consumed_seq_len_t_all.cpu().numpy() mean_consumed_seq_len_all = mean_consumed_seq_len_ranks.mean()