From ae33f1b22f04b1bca8c44a594183c9118557696b Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Thu, 5 Dec 2024 14:19:17 +0000 Subject: [PATCH] add torch/tf/np dtype --- optimum/utils/input_generators.py | 326 ++++++++++++++++++++++++------ 1 file changed, 263 insertions(+), 63 deletions(-) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 18a2a5a3fd1..e8c5ab9da0b 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -42,7 +42,8 @@ def check_framework_is_available(func): @functools.wraps(func) def wrapper(*args, **kwargs): - framework = kwargs.get("framework", "pt") + dtype = kwargs.get("dtype", "fp32") + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or kwargs.get("framework", "pt") pt_asked_but_not_available = framework == "pt" and not is_torch_available() tf_asked_but_not_available = framework == "tf" and not is_tf_available() if (pt_asked_but_not_available or tf_asked_but_not_available) and framework != "np": @@ -132,7 +133,13 @@ def supports_input(self, input_name: str) -> bool: return any(input_name.startswith(supported_input_name) for supported_input_name in self.SUPPORTED_INPUT_NAMES) @abstractmethod - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "int64", + float_dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "fp32", + ): """ Generates the dummy input matching `input_name` for the requested framework. @@ -141,9 +148,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int The name of the input to generate. framework (`str`, defaults to `"pt"`): The requested framework. - int_dtype (`str`, defaults to `"int64"`): + int_dtype (`Union[str, type, torch.dtype, tf.dtypes.DType]`, defaults to `"int64"`): The dtypes of generated integer tensors. - float_dtype (`str`, defaults to `"fp32"`): + float_dtype (`Union[str, type, torch.dtype, tf.dtypes.DType]`, defaults to `"fp32"`): The dtypes of generated float tensors. Returns: @@ -154,7 +161,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int @staticmethod @check_framework_is_available def random_int_tensor( - shape: List[int], max_value: int, min_value: int = 0, framework: str = "pt", dtype: str = "int64" + shape: List[int], + max_value: int, + min_value: int = 0, + dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "int64", + framework: str = "pt", ): """ Generates a tensor of random integers in the [min_value, max_value) range. @@ -166,24 +177,31 @@ def random_int_tensor( The maximum value allowed. min_value (`int`, defaults to 0): The minimum value allowed. + dtype (`Union[str, type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `"int64"`): + The numpy or torch or tensorflow dtype of the generated integer tensor. Could be int64, int32 or int8. framework (`str`, defaults to `"pt"`): - The requested framework. - dtype (`str`, defaults to `"int64"`): - The dtype of the generated integer tensor. Could be "int64", "int32", "int8". + The requested framework if it's not defined by the dtype. Returns: A random tensor in the requested framework. 
""" + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework + dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype if framework == "pt": - return torch.randint(low=min_value, high=max_value, size=shape, dtype=DTYPE_MAPPER.pt(dtype)) + return torch.randint(low=min_value, high=max_value, size=shape, dtype=dtype) elif framework == "tf": - return tf.random.uniform(shape, minval=min_value, maxval=max_value, dtype=DTYPE_MAPPER.tf(dtype)) + return tf.random.uniform(shape, minval=min_value, maxval=max_value, dtype=dtype) else: - return np.random.randint(min_value, high=max_value, size=shape, dtype=DTYPE_MAPPER.np(dtype)) + return np.random.randint(min_value, high=max_value, size=shape, dtype=dtype) @staticmethod @check_framework_is_available - def random_mask_tensor(shape: List[int], padding_side: str = "right", framework: str = "pt", dtype: str = "int64"): + def random_mask_tensor( + shape: List[int], + padding_side: str = "right", + dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "int64", + framework: str = "pt", + ): """ Generates a mask tensor either right or left padded. @@ -192,21 +210,23 @@ def random_mask_tensor(shape: List[int], padding_side: str = "right", framework: The shape of the random tensor. padding_side (`str`, defaults to "right"): The side on which the padding is applied. + dtype (`dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `"int64"`): + The numpy or torch or tensorflow dtype of the generated integer tensor. Could be int64, int32 or int8. framework (`str`, defaults to `"pt"`): - The requested framework. - dtype (`str`, defaults to `"int64"`): - The dtype of the generated integer tensor. Could be "int64", "int32", "int8". + The requested framework if it's not defined by the dtype. Returns: A random mask tensor either left padded or right padded in the requested framework. 
""" shape = tuple(shape) mask_length = random.randint(1, shape[-1] - 1) + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework + dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype if framework == "pt": mask_tensor = torch.cat( [ - torch.ones(*shape[:-1], shape[-1] - mask_length, dtype=DTYPE_MAPPER.pt(dtype)), - torch.zeros(*shape[:-1], mask_length, dtype=DTYPE_MAPPER.pt(dtype)), + torch.ones(*shape[:-1], shape[-1] - mask_length, dtype=dtype), + torch.zeros(*shape[:-1], mask_length, dtype=dtype), ], dim=-1, ) @@ -215,8 +235,8 @@ def random_mask_tensor(shape: List[int], padding_side: str = "right", framework: elif framework == "tf": mask_tensor = tf.concat( [ - tf.ones((*shape[:-1], shape[-1] - mask_length), dtype=DTYPE_MAPPER.tf(dtype)), - tf.zeros((*shape[:-1], mask_length), dtype=DTYPE_MAPPER.tf(dtype)), + tf.ones((*shape[:-1], shape[-1] - mask_length), dtype=dtype), + tf.zeros((*shape[:-1], mask_length), dtype=dtype), ], axis=-1, ) @@ -225,8 +245,8 @@ def random_mask_tensor(shape: List[int], padding_side: str = "right", framework: else: mask_tensor = np.concatenate( [ - np.ones((*shape[:-1], shape[-1] - mask_length), dtype=DTYPE_MAPPER.np(dtype)), - np.zeros((*shape[:-1], mask_length), dtype=DTYPE_MAPPER.np(dtype)), + np.ones((*shape[:-1], shape[-1] - mask_length), dtype=dtype), + np.zeros((*shape[:-1], mask_length), dtype=dtype), ], axis=-1, ) @@ -237,7 +257,11 @@ def random_mask_tensor(shape: List[int], padding_side: str = "right", framework: @staticmethod @check_framework_is_available def random_float_tensor( - shape: List[int], min_value: float = 0, max_value: float = 1, framework: str = "pt", dtype: str = "fp32" + shape: List[int], + min_value: float = 0, + max_value: float = 1, + dtype: Union[str, type, "torch.dtype", "tf.dtypes.DType"] = "fp32", + framework: str = "pt", ): """ Generates a tensor of random floats in the [min_value, max_value) range. @@ -249,26 +273,31 @@ def random_float_tensor( The minimum value allowed. max_value (`float`, defaults to 1): The maximum value allowed. + dtype (`Union[str, type, "torch.dtype", "tf.dtypes.DType"]`, defaults to `"fp32"`): + The numpy or torch or tensorflow dtype of the generated float tensor. Could be fp32, fp16 or bf16. framework (`str`, defaults to `"pt"`): - The requested framework. - dtype (`str`, defaults to `"fp32"`): - The dtype of the generated float tensor. Could be "fp32", "fp16", "bf16". + The requested framework if it's not defined by the dtype. Returns: A random tensor in the requested framework. 
""" + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework + dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype if framework == "pt": - tensor = torch.empty(shape, dtype=DTYPE_MAPPER.pt(dtype)).uniform_(min_value, max_value) + tensor = torch.empty(shape, dtype=dtype).uniform_(min_value, max_value) return tensor elif framework == "tf": - return tf.random.uniform(shape, minval=min_value, maxval=max_value, dtype=DTYPE_MAPPER.tf(dtype)) + return tf.random.uniform(shape, minval=min_value, maxval=max_value, dtype=dtype) else: - return np.random.uniform(low=min_value, high=max_value, size=shape).astype(DTYPE_MAPPER.np(dtype)) + return np.random.uniform(low=min_value, high=max_value, size=shape).astype(dtype) @staticmethod @check_framework_is_available def constant_tensor( - shape: List[int], value: Union[int, float] = 1, dtype: Optional[Any] = None, framework: str = "pt" + shape: List[int], + value: Union[int, float] = 1, + dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = None, + framework: str = "pt", ): """ Generates a constant tensor. @@ -278,7 +307,7 @@ def constant_tensor( The shape of the constant tensor. value (`Union[int, float]`, defaults to 1): The value to fill the constant tensor with. - dtype (`Optional[Any]`, defaults to `None`): + dtype (`Optional[Union[type, torch.dtype, tf.dtypes.DType]]`, defaults to `None`): The dtype of the constant tensor. framework (`str`, defaults to `"pt"`): The requested framework. @@ -286,6 +315,8 @@ def constant_tensor( Returns: A constant tensor in the requested framework. """ + framework = DummyInputGenerator.infer_framework_from_dtype(dtype) or framework + dtype = getattr(DTYPE_MAPPER, framework)(dtype) if isinstance(dtype, str) else dtype if framework == "pt": return torch.full(shape, value, dtype=dtype) elif framework == "tf": @@ -306,6 +337,19 @@ def _infer_framework_from_input(input_) -> str: raise RuntimeError(f"Could not infer the framework from {input_}") return framework + @staticmethod + def infer_framework_from_dtype(dtype): + framework = None + if is_torch_available() and isinstance(dtype, torch.dtype): + framework = "pt" + elif is_tf_available() and isinstance(dtype, tf.dtypes.DType): + framework = "tf" + elif isinstance(dtype, type): + framework = "np" + else: + framework = None + return framework + @classmethod def concat_inputs(cls, inputs, dim: int): """ @@ -430,8 +474,8 @@ def generate( self, input_name: str, framework: str = "pt", - int_dtype: str = "int64", - float_dtype: str = "fp32", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", ): min_value = 0 max_value = 2 if input_name != "input_ids" else self.vocab_size @@ -487,8 +531,8 @@ def generate( self, input_name: str, framework: str = "pt", - int_dtype: str = "int64", - float_dtype: str = "fp32", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", ): min_value = 0 max_value = self.tag_pad_id if input_name == "xpath_tags_seq" else self.subs_pad_id @@ -526,7 +570,13 @@ def __init__(self, *args, **kwargs): self.state_dim = self.normalized_config.config.state_dim self.max_ep_len = self.normalized_config.config.max_ep_len - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + 
input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "states": shape = [self.batch_size, self.sequence_length, self.state_dim] elif input_name == "actions": @@ -580,7 +630,13 @@ def __init__( else: self.hidden_size = normalized_config.hidden_size - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name in ["encoder_outputs", "encoder_hidden_states"]: return ( self.random_float_tensor( @@ -628,7 +684,13 @@ def __init__( else: self.sequence_length = sequence_length - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): shape = ( self.batch_size, self.num_attention_heads, @@ -696,7 +758,13 @@ def __init__( self.decoder_hidden_size = self.normalized_config.hidden_size self.decoder_num_layers = self.normalized_config.decoder_num_layers - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "past_key_values": encoder_shape = ( self.batch_size, @@ -762,7 +830,13 @@ def __init__( else: self.sequence_length = sequence_length - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): return self.random_int_tensor( [self.batch_size, self.sequence_length, 4], # TODO: find out why this fails with the commented code. 
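For illustration, a minimal sketch of how the reworked dtype handling is expected to be used once the patch above is applied, assuming torch and numpy are installed; the import path follows the file modified by this patch, and the shapes and values are arbitrary:

import numpy as np
import torch

from optimum.utils.input_generators import DummyInputGenerator

# Passing a torch.dtype: the framework ("pt") is inferred from the dtype itself,
# so the `framework` argument can be omitted.
pt_ids = DummyInputGenerator.random_int_tensor([2, 8], max_value=10, dtype=torch.int32)

# Passing a NumPy scalar type is interpreted as the "np" framework.
np_ids = DummyInputGenerator.random_int_tensor([2, 8], max_value=10, dtype=np.int64)

# String dtypes keep the previous behaviour and fall back to the `framework` argument.
pt_feats = DummyInputGenerator.random_float_tensor([2, 8], dtype="fp32", framework="pt")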
@@ -816,7 +890,13 @@ def __init__( self.batch_size = batch_size self.height, self.width = self.image_size - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "pixel_mask": return self.random_int_tensor( shape=[self.batch_size, self.height, self.width], @@ -856,7 +936,13 @@ def __init__( self.batch_size = batch_size self.sequence_length = audio_sequence_length - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "input_values": # raw waveform return self.random_float_tensor( shape=[self.batch_size, self.sequence_length], @@ -906,7 +992,13 @@ def __init__( self.batch_size = batch_size self.time_cond_proj_dim = getattr(normalized_config.config, "time_cond_proj_dim", None) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "timestep": shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture) return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) @@ -956,7 +1048,13 @@ def __init__( self.sequence_length = kwargs.get("sequence_length", None) self.num_labels = kwargs.get("num_labels", None) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): max_value = self.num_labels if self.num_labels is not None else 0 if self.sequence_length is None: shape = [self.batch_size] @@ -988,7 +1086,13 @@ def __init__( self.point_batch_size = point_batch_size self.nb_points_per_image = nb_points_per_image - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "input_points": shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2] return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) @@ -1021,9 +1125,15 @@ def __init__( output_channels if output_channels is not None else normalized_config.vision_config.output_channels ) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: 
Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): shape = [self.batch_size, self.output_channels, self.image_embedding_size, self.image_embedding_size] - return self.random_float_tensor(shape, framework=framework) + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) class DummyPix2StructInputGenerator(DummyInputGenerator): @@ -1052,13 +1162,25 @@ def __init__( self.flattened_patch_size = 2 + patch_height * patch_width * num_channels self.max_patches = preprocessors[1].image_processor.max_patches - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): shape = [self.batch_size, self.max_patches, self.flattened_patch_size] return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): past_key_value_shape = ( self.batch_size, self.sequence_length, @@ -1071,7 +1193,13 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if check_if_transformers_greater("4.44"): return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) else: @@ -1116,7 +1244,13 @@ def __init__( ) self.num_kv_heads = normalized_config.num_kv_heads - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): past_shape = ( self.batch_size * self.num_kv_heads, self.sequence_length, @@ -1158,7 +1292,13 @@ def __init__( ) self.head_dim = self.hidden_size // self.num_attention_heads - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): past_key_shape = ( self.batch_size, self.num_kv_heads, @@ -1201,7 +1341,13 @@ def __init__( ) self.num_key_value_heads = normalized_config.num_key_value_heads - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = 
"int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): shape = ( self.batch_size, self.num_key_value_heads, @@ -1239,7 +1385,13 @@ def __init__( self.num_key_value_heads = normalized_config.num_key_value_heads self.head_dim = normalized_config.head_dim - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): shape = ( self.batch_size, self.num_key_value_heads, @@ -1272,7 +1424,13 @@ def __init__( self.speaker_embedding_dim = normalized_config.speaker_embedding_dim self.num_mel_bins = normalized_config.num_mel_bins - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "output_sequence": shape = [self.batch_size, self.sequence_length, self.num_mel_bins] elif input_name == "speaker_embeddings": @@ -1326,7 +1484,13 @@ def __init__( self.num_layers = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_layers self.use_cross_attention = False - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): decoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size decoder_num_attention_heads = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_attention_heads decoder_shape = ( @@ -1390,7 +1554,13 @@ def __init__( ) self.num_codebooks = normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_codebooks - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name in ["decoder_input_ids"]: min_value = 0 max_value = 2 if input_name != "input_ids" else self.vocab_size @@ -1417,7 +1587,13 @@ def __init__( self.num_codebooks = normalized_config.decoder.num_codebooks self.sequence_length = sequence_length - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "audio_codes": # Kind of a hack to use `self.sequence_length` here, for Musicgen pad tokens are filtered out, see # 
https://github.com/huggingface/transformers/blob/31c575bcf13c2b85b65d652dd1b5b401f99be999/src/transformers/models/musicgen/modeling_musicgen.py#L2458 @@ -1452,8 +1628,8 @@ def generate( self, input_name: str, framework: str = "pt", - int_dtype: str = "int64", - float_dtype: str = "fp32", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", ): return self.random_int_tensor(shape=(1,), min_value=20, max_value=22, framework=framework, dtype=int_dtype) @@ -1461,7 +1637,13 @@ def generate( class DummyTransformerTimestepInputGenerator(DummyTimestepInputGenerator): SUPPORTED_INPUT_NAMES = ("timestep",) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "timestep": shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) @@ -1479,7 +1661,13 @@ class DummyTransformerTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): "pooled_projection", ) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "encoder_hidden_states": return super().generate(input_name, framework, int_dtype, float_dtype)[0] @@ -1497,7 +1685,13 @@ class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenera "img_ids", ) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "hidden_states": shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels] return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) @@ -1520,7 +1714,13 @@ class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator) "txt_ids", ) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "int64", + float_dtype: Optional[Union[str, type, "torch.dtype", "tf.dtypes.DType"]] = "fp32", + ): if input_name == "txt_ids": shape = ( [self.sequence_length, 3]