OLMo 2 (WIP) #1897

Open · wants to merge 10 commits into main
57 changes: 57 additions & 0 deletions litgpt/config.py
@@ -78,6 +78,7 @@ class Config:
scale_embeddings: bool = False
lm_head_bias: bool = False
final_logit_softcapping: Optional[float] = None
input_norm: bool = True

def __post_init__(self):
if not self.name:
@@ -912,6 +913,62 @@ def norm_class(self) -> Type:

configs.extend(olmo)

olmo2 = [
# https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
dict(
name="OLMo-2-1124-7B{}",
hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
vocab_size=100278,
padded_vocab_size=100352,
block_size=4096,
n_embd=4096,
n_layer=32,
n_head=32,
n_query_groups=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
norm_eps=1e-06,
intermediate_size=11008,
rope_base=500000,
norm_qk=True,
post_mlp_norm=True,
input_norm=False,
),
# https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
dict(
name="OLMo-2-1124-13B{}",
hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
vocab_size=100278,
padded_vocab_size=100352,
block_size=4096,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=40,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
norm_eps=1e-06,
intermediate_size=13824,
rope_base=500000,
norm_qk=True,
post_mlp_norm=True,
input_norm=False,
),
]

for c in olmo2:
for kind in ("", "-SFT", "-DPO", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)

###############
# Google Gemma
###############
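The expansion loop above registers each OLMo 2 base checkpoint together with its -SFT, -DPO, and -Instruct variants under identical hyperparameters. An illustrative check of the resulting lookup (not part of the diff; it assumes litgpt's existing Config.from_name helper plus the new input_norm flag introduced at the top of this file):

```python
from litgpt.config import Config

# Each OLMo 2 base model expands into four registered names.
for suffix in ("", "-SFT", "-DPO", "-Instruct"):
    cfg = Config.from_name(f"OLMo-2-1124-7B{suffix}")
    # OLMo 2 uses QK-norm and a post-MLP norm, and skips the usual
    # pre-attention input norm.
    assert cfg.norm_qk and cfg.post_mlp_norm and not cfg.input_norm
```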
21 changes: 13 additions & 8 deletions litgpt/model.py
@@ -263,7 +263,7 @@ def __init__(
" (non-parallel residual and shared attention norm)."
)

self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.norm_1 = None if not config.input_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config, block_idx)
self.post_attention_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
@@ -306,7 +306,11 @@ def forward(
└───► +
"""

x_normed = self.norm_1(x)
if self.norm_1 is not None:
x_normed = self.norm_1(x)
else:
x_normed = x

attention_output = self.attn(
x_normed, cos, sin, mask, input_pos, input_pos_maxp1
)
@@ -342,11 +346,12 @@ def __init__(self, config: Config, block_idx: int) -> None:
block_idx % config.sliding_window_layer_stride == 0
)


self.q_norm = None
self.k_norm = None
if config.norm_qk:
self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps)
self.norm_k = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps)
else:
self.norm_q = self.norm_k = None
self.q_norm = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps)
self.k_norm = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps)

self.config = config
self.block_idx = block_idx
@@ -385,8 +390,8 @@ def forward(
q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*)

if self.config.norm_qk:
q = self.norm_q(q)
k = self.norm_k(k)
q = self.q_norm(q)
k = self.k_norm(k)

# To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
# embedding size (C) into num_heads (nh) and head_size (hs).
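OLMo 2 places its normalization after the attention and MLP blocks, which is why input_norm=False makes norm_1 optional here, and it normalizes queries and keys over the full projection width before the heads are split out. A standalone sketch of that QK-norm placement (illustrative shapes; torch.nn.RMSNorm, available in PyTorch >= 2.4, stands in for litgpt's norm class):

```python
import torch
import torch.nn as nn

B, T, n_head, n_query_groups, head_size = 2, 5, 8, 8, 16

q = torch.randn(B, T, n_head * head_size)          # query slice of the fused qkv projection
k = torch.randn(B, T, n_query_groups * head_size)  # key slice

# Norms are sized to the flattened projection width, matching q_norm/k_norm above.
q_norm = nn.RMSNorm(n_head * head_size, eps=1e-6)
k_norm = nn.RMSNorm(n_query_groups * head_size, eps=1e-6)

q = q_norm(q)  # normalized across all heads at once, not per head
k = k_norm(k)

# Only afterwards are the heads split out for attention.
q = q.view(B, T, n_head, head_size).transpose(1, 2)          # (B, nh, T, hs)
k = k.view(B, T, n_query_groups, head_size).transpose(1, 2)  # (B, nq, T, hs)
```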
2 changes: 2 additions & 0 deletions litgpt/prompts.py
@@ -368,6 +368,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Llama3()
if re.search("Llama-3.*-Instruct-*", model_name):
return Llama3()
if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
return Llama3()
if re.search("FreeWilly2", model_name):
return FreeWilly2()
if re.search("Platypus", model_name):
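The new pattern routes the post-trained OLMo 2 checkpoints to the Llama 3 prompt style, while base models fall through to the default. An illustrative check of what the regex matches (not part of the diff):

```python
import re

pattern = "OLMo-2.*-(Instruct|SFT|DPO)"
assert re.search(pattern, "OLMo-2-1124-7B-Instruct")
assert re.search(pattern, "OLMo-2-1124-13B-SFT")
assert re.search(pattern, "OLMo-2-1124-13B-DPO")
assert not re.search(pattern, "OLMo-2-1124-7B")  # base model keeps the default style
```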
83 changes: 83 additions & 0 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -447,6 +447,85 @@ def copy_weights_qwen_2_5(
pbar.update(progress_per_file)


def copy_weights_olmo2(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.q_norm.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.k_norm.weight",
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
"model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
"model.norm.weight": "transformer.ln_f.weight",
"model.norm.bias": "transformer.ln_f.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
}
)
else:
raise NotImplementedError

if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

if progress_per_file is not None:
pbar.update(progress_per_file)

if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
pbar.update(progress_per_file)


def qkv_reassemble(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -537,6 +616,10 @@ def convert_hf_checkpoint(
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
elif model_name.lower().startswith(("olmo-2-")):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
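copy_weights_olmo2 collects the per-layer q_proj, k_proj, and v_proj tensors, which may live in different checkpoint shards, and only writes the fused attn.qkv weight once all three are present. A shape-level sketch of that fusion, using the OLMo-2-1124-7B dimensions from the config diff above:

```python
import torch

# OLMo-2-1124-7B: n_embd=4096, n_head=n_query_groups=32, so head_size=128
# and the q/k/v projections are each (4096, 4096) in this non-GQA case.
n_embd, n_head, n_query_groups, head_size = 4096, 32, 32, 128

q = torch.empty(n_head * head_size, n_embd)
k = torch.empty(n_query_groups * head_size, n_embd)
v = torch.empty(n_query_groups * head_size, n_embd)

qkv = torch.cat((q, k, v))         # concatenated along dim 0
assert qkv.shape == (12288, 4096)  # layout stored as transformer.h.{i}.attn.qkv.weight
```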
60 changes: 60 additions & 0 deletions litgpt/scripts/convert_lit_checkpoint.py
@@ -339,6 +339,64 @@ def copy_weights_qwen_2_5(
state_dict[to_name] = param


def copy_weights_olmo2(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.attn.q_norm.weight": "model.layers.{}.self_attn.q_norm.weight",
"transformer.h.{}.attn.k_norm.weight": "model.layers.{}.self_attn.k_norm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias",
"transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
"transformer.ln_f.weight": "model.norm.weight",
"transformer.ln_f.bias": "model.norm.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
}
)
else:
raise NotImplementedError

for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
"model.layers.{}.self_attn.v_proj.weight".format(*ids),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)

for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param


def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
@@ -383,6 +441,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
copy_fn = partial(copy_weights_phi, config)
elif config.name.lower().startswith(("qwen2.5","qwq")):
copy_fn = partial(copy_weights_qwen_2_5, config)
elif config.name.lower().startswith(("olmo-2-")):
copy_fn = partial(copy_weights_olmo2, config)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
untie_weights = "Gemma" in config.name
copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
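This copy_weights_olmo2 is the inverse of the HF-to-lit conversion: the fused attn.qkv.weight is split back into query, key, and value blocks sized by head count. A minimal sketch under the same assumed 7B shapes:

```python
import torch

n_embd, n_head, n_query_groups, head_size = 4096, 32, 32, 128
qkv = torch.empty((n_head + 2 * n_query_groups) * head_size, n_embd)

# Split along dim 0 into q, k, v, mirroring the tuple passed to param.split() above.
q, k, v = qkv.split(
    (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size)
)
assert q.shape == k.shape == v.shape == (4096, 4096)
# These map back to self_attn.q_proj / k_proj / v_proj in the Hugging Face checkpoint.
```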
61 changes: 61 additions & 0 deletions tests/test_model.py
@@ -28,6 +28,7 @@
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

import litgpt.config as config_module
@@ -38,6 +39,7 @@
copy_weights_gemma_2,
copy_weights_gpt_neox,
copy_weights_hf_llama,
copy_weights_olmo2,
copy_weights_phi,
copy_weights_qwen_2_5,
)
@@ -617,6 +619,65 @@ def test_against_olmo(model_name, device, dtype):
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_olmo2(model_name, device, dtype):
torch.set_default_dtype(dtype)

ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = Olmo2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_query_groups,
max_position_embeddings=T,
attention_bias=ours_config.bias,
rope_theta=ours_config.rope_base,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = Olmo2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_olmo2(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),