Skip to content

Commit

Permalink
enable mistral model normalized config
Browse files Browse the repository at this point in the history
Signed-off-by: changwangss <[email protected]>
  • Loading branch information
changwangss committed Oct 16, 2023
1 parent 6e15777 commit 10c7c53
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,13 @@ def __getattr__(self, attr_name):
MPTNormalizedTextConfig = NormalizedTextConfig.with_args(
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)

GPTBigCodeNormalizedTextConfig = NormalizedTextConfig.with_args(
num_attention_heads="n_head", hidden_size="n_embd", num_layers="n_layer"
)
MistralNormalizedTextConfig = NormalizedTextConfig.with_args(
num_attention_heads="num_key_value_heads", num_layers="num_hidden_layers"
)

WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
hidden_size="d_model",
Expand All @@ -167,8 +171,6 @@ def __getattr__(self, attr_name):
allow_new=True,
)

MistralNormalizedTextConfig = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)


class NormalizedConfigManager:
"""
Expand Down Expand Up @@ -211,11 +213,9 @@ class NormalizedConfigManager:
# "big_bird": NormalizedTextConfig,
# "bigbird_pegasus": BartLikeNormalizedTextConfig,
"blenderbot": BartLikeNormalizedTextConfig,
"blenderbot-small": BartLikeNormalizedTextConfig,
"blenderbot_small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"falcon": NormalizedTextConfig.with_args(
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads"
),
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
"camembert": NormalizedTextConfig,
"codegen": GPT2LikeNormalizedTextConfig,
"cvt": NormalizedVisionConfig,
Expand All @@ -227,18 +227,17 @@ class NormalizedConfigManager:
"electra": NormalizedTextConfig,
"encoder-decoder": NormalizedEncoderDecoderConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPTBigCodeNormalizedTextConfig,
"gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt-neox": NormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mistral": MistralNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"opt": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
Expand All @@ -247,7 +246,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"speech_to_text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
"trocr": TrOCRLikeNormalizedTextConfig,
Expand All @@ -257,6 +256,8 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"mpt": MPTNormalizedTextConfig,
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
"mistral": MistralNormalizedTextConfig,
}

@classmethod
Expand All @@ -270,6 +271,5 @@ def check_supported_model(cls, model_type: str):

@classmethod
def get_normalized_config_class(cls, model_type: str) -> Type:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

0 comments on commit 10c7c53

Please sign in to comment.