Skip to content

Commit

Permalink
fix for tflite as well
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Dec 19, 2024
1 parent 35d52d2 commit da441fb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
Generates the dummy inputs necessary for tracing the model. If not explicitely specified, default input shapes are used.
Args:
framework (`str`, defaults to `"pt"`):
The framework for which to create the dummy inputs.
batch_size (`int`, defaults to {batch_size}):
The batch size to use in the dummy inputs.
sequence_length (`int`, defaults to {sequence_length}):
Expand Down
11 changes: 10 additions & 1 deletion optimum/exporters/tflite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
if is_tf_available():
import tensorflow as tf

from ...utils import DTYPE_MAPPER
from ..base import ExportConfig


Expand Down Expand Up @@ -191,12 +192,16 @@ def __init__(
audio_sequence_length: Optional[int] = None,
point_batch_size: Optional[int] = None,
nb_points_per_image: Optional[int] = None,
int_dtype: str = "int64",
float_dtype: str = "fp32",
):
self._config = config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.mandatory_axes = ()
self.task = task
self._axes: Dict[str, int] = {}
self.int_dtype = int_dtype
self.float_dtype = float_dtype

# To avoid using **kwargs.
axes_values = {
Expand Down Expand Up @@ -310,12 +315,16 @@ def generate_dummy_inputs(self) -> Dict[str, "tf.Tensor"]:
"""
dummy_inputs_generators = self._create_dummy_input_generator_classes()
dummy_inputs = {}
int_dtype = DTYPE_MAPPER.tf(self.int_dtype)
float_dtype = DTYPE_MAPPER.tf(self.float_dtype)

for input_name in self.inputs:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="tf")
dummy_inputs[input_name] = dummy_input_gen.generate(
input_name, int_dtype=int_dtype, float_dtype=float_dtype
)
input_was_inserted = True
break
if not input_was_inserted:
Expand Down

0 comments on commit da441fb

Please sign in to comment.