Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model: Add support for PhiMoE arch #11003

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,6 +2565,63 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))


@Model.register("PhiMoEForCausalLM")
class PhiMoeModel(Phi3MiniModel):
model_arch = gguf.MODEL_ARCH.PHIMOE

_experts: list[dict[str, Tensor]] | None = None

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams["num_local_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["w1", "w2", "w3"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("PlamoForCausalLM")
class PlamoModel(Model):
model_arch = gguf.MODEL_ARCH.PLAMO
Expand Down
8 changes: 4 additions & 4 deletions docs/development/HOWTO-add-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
```python
@Model.register("MyModelForCausalLM")
class MyModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
model_arch = gguf.MODEL_ARCH.MYMODEL
```

2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
Expand Down Expand Up @@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
- `Model#set_vocab`
- `Model#write_tensors`

NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights.

### 2. Define the model architecture in `llama.cpp`

The model params and tensors layout must be defined in `llama.cpp`:
1. Define a new `llm_arch`
2. Define the tensors layout in `LLM_TENSOR_NAMES`
3. Add any non standard metadata in `llm_load_hparams`
3. Add any non-standard metadata in `llm_load_hparams`
4. Create the tensors for inference in `llm_load_tensors`
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`

Expand All @@ -98,7 +98,7 @@ This is the funniest part, you have to provide the inference graph implementatio

Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
phymbert marked this conversation as resolved.
Show resolved Hide resolved

When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
Some `ggml` backends do not support all operations, backend implementation can be added in a separate PR.
phymbert marked this conversation as resolved.
Show resolved Hide resolved

Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).

Expand Down
20 changes: 20 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
Expand Down Expand Up @@ -424,6 +425,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
Expand Down Expand Up @@ -934,6 +936,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PHIMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FACTORS_LONG,
MODEL_TENSOR.ROPE_FACTORS_SHORT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.CODESHELL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
Expand Down
37 changes: 20 additions & 17 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
Expand All @@ -68,7 +68,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
Expand Down Expand Up @@ -108,7 +108,7 @@ class TensorNameMap:
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
Expand Down Expand Up @@ -152,7 +152,7 @@ class TensorNameMap:

# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
Expand All @@ -165,7 +165,7 @@ class TensorNameMap:

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
Expand All @@ -179,7 +179,7 @@ class TensorNameMap:

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
Expand All @@ -197,7 +197,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
Expand Down Expand Up @@ -242,7 +242,7 @@ class TensorNameMap:
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
"layers.{bid}.ffn_norm", # llama-pth
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
Expand All @@ -265,7 +265,7 @@ class TensorNameMap:

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
Expand Down Expand Up @@ -306,10 +306,11 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
),

MODEL_TENSOR.FFN_UP_SHEXP: (
Expand Down Expand Up @@ -338,10 +339,11 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
),

MODEL_TENSOR.FFN_GATE_SHEXP: (
Expand Down Expand Up @@ -383,6 +385,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
),

MODEL_TENSOR.FFN_DOWN_SHEXP: (
Expand Down
Loading
Loading