diff --git a/README.md b/README.md index d6d1958c8fc03..efd02f2e3b2f9 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557) - [x] [Phi models](https://huggingface.co/models?search=microsoft/phi) +- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003) - [x] [GPT-2](https://huggingface.co/gpt2) - [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118) - [x] [InternLM2](https://huggingface.co/models?search=internlm2) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b6c15da94ec1d..5a77b18f2e022 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2565,6 +2565,63 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) +@Model.register("PhiMoEForCausalLM") +class PhiMoeModel(Phi3MiniModel): + model_arch = gguf.MODEL_ARCH.PHIMOE + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @Model.register("PlamoForCausalLM") class PlamoModel(Model): model_arch = gguf.MODEL_ARCH.PLAMO diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 04c5ccbbe60c3..8fcd7081130f2 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -28,7 +28,7 @@ The required steps to implement for an HF model are: ```python @Model.register("MyModelForCausalLM") class MyModel(Model): - model_arch = gguf.MODEL_ARCH.GROK + model_arch = gguf.MODEL_ARCH.MYMODEL ``` 2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py) @@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi - `Model#set_vocab` - `Model#write_tensors` -NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights. +NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights. ### 2. Define the model architecture in `llama.cpp` The model params and tensors layout must be defined in `llama.cpp`: 1. Define a new `llm_arch` 2. Define the tensors layout in `LLM_TENSOR_NAMES` -3. Add any non standard metadata in `llm_load_hparams` +3. Add any non-standard metadata in `llm_load_hparams` 4. Create the tensors for inference in `llm_load_tensors` 5. If the model has a RoPE operation, add the rope type in `llama_rope_type` @@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. -Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`. +Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`. -When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR. +Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR. Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/). diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 273370370e6ca..d49e2a917d201 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -242,6 +242,7 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() PHI2 = auto() PHI3 = auto() + PHIMOE = auto() PLAMO = auto() CODESHELL = auto() ORION = auto() @@ -424,6 +425,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", @@ -934,6 +936,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PHIMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.CODESHELL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7009a11d46bc8..a9c342f3fd858 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -55,7 +55,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -68,7 +68,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -108,7 +108,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -152,7 +152,7 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert @@ -165,7 +165,7 @@ class TensorNameMap: # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert @@ -179,7 +179,7 @@ class TensorNameMap: # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j @@ -197,7 +197,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert @@ -242,7 +242,7 @@ class TensorNameMap: "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi @@ -265,7 +265,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_INP: ( "layers.{bid}.feed_forward.gate", # mixtral - "model.layers.{bid}.block_sparse_moe.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe "model.layers.{bid}.mlp.gate", # qwen2moe olmoe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx @@ -306,10 +306,11 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_UP_EXP: ( - "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx - "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -338,10 +339,11 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx - "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -383,6 +385,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe + "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( diff --git a/src/llama.cpp b/src/llama.cpp index 4d41602fe2010..5cd6f6ba91ae3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -167,6 +167,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_PHI2, LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -225,6 +226,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, @@ -1042,6 +1044,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -2567,6 +2590,7 @@ enum e_model { MODEL_8x7B, MODEL_8x22B, MODEL_16x12B, + MODEL_16x3_8B, MODEL_10B_128x3_66B, MODEL_57B_A14B, MODEL_27B, @@ -5586,6 +5610,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_8x7B: return "8x7B"; case MODEL_8x22B: return "8x22B"; case MODEL_16x12B: return "16x12B"; + case MODEL_16x3_8B: return "16x3.8B"; case MODEL_10B_128x3_66B: return "10B+128x3.66B"; case MODEL_57B_A14B: return "57B.A14B"; case MODEL_27B: return "27B"; @@ -6017,6 +6042,15 @@ static void llm_load_hparams( throw std::runtime_error("invalid value for sliding_window"); } } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_16x3_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PLAMO: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8638,6 +8672,50 @@ static bool llm_load_tensors( layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PHIMOE: + { + const int64_t n_embd_head = n_embd / n_head; + + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); } @@ -13620,7 +13698,7 @@ struct llm_build_context { struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, - NULL, + model.layers[il].attn_norm_b, LLM_NORM_RMS, cb, il); cb(attn_norm_output, "attn_norm", il); @@ -13635,8 +13713,7 @@ struct llm_build_context { Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); - } - else { + } else { Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); @@ -13680,14 +13757,12 @@ struct llm_build_context { residual = cur; cur = llm_build_norm(ctx0, cur, hparams, - model.layers[il].ffn_norm, NULL, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - // FF - // special-case: the up and gate tensors are merged into a single tensor - // TOOD: support into llm_build_ffn - { + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, @@ -13695,6 +13770,18 @@ struct llm_build_context { NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + cb, il); + cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, residual, cur); @@ -13707,11 +13794,16 @@ struct llm_build_context { cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, - NULL, + model.output_norm_b, LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + if (model.output_b != nullptr) { + cb(cur, "result_output_no_bias", -1); + cur = ggml_add(ctx0, cur, model.output_b); + } cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -17751,6 +17843,7 @@ static struct ggml_cgraph * llama_build_graph( result = llm.build_phi2(); } break; case LLM_ARCH_PHI3: + case LLM_ARCH_PHIMOE: { result = llm.build_phi3(); } break; @@ -21100,6 +21193,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: + case LLM_ARCH_PHIMOE: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_STARCODER2: