Skip to content

Commit

Permalink
rm ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Apr 7, 2024
1 parent 8559f93 commit 77282d6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 21 deletions.
6 changes: 3 additions & 3 deletions wenet/text/LLM/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
if tie_word_embedding:
self.out.weight = self.embed.weight

self.decoders = decoder
self.decoder = decoder
self.sos = (vocab_size - 1 if special_tokens is None else
special_tokens.get("<sos>", vocab_size - 1))
self.eos = (vocab_size - 1 if special_tokens is None else
Expand Down Expand Up @@ -66,8 +66,8 @@ def forward(
att_mask = causal_mask & tgt_mask # (B, L, L)

embeding = self.embed(ys_in_pad)
decoder_out = self.out(self.decoders(embeding,
att_mask)) # (B, L, vocab_size)
decoder_out = self.out(self.decoder(embeding,
att_mask)) # (B, L, vocab_size)

loss = self.criterion_att(decoder_out, ys_out_pad)
acc = th_accuracy(decoder_out.view(-1, self.vocab_size),
Expand Down
24 changes: 6 additions & 18 deletions wenet/text/LLM/decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch

import torch.utils.checkpoint as ckpt

from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
Expand Down Expand Up @@ -64,22 +62,12 @@ def forward(self, input: torch.Tensor, att_mask: torch.Tensor):
tgt_mask = att_mask
if self.use_sdpa:
tgt_mask = mask_to_bias(tgt_mask, xs.dtype)
if not self.gradient_checkpoint:
decoder_out, _, _, _ = self.decoders(xs,
tgt_mask,
pos_emb,
mask_pad=None)
else:
assert self.training
decoder_out = xs
for layer in self.decoders:
decoder_out, _, _, _ = ckpt.checkpoint(
layer.__call__,
decoder_out,
tgt_mask,
pos_emb,
)

decoder_out, _, _, _ = self.decoders(
xs,
tgt_mask,
pos_emb,
mask_pad=None,
)
return decoder_out

@property
Expand Down

0 comments on commit 77282d6

Please sign in to comment.