From ea4349d98f18b951c296ba919a5f38ace59d08da Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 15 Nov 2023 12:33:08 +0100
Subject: [PATCH] Refactor NormalizedConfigs for GQA (#1539)

---
 optimum/utils/normalized_config.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 340fea4f1f9..7a0af9a1a48 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -85,6 +85,10 @@ class NormalizedTextConfig(NormalizedConfig):
     EOS_TOKEN_ID = "eos_token_id"
 
 
+class NormalizedTextConfigWithGQA(NormalizedTextConfig):
+    NUM_KEY_VALUE_HEADS = "num_key_value_heads"
+
+
 class NormalizedSeq2SeqConfig(NormalizedTextConfig):
     ENCODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
     DECODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
@@ -166,8 +170,6 @@ def __getattr__(self, attr_name):
     allow_new=True,
 )
 
-MistralNormalizedTextConfig = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
-
 
 class NormalizedConfigManager:
     """
@@ -227,13 +229,13 @@ class NormalizedConfigManager:
         "gpt-bigcode": GPTBigCodeNormalizedTextConfig,
         "gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
         "gpt-neox": NormalizedTextConfig,
-        "llama": NormalizedTextConfig,
+        "llama": NormalizedTextConfigWithGQA,
         "gptj": GPT2LikeNormalizedTextConfig,
         "imagegpt": GPT2LikeNormalizedTextConfig,
         "longt5": T5LikeNormalizedTextConfig,
         "marian": BartLikeNormalizedTextConfig,
         "mbart": BartLikeNormalizedTextConfig,
-        "mistral": MistralNormalizedTextConfig,
+        "mistral": NormalizedTextConfigWithGQA,
         "mt5": T5LikeNormalizedTextConfig,
         "m2m-100": BartLikeNormalizedTextConfig,
         "nystromformer": NormalizedTextConfig,
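
Usage note (not part of the patch): the change replaces the ad-hoc MistralNormalizedTextConfig with a shared NormalizedTextConfigWithGQA, so both the "llama" and "mistral" entries expose num_key_value_heads for grouped-query attention. Below is a minimal sketch of how that attribute becomes reachable through the normalized config, assuming optimum and transformers are installed, that NormalizedConfigManager.get_normalized_config_class resolves a model type to its entry in the mapping above, and that NormalizedConfig forwards lowercase attribute access through the uppercase class attributes; the head counts are made up for illustration.

    from transformers import LlamaConfig

    from optimum.utils.normalized_config import NormalizedConfigManager

    # Hypothetical GQA setup: 32 query heads sharing 8 key/value heads.
    config = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)

    # With this patch, "llama" (like "mistral") maps to NormalizedTextConfigWithGQA.
    normalized_config_class = NormalizedConfigManager.get_normalized_config_class("llama")
    normalized_config = normalized_config_class(config)

    # NUM_KEY_VALUE_HEADS points the normalized attribute at the config field,
    # so downstream code can compute the number of queries per key/value head.
    queries_per_kv_head = normalized_config.num_attention_heads // normalized_config.num_key_value_heads
    print(queries_per_kv_head)  # 4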