diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 335bb4dabcf..4be6e41bbe2 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -145,9 +145,13 @@ def __getattr__(self, attr_name):
 MPTNormalizedTextConfig = NormalizedTextConfig.with_args(
     num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
 )
+
 GPTBigCodeNormalizedTextConfig = NormalizedTextConfig.with_args(
     num_attention_heads="n_head", hidden_size="n_embd", num_layers="n_layer"
 )
+MistralNormalizedTextConfig = NormalizedTextConfig.with_args(
+    num_attention_heads="num_key_value_heads", num_layers="num_hidden_layers"
+)
 
 WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
     hidden_size="d_model",
@@ -167,8 +171,6 @@ def __getattr__(self, attr_name):
     allow_new=True,
 )
 
-MistralNormalizedTextConfig = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
-
 
 class NormalizedConfigManager:
     """
@@ -211,11 +213,9 @@ class NormalizedConfigManager:
         # "big_bird": NormalizedTextConfig,
         # "bigbird_pegasus": BartLikeNormalizedTextConfig,
         "blenderbot": BartLikeNormalizedTextConfig,
-        "blenderbot-small": BartLikeNormalizedTextConfig,
+        "blenderbot_small": BartLikeNormalizedTextConfig,
         "bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
-        "falcon": NormalizedTextConfig.with_args(
-            num_layers="num_hidden_layers", num_attention_heads="num_attention_heads"
-        ),
+        "falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
         "camembert": NormalizedTextConfig,
         "codegen": GPT2LikeNormalizedTextConfig,
         "cvt": NormalizedVisionConfig,
@@ -227,18 +227,17 @@ class NormalizedConfigManager:
         "electra": NormalizedTextConfig,
         "encoder-decoder": NormalizedEncoderDecoderConfig,
         "gpt2": GPT2LikeNormalizedTextConfig,
-        "gpt-bigcode": GPTBigCodeNormalizedTextConfig,
-        "gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
-        "gpt-neox": NormalizedTextConfig,
+        "gpt-bigcode": GPT2LikeNormalizedTextConfig,
+        "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
+        "gpt_neox": NormalizedTextConfig,
         "llama": NormalizedTextConfig,
         "gptj": GPT2LikeNormalizedTextConfig,
         "imagegpt": GPT2LikeNormalizedTextConfig,
         "longt5": T5LikeNormalizedTextConfig,
         "marian": BartLikeNormalizedTextConfig,
         "mbart": BartLikeNormalizedTextConfig,
-        "mistral": MistralNormalizedTextConfig,
         "mt5": T5LikeNormalizedTextConfig,
-        "m2m-100": BartLikeNormalizedTextConfig,
+        "m2m_100": BartLikeNormalizedTextConfig,
         "nystromformer": NormalizedTextConfig,
         "opt": NormalizedTextConfig,
         "pegasus": BartLikeNormalizedTextConfig,
@@ -247,7 +246,7 @@ class NormalizedConfigManager:
         "regnet": NormalizedVisionConfig,
         "resnet": NormalizedVisionConfig,
         "roberta": NormalizedTextConfig,
-        "speech-to-text": SpeechToTextLikeNormalizedTextConfig,
+        "speech_to_text": SpeechToTextLikeNormalizedTextConfig,
         "splinter": NormalizedTextConfig,
         "t5": T5LikeNormalizedTextConfig,
         "trocr": TrOCRLikeNormalizedTextConfig,
@@ -257,6 +256,8 @@ class NormalizedConfigManager:
         "xlm-roberta": NormalizedTextConfig,
         "yolos": NormalizedVisionConfig,
         "mpt": MPTNormalizedTextConfig,
+        "gpt_bigcode": GPTBigCodeNormalizedTextConfig,
+        "mistral": MistralNormalizedTextConfig,
     }
 
     @classmethod
@@ -270,6 +271,5 @@ def check_supported_model(cls, model_type: str):
 
     @classmethod
     def get_normalized_config_class(cls, model_type: str) -> Type:
-        model_type = model_type.replace("_", "-")
         cls.check_supported_model(model_type)
         return cls._conf[model_type]