From b41114ad3631c8409c8c2409755e70acbf66e016 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 19 Nov 2024 14:31:28 -0800
Subject: [PATCH] fix graph breaks

---
 recipes/lora_finetune_single_device.py           |  8 +++++++-
 torchtune/modules/loss/ce_chunked_output_loss.py |  6 ++++++
 torchtune/modules/transformer.py                 | 11 ++++++++++-
 3 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index fcdb3e4ea5..467ad53fc1 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -664,6 +664,7 @@ def train(self) -> None:
         )
 
         # Initialize tokens count and running loss (for grad accumulation)
+        start = time.perf_counter()
         t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
@@ -730,6 +731,7 @@ def train(self) -> None:
                     # Log per-step metrics
                     if self.global_step % self._log_every_n_steps == 0:
                         time_per_step = time.perf_counter() - t0
+                        print(time_per_step)
                         log_dict = {
                             "loss": loss_to_log,
                             "lr": self._optimizer.param_groups[0]["lr"],
@@ -773,13 +775,17 @@ def train(self) -> None:
             self.epochs_run += 1
             start_save_checkpoint = time.perf_counter()
             log.info("Starting checkpoint save...")
-            self.save_checkpoint(epoch=curr_epoch)
+            # self.save_checkpoint(epoch=curr_epoch)
             log.info(
                 "Checkpoint saved in {:.2f} seconds.".format(
                     time.perf_counter() - start_save_checkpoint
                 )
             )
 
+        end = time.perf_counter()
+        time_total = end - start
+        print(f"{time_total=}")
+
 
     def cleanup(self) -> None:
         self._metric_logger.close()

diff --git a/torchtune/modules/loss/ce_chunked_output_loss.py b/torchtune/modules/loss/ce_chunked_output_loss.py
index 17a5eced36..ff1525758a 100644
--- a/torchtune/modules/loss/ce_chunked_output_loss.py
+++ b/torchtune/modules/loss/ce_chunked_output_loss.py
@@ -78,6 +78,12 @@ def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
         # compute one chunk at a time
         total_loss = 0.0
         for logits_chunk, labels_chunk in zip(logits, labels):
+
+            # avoid graph breaks when seq_len is not constant in the batch
+            torch._dynamo.mark_dynamic(logits_chunk, 0)
+            torch._dynamo.mark_dynamic(labels_chunk, 0)
+
+            # CE
             total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
 
         return total_loss / total_elements

diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 9d64a1228e..23aa79e050 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -628,14 +628,23 @@ def forward(
         # shape: [b, s, d]
         h = self.tok_embeddings(tokens)
+        h.requires_grad = True  # avoid graph breaks when using LoRA
 
         hidden = []
         for i, layer in enumerate(self.layers):
             if i in self.output_hidden_states:
                 hidden.append(h)
+
+            # avoid graph breaks when seq_len is not constant in the batch
+            torch._dynamo.mark_dynamic(h, 1)
+            if mask is not None:
+                torch._dynamo.mark_dynamic(mask, 1)
+            if encoder_mask is not None:
+                torch._dynamo.mark_dynamic(encoder_mask, 1)
+            if input_pos is not None:
+                torch._dynamo.mark_dynamic(input_pos, 1)
+
             # shape: [b, s, d]
-            torch._dynamo.mark_dynamic(h, 1)  # avoid graph breaks
             h = layer(
                 h,
                 mask=mask,
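
Note (not part of the patch): below is a minimal, self-contained sketch of the torch._dynamo.mark_dynamic(tensor, dim) idiom the patch relies on. Marking a dimension as dynamic tells torch.compile not to specialize on that size, so batches with a varying seq_len reuse the same compiled graph instead of recompiling or breaking the graph. The toy compiled loss below is hypothetical and only illustrates the call; it is not torchtune code.

import torch
import torch.nn.functional as F


@torch.compile
def toy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # stand-in for a chunked cross-entropy computation
    return F.cross_entropy(logits.float(), labels)


for seq_len in (128, 256, 512):  # seq_len changes from batch to batch
    logits = torch.randn(seq_len, 32000)
    labels = torch.randint(0, 32000, (seq_len,))
    # mark the varying dimension (dim 0 here) as dynamic before the compiled call
    torch._dynamo.mark_dynamic(logits, 0)
    torch._dynamo.mark_dynamic(labels, 0)
    loss = toy_loss(logits, labels)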