fix generation input preparation #1512

Merged · 3 commits · Nov 3, 2023
Changes from 2 commits
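The recurring change in this PR replaces the unconditional `input_ids = input_ids[:, -1:]` slice in each `prepare_inputs_for_generation` with trimming that is aware of how many tokens the KV cache already holds. A minimal standalone sketch of that rule, assuming the usual `past_key_values` layout where dimension 2 of each key tensor is the cached sequence length (the helper name and mock tensors below are illustrative, not part of the PR):

```python
import torch

def trim_uncached_tokens(input_ids, past_key_values):
    """Illustrative helper: keep only the token ids the KV cache has not seen yet."""
    if past_key_values is None:
        return input_ids  # no cache yet: feed the full prompt
    past_length = past_key_values[0][0].shape[2]  # cached sequence length
    if input_ids.shape[1] > past_length:
        remove_prefix_length = past_length  # keep every uncached token (possibly several)
    else:
        remove_prefix_length = input_ids.shape[1] - 1  # old behavior: keep only the final id
    return input_ids[:, remove_prefix_length:]

# 5 tokens already cached, 8 ids passed in -> the last 3 are kept.
past = ((torch.zeros(1, 12, 5, 64), torch.zeros(1, 12, 5, 64)),)  # mock (key, value) pair for one layer
ids = torch.arange(8).unsqueeze(0)
print(trim_uncached_tokens(ids, past).shape)  # torch.Size([1, 3])
```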
51 changes: 48 additions & 3 deletions optimum/onnxruntime/modeling_decoder.py
@@ -204,7 +204,6 @@ def forward(
loss = None
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
past_key_values = tuple(
@@ -630,8 +629,17 @@ def _from_transformers(

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -667,6 +675,16 @@ def can_generate(self):
class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -706,6 +724,16 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
class ORTOPTForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -721,6 +749,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
class ORTMPTForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -794,7 +832,14 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

# the cache may be in the standard format (e.g. in contrastive search), convert to falcon's format if needed
if len(past_key_values[0][0].shape) == 4:
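All of the decoder-side changes above apply the same pattern. The old `input_ids[:, -1:]` slice assumed the generation loop hands the model exactly one new token per step; some generation paths instead pass several tokens that are not yet in the cache (for example when a batch of candidate tokens is validated at once, or when generation resumes from a shorter cache), and the old slice silently dropped all but the last of them. A small illustrative comparison, not taken from the diff:

```python
import torch

input_ids = torch.tensor([[101, 7, 8, 9]])  # 1 cached token followed by 3 new, uncached tokens
past_length = 1                              # sequence length already stored in the KV cache

old = input_ids[:, -1:]           # old rule: always keep the last id  -> [[9]]
new = input_ids[:, past_length:]  # new rule: keep every uncached id   -> [[7, 8, 9]]
print(old.tolist(), new.tolist())
```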
183 changes: 88 additions & 95 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -1192,30 +1192,18 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1236,6 +1224,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1331,28 +1329,18 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(input_features=input_features, attention_mask=attention_mask)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1372,6 +1360,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1526,27 +1524,17 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(pixel_values=pixel_values)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1565,6 +1553,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1641,34 +1639,19 @@ def forward(
else:
attention_mask = attention_mask.astype(np.int64)

# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1690,6 +1673,16 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

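For reference, one way to exercise the patched decoder path end to end is to export a causal-LM checkpoint to ONNX and run `generate` with the cache enabled, which routes every decoding step through `prepare_inputs_for_generation`. This is only a sketch; the checkpoint and generation settings are arbitrary:

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

# Export gpt2 to ONNX and generate with the KV cache enabled.
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("ONNX Runtime generation test:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```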