Commit: fix graph breaks

Felipe Mello committed Nov 19, 2024
1 parent d9f417d commit b41114a
Showing 3 changed files with 23 additions and 2 deletions.
8 changes: 7 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -664,6 +664,7 @@ def train(self) -> None:
)

# Initialize tokens count and running loss (for grad accumulation)
start = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
@@ -730,6 +731,7 @@ def train(self) -> None:
# Log per-step metrics
if self.global_step % self._log_every_n_steps == 0:
time_per_step = time.perf_counter() - t0
print(time_per_step)
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
@@ -773,13 +775,17 @@ def train(self) -> None:
self.epochs_run += 1
start_save_checkpoint = time.perf_counter()
log.info("Starting checkpoint save...")
self.save_checkpoint(epoch=curr_epoch)
# self.save_checkpoint(epoch=curr_epoch)
log.info(
"Checkpoint saved in {:.2f} seconds.".format(
time.perf_counter() - start_save_checkpoint
)
)

end = time.perf_counter()
time_total = end - start
print(f"{time_total=}")

def cleanup(self) -> None:
self._metric_logger.close()

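The recipe changes above are wall-clock benchmarking instrumentation: a second perf_counter marks the start of the whole training loop, the per-interval step time is printed, and the checkpoint save is commented out so it does not skew the total. A minimal sketch of the same timing pattern outside the recipe (benchmark, train_step, and the counter names are illustrative, not part of the commit):

import time

def benchmark(train_step, num_steps: int, log_every: int = 10) -> float:
    """Time a training loop with perf_counter, mirroring the pattern above."""
    start = time.perf_counter()  # total-run timer (start/end pair)
    t0 = time.perf_counter()     # per-logging-interval timer
    for step in range(1, num_steps + 1):
        train_step()
        if step % log_every == 0:
            print(f"time_per_step={time.perf_counter() - t0:.4f}s")
            t0 = time.perf_counter()
    time_total = time.perf_counter() - start
    print(f"{time_total=}")
    return time_total
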
6 changes: 6 additions & 0 deletions torchtune/modules/loss/ce_chunked_output_loss.py
@@ -78,6 +78,12 @@ def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):

# avoid graph breaks when seq_len is not constant in the batch
torch._dynamo.mark_dynamic(logits_chunk, 0)
torch._dynamo.mark_dynamic(labels_chunk, 0)

# CE
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)

return total_loss / total_elements
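
The two mark_dynamic calls tell Dynamo that dimension 0 of each chunk (the flattened batch-by-sequence dimension) may vary between calls, so torch.compile keeps one size-generic graph instead of re-specializing and recompiling whenever the batch's sequence length changes. A rough standalone sketch of the pattern, not the torchtune implementation (the chunk shapes, ignore_index, and the choice to compile only the per-chunk cross-entropy are assumptions here):

import torch
import torch.nn.functional as F

# Compile only the per-chunk cross-entropy; the loop stays in eager mode.
compiled_ce = torch.compile(
    lambda logits, labels: F.cross_entropy(
        logits.float(), labels, reduction="sum", ignore_index=-100
    )
)

def chunked_ce(logits_chunks, labels_chunks):
    total_loss = 0.0
    for logits_chunk, labels_chunk in zip(logits_chunks, labels_chunks):
        # Mark the flattened (batch * seq) dimension as dynamic so the compiled
        # graph is reused when the chunk length differs from the traced one.
        torch._dynamo.mark_dynamic(logits_chunk, 0)
        torch._dynamo.mark_dynamic(labels_chunk, 0)
        total_loss += compiled_ce(logits_chunk, labels_chunk)
    return total_loss
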
11 changes: 10 additions & 1 deletion torchtune/modules/transformer.py
@@ -628,14 +628,23 @@ def forward(

# shape: [b, s, d]
h = self.tok_embeddings(tokens)
h.requires_grad = True # avoid graph breaks when using LoRA

hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)

# avoid graph breaks when seq_len is not constant in the batch
torch._dynamo.mark_dynamic(h, 1)
if mask is not None:
torch._dynamo.mark_dynamic(mask, 1)
if encoder_mask is not None:
torch._dynamo.mark_dynamic(encoder_mask, 1)
if input_pos is not None:
torch._dynamo.mark_dynamic(input_pos, 1)

# shape: [b, s, d]
torch._dynamo.mark_dynamic(h, 1) # avoid graph breaks
h = layer(
h,
mask=mask,
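
The same idea is applied inside the decoder forward: the sequence dimension (dim 1) of the hidden states and of the optional mask, encoder_mask, and input_pos is marked dynamic so a compiled layer can be reused across varying sequence lengths, and requires_grad is forced on the embedding output because under LoRA the embedding table is frozen, so h would otherwise not require grad, which the commit notes causes graph breaks. A minimal sketch with a stand-in layer (the Embedding/Linear modules and shapes below are placeholders, not torchtune code):

import torch
from torch import nn

tok_embeddings = nn.Embedding(32000, 64)
tok_embeddings.requires_grad_(False)      # frozen base weights, as under LoRA
layer = torch.compile(nn.Linear(64, 64))  # stand-in for one compiled transformer layer

def forward(tokens: torch.Tensor) -> torch.Tensor:
    h = tok_embeddings(tokens)  # shape: [b, s, d]
    # With a frozen embedding h has no grad_fn, so the flag can be set directly;
    # the commit does this to avoid graph breaks when compiling with LoRA.
    h.requires_grad = True
    # Mark the sequence dimension as dynamic so one compiled graph serves any seq_len.
    torch._dynamo.mark_dynamic(h, 1)
    return layer(h)

out = forward(torch.randint(0, 32000, (2, 17)))  # b=2, s=17; varying s reuses the graph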
