Revert "fix generation input preparation (#1512)"
This reverts commit 554e312.
echarlaix authored Nov 3, 2023
1 parent 554e312 commit 35ac713
Showing 3 changed files with 100 additions and 141 deletions.
7 changes: 2 additions & 5 deletions optimum/exporters/onnx/model_patcher.py
@@ -412,11 +412,8 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
# TODO: Remove this once transformers is much above 4.35
# TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy.
# We need to find a proper solution.
# if AttentionMaskConverter is not None:
# AttentionMaskConverter._make_causal_mask = self.original_make_causal
pass
if AttentionMaskConverter is not None:
AttentionMaskConverter._make_causal_mask = self.original_make_causal

def __init__(
self,
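For context, the restored `__exit__` undoes a monkey-patch that `__enter__` installs on `AttentionMaskConverter`. A minimal sketch of that patch/unpatch pattern follows; the `CausalMaskPatcher` name, the `_patched_make_causal_mask` placeholder, and the `__enter__` body are illustrative assumptions, only the unpatching line is taken from this diff.

from transformers.modeling_attn_mask_utils import AttentionMaskConverter  # assumed available in transformers >= 4.35


def _patched_make_causal_mask(*args, **kwargs):
    # Hypothetical stand-in for the export-friendly mask function installed during tracing.
    raise NotImplementedError


class CausalMaskPatcher:
    def __enter__(self):
        # Save the original implementation so it can be put back on exit.
        self.original_make_causal = AttentionMaskConverter._make_causal_mask
        AttentionMaskConverter._make_causal_mask = _patched_make_causal_mask
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original, as the reverted __exit__ above does.
        AttentionMaskConverter._make_causal_mask = self.original_make_causal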
51 changes: 3 additions & 48 deletions optimum/onnxruntime/modeling_decoder.py
@@ -204,6 +204,7 @@ def forward(
loss = None
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
past_key_values = tuple(
@@ -629,17 +630,8 @@ def _from_transformers(

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -675,16 +667,6 @@ def can_generate(self):
class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -724,16 +706,6 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
class ORTOPTForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -749,16 +721,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
class ORTMPTForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

@@ -832,14 +794,7 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
input_ids = input_ids[:, -1:]

# the cache may be in the standard format (e.g. in contrastive search), convert to falcon's format if needed
if len(past_key_values[0][0].shape) == 4:
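Both input-preparation strategies appear side by side in this file's diff: the reverted code trimmed `input_ids` inside `prepare_inputs_for_generation` using the cached sequence length, while the restored code passes `input_ids` through untouched and slices to the last token inside `forward` once a cache exists. A condensed sketch of the two behaviours (the helper names are illustrative, not optimum APIs):

import torch


def trim_in_prepare_inputs(input_ids, past_key_values=None):
    # Strategy removed by this revert (#1512): drop the prefix already covered by the cache.
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[2]
        # Some generation methods already pass only the last input ID.
        if input_ids.shape[1] > past_length:
            remove_prefix_length = past_length
        else:
            # Default to old behavior: keep only the final ID.
            remove_prefix_length = input_ids.shape[1] - 1
        input_ids = input_ids[:, remove_prefix_length:]
    return input_ids


def trim_in_forward(input_ids, past_key_values=None):
    # Strategy restored by this revert: forward() keeps only the last token when a cache is present.
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return input_ids


# During step-by-step decoding the two coincide: a 5-token prompt with a cache
# covering the first 4 positions reduces to the final token either way.
ids = torch.arange(5).unsqueeze(0)
cache = ((torch.zeros(1, 1, 4, 8), torch.zeros(1, 1, 4, 8)),)
assert torch.equal(trim_in_prepare_inputs(ids, cache), ids[:, -1:])
assert torch.equal(trim_in_forward(ids, cache), ids[:, -1:])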
183 changes: 95 additions & 88 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -1192,18 +1192,30 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1224,16 +1236,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1329,18 +1331,28 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(input_features=input_features, attention_mask=attention_mask)

model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1360,16 +1372,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1524,17 +1526,27 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(pixel_values=pixel_values)

model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1553,16 +1565,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1639,19 +1641,34 @@ def forward(
else:
attention_mask = attention_mask.astype(np.int64)

model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)
decoder_outputs = model(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
# Decode
if past_key_values is None or self.use_cache is False:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
elif self.use_merged is True:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:],
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)
else:
decoder_outputs = self.decoder_with_past(
input_ids=decoder_input_ids[:, -1:], # Cut decoder_input_ids if past is used
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
@@ -1673,16 +1690,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
) -> Dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

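The forward methods restored in this file all dispatch between three decoding paths: no cache yet (full `decoder_input_ids` through `self.decoder`), a merged decoder (last token plus `past_key_values` through `self.decoder`), and a separate `decoder_with_past` session (last token plus cache). A minimal sketch of that dispatch, with the decoder callables as stand-ins for the ORT inference sessions and the keyword pass-through simplified:

def run_decoder_step(decoder, decoder_with_past, decoder_input_ids,
                     past_key_values=None, use_cache=True, use_merged=False, **decoder_kwargs):
    # Condensed version of the branching restored in modeling_seq2seq.py.
    if past_key_values is None or use_cache is False:
        # First step, or caching disabled: feed the full decoder prompt.
        return decoder(input_ids=decoder_input_ids, **decoder_kwargs)
    elif use_merged is True:
        # A merged decoder serves both phases; only the last token is needed once a cache exists.
        return decoder(input_ids=decoder_input_ids[:, -1:], past_key_values=past_key_values, **decoder_kwargs)
    else:
        # Separate with-past decoder: cut decoder_input_ids because the cache already covers the prefix.
        return decoder_with_past(
            input_ids=decoder_input_ids[:, -1:], past_key_values=past_key_values, **decoder_kwargs
        )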
