From b62eed277ee2014014af98127d353b6687b2c64b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 4 Jan 2025 00:12:08 -0500 Subject: [PATCH 01/10] OLMo 2: implemented core --- litgpt/config.py | 52 +++++++++++++++++++++++++++++++++++++++ litgpt/prompts.py | 2 ++ tests/test_model.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+) diff --git a/litgpt/config.py b/litgpt/config.py index 133a9247a1..a4ef0bc21f 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -912,6 +912,58 @@ def norm_class(self) -> Type: configs.extend(olmo) +olmo2 = [ + # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json + dict( + name="OLMo-2-1124-7B{}", + hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"), + vocab_size=100352, + padded_vocab_size=100352, + block_size=4096, + n_embd=4096, + n_layer=32, + n_head=32, + n_query_groups=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + norm_eps=1e-06, + intermediate_size=11008, + rope_base=500000, + norm_qk=True, + ), + # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json + dict( + name="OLMo-2-1124-13B{}", + hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"), + vocab_size=100352, + padded_vocab_size=100352, + block_size=4096, + n_embd=5120, + n_layer=40, + n_head=40, + n_query_groups=40, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + norm_eps=1e-06, + intermediate_size=13824, + rope_base=500000, + nork_qk=True, + ), +] + +for c in olmo2: + for kind in ("", "-SFT", "-DPO", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + ############### # Google Gemma ############### diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 48850efd51..d815712cb4 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -368,6 +368,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Llama3() if re.search("Llama-3.*-Instruct-*", model_name): return Llama3() + if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name): + return Llama3() if re.search("FreeWilly2", model_name): return FreeWilly2() if re.search("Platypus", model_name): diff --git a/tests/test_model.py b/tests/test_model.py index e8a110a409..2ddaec0576 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -28,6 +28,7 @@ from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.olmo import OlmoConfig, OlmoForCausalLM +from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM import litgpt.config as config_module @@ -617,6 +618,64 @@ def test_against_olmo(model_name, device, dtype): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_olmo2(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + intermediate_size=86, + ) + T = 5 + theirs_config = Olmo2Config( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + intermediate_size=ours_config.intermediate_size, + num_hidden_layers=ours_config.n_layer, + num_attention_heads=ours_config.n_head, + num_key_value_heads=ours_config.n_query_groups, + max_positional_embeddings=T, + attention_bias=ours_config.bias, + rope_theta=ours_config.rope_base, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = Olmo2ForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), From f559763b759dfadd67e865f218fc3704817a932e Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 4 Jan 2025 00:37:30 -0500 Subject: [PATCH 02/10] minor fix --- litgpt/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/config.py b/litgpt/config.py index a4ef0bc21f..e22c90268f 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -953,7 +953,7 @@ def norm_class(self) -> Type: norm_eps=1e-06, intermediate_size=13824, rope_base=500000, - nork_qk=True, + norm_qk=True, ), ] From 276a8fc14f8a4682aad12a5ddb80622174563477 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 4 Jan 2025 00:58:03 -0500 Subject: [PATCH 03/10] fix vocab size --- litgpt/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index e22c90268f..40583cfaf5 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -917,7 +917,7 @@ def norm_class(self) -> Type: dict( name="OLMo-2-1124-7B{}", hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"), - vocab_size=100352, + vocab_size=100278, padded_vocab_size=100352, block_size=4096, n_embd=4096, @@ -938,7 +938,7 @@ def norm_class(self) -> Type: dict( name="OLMo-2-1124-13B{}", hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"), - vocab_size=100352, + vocab_size=100278, padded_vocab_size=100352, block_size=4096, n_embd=5120, From 1ac888fd48052a48e0c59616581ee9805b8f99d7 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 4 Jan 2025 01:26:31 -0500 Subject: [PATCH 04/10] fix test_model --- tests/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model.py b/tests/test_model.py index 2ddaec0576..df1c9ab0b0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -645,6 +645,7 @@ def test_against_olmo2(model_name, device, dtype): n_layer=2, n_head=8, n_embd=32, + n_query_groups=2, intermediate_size=86, ) T = 5 From d3456e350ff1504bf5d81f4fe67febf39d3a55b9 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 7 Jan 2025 19:47:57 -0500 Subject: [PATCH 05/10] custom conversion 
fn for olmo2 due to new q_norm and k_norm components --- litgpt/model.py | 9 +-- litgpt/scripts/convert_hf_checkpoint.py | 84 ++++++++++++++++++++++++ litgpt/scripts/convert_lit_checkpoint.py | 61 +++++++++++++++++ 3 files changed, 150 insertions(+), 4 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 234d174466..89eb007948 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -342,11 +342,12 @@ def __init__(self, config: Config, block_idx: int) -> None: block_idx % config.sliding_window_layer_stride == 0 ) + + self.q_norm = None + self.k_norm = None if config.norm_qk: - self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps) - self.norm_k = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps) - else: - self.norm_q = self.norm_k = None + self.q_norm = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps) + self.k_norm = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps) self.config = config self.block_idx = block_idx diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 2b41bea9bf..733709469d 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -447,6 +447,86 @@ def copy_weights_qwen_2_5( pbar.update(progress_per_file) +def copy_weights_olmo2( + config: Config, + qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, + pbar: Optional[tqdm] = None, + progress_per_file: Optional[float] = None, + debug_mode: Optional[bool] = False, +) -> None: + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", + "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.q_norm.weight", + "model.layers.{}.self_attn.q_proj.weight": None, + "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.k_norm.weight", + "model.layers.{}.self_attn.k_proj.weight": None, + "model.layers.{}.self_attn.v_proj.weight": None, + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", + "model.norm.weight": "transformer.ln_f.weight", + "model.norm.bias": "transformer.ln_f.bias", + "lm_head.weight": "lm_head.weight", + } + if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): + weight_map.update( + { + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", + } + ) + else: + raise NotImplementedError + + if progress_per_file is not None: + progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) + + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = 
qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + if progress_per_file is not None: + pbar.update(progress_per_file) + + if "lm_head.weight" not in state_dict: + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # qkv is splitted across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + + if progress_per_file is not None: + pbar.update(progress_per_file) + + def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -537,6 +617,10 @@ def convert_hf_checkpoint( # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights) + elif model_name.lower().startswith(("olmo-2-")): + # holder to reconstitute the split q, k, v + qkv_weights = {} + copy_fn = partial(copy_weights_olmo2, config, qkv_weights) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index f276e3ae31..b7698865df 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -339,6 +339,65 @@ def copy_weights_qwen_2_5( state_dict[to_name] = param +def copy_weights_olmo2( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + untie_weights: bool = False, + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.attn.q_norm.weight": "model.layers.{}.self_attn.q_norm.weight", + "transformer.h.{}.attn.k_norm.weight": "model.layers.{}.self_attn.k_norm.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", + "transformer.ln_f.weight": "model.norm.weight", + "transformer.ln_f.bias": "model.norm.bias", + "lm_head.weight": "lm_head.weight", + } + if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): + weight_map.update( + { + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", + } + ) + else: + raise NotImplementedError + + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: + continue + 
name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.qkv.weight"): + to_names = ( + "model.layers.{}.self_attn.q_proj.weight".format(*ids), + "model.layers.{}.self_attn.k_proj.weight".format(*ids), + "model.layers.{}.self_attn.v_proj.weight".format(*ids), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + else: + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] @@ -383,6 +442,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: copy_fn = partial(copy_weights_phi, config) elif config.name.lower().startswith(("qwen2.5","qwq")): copy_fn = partial(copy_weights_qwen_2_5, config) + elif config.name.lower().startswith(("olmo-2-")): + copy_fn = partial(copy_weights_olmo2, config) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) From 121f8515fa2784603e3ef3804e9504cea466c849 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 7 Jan 2025 20:04:07 -0500 Subject: [PATCH 06/10] minor fix --- litgpt/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 89eb007948..062ae5701a 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -386,8 +386,8 @@ def forward( q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) if self.config.norm_qk: - q = self.norm_q(q) - k = self.norm_k(k) + q = self.q_norm(q) + k = self.k_norm(k) # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the # embedding size (C) into num_heads (nh) and head_size (hs). 
From 3d34921e0dab2202c1a6dd4dbb93d716235721db Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 7 Jan 2025 20:20:48 -0500 Subject: [PATCH 07/10] minor fix on test_model.py --- tests/test_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index df1c9ab0b0..6766c2058c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -39,6 +39,7 @@ copy_weights_gemma_2, copy_weights_gpt_neox, copy_weights_hf_llama, + copy_weights_olmo2, copy_weights_phi, copy_weights_qwen_2_5, ) @@ -665,7 +666,7 @@ def test_against_olmo2(model_name, device, dtype): theirs_model = Olmo2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} - copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_olmo2(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) From 15f549dfad6828285ebc2754ec01a0bf07890e0e Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 7 Jan 2025 20:36:55 -0500 Subject: [PATCH 08/10] fix: post_feedforward_layernorm --- litgpt/scripts/convert_hf_checkpoint.py | 1 + litgpt/scripts/convert_lit_checkpoint.py | 1 + 2 files changed, 2 insertions(+) diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 733709469d..449c07a54b 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -471,6 +471,7 @@ def copy_weights_olmo2( "model.layers.{}.self_attn.rotary_emb.inv_freq": None, "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", + "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight", "model.norm.weight": "transformer.ln_f.weight", "model.norm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index b7698865df..afb6608f94 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -355,6 +355,7 @@ def copy_weights_olmo2( "transformer.h.{}.attn.k_norm.weight": "model.layers.{}.self_attn.k_norm.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", + "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight", "transformer.ln_f.weight": "model.norm.weight", "transformer.ln_f.bias": "model.norm.bias", "lm_head.weight": "lm_head.weight", From ac3509fffa9f9693426fcb4d077afd93c2d5610b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 7 Jan 2025 21:08:36 -0500 Subject: [PATCH 09/10] minor fix --- litgpt/config.py | 2 ++ litgpt/scripts/convert_hf_checkpoint.py | 2 -- litgpt/scripts/convert_lit_checkpoint.py | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 40583cfaf5..4b190265bf 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -933,6 +933,7 @@ def norm_class(self) -> Type: intermediate_size=11008, rope_base=500000, norm_qk=True, + post_mlp_norm=True, ), # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json dict( @@ -954,6 +955,7 @@ def norm_class(self) -> Type: intermediate_size=13824, rope_base=500000, norm_qk=True, + post_mlp_norm=True, ), ] diff --git 
a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 449c07a54b..e6dcb6781e 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -460,8 +460,6 @@ def copy_weights_olmo2( ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", - "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", - "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.q_norm.weight", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.k_norm.weight", diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index afb6608f94..f8bfac6761 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -348,8 +348,6 @@ def copy_weights_olmo2( ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", - "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", - "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.attn.q_norm.weight": "model.layers.{}.self_attn.q_norm.weight", "transformer.h.{}.attn.k_norm.weight": "model.layers.{}.self_attn.k_norm.weight", From 852ca3e9d186a3959de6d92fe57f648d141ccb5b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Tue, 7 Jan 2025 22:17:13 -0500 Subject: [PATCH 10/10] input_norm --- litgpt/config.py | 3 +++ litgpt/model.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 4b190265bf..c9300d8a5f 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -78,6 +78,7 @@ class Config: scale_embeddings: bool = False lm_head_bias: bool = False final_logit_softcapping: Optional[float] = None + input_norm: bool = True def __post_init__(self): if not self.name: @@ -934,6 +935,7 @@ def norm_class(self) -> Type: rope_base=500000, norm_qk=True, post_mlp_norm=True, + input_norm=False, ), # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json dict( @@ -956,6 +958,7 @@ def norm_class(self) -> Type: rope_base=500000, norm_qk=True, post_mlp_norm=True, + input_norm=False, ), ] diff --git a/litgpt/model.py b/litgpt/model.py index 062ae5701a..788785133d 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -263,7 +263,7 @@ def __init__( " (non-parallel residual and shared attention norm)." ) - self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.norm_1 = None if not config.input_norm else config.norm_class(config.n_embd, eps=config.norm_eps) self.attn = CausalSelfAttention(config, block_idx) self.post_attention_norm = ( config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() @@ -306,7 +306,11 @@ def forward( └───► + """ - x_normed = self.norm_1(x) + if self.norm_1 is not None: + x_normed = self.norm_1(x) + else: + x_normed = x + attention_output = self.attn( x_normed, cos, sin, mask, input_pos, input_pos_maxp1 )
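
Note on the norm_qk path introduced in this series: per the model.py hunks above, RMSNorm is applied over the full projected query and key widths (head_size * n_head and head_size * n_query_groups) before the heads are split out and RoPE is applied. The snippet below is a minimal standalone sketch of that idea only, not litgpt code; it uses separate q/k projections and illustrative sizes instead of the fused qkv layer the model actually uses.

    # Standalone sketch (assumed shapes, separate q/k projections for clarity).
    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        def __init__(self, size: int, eps: float = 1e-6) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.ones(size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # normalize over the feature dimension, then rescale
            norm_x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return norm_x * self.weight

    n_embd, n_head, head_size = 32, 8, 4
    q_proj = nn.Linear(n_embd, n_head * head_size, bias=False)
    k_proj = nn.Linear(n_embd, n_head * head_size, bias=False)
    q_norm = RMSNorm(n_head * head_size)   # analogous to config.head_size * config.n_head
    k_norm = RMSNorm(n_head * head_size)   # analogous to config.head_size * config.n_query_groups (MHA here)

    x = torch.randn(1, 5, n_embd)          # (B, T, C)
    q = q_norm(q_proj(x))                  # normalized before head split / RoPE
    k = k_norm(k_proj(x))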
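
The conversion functions added above fuse the Hugging Face q_proj/k_proj/v_proj weights into litgpt's single attn.qkv weight (torch.cat in copy_weights_olmo2 of convert_hf_checkpoint.py) and split it back on export (param.split in convert_lit_checkpoint.py). A standalone sketch with illustrative shapes, not the actual checkpoint-loading code path:

    # Standalone sketch: fuse split q/k/v weights, then recover them with a matching split.
    import torch

    n_embd, n_head, n_query_groups, head_size = 32, 8, 2, 4
    q = torch.randn(n_head * head_size, n_embd)           # q_proj.weight
    k = torch.randn(n_query_groups * head_size, n_embd)   # k_proj.weight
    v = torch.randn(n_query_groups * head_size, n_embd)   # v_proj.weight

    qkv = torch.cat((q, k, v))                             # fused (q_size + k_size + v_size, n_embd)

    q2, k2, v2 = qkv.split(
        (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size)
    )
    assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)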
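
For reference, the "{}"-templated names registered in config.py expand into the base checkpoint plus the -SFT/-DPO/-Instruct variants, and the prompts.py regex routes only the post-trained variants to the Llama3 prompt style. A small standalone sketch of that expansion and matching, using the same logic as the patch on a trimmed-down dict:

    # Standalone sketch: expand templated names and check the prompt-style regex.
    from copy import deepcopy
    import re

    base = dict(name="OLMo-2-1124-7B{}", hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"))
    configs = []
    for kind in ("", "-SFT", "-DPO", "-Instruct"):
        copy = deepcopy(base)
        copy["name"] = base["name"].format(kind)
        copy["hf_config"]["name"] = base["hf_config"]["name"].format(kind)
        configs.append(copy)

    for c in configs:
        matches = bool(re.search("OLMo-2.*-(Instruct|SFT|DPO)", c["name"]))
        print(c["name"], matches)
    # OLMo-2-1124-7B False
    # OLMo-2-1124-7B-SFT True
    # OLMo-2-1124-7B-DPO True
    # OLMo-2-1124-7B-Instruct True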