diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 1739f3782..af13e9012 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -542,7 +542,6 @@ def export_neuronx( # Construct compiler configurations if auto_cast is not None: logger.info(f"Using Neuron: --auto-cast {auto_cast}") - auto_cast = "matmult" if auto_cast == "matmul" else auto_cast compiler_args = ["--auto-cast", auto_cast] @@ -552,6 +551,10 @@ def export_neuronx( compiler_args = ["--auto-cast", "none"] compiler_args.extend(["--optlevel", optlevel]) + logger.info(f"Using Neuron: --optlevel {optlevel}") + + if getattr(config._config, "is_encoder_decoder", False): + compiler_args.extend(["--model-type", "transformer"]) compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) # diffusers specific diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py index d91d7ed8e..d451c2b88 100644 --- a/optimum/exporters/neuron/model_wrappers.py +++ b/optimum/exporters/neuron/model_wrappers.py @@ -383,7 +383,8 @@ def update_past(self, past_key_values): def reorder_cache(self, past_key_values, beam_idx): for i in range(len(past_key_values)): - past_key_values[i] = torch.index_select(past_key_values[i], 0, beam_idx) + gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(past_key_values[i]) + past_key_values[i] = torch.gather(past_key_values[i], dim=0, index=gather_index) return past_key_values def forward( diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index 6cb53d1c0..cc86e5557 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -462,14 +462,14 @@ def forward( decoder_hidden_states = None # Skip pkv which can't be copied from memory to buffer - if output_attentions and self.config.neuron.get("output_attentions"): + if output_attentions and 
self.configs["decoder"].neuron.get("output_attentions"): if self.config.is_encoder_decoder: cross_attentions = outputs[-self.config.num_decoder_layers :] cur_idx += self.config.num_decoder_layers decoder_attentions = outputs[-(self.config.num_decoder_layers + cur_idx) : -cur_idx] cur_idx += self.config.num_decoder_layers - if output_hidden_states and self.config.neuron.get("output_hidden_states"): + if output_hidden_states and self.configs["decoder"].neuron.get("output_hidden_states"): decoder_hidden_states = outputs[-(self.config.num_decoder_layers + 1 + cur_idx) : -cur_idx] decoder_outputs = ModelOutput( diff --git a/optimum/neuron/version.py b/optimum/neuron/version.py index 635898473..8ff56cce1 100644 --- a/optimum/neuron/version.py +++ b/optimum/neuron/version.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.25.dev0" +__version__ = "0.0.27.dev0" __sdk_version__ = "2.20.0" diff --git a/tests/cli/test_export_cli.py b/tests/cli/test_export_cli.py index 863a9f41a..ac333fc9d 100644 --- a/tests/cli/test_export_cli.py +++ b/tests/cli/test_export_cli.py @@ -303,9 +303,6 @@ def test_replace_unet(self): check=True, ) - @unittest.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." - ) @requires_neuronx def test_encoder_decoder(self): model_id = "hf-internal-testing/tiny-random-t5" @@ -335,9 +332,6 @@ def test_encoder_decoder(self): check=True, ) - @unittest.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." 
- ) @requires_neuronx def test_encoder_decoder_optional_outputs(self): model_id = "hf-internal-testing/tiny-random-t5" @@ -369,9 +363,6 @@ def test_encoder_decoder_optional_outputs(self): check=True, ) - @unittest.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." - ) @requires_neuronx def test_encoder_decoder_tp2(self): model_id = "michaelbenayoun/t5-tiny-random" diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index dcf2b09dd..167ea2d6c 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -310,9 +310,6 @@ def test_export_sd_with_fused_lora_weights(self): ) -@unittest.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." -) @is_inferentia_test @requires_neuronx class NeuronEncoderDecoderExportTestCase(unittest.TestCase): diff --git a/tests/generation/test_generate.py b/tests/generation/test_generate.py index d5de5e018..6bb4ceca1 100644 --- a/tests/generation/test_generate.py +++ b/tests/generation/test_generate.py @@ -35,9 +35,6 @@ def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen assert sample_output.shape[0] == batch_size -@pytest.mark.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." -) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_beam(neuron_seq2seq_beam_path): @@ -58,9 +55,6 @@ def test_seq2seq_generation_beam(neuron_seq2seq_beam_path): assert len(output[0].unique()) <= 5 + 1 # +1 for `decoder_start_token_id` -@pytest.mark.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." 
-) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_with_optional_outputs): @@ -83,9 +77,6 @@ def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_ assert "decoder_hidden_states" in output -@pytest.mark.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." -) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path): @@ -106,9 +97,6 @@ def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path): assert len(output[0]) <= 5 + 1 # +1 for `decoder_start_token_id` -@pytest.mark.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." -) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_path_with_optional_outputs): @@ -129,29 +117,6 @@ def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_p assert "decoder_hidden_states" in output -@pytest.mark.skip( - "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." 
-) -@is_inferentia_test -@requires_neuronx -def test_seq2seq_generation_tp2(neuron_seq2seq_tp2_path): - model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_tp2_path) - tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_tp2_path) - inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt") - - output = model.generate( - **inputs, - num_return_sequences=1, - max_length=20, - output_attentions=True, - output_hidden_states=True, - return_dict_in_generate=True, - ) - assert "decoder_attentions" in output - assert "cross_attentions" in output - assert "decoder_hidden_states" in output - - @pytest.mark.skip("Makes pytest fail, to fix.") @pytest.mark.parametrize( "gen_kwargs", @@ -195,10 +160,3 @@ def test_general_seq2seq_generation(export_seq2seq_id, export_seq2seq_model_clas model = export_seq2seq_model_class.from_pretrained(export_seq2seq_id) tokenizer = AutoTokenizer.from_pretrained(export_seq2seq_id) _test_model_generation_trn(model, tokenizer, 1, 10, **gen_kwargs) - - -# Compulsory for multiprocessing tests, since we want children processes to be spawned only in the main program. -# eg. tensor parallel tracing, `neuronx_distributed.parallel_model_trace` will spawn multiple processes to trace -# and compile the model. -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/generation/test_parallel.py b/tests/generation/test_parallel.py new file mode 100644 index 000000000..b70712574 --- /dev/null +++ b/tests/generation/test_parallel.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +from transformers import AutoTokenizer + +from optimum.neuron import NeuronModelForSeq2SeqLM +from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx + + +@is_inferentia_test +@requires_neuronx +def test_seq2seq_generation_tp2(neuron_seq2seq_tp2_path): + model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_tp2_path) + tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_tp2_path) + inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt") + + output = model.generate( + **inputs, + num_return_sequences=1, + max_length=20, + output_attentions=True, + output_hidden_states=True, + return_dict_in_generate=True, + ) + assert "decoder_attentions" in output + assert "cross_attentions" in output + assert "decoder_hidden_states" in output + + +# Compulsory for multiprocessing tests, since we want child processes to be spawned only in the main program. +# e.g. during tensor parallel tracing, `neuronx_distributed.parallel_model_trace` will spawn multiple processes to trace +# and compile the model. +if __name__ == "__main__": + pytest.main([__file__])