From bb185a0d460c59cdbe02cbf41979b3f1c6d5d63f Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 18 Dec 2024 17:37:44 +0000 Subject: [PATCH] apply suggestion --- optimum/utils/input_generators.py | 316 ++++++++++++++++-------------- 1 file changed, 169 insertions(+), 147 deletions(-) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index e8c5ab9da0..01309ff9f5 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -14,8 +14,8 @@ # limitations under the License. """Dummy input generation classes.""" -import functools import random +import warnings from abc import ABC, abstractmethod from typing import Any, List, Optional, Tuple, Union @@ -39,21 +39,6 @@ import tensorflow as tf # type: ignore -def check_framework_is_available(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - dtype = kwargs.get("dtype", "fp32") - framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or kwargs.get("framework", "pt") - pt_asked_but_not_available = framework == "pt" and not is_torch_available() - tf_asked_but_not_available = framework == "tf" and not is_tf_available() - if (pt_asked_but_not_available or tf_asked_but_not_available) and framework != "np": - framework_name = "PyTorch" if framework == "pt" else "TensorFlow" - raise RuntimeError(f"Requested the {framework_name} framework, but it does not seem installed.") - return func(*args, **kwargs) - - return wrapper - - DEFAULT_DUMMY_SHAPES = { "batch_size": 2, "sequence_length": 16, @@ -136,9 +121,9 @@ def supports_input(self, input_name: str) -> bool: def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "int64", - float_dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "fp32", + framework: Optional[str] = None, + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", 
"tf.dtypes.DType"]] = None, ): """ Generates the dummy input matching `input_name` for the requested framework. @@ -146,12 +131,12 @@ def generate( Args: input_name (`str`): The name of the input to generate. - framework (`str`, defaults to `"pt"`): - The requested framework. - int_dtype (`Union[str, type, torch.dtype, tf.dtypes.DType]`, defaults to `"int64"`): - The dtypes of generated integer tensors. - float_dtype (`Union[str, type, torch.dtype, tf.dtypes.DType]`, defaults to `"fp32"`): - The dtypes of generated float tensors. + framework (`Optional[str]`, defaults to `None`): + Deprecated, please use the `int_dtype` or `float_dtype` argument to indicate the framework. + int_dtype (`Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]]`, defaults to `None`): + The dtypes of generated integer tensors. Could be int64, int32 or int8, defaults to `torch.int64` if not given. + float_dtype (`Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]]`, defaults to `None`): + The dtypes of generated float tensors. Could be fp32, fp16 or bf16, defaults to `torch.float32` if not given. Returns: A tensor in the requested framework of the input. @@ -159,13 +144,12 @@ def generate( raise NotImplementedError @staticmethod - @check_framework_is_available def random_int_tensor( shape: List[int], max_value: int, min_value: int = 0, - dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "int64", - framework: str = "pt", + dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): """ Generates a tensor of random integers in the [min_value, max_value) range. @@ -177,16 +161,21 @@ def random_int_tensor( The maximum value allowed. min_value (`int`, defaults to 0): The minimum value allowed. - dtype (`Union[str, type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `"int64"`): - The numpy or torch or tensorflow dtype of the generated integer tensor. Could be int64, int32 or int8. 
- framework (`str`, defaults to `"pt"`): - The requested framework if it's not defined by the dtype. + dtype (`Union[type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `None`): + The numpy or torch or tensorflow dtype of the generated integer tensor. Could be int64, int32 or int8, defaults to `torch.int64` if not given. + framework (`Optional[str]`, defaults to `None`): + Deprecated, please use the `dtype` argument to indicate the framework. Returns: A random tensor in the requested framework. """ - framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework - dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype + if framework is not None: + warnings.warn( + "The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.", + FutureWarning, + ) + dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) if framework == "pt": return torch.randint(low=min_value, high=max_value, size=shape, dtype=dtype) elif framework == "tf": @@ -195,12 +184,11 @@ def random_int_tensor( return np.random.randint(min_value, high=max_value, size=shape, dtype=dtype) @staticmethod - @check_framework_is_available def random_mask_tensor( shape: List[int], padding_side: str = "right", - dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "int64", - framework: str = "pt", + dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): """ Generates a mask tensor either right or left padded. @@ -210,18 +198,23 @@ def random_mask_tensor( The shape of the random tensor. padding_side (`str`, defaults to "right"): The side on which the padding is applied. - dtype (`dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `"int64"`): - The numpy or torch or tensorflow dtype of the generated integer tensor. 
Could be int64, int32 or int8. - framework (`str`, defaults to `"pt"`): - The requested framework if it's not defined by the dtype. + dtype (`Union[type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `None`): + The numpy or torch or tensorflow dtype of the generated integer tensor. Could be int64, int32 or int8, defaults to `torch.int64` if not given. + framework (`Optional[str]`, defaults to `None`): + Deprecated, please use the `dtype` argument to indicate the framework. Returns: A random mask tensor either left padded or right padded in the requested framework. """ + if framework is not None: + warnings.warn( + "The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.", + FutureWarning, + ) shape = tuple(shape) mask_length = random.randint(1, shape[-1] - 1) - framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework - dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype + dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) if framework == "pt": mask_tensor = torch.cat( [ @@ -255,13 +248,12 @@ def random_mask_tensor( return mask_tensor @staticmethod - @check_framework_is_available def random_float_tensor( shape: List[int], min_value: float = 0, max_value: float = 1, - dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "fp32", - framework: str = "pt", + dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): """ Generates a tensor of random floats in the [min_value, max_value) range. @@ -273,16 +265,21 @@ def random_float_tensor( The minimum value allowed. max_value (`float`, defaults to 1): The maximum value allowed. - dtype (`Union[str, type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `"fp32"`): - The numpy or torch or tensorflow dtype of the generated float tensor. 
Could be fp32, fp16 or bf16. - framework (`str`, defaults to `"pt"`): - The requested framework if it's not defined by the dtype. + dtype (`Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]]`, defaults to `None`): + The numpy or torch or tensorflow dtype of the generated float tensor. Could be fp32, fp16 or bf16, defaults to `torch.float32` if not given. + framework (`Optional[str]`, defaults to `None`): + Deprecated, please use the `dtype` argument to indicate the framework. Returns: A random tensor in the requested framework. """ - framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework - dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype + if framework is not None: + warnings.warn( + "The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.", + FutureWarning, + ) + dtype = DummyInputGenerator._set_default_float_dtype() if dtype is None else dtype + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) if framework == "pt": tensor = torch.empty(shape, dtype=dtype).uniform_(min_value, max_value) return tensor @@ -292,12 +289,11 @@ def random_float_tensor( return np.random.uniform(low=min_value, high=max_value, size=shape).astype(dtype) @staticmethod - @check_framework_is_available def constant_tensor( shape: List[int], value: Union[int, float] = 1, - dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = None, - framework: str = "pt", + dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): """ Generates a constant tensor. @@ -309,14 +305,19 @@ def constant_tensor( The value to fill the constant tensor with. dtype (`Optional[Union[type, torch.dtype, tf.dtypes.DType]]`, defaults to `None`): The dtype of the constant tensor. - framework (`str`, defaults to `"pt"`): - The requested framework. 
+ framework (`Optional[str]`, defaults to `None`): + Deprecated, please use the `dtype` argument to indicate the framework. Returns: A constant tensor in the requested framework. """ + if framework is not None: + warnings.warn( + "The `framework` argument is deprecated and will be removed soon. Please use the `dtype` argument instead to indicate the framework.", + FutureWarning, + ) + dtype = DummyInputGenerator._set_default_int_dtype() if dtype is None else dtype framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework - dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype if framework == "pt": return torch.full(shape, value, dtype=dtype) elif framework == "tf": @@ -337,9 +338,28 @@ def _infer_framework_from_input(input_) -> str: raise RuntimeError(f"Could not infer the framework from {input_}") return framework + @staticmethod + def _set_default_int_dtype(): + "Default to int64 of available framework." + if is_torch_available(): + return torch.int64 + elif is_tf_available(): + return tf.int64 + else: + return np.int64 + + @staticmethod + def _set_default_float_dtype(): + "Default to float32 of available framework." + if is_torch_available(): + return torch.float32 + elif is_tf_available(): + return tf.float32 + else: + return np.float32 + @staticmethod def infer_framework_from_dtype(dtype): - framework = None if is_torch_available() and isinstance(dtype, torch.dtype): framework = "pt" elif is_tf_available() and isinstance(dtype, tf.dtypes.DType): @@ -347,7 +367,9 @@ def infer_framework_from_dtype(dtype): elif isinstance(dtype, type): framework = "np" else: - framework = None + raise ValueError( + f"Unable to create a tensor/array with dtype({dtype}). Please ensure the corresponding framework is installed."
+ ) return framework @classmethod @@ -473,9 +495,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): min_value = 0 max_value = 2 if input_name != "input_ids" else self.vocab_size @@ -530,9 +552,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): min_value = 0 max_value = self.tag_pad_id if input_name == "xpath_tags_seq" else self.subs_pad_id @@ -573,9 +595,9 @@ def __init__(self, *args, **kwargs): def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "states": shape = [self.batch_size, self.sequence_length, self.state_dim] @@ -633,9 +655,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: 
Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name in ["encoder_outputs", "encoder_hidden_states"]: return ( @@ -687,9 +709,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): shape = ( self.batch_size, @@ -761,9 +783,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "past_key_values": encoder_shape = ( @@ -833,9 +855,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): return self.random_int_tensor( [self.batch_size, self.sequence_length, 4], @@ -893,9 +915,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: 
Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "pixel_mask": return self.random_int_tensor( @@ -939,9 +961,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "input_values": # raw waveform return self.random_float_tensor( @@ -995,9 +1017,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "timestep": shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture) @@ -1051,9 +1073,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): max_value = self.num_labels if self.num_labels is not None else 0 if 
self.sequence_length is None: @@ -1089,9 +1111,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "input_points": shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2] @@ -1128,9 +1150,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): shape = [self.batch_size, self.output_channels, self.image_embedding_size, self.image_embedding_size] return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) @@ -1165,9 +1187,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): shape = [self.batch_size, self.max_patches, self.flattened_patch_size] return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) @@ -1177,9 +1199,9 @@ class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate( self, input_name: str, - 
framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): past_key_value_shape = ( self.batch_size, @@ -1196,9 +1218,9 @@ class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if check_if_transformers_greater("4.44"): return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) @@ -1247,9 +1269,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): past_shape = ( self.batch_size * self.num_kv_heads, @@ -1295,9 +1317,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, 
"torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): past_key_shape = ( self.batch_size, @@ -1344,9 +1366,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): shape = ( self.batch_size, @@ -1388,9 +1410,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): shape = ( self.batch_size, @@ -1427,9 +1449,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "output_sequence": shape = [self.batch_size, self.sequence_length, self.num_mel_bins] @@ -1487,9 +1509,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", 
"tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): decoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size decoder_num_attention_heads = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_attention_heads @@ -1557,9 +1579,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name in ["decoder_input_ids"]: min_value = 0 @@ -1590,9 +1612,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "audio_codes": # Kind of a hack to use `self.sequence_length` here, for Musicgen pad tokens are filtered out, see @@ -1627,9 +1649,9 @@ def __init__( def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): return self.random_int_tensor(shape=(1,), min_value=20, max_value=22, 
framework=framework, dtype=int_dtype) @@ -1640,9 +1662,9 @@ class DummyTransformerTimestepInputGenerator(DummyTimestepInputGenerator): def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "timestep": shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor @@ -1664,9 +1686,9 @@ class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "encoder_hidden_states": return super().generate(input_name, framework, int_dtype, float_dtype)[0] @@ -1688,9 +1710,9 @@ class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenera def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "hidden_states": shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels] @@ -1717,9 +1739,9 @@ class 
DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator) def generate( self, input_name: str, - framework: str = "pt", - int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", - float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + int_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + float_dtype: Optional[Union[type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: Optional[str] = None, ): if input_name == "txt_ids": shape = (