From 2cac863e40caf8c8bf608a1505ca8742461ea143 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Thu, 2 Nov 2023 18:44:36 +0100
Subject: [PATCH 1/2] fix generation input preparation

---
 optimum/onnxruntime/modeling_decoder.py |  51 ++++++-
 optimum/onnxruntime/modeling_seq2seq.py | 183 ++++++++++++------------
 2 files changed, 136 insertions(+), 98 deletions(-)

diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 13aef3546a5..8380d240eca 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -204,7 +204,6 @@ def forward(
         loss = None
         if self.use_cache:
             if past_key_values is not None:
-                input_ids = input_ids[:, -1:]
                 # Flatten the past_key_values (no need to flatten for models using multi-query attn)
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
                     past_key_values = tuple(
@@ -630,8 +629,17 @@ def _from_transformers(
 
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
@@ -667,6 +675,16 @@ def can_generate(self):
 class ORTBloomForCausalLM(ORTModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
@@ -706,6 +724,16 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
 class ORTOPTForCausalLM(ORTModelForCausalLM):
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
@@ -721,6 +749,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
 class ORTMPTForCausalLM(ORTModelForCausalLM):
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
@@ -794,7 +832,14 @@ def prepare_inputs_for_generation(
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
 
             # the cache may be in the stardard format (e.g. in contrastive search), convert to falcon's format if needed
             if len(past_key_values[0][0].shape) == 4:
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index ee3e0534a32..de37aaa153c 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -1192,30 +1192,18 @@ def forward(
         if encoder_outputs is None:
             encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
 
-        # Decode
-        if past_key_values is None or self.use_cache is False:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
-        elif self.use_merged is True:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids[:, -1:],
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                past_key_values=past_key_values,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
-        else:
-            decoder_outputs = self.decoder_with_past(
-                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
+        model = (
+            self.decoder
+            if past_key_values is None or not self.use_cache or self.use_merged
+            else self.decoder_with_past
+        )
+        decoder_outputs = model(
+            input_ids=decoder_input_ids,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
+            encoder_attention_mask=attention_mask,
+            labels=labels,
+        )
 
         return Seq2SeqLMOutput(
             loss=decoder_outputs.get("loss", None),
@@ -1236,6 +1224,16 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ) -> Dict:
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         return {
             "decoder_input_ids": input_ids,
             "past_key_values": past_key_values,
@@ -1331,28 +1329,18 @@ def forward(
         if encoder_outputs is None:
             encoder_outputs = self.encoder(input_features=input_features, attention_mask=attention_mask)
 
-        # Decode
-        if past_key_values is None or self.use_cache is False:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                labels=labels,
-            )
-        elif self.use_merged is True:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids[:, -1:],
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                past_key_values=past_key_values,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
-        else:
-            decoder_outputs = self.decoder_with_past(
-                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                labels=labels,
-            )
+        model = (
+            self.decoder
+            if past_key_values is None or not self.use_cache or self.use_merged
+            else self.decoder_with_past
+        )
+        decoder_outputs = model(
+            input_ids=decoder_input_ids,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
+            encoder_attention_mask=attention_mask,
+            labels=labels,
+        )
 
         return Seq2SeqLMOutput(
             loss=decoder_outputs.get("loss", None),
@@ -1372,6 +1360,16 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ) -> Dict:
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         return {
             "decoder_input_ids": input_ids,
             "past_key_values": past_key_values,
@@ -1526,27 +1524,17 @@ def forward(
         if encoder_outputs is None:
             encoder_outputs = self.encoder(pixel_values=pixel_values)
 
-        # Decode
-        if past_key_values is None or self.use_cache is False:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                labels=labels,
-            )
-        elif self.use_merged is True:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids[:, -1:],
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                past_key_values=past_key_values,
-                labels=labels,
-            )
-        else:
-            decoder_outputs = self.decoder_with_past(
-                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                labels=labels,
-            )
+        model = (
+            self.decoder
+            if past_key_values is None or not self.use_cache or self.use_merged
+            else self.decoder_with_past
+        )
+        decoder_outputs = model(
+            input_ids=decoder_input_ids,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
+            labels=labels,
+        )
 
         return Seq2SeqLMOutput(
             loss=decoder_outputs.get("loss", None),
@@ -1565,6 +1553,16 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ) -> Dict:
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         return {
             "decoder_input_ids": input_ids,
             "past_key_values": past_key_values,
@@ -1641,34 +1639,19 @@ def forward(
         else:
             attention_mask = attention_mask.astype(np.int64)
 
-        # Decode
-        if past_key_values is None or self.use_cache is False:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
-        elif self.use_merged is True:
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids[:, -1:],
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
-        else:
-            decoder_outputs = self.decoder_with_past(
-                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
-                decoder_attention_mask=decoder_attention_mask,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                encoder_attention_mask=attention_mask,
-                labels=labels,
-            )
+        model = (
+            self.decoder
+            if past_key_values is None or not self.use_cache or self.use_merged
+            else self.decoder_with_past
+        )
+        decoder_outputs = model(
+            input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
+            encoder_attention_mask=attention_mask,
+            labels=labels,
+        )
 
         return Seq2SeqLMOutput(
             loss=decoder_outputs.get("loss", None),
@@ -1690,6 +1673,16 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ) -> Dict:
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+            input_ids = input_ids[:, remove_prefix_length:]
+
         if decoder_attention_mask is None:
             decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

From d5361db600688ea76e98deb3deda18d0bbf332c9 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Fri, 3 Nov 2023 17:32:10 +0100
Subject: [PATCH 2/2] fix

---
 optimum/exporters/onnx/model_patcher.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 09cdddc95fe..b9f0df29eaa 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -412,8 +412,11 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         # TODO: Remove this if once transformers if much above 4.35
-        if AttentionMaskConverter is not None:
-            AttentionMaskConverter._make_causal_mask = self.original_make_causal
+        # TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy.
+        # We need to find a proper solution.
+        # if AttentionMaskConverter is not None:
+        #     AttentionMaskConverter._make_causal_mask = self.original_make_causal
+        pass
 
     def __init__(
         self,
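
Note (not part of the patches above): a minimal standalone sketch of the input-trimming logic that PATCH 1/2 adds to each `prepare_inputs_for_generation`. The helper name `trim_input_ids` and the toy cache shapes are assumptions made for illustration only; the patch itself inlines this logic directly in each model class.

```python
import torch


def trim_input_ids(input_ids: torch.Tensor, past_key_values) -> torch.Tensor:
    """Drop the prefix of input_ids already covered by the KV cache (illustrative helper)."""
    if past_key_values is None:
        return input_ids
    # Cached sequence length: keys are (batch, num_heads, past_sequence_length, head_dim),
    # so dimension 2 holds the number of tokens already processed.
    past_length = past_key_values[0][0].shape[2]
    if input_ids.shape[1] > past_length:
        # More tokens were passed than are cached (e.g. the full prompt again):
        # feed only the uncached suffix to the decoder.
        remove_prefix_length = past_length
    else:
        # Default to the old behavior: keep only the final token.
        remove_prefix_length = input_ids.shape[1] - 1
    return input_ids[:, remove_prefix_length:]


# Toy usage with a dummy cache covering 4 past tokens (shapes are illustrative only).
dummy_past = ((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),)
full_prompt = torch.arange(6).unsqueeze(0)          # 6 tokens passed, 4 already cached
print(trim_input_ids(full_prompt, dummy_past))      # tensor([[4, 5]]) -> uncached suffix
last_token_only = torch.tensor([[42]])
print(trim_input_ids(last_token_only, dummy_past))  # tensor([[42]]) -> unchanged
```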