Skip to content

Commit

Permalink
change up logic so we always truncate to top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 21, 2025
1 parent bb5e6f4 commit bded6df
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/axolotl/integrations/kd/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ def transform_logprobs(self, sample):
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
# and truncate the second dimension of the logprobs to top_k
logprobs = logprobs[:-input_seq_len, :top_k]
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len

# truncate the second dimension of the logprobs to top_k
logprobs = logprobs[:, :top_k]

# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf

Expand Down

0 comments on commit bded6df

Please sign in to comment.