Fix forward pass for merged model
kunal-vaishnavi committed Oct 17, 2023
1 parent e7bd60d commit 5b0d813
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -204,7 +204,7 @@ def forward(
         loss = None
         if self.use_cache:
             if past_key_values is not None:
-                input_ids = input_ids[:, -1:]
+                input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
             # Flatten the past_key_values (no need to flatten for models using multi-query attn)
             if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
                 past_key_values = tuple(
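The changed line can be read in isolation as the sketch below. It is an illustration only, not the code in optimum/onnxruntime/modeling_decoder.py: the helper name select_input_ids is hypothetical, NumPy arrays stand in for the real tensors, and the cache layout (batch, num_heads, past_seq_len, head_dim) is an assumption inferred from the shape[2] check in the diff. Under that assumption, a cache whose past length is 0 holds no earlier tokens, so the full input_ids are kept; otherwise only the newest token is passed on.

```python
import numpy as np


def select_input_ids(input_ids, past_key_values):
    """Hypothetical helper that mirrors the patched conditional.

    Assumes past_key_values[layer][0] has shape
    (batch, num_heads, past_seq_len, head_dim).
    """
    if past_key_values is None:
        # No cache was passed in: nothing to trim.
        return input_ids
    past_len = past_key_values[0][0].shape[2]
    # An empty cache (past_len == 0) carries no context, so keep every token.
    return input_ids[:, -1:] if past_len != 0 else input_ids


batch, heads, head_dim = 1, 2, 4


def make_past(past_len):
    # One layer holding a (key, value) pair of zeros with the assumed layout.
    kv = np.zeros((batch, heads, past_len, head_dim))
    return ((kv, kv),)


prompt = np.array([[11, 12, 13]])
first_step = select_input_ids(prompt, make_past(0))

extended = np.array([[11, 12, 13, 14]])
later_step = select_input_ids(extended, make_past(3))
```

With an empty cache, first_step still contains all three prompt tokens; with three cached positions, later_step contains only the last token. Before the change, both cases would have been trimmed to a single token.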