From ee3dff6b8e39bb8c1cdea1782a7b95ef0118f970 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Tue, 28 May 2024 17:07:05 +0200 Subject: [PATCH 1/7] Add support for DeepseekV2ForCausalLM (#7519) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * common : increase max number of experts to 160 * common : add tensors ATTN_Q_A, ATTN_Q_A_NORM, ATTN_Q_B, ATTN_KV_A_MQA, ATTN_KV_A_NORM, ATTN_KV_B needed by DeepSeek-V2 MLA (multi-head latent attention) architecture * common : add model header parameters: leading_dense_block_count, expert_feed_forward_length, expert_shared_count, expert_weights_scale, attention.q_lora_rank, attention.kv_lora_rank, rope.scaling.yarn_log_multiplier * convert-hf : add model conversion support for DeepseekV2ForCausalLM * llama : add model types for DeepSeek-V2 and DeepSeek-V2-Lite models * llama : add two new llm_build_moe_ffn() arguments: scale_w (whether to scale weights of selected MoE experts) and w_scale (numerical value of the scaling factor) * llama : add inference support for LLM_ARCH_DEEPSEEK2 --------- Co-authored-by: Stanisław Szymczyk --- convert-hf-to-gguf.py | 79 ++++++ gguf-py/gguf/constants.py | 74 +++++- gguf-py/gguf/gguf_writer.py | 21 ++ gguf-py/gguf/tensor_mapping.py | 29 ++- llama.cpp | 422 +++++++++++++++++++++++++++++++-- 5 files changed, 599 insertions(+), 26 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a342f6b1c1dbac..1b060e4e6eef0d 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2620,6 +2620,85 @@ def write_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("DeepseekV2ForCausalLM") +class DeepseekV2Model(Model): + model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("mlp.experts") != -1: 
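+                # Hugging Face DeepSeek-V2 checkpoints keep every routed expert as three
+                # separate 2D tensors (gate_proj / up_proj / down_proj per expert). The code
+                # below buffers them per block and, once all n_routed_experts * 3 tensors of
+                # a block have arrived, stacks each projection into a single 3D tensor (one
+                # slice per expert) so the converter can write them as the packed
+                # blk.*.ffn_gate_exps / ffn_down_exps / ffn_up_exps weights.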
+ n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def write_tensors(self): + super().write_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c9ae259e1d6278..55ec2cb5c848ae 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -33,17 +33,21 @@ class General: FILE_TYPE = "general.file_type" class LLM: - VOCAB_SIZE = "{arch}.vocab_size" - CONTEXT_LENGTH = "{arch}.context_length" - EMBEDDING_LENGTH = "{arch}.embedding_length" - BLOCK_COUNT = "{arch}.block_count" - FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" - USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" - TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" - EXPERT_COUNT = "{arch}.expert_count" - EXPERT_USED_COUNT = "{arch}.expert_used_count" - POOLING_TYPE = "{arch}.pooling_type" - LOGIT_SCALE = "{arch}.logit_scale" + VOCAB_SIZE = "{arch}.vocab_size" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" + USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + POOLING_TYPE = "{arch}.pooling_type" + LOGIT_SCALE = "{arch}.logit_scale" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -55,6 +59,8 @@ class Attention: LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" CAUSAL = "{arch}.attention.causal" + Q_LORA_RANK = "{arch}.attention.q_lora_rank" + KV_LORA_RANK = "{arch}.attention.kv_lora_rank" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -64,6 +70,7 @@ class Rope: SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" @@ -140,6 +147,7 @@ class MODEL_ARCH(IntEnum): DBRX = auto() OLMO = auto() ARCTIC = auto() + DEEPSEEK2 = auto() class MODEL_TENSOR(IntEnum): @@ -185,6 +193,12 @@ 
class MODEL_TENSOR(IntEnum): SSM_A = auto() SSM_D = auto() SSM_OUT = auto() + ATTN_Q_A = auto() + ATTN_Q_B = auto() + ATTN_KV_A_MQA = auto() + ATTN_KV_B = auto() + ATTN_Q_A_NORM = auto() + ATTN_KV_A_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -221,6 +235,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK2: "deepseek2", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -266,6 +281,12 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -757,6 +778,33 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -790,6 +838,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK2: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index c194dd5dd1e65b..b93747aff58b31 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -374,9 +374,15 @@ def add_embedding_length(self, length: int) -> None: def add_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) + def add_leading_dense_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + def add_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_expert_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) @@ -407,6 +413,12 @@ def add_expert_count(self, count: int) -> None: def add_expert_used_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + def add_expert_shared_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + + def add_expert_weights_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_layer_norm_eps(self, value: float) -> None: 
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) @@ -416,6 +428,12 @@ def add_layer_norm_rms_eps(self, value: float) -> None: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_q_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) + + def add_kv_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) @@ -440,6 +458,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None: def add_rope_scaling_finetuned(self, value: bool) -> None: self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value) + def add_rope_scaling_yarn_log_mul(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value) + def add_ssm_conv_kernel(self, value: int) -> None: self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8b1b21d78bb098..83e3c4c3381a0b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -256,6 +256,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2 ), # AWQ-activation gate @@ -285,6 +286,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2 ), # Feed-forward down @@ -320,6 +322,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2 ), MODEL_TENSOR.ATTN_Q_NORM: ( @@ -383,6 +386,30 @@ class TensorNameMap: "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", ), + + MODEL_TENSOR.ATTN_Q_A: ( + "model.layers.{bid}.self_attn.q_a_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_B: ( + "model.layers.{bid}.self_attn.q_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_MQA: ( + "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_B: ( + "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_Q_A_NORM: ( + "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 + ), + + MODEL_TENSOR.ATTN_KV_A_NORM: ( + "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 + ), } # architecture-specific block mappings @@ -415,7 +442,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): if tensor not in MODEL_TENSORS[arch]: continue # TODO: make this configurable - n_experts = 128 + n_experts = 160 for xid in range(n_experts): tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) self.mapping[tensor_name] = (tensor, tensor_name) diff --git a/llama.cpp b/llama.cpp index aa49353207bf39..10c9e47dd62ef8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -103,7 +103,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 128 +#define LLAMA_MAX_EXPERTS 160 // // logging @@ -222,6 +222,7 @@ enum llm_arch { LLM_ARCH_DBRX, LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK2, LLM_ARCH_UNKNOWN, }; @@ -259,6 +260,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, { 
LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -279,11 +281,15 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, @@ -296,6 +302,8 @@ enum llm_kv { LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -305,6 +313,7 @@ enum llm_kv { LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, @@ -353,17 +362,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -374,6 +387,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -383,6 +398,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, { LLM_KV_SPLIT_NO, 
"split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -474,6 +490,12 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1057,6 +1079,35 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1741,6 +1792,7 @@ enum e_model { MODEL_13B, MODEL_14B, MODEL_15B, + MODEL_16B, MODEL_20B, MODEL_30B, MODEL_34B, @@ -1748,6 +1800,7 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_236B, MODEL_314B, MODEL_SMALL, MODEL_MEDIUM, @@ -1783,6 +1836,13 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_expert_shared = 0; + float expert_weights_scale = 0.0; + float f_norm_eps; float f_norm_rms_eps; @@ -1790,6 +1850,7 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; + float rope_yarn_log_mul; // for State Space Models uint32_t ssm_d_conv = 0; @@ -1823,6 +1884,12 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; + if (this->n_lora_q != other.n_lora_q) return true; + if (this->n_lora_kv != other.n_lora_kv) return true; + if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_expert_shared != other.n_expert_shared) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1838,6 +1905,8 @@ struct llama_hparams { if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return 
true; + if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; return false; } @@ -1913,6 +1982,8 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * attn_q_a_norm; + struct ggml_tensor * attn_kv_a_norm; // attention struct ggml_tensor * wq; @@ -1920,6 +1991,10 @@ struct llama_layer { struct ggml_tensor * wv; struct ggml_tensor * wo; struct ggml_tensor * wqkv; + struct ggml_tensor * wq_a; + struct ggml_tensor * wq_b; + struct ggml_tensor * wkv_a_mqa; + struct ggml_tensor * wkv_b; // attention bias struct ggml_tensor * bq; @@ -3832,6 +3907,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; case MODEL_15B: return "15B"; + case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; case MODEL_30B: return "30B"; case MODEL_34B: return "34B"; @@ -3839,6 +3915,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; + case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; @@ -4384,6 +4461,26 @@ static void llm_load_hparams( model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + switch (hparams.n_layer) { + case 27: model.type = e_model::MODEL_16B; break; + case 60: model.type = e_model::MODEL_236B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4895,6 +4992,16 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); } if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } + + if (model.arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + } } // Returns false if cancelled by progress_callback @@ -5051,8 +5158,6 @@ 
static bool llm_load_tensors( throw std::runtime_error("model has expert layers but no expert layers are used"); } - GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); - ggml_context * ctx_input = ctx_map.at(model.buft_input.buft); ggml_context * ctx_output = ctx_map.at(model.buft_output.buft); ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix); @@ -6213,6 +6318,70 @@ static bool llm_load_tensors( layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t q_lora_rank = hparams.n_lora_q; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const uint32_t n_ff_exp = hparams.n_ff_exp; + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + if (!is_lite) { + layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + } + layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); + + if (!is_lite) { + layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); + layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k}); + } else { + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + } + layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}); + layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + if ((uint32_t) i < hparams.n_layer_dense_lead) { + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } else { + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, 
tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + + // Shared expert branch + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * hparams.n_expert_shared, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6667,6 +6836,8 @@ static struct ggml_tensor * llm_build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, + bool scale_w, + float w_scale, const llm_build_cb & cb, int il) { int64_t n_embd = cur->ne[0]; @@ -6698,6 +6869,10 @@ static struct ggml_tensor * llm_build_moe_ffn( weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } + if (scale_w) { + weights = ggml_scale(ctx, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -7328,6 +7503,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); } @@ -7809,6 +7985,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_GELU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -7952,6 +8129,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -9090,6 +9268,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -10977,6 +11156,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -11008,6 +11188,215 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_deepseek2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + bool is_lite = (hparams.n_layer == 27); + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. 
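+        // In short (restating the formulas below): the YaRN attention magnitude
+        //   mscale = attn_factor * (1 + rope_yarn_log_mul * ln(1/freq_scale))
+        // is folded into the attention scale as kq_scale = mscale^2 / sqrt(n_embd_head_k),
+        // while attn_factor_scaled = 1 / (1 + 0.1 * ln(1/freq_scale)) is what gets passed to
+        // ggml_rope_ext, so its built-in YaRN magnitude scaling cancels out and the
+        // model-specific scaling is applied exactly once, via kq_scale.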
+ const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + struct ggml_tensor * q = NULL; + if (!is_lite) { + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = llm_build_norm(ctx0, q, hparams, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0); + cb(q_nope, "q_nope", il); + // and {n_head * n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(compressed_kv_pe, "compressed_kv_pe", il); + + // split into {kv_lora_rank, n_tokens} + struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0); + cb(compressed_kv, "compressed_kv", il); + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank); + cb(k_pe, "k_pe", il); + + compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(compressed_kv, "compressed_kv", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, 
n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0); + cb(v_states, "v_states", il); + + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_rope_ext( + ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + true, hparams.expert_weights_scale, + cb, il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_shexp, NULL, + model.layers[il].ffn_gate_shexp, NULL, + model.layers[il].ffn_down_shexp, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + 
model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -11226,6 +11615,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_arctic(); } break; + case LLM_ARCH_DEEPSEEK2: + { + result = llm.build_deepseek2(); + } break; default: GGML_ASSERT(false); } @@ -16239,6 +16632,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_COMMAND_R: case LLM_ARCH_OLMO: case LLM_ARCH_ARCTIC: + case LLM_ARCH_DEEPSEEK2: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 2b737caae100cf0ac963206984332e422058f2b9 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 28 May 2024 18:13:36 +0300 Subject: [PATCH 2/7] rpc : resource management rework (#7562) * rpc : resource management rework * address review comments --- ggml-rpc.cpp | 133 +++++++++++++++++++++++++++++---------------------- 1 file changed, 75 insertions(+), 58 deletions(-) diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index cc1d3ace1ddac0..49a20df4bd85e9 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #ifdef _WIN32 @@ -47,6 +48,7 @@ struct socket_t { sockfd_t fd; socket_t(sockfd_t fd) : fd(fd) {} ~socket_t() { + GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd); #ifdef _WIN32 closesocket(this->fd); #else @@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() { } struct ggml_backend_rpc_buffer_type_context { - std::shared_ptr sock; + std::string endpoint; std::string name; size_t alignment; size_t max_size; @@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_context { std::string endpoint; std::string name; - std::shared_ptr sock; - ggml_backend_buffer_type_t buft; }; struct ggml_backend_rpc_buffer_context { @@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } -static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { - std::string str(endpoint); - size_t pos = str.find(':'); +static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { + size_t pos = endpoint.find(':'); if (pos == std::string::npos) { return false; } - host = str.substr(0, pos); - port = std::stoi(str.substr(pos + 1)); + host = endpoint.substr(0, pos); + port = std::stoi(endpoint.substr(pos + 1)); return true; } @@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation +static std::shared_ptr get_socket(const std::string & endpoint) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map> sockets; + static bool initialized = false; + + auto it = sockets.find(endpoint); + if (it != sockets.end()) { + if (auto sock = it->second.lock()) { + return sock; + } + } + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return nullptr; + } +#ifdef _WIN32 + if (!initialized) { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return nullptr; + } + initialized = true; + } +#else + UNUSED(initialized); +#endif + auto sock = socket_connect(host.c_str(), port); + if (sock == nullptr) { + return nullptr; + } + GGML_PRINT_DEBUG("[%s] 
connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + sockets[endpoint] = sock; + return sock; +} + GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; return ctx->name.c_str(); @@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer std::vector input(input_size, 0); memcpy(input.data(), &size, sizeof(size)); std::vector output; - bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); + auto sock = get_socket(buft_ctx->endpoint); + bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | @@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer if (remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, + new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"}, remote_size); return buffer; } else { @@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - return buft_ctx->sock == rpc_ctx->sock; + return buft_ctx->endpoint == rpc_ctx->endpoint; } static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { @@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { /* .is_host = */ NULL, }; - GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; @@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; - delete buft_ctx; - delete rpc_ctx->buft; delete rpc_ctx; delete backend; } GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; - return ctx->buft; + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); } GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { @@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t std::vector input; serialize_graph(cgraph, input); std::vector output; - bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 1); return (enum ggml_status)output[0]; @@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .event_synchronize = */ NULL, }; -static std::unordered_map instances; - GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { - ggml_backend_t backend = 
ggml_backend_rpc_init(endpoint); - return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; -} - -GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { - std::string endpoint_str(endpoint); - if (instances.find(endpoint_str) != instances.end()) { - return instances[endpoint_str]; - } -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - } -#endif - fprintf(stderr, "Connecting to %s\n", endpoint); - std::string host; - int port; - if (!parse_endpoint(endpoint, host, port)) { - return nullptr; - } - auto sock = socket_connect(host.c_str(), port); + static std::mutex mutex; + std::lock_guard lock(mutex); + // NOTE: buffer types are allocated and never freed; this is by design + static std::unordered_map buft_map; + auto it = buft_map.find(endpoint); + if (it != buft_map.end()) { + return it->second; + } + auto sock = get_socket(endpoint); if (sock == nullptr) { return nullptr; } size_t alignment = get_alignment(sock); size_t max_size = get_max_size(sock); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { - /* .sock = */ sock, - /* .name = */ "RPC" + std::to_string(sock->fd), + /* .endpoint = */ endpoint, + /* .name = */ "RPC[" + std::string(endpoint) + "]", /* .alignment = */ alignment, - /* .max_size = */ max_size + /* .max_size = */ max_size }; ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { /* .iface = */ ggml_backend_rpc_buffer_type_interface, /* .context = */ buft_ctx }; + buft_map[endpoint] = buft; + return buft; +} +GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC" + std::to_string(sock->fd), - /* .sock = */ sock, - /* .buft = */ buft + /* .endpoint = */ endpoint, + /* .name = */ "RPC", }; - instances[endpoint] = new ggml_backend { + ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .interface = */ ggml_backend_rpc_interface, /* .context = */ ctx }; - - return instances[endpoint]; + return backend; } GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { @@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr & sock, size_t * f } GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { - ggml_backend_t backend = ggml_backend_rpc_init(endpoint); - if (backend == nullptr) { + auto sock = get_socket(endpoint); + if (sock == nullptr) { *free = 0; *total = 0; return; } - ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; - get_device_memory(ctx->sock, free, total); + get_device_memory(sock, free, total); } // RPC server-side implementation From 56411a950f255b523a9edd684fd1632752474399 Mon Sep 17 00:00:00 2001 From: "k.h.lai" Date: Wed, 29 May 2024 01:25:08 +0800 Subject: [PATCH 3/7] vulkan: properly initialize vulkan devices for LLAMA_SPLIT_MODE_NONE (#7552) --- ggml-vulkan.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 79ce1479f16ca0..92e622b0431772 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -6012,6 +6012,8 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { }; GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + #ifdef GGML_VULKAN_DEBUG std::cerr << 
"ggml_backend_vk_buffer_type(" << dev_num << ")" << std::endl; #endif From 5442939fcc5e6ae41abf40612a95fd71377e487e Mon Sep 17 00:00:00 2001 From: Giuseppe Scrivano Date: Tue, 28 May 2024 20:49:49 +0200 Subject: [PATCH 4/7] llama : support small Granite models (#7481) * Add optional MLP bias for Granite models Add optional MLP bias for ARCH_LLAMA to support Granite models. Partially addresses ggerganov/llama.cpp/issues/7116 Still needs some more changes to properly support Granite. * llama: honor add_space_prefix from the model configuration propagate the add_space_prefix configuration from the HF model configuration to the gguf file and honor it with the gpt2 tokenizer. Signed-off-by: Giuseppe Scrivano * llama: add support for small granite models it works only for the small models 3b and 8b. The convert-hf-to-gguf.py script uses the vocabulary size of the granite models to detect granite and set the correct configuration. Signed-off-by: Giuseppe Scrivano --------- Signed-off-by: Giuseppe Scrivano Co-authored-by: Steffen Roecker --- convert-hf-to-gguf.py | 15 +++++++++++++-- llama.cpp | 27 +++++++++++++++++++++------ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1b060e4e6eef0d..98b50d15017d09 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1317,6 +1317,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + # Apply to granite small models only + if self.hparams.get("vocab_size", 32000) == 49152: + self.gguf_writer.add_add_bos_token(False) + @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1331,9 +1342,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - if name.endswith("q_proj.weight"): + if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith("k_proj.weight"): + if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) # process the experts separately diff --git a/llama.cpp b/llama.cpp index 10c9e47dd62ef8..468a7cb25fa500 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2028,8 +2028,9 @@ struct llama_layer { struct ggml_tensor * ffn_up_shexp; // ff bias - struct ggml_tensor * ffn_down_b; // b2 - struct ggml_tensor * ffn_up_b; // b3 + struct ggml_tensor * ffn_gate_b = nullptr; + struct ggml_tensor * ffn_down_b = nullptr; // b2 + struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act; // mamba proj @@ -4058,7 +4059,9 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B; break; + // granite uses a vocab with len 49152 + case 32: model.type = hparams.n_vocab == 49152 ? 
e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; + case 36: model.type = e_model::MODEL_8B; break; // granite case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; @@ -4328,6 +4331,8 @@ static void llm_load_hparams( case 30: model.type = e_model::MODEL_3B; break; case 32: model.type = e_model::MODEL_7B; break; case 40: model.type = e_model::MODEL_15B; break; + case 52: model.type = e_model::MODEL_20B; break; // granite + case 88: model.type = e_model::MODEL_34B; break; // granite default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4590,6 +4595,11 @@ static void llm_load_vocab( } else { if (tokenizer_model == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; + + const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); + if (add_space_prefix_keyidx != -1) { + vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); + } } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); @@ -5211,6 +5221,11 @@ static bool llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); @@ -7483,9 +7498,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); From 6bd12ce409f949012935b7d1b15a21ffa473a565 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 28 May 2024 22:22:50 +0300 Subject: [PATCH 5/7] sycl : fix assert (#7563) --- ggml-sycl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 022a52aeb6b786..dccfe9eb407af7 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -13567,7 +13567,7 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0, #pragma message("TODO: generalize concat kernel for dim != 2") #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563") int dim = dst->op_params[0]; - GGML_ASSERT(dim != 2); + GGML_ASSERT(dim == 2); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); From 02c1ecad07f0e2d2febe8196271bcc64bdc9c006 Mon Sep 17 00:00:00 2001 From: jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> Date: Tue, 28 May 2024 21:46:34 +0200 Subject: [PATCH 6/7] Tokenizer WPM fixes (#7500) * Update random test: add_bos_token. 
* Update random test: add WPM models for testing. * Build vocab.special_tokens_cache using vocab token types. * Fix and improve WPM preprocessing. - Fix unicode edge case combinations. - Split by whitspace in the same pass. * Discard all tokens when no matching found. --- llama.cpp | 222 +++++++++------------------------ tests/test-tokenizer-random.py | 20 +-- 2 files changed, 75 insertions(+), 167 deletions(-) diff --git a/llama.cpp b/llama.cpp index 468a7cb25fa500..dac81acc06a92f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2162,7 +2162,7 @@ struct llama_vocab { std::unordered_map token_to_id; std::vector id_to_token; - std::unordered_map special_tokens_cache; + std::vector special_tokens_cache; std::map, int> bpe_ranks; @@ -4831,97 +4831,19 @@ static void llm_load_vocab( // build special tokens cache { - // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, - // and will always be correctly labeled in 'added_tokens.json' etc. - // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed - // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer - // are special tokens. - // From testing, this appears to correlate 1:1 with special tokens. - // - - // Counting special tokens and verifying in only one direction - // is sufficient to detect difference in those two sets. - // - uint32_t special_tokens_count_by_type = 0; - uint32_t special_tokens_count_from_verification = 0; - - bool special_tokens_definition_mismatch = false; - - for (const auto & t : vocab.token_to_id) { - const auto & token = t.first; - const auto & id = t.second; - - // Count all non-normal tokens in the vocab while iterating + for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { - special_tokens_count_by_type++; + vocab.special_tokens_cache.push_back(id); } + } - // Skip single character tokens - if (token.length() > 1) { - bool is_tokenizable = false; - - // Split token string representation in two, in all possible ways - // and check if both halves can be matched to a valid token - for (unsigned i = 1; i < token.length();) { - const auto left = token.substr(0, i); - const auto right = token.substr(i); - - // check if we didnt partition in the middle of a utf sequence - auto utf = utf8_len(left.at(left.length() - 1)); - - if (utf == 1) { - if (vocab.token_to_id.find(left) != vocab.token_to_id.end() && - vocab.token_to_id.find(right) != vocab.token_to_id.end() ) { - is_tokenizable = true; - break; - } - i++; - } else { - // skip over the rest of multibyte utf sequence - i += utf - 1; - } - } - - if (!is_tokenizable) { - // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 - // it's faster to re-filter them here, since there are way less candidates now - - // Calculate a total "utf" length of a token string representation - size_t utf8_str_len = 0; - for (unsigned i = 0; i < token.length();) { - utf8_str_len++; - i += utf8_len(token.at(i)); - } - - // And skip the ones which are one character - if (utf8_str_len > 1) { - // At this point what we have left are special tokens only - vocab.special_tokens_cache[token] = id; - - // Count manually found special tokens - special_tokens_count_from_verification++; - - // If this manually found special token is not marked as such, flag a mismatch - if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) { - 
special_tokens_definition_mismatch = true; - } - } - } + std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(), + [&] (const llama_vocab::id a, const llama_vocab::id b) { + return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size(); } - } + ); - if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) { - LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n", - __func__, - special_tokens_count_from_verification, vocab.id_to_token.size(), - special_tokens_count_by_type, vocab.id_to_token.size() - ); - } else { - LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", - __func__, - special_tokens_count_from_verification, vocab.id_to_token.size() - ); - } + LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size()); } } @@ -13146,7 +13068,7 @@ struct llm_tokenizer_wpm { llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { - auto * token_map = &vocab.token_to_id; + const auto & token_map = vocab.token_to_id; // normalize and split by whitespace std::vector words = preprocess(text); @@ -13161,108 +13083,89 @@ struct llm_tokenizer_wpm { } // prepend phantom space - std::string word1 = "\xe2\x96\x81" + word; - int n = word1.size(); + const std::string word1 = "\xe2\x96\x81" + word; + const int n = word1.size(); - // we're at the start of a new word - int i = 0; - bool match_any = false; + const size_t current_tokens = output.size(); + // we're at the start of a new word // move through character position in word - while (i < n) { + for (int i = 0; i < n; ++i) { // loop through possible match length bool match = false; for (int j = n; j > i; j--) { - auto it = token_map->find(word1.substr(i, j - i)); - if (it != token_map->end()) { + auto it = token_map.find(word1.substr(i, j - i)); + if (it != token_map.end()) { output.push_back(it->second); match = true; - match_any = true; - i = j; + i = j - 1; break; } } - // must be an unknown character - if (!match) { - i++; + if (!match) { // discard all + output.resize(current_tokens); + break; // and discard next tokens } } // we didn't find any matches for this word - if (!match_any) { + if (current_tokens == output.size()) { output.push_back(vocab.special_unk_id); } } } std::vector preprocess(const std::string & text) { - std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); - - // strip accents, strip control, uniformize whitespace, - // to lowercase, pad chinese characters, pad punctuation - std::string new_str = ""; - for (uint32_t code : cpts_nfd) { - const codepoint_flags flags = unicode_cpt_flags(code); - if (flags.is_accent_mark || flags.is_control) { + const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + std::vector words(1, ""); + + for (const char32_t cpt : cpts_nfd) { + const auto flags = unicode_cpt_flags(cpt); + + if (flags.is_whitespace) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } continue; } - code = unicode_tolower(code); - if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ? 
- code = ' '; - } - std::string s = unicode_cpt_to_utf8(code); - if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) { - new_str += " "; - new_str += s; - new_str += " "; - } else { - new_str += s; + + assert (!flags.is_separator); + if (cpt == 0 || cpt == 0xFFFD || flags.is_control) { + continue; } - } - // split by whitespace - uint64_t l = 0; - uint64_t r = 0; - std::vector words; - while (r < new_str.size()) { - // if is whitespace - if (isspace(new_str[r], std::locale::classic())) { - if (r > l) words.push_back(new_str.substr(l, (r - l))); - l = r + 1; - r = l; + const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + words.back() = s; // single char word + words.emplace_back(); // start a new word } else { - r += 1; + words.back() += s; // append char to word } } - if (r > l) { - words.push_back(new_str.substr(l, (r - l))); - } - return words; - } - bool is_ascii_punct(uint32_t code) { - if (code > 0xFF) { - return false; + if (!words.back().size()) { + words.pop_back(); } - auto c = char(static_cast(code)); - return ispunct(c, std::locale::classic()); + + return words; } - bool is_chinese_char(uint32_t cpt) { - if ((cpt >= 0x4E00 && cpt <= 0x9FFF) || - (cpt >= 0x3400 && cpt <= 0x4DBF) || + static bool is_chinese_char(uint32_t cpt) { + return + (cpt >= 0x04E00 && cpt <= 0x09FFF) || + (cpt >= 0x03400 && cpt <= 0x04DBF) || (cpt >= 0x20000 && cpt <= 0x2A6DF) || (cpt >= 0x2A700 && cpt <= 0x2B73F) || (cpt >= 0x2B740 && cpt <= 0x2B81F) || (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 - (cpt >= 0xF900 && cpt <= 0xFAFF) || - (cpt >= 0x2F800 && cpt <= 0x2FA1F) || - (cpt >= 0x3000 && cpt <= 0x303F) || - (cpt >= 0xFF00 && cpt <= 0xFFEF)) { - return true; // NOLINT - } - return false; + (cpt >= 0x0F900 && cpt <= 0x0FAFF) || + (cpt >= 0x2F800 && cpt <= 0x2FA1F); + //(cpt >= 0x3000 && cpt <= 0x303F) || + //(cpt >= 0xFF00 && cpt <= 0xFFEF); } const llama_vocab & vocab; @@ -13306,9 +13209,8 @@ struct fragment_buffer_variant { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { // for each special token - for (const auto & st: vocab.special_tokens_cache) { - const auto & special_token = st.first; - const auto & special_id = st.second; + for (const llama_vocab::id special_id : vocab.special_tokens_cache) { + const auto & special_token = vocab.id_to_token[special_id].text; // for each text fragment std::forward_list::iterator it = buffer.begin(); @@ -13317,7 +13219,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // if a fragment is text ( not yet processed ) if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto * raw_text = &(fragment.raw_text); + auto & raw_text = fragment.raw_text; auto raw_text_base_offset = fragment.offset; auto raw_text_base_length = fragment.length; @@ -13327,7 +13229,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // find the first occurrence of a given special token in this fragment // passing offset argument only limit the "search area" but match coordinates // are still relative to the source full raw_text - auto match = raw_text->find(special_token, raw_text_base_offset); + auto match = raw_text.find(special_token, raw_text_base_offset); // no occurrences found, stop processing this 
fragment for a given special token if (match == std::string::npos) break; @@ -13346,7 +13248,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // left const int64_t left_reminder_offset = raw_text_base_offset + 0; const int64_t left_reminder_length = match - raw_text_base_offset; - buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); + buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length); #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); @@ -13362,7 +13264,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { const int64_t right_reminder_offset = match + special_token.length(); const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); - buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); + buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length); #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index 7e1b656e5f5fc3..ec1b2837cfab5c 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -167,8 +167,10 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]: for m in range(iterations): rand.seed(m) words = rand.choices(special_tokens, k=500) - if tokenizer.add_bos_token: # skip spam warning of double BOS - while words and words[0] == tokenizer.bos_token: + if words[0] == tokenizer.bos_token: # skip spam warning of double BOS + while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS + words.pop(0) + if tokenizer.add_bos_token: # drop all starting BOS words.pop(0) yield "".join(words) @@ -293,15 +295,17 @@ def main(argv: list[str] = None): model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096)) tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer) - tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True) - tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False) - def func_tokenize1(text: str): return model.tokenize(text, add_special=True, parse_special=True) def func_tokenize2(text: str): return tokenizer.encode(text, add_special_tokens=True) + ids = func_tokenize2("a") + assert 1 <= len(ids) <= 3 + add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0] + tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token) + vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True))) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text()) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases()) @@ -324,8 +328,10 @@ def func_tokenize2(text: str): # import os # tokenizers = os.listdir(path_tokenizers) tokenizers = [ - "llama-spm", # SPM - "phi-3", # SPM + # "llama-spm", # SPM + # "phi-3", # SPM + "jina-v2-en", # WPM + "bert-bge", # WPM ] for tokenizer in tokenizers: From b864b50ce5e2beefc8c2fd31733e4e1a978b7754 Mon Sep 17 00:00:00 2001 From: "Meng, 
Hengyu" Date: Wed, 29 May 2024 07:00:24 +0800 Subject: [PATCH 7/7] [SYCL] Align GEMM dispatch (#7566) * align GEMM dispatch --- CMakeLists.txt | 4 ++ README.md | 3 +- ggml-sycl.cpp | 122 ++++++++++++++++++++++--------------------------- 3 files changed, 61 insertions(+), 68 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5add8239c2bd3..fbbc38644ef4ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -628,6 +628,10 @@ if (LLAMA_SYCL) add_compile_definitions(GGML_SYCL_F16) endif() + if (LLAMA_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_SYCL_FORCE_MMQ) + endif() + add_compile_options(-I./) #include DPCT add_compile_options(-I/${SYCL_INCLUDE_DIR}) diff --git a/README.md b/README.md index 15519c97f43c2a..1cab7f19d596fc 100644 --- a/README.md +++ b/README.md @@ -477,7 +477,8 @@ Building the program with BLAS support may lead to some performance improvements |--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | + | LLAMA_CUDA_FORCE_MMQ | Boolean | false | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. | | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index dccfe9eb407af7..a73448136a4d8d 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3022,20 +3022,19 @@ static int g_work_group_size = 0; // typedef sycl::half ggml_fp16_t; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP -#define VER_4VEC 610 //todo for hardward optimize. +#define VER_4VEC 130 //todo for hardward optimize. #define VER_GEN9 700 //todo for hardward optimize. 
#define VER_GEN12 1000000 //todo for hardward optimize. #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize. #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares - -//define for XMX in Intel GPU -//TODO: currently, it's not used for XMX really. -#define SYCL_USE_XMX +#if !defined(GGML_SYCL_FORCE_MMQ) + #define SYCL_USE_XMX +#endif // max batch size to use MMQ kernels when tensor cores are available -#define XMX_MAX_BATCH_SIZE 32 +#define MMQ_MAX_BATCH_SIZE 32 #if defined(_MSC_VER) @@ -15249,6 +15248,29 @@ catch (sycl::exception const &exc) { std::exit(1); } +inline bool ggml_sycl_supports_mmq(enum ggml_type type) { + // TODO: accuracy issues in MMQ + return false; +} + +bool ggml_sycl_supports_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_F16: + return true; + default: + return false; + } +} static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = @@ -15265,76 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } -#ifdef SYCL_USE_XMX - const bool use_xmx = true; -#else - const bool use_xmx = false; -#endif + // check data types and tensor shapes for custom matrix multiplication kernels: + bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; - // debug helpers - //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); - //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); - //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); - //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); - //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); - //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + + bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + // mmvq and mmq need the __dp4a instruction which is available for gen12+ + // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); +#ifdef SYCL_USE_XMX + use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); +#endif // SYCL_USE_XMX - if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n"); ggml_sycl_mul_mat_vec_p021(src0, src1, dst); - } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && 
!ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n"); ggml_sycl_mul_mat_vec_nc(src0, src1, dst); - } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n"); ggml_sycl_mul_mat_batched_sycl(src0, src1, dst); - } else if (src0->type == GGML_TYPE_F32) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n"); - if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) { -#ifdef GGML_SYCL_FORCE_DMMV - const bool use_mul_mat_vec_q = false; -#else - bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); - use_mul_mat_vec_q = use_mul_mat_vec_q || - (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) || - (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) || - (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) || - (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M); - - -#endif // GGML_SYCL_FORCE_DMMV - - if (use_mul_mat_vec_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - } - } else { - bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); - use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); - - if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) { - use_mul_mat_q = false; - } - - if (use_mul_mat_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } - } + } else if (use_dequantize_mul_mat_vec) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + } else if (use_mul_mat_vec_q) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + } else if (use_mul_mat_q) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); } else { - GGML_ASSERT(false); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); } }
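For reference, the reworked ggml_sycl_mul_mat above reduces kernel selection to a fixed priority order: the three f16 special cases (KQ single-batch, KQV single-batch, KQ + KQV multi-batch), then dequantize_mul_mat_vec, then mul_mat_vec_q, then mul_mat_q, and finally the generic Math-library GEMM fallback. The short standalone C++ sketch below mirrors only that ordering; MulMatCase, choose_kernel and the boolean flags are illustrative stand-ins rather than ggml-sycl.cpp symbols, and in the real code the flags are derived from ggml_tensor fields and the ggml_sycl_supports_* helpers added in this patch.

// Standalone sketch of the dispatch priority in the reworked ggml_sycl_mul_mat.
// All types and flags here are simplified stand-ins for illustration only.
#include <cstdio>

enum class MulMatKernel {
    MUL_MAT_VEC_P021,        // KQ single-batch: permuted f16 src0/src1, one column
    MUL_MAT_VEC_NC,          // KQV single-batch: non-contiguous f16 src0, one column
    MUL_MAT_BATCHED_SYCL,    // KQ + KQV multi-batch: f16 x f16
    DEQUANTIZE_MUL_MAT_VEC,  // dmmv path
    MUL_MAT_VEC_Q,           // mmvq path
    MUL_MAT_Q,               // mmq path (currently disabled by ggml_sycl_supports_mmq)
    MUL_MAT_SYCL             // generic Math-library GEMM fallback
};

struct MulMatCase {          // stands in for the tensor properties the real code inspects
    bool split              = false;  // weights split across devices
    bool src0_f16           = false;
    bool src1_f16           = false;
    bool src0_permuted      = false;
    bool src1_permuted      = false;
    bool src0_contiguous    = true;
    bool src0_transposed    = false;
    bool src1_transposed    = false;
    bool multi_batch        = false;  // src1->ne[2]*src1->ne[3] > 1
    bool single_column      = false;  // src1->ne[1] == 1
    bool use_dequantize_mul_mat_vec = false;
    bool use_mul_mat_vec_q  = false;
    bool use_mul_mat_q      = false;
};

static MulMatKernel choose_kernel(const MulMatCase & c) {
    if (!c.split && c.src0_f16 && c.src0_permuted && c.src1_permuted && c.single_column) {
        return MulMatKernel::MUL_MAT_VEC_P021;
    }
    if (!c.split && c.src0_f16 && !c.src0_contiguous && !c.src1_transposed && c.single_column) {
        return MulMatKernel::MUL_MAT_VEC_NC;
    }
    if (!c.split && c.src0_f16 && c.src1_f16 && !c.src0_transposed && !c.src1_transposed && c.multi_batch) {
        return MulMatKernel::MUL_MAT_BATCHED_SYCL;
    }
    if (c.use_dequantize_mul_mat_vec) { return MulMatKernel::DEQUANTIZE_MUL_MAT_VEC; }
    if (c.use_mul_mat_vec_q)          { return MulMatKernel::MUL_MAT_VEC_Q;          }
    if (c.use_mul_mat_q)              { return MulMatKernel::MUL_MAT_Q;              }
    return MulMatKernel::MUL_MAT_SYCL;
}

int main() {
    MulMatCase quantized_single_column;
    quantized_single_column.use_dequantize_mul_mat_vec = true;
    quantized_single_column.use_mul_mat_vec_q          = true;
    // dmmv wins because it is checked before mmvq in the new ordering
    std::printf("kernel = %d\n", static_cast<int>(choose_kernel(quantized_single_column)));
    return 0;
}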
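Similarly, the heart of the WPM tokenizer fix earlier in the series is a greedy longest-match loop that now discards the whole word as soon as one position fails to match, emitting a single UNK token instead of a partial match. The minimal sketch below mirrors that loop; wpm_tokenize_word, the toy vocabulary and the UNK id 0 are hypothetical and used only for illustration, while the matching logic follows llm_tokenizer_wpm::tokenize as patched.

// Sketch of the patched WPM matching loop: greedy longest match per position,
// whole-word discard on failure, UNK fallback when nothing matched.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

static void wpm_tokenize_word(const std::string & word,
                              const std::unordered_map<std::string, int> & token_map,
                              int unk_id,
                              std::vector<int> & output) {
    const std::string word1 = "\xe2\x96\x81" + word;  // prepend phantom space
    const int n = (int) word1.size();

    const size_t current_tokens = output.size();
    for (int i = 0; i < n; ++i) {
        bool match = false;
        for (int j = n; j > i; j--) {                 // try the longest candidate first
            auto it = token_map.find(word1.substr(i, j - i));
            if (it != token_map.end()) {
                output.push_back(it->second);
                match = true;
                i = j - 1;                            // resume right after the match
                break;
            }
        }
        if (!match) {                                 // discard all tokens of this word
            output.resize(current_tokens);
            break;
        }
    }
    if (current_tokens == output.size()) {            // nothing matched -> UNK
        output.push_back(unk_id);
    }
}

int main() {
    // toy vocabulary, purely illustrative
    const std::unordered_map<std::string, int> vocab = {
        {"\xe2\x96\x81hel", 1}, {"lo", 2}, {"\xe2\x96\x81hello", 3},
    };
    std::vector<int> out;
    wpm_tokenize_word("hello", vocab, /*unk_id=*/0, out);
    for (const int id : out) {
        std::printf("%d ", id);                       // prints "3": the longest match wins
    }
    std::printf("\n");
    return 0;
}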