Skip to content

Commit

Permalink
update test num beams
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 28, 2024
1 parent b86eaf4 commit 6ebf667
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoTokenizer,
GenerationConfig,
MBartForConditionalGeneration,
Pix2StructForConditionalGeneration, # Pix2Struct does not work with AutoModel
PretrainedConfig,
GenerationConfig,
set_seed,
)
from transformers.modeling_outputs import ImageSuperResolutionOutput
Expand Down Expand Up @@ -2401,7 +2401,7 @@ def test_merge_from_onnx_and_save(self, model_arch):
self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents)
self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)

@parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 3]}))
@parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 4]}))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
use_io_binding = None
if use_cache is False:
Expand Down Expand Up @@ -2473,7 +2473,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach

beam_search_gen_config = GenerationConfig(do_sample=False, **gen_kwargs)

if use_cache and num_beams == 3:
if use_cache and num_beams == 4:
beam_sample_gen_config = GenerationConfig(do_sample=True, **gen_kwargs)
group_beam_search_gen_config = GenerationConfig(
do_sample=False, num_beam_groups=2, diversity_penalty=0.0000001, **gen_kwargs
Expand Down

0 comments on commit 6ebf667

Please sign in to comment.