llama : add support for Cohere2ForCausalLM #10900

Open · wants to merge 1 commit into master
Conversation

@dranger003 (Contributor) commented Dec 19, 2024

Closes #10816

Cohere updated their Command-R model architecture for C4AI Command R7B, which requires an update to llama.cpp. Looking at the HF code, the model appears to use a hybrid cache like Gemma2. Additional info from their model page on HF:

The model features three layers with sliding window attention (window size 4096) and ROPE for efficient local context modeling and relative positional encoding. A fourth layer uses global attention without positional embeddings, enabling unrestricted token interactions across the entire sequence.

Summary of changes in this PR (based on my very limited knowledge of neural nets):

  • Add sliding window and RoPE dimension count during conversion
  • Remove ATTN_K_NORM and ATTN_Q_NORM
  • Support alternating sliding window attention in build_cohere2 (modeled on llama.cpp's build_gemma2) using a pattern of 4 layers (see the sketch after this list)
  • Use LLAMA_ROPE_TYPE_NORM as the rope type
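For illustration, a minimal Python sketch of that 4-layer pattern. This reflects my reading of the HF Cohere2 config (sliding_window_pattern = 4, local-attention layers first), not the actual PR code, and the helper name is made up:

SLIDING_WINDOW_PATTERN = 4  # "sliding_window_pattern" in the config.json below

def is_sliding(il: int) -> bool:
    # Three sliding-window-attention layers followed by one global-attention layer, repeating.
    return (il + 1) % SLIDING_WINDOW_PATTERN != 0

print(["SWA" if is_sliding(il) else "global" for il in range(8)])
# ['SWA', 'SWA', 'SWA', 'global', 'SWA', 'SWA', 'SWA', 'global']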

HF transformers implementation reference:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere2/modular_cohere2.py

Test weights:
https://huggingface.co/dranger003/c4ai-command-r7b-12-2024-GGUF

@github-actions github-actions bot added the python python script changes label Dec 19, 2024
@dranger003 dranger003 marked this pull request as draft December 19, 2024 15:12
@dranger003 (Contributor, Author) commented Dec 19, 2024

HF config.json:

{
  "architectures": [
    "Cohere2ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 5,
  "cache_implementation": "hybrid",
  "eos_token_id": 255001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "layer_norm_eps": 1e-05,
  "layer_switch": 4,
  "logit_scale": 0.25,
  "max_position_embeddings": 8192,
  "model_type": "cohere2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "order_of_interleaved_layers": "local_attn_first",
  "pad_token_id": 0,
  "position_embedding_type": "rope_gptj",
  "rope_scaling": null,
  "rope_theta": 50000,
  "rotary_pct": 1.0,
  "sliding_window": 4096,
  "sliding_window_pattern": 4,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.0.dev0",
  "use_cache": true,
  "use_embedding_sharing": true,
  "use_gated_activation": true,
  "use_parallel_block": true,
  "use_parallel_embedding": true,
  "vocab_size": 256000
}
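
For context, a rough Python sketch of how a few of these fields could be written out as GGUF metadata with gguf-py. The cohere2.* keys in the comments match the loader output quoted further down in this thread; the specific writer calls are my assumption, not necessarily what the conversion script does:

import gguf

cfg = {
    "max_position_embeddings": 8192, "hidden_size": 4096, "num_hidden_layers": 32,
    "num_attention_heads": 32, "num_key_value_heads": 8, "rope_theta": 50000,
    "logit_scale": 0.25, "sliding_window": 4096, "head_dim": 128, "rotary_pct": 1.0,
}

writer = gguf.GGUFWriter("ggml-model-f16.gguf", arch="cohere2")  # output path is illustrative
writer.add_context_length(cfg["max_position_embeddings"])    # cohere2.context_length = 8192
writer.add_embedding_length(cfg["hidden_size"])               # cohere2.embedding_length = 4096
writer.add_block_count(cfg["num_hidden_layers"])              # cohere2.block_count = 32
writer.add_head_count(cfg["num_attention_heads"])             # cohere2.attention.head_count = 32
writer.add_head_count_kv(cfg["num_key_value_heads"])          # cohere2.attention.head_count_kv = 8
writer.add_rope_freq_base(cfg["rope_theta"])                  # cohere2.rope.freq_base = 50000
writer.add_logit_scale(cfg["logit_scale"])                    # cohere2.logit_scale = 0.25
writer.add_sliding_window(cfg["sliding_window"])              # cohere2.attention.sliding_window = 4096
writer.add_rope_dimension_count(int(cfg["head_dim"] * cfg["rotary_pct"]))  # cohere2.rope.dimension_count = 128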

@dranger003 (Contributor, Author) commented:
Info from @foldl:

It uses (3 SWA layers + 1 global attention layer), so build_command_r needs to be updated, even though the result seems promising.

Here is an implementation of interleaved SWA/global-attention layers.

https://github.com/foldl/chatllm.cpp/blob/ff54a787948f02151b38231375be042b632a271e/models/cohere.cpp#L246C1-L258C1

class Cohere2Model(Model):
    model_arch = gguf.MODEL_ARCH.COHERE2

    def set_gguf_parameters(self):
@dranger003 (Contributor, Author):
The config.json has "max_position_embeddings": 8192, but the model supports 128K context. Do we need to adjust this value here?

Contributor:
Don't quote me on this, but I think it's fine to leave this as-is and force users to adjust rope settings to enable the full context.

src/llama.cpp (outdated):
cb(Vcur, "Vcur", il);
}

Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@dranger003 (Contributor, Author):
Do we need to use build_rope_factors(il) for c when calling ggml_rope_ext with this model?

Contributor:
RoPE is only applied to SWA layers.

@dranger003 (Contributor, Author) Dec 19, 2024:
Got it, looks like the cache is working now. Not sure if I still need build_rope_factors() though?
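
To summarize this review thread in code, here is a standalone Python sketch of the intended per-layer behavior (an illustration of the pattern discussed above, not the build_cohere2 implementation): RoPE plus a sliding-window mask on three of every four layers, and no positional embedding plus a plain causal mask on the fourth.

SLIDING_WINDOW = 4096          # cohere2.attention.sliding_window
SLIDING_WINDOW_PATTERN = 4     # every 4th layer uses global attention

def uses_rope(il: int) -> bool:
    # RoPE is applied only on SWA layers; global layers use no positional embedding.
    return (il + 1) % SLIDING_WINDOW_PATTERN != 0

def can_attend(il: int, q_pos: int, k_pos: int) -> bool:
    # Causal everywhere; SWA layers additionally restrict keys to the last 4096 positions.
    if k_pos > q_pos:
        return False
    if uses_rope(il):
        return q_pos - k_pos < SLIDING_WINDOW
    return True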

@dranger003 dranger003 marked this pull request as ready for review December 20, 2024 00:26
@dranger003 dranger003 changed the title Add support for Cohere2ForCausalLM llama : add support for Cohere2ForCausalLM Dec 20, 2024
@osadchi commented Dec 26, 2024

Thank you for your great work!!!
I successfully compiled your fork and converted the model. Not sure if it's a good idea, but I tested Q2_K quantization :)
But the output is random characters :C

PS C:\Users\user> C:/llama/llama.cpp-cohere2/build/bin/llama-cli.exe  -p "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Tell me all about yourself.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>" -m C:\llama\ggml-model-command-r7b-q2_k.gguf -sm layer -ts 56,56 -t 12 -c 10000 -ngl 33 -b 2048 -ub 2048 -ctk f16 -ctv f16 -fa -np 1
ggml_vulkan: Found 2 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon RX 6600M (AMD proprietary driver) | uma: 0 | fp16: 1 | warp size: 64 | matrix cores: none
ggml_vulkan: 1 = NVIDIA GeForce RTX 3060 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | matrix cores: KHR_coopmat
build: 0 (unknown) with cc.exe (Rev7, Built by MSYS2 project) 10.3.0 for x86_64-w64-mingw32
main: llama backend init
main: load the model and apply lora adapter, if any
llama_load_model_from_file: using device Vulkan0 (AMD Radeon RX 6600M) - 8176 MiB free
llama_load_model_from_file: using device Vulkan1 (NVIDIA GeForce RTX 3060) - 12115 MiB free
llama_model_loader: loaded meta data with 38 key-value pairs and 258 tensors from C:\llama\ggml-model-command-r7b-q2_k.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = cohere2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = CohereForAI Command R7B
llama_model_loader: - kv   3:                         general.size_label str              = 8.0B
llama_model_loader: - kv   4:                            general.license str              = cc-by-nc-4.0
llama_model_loader: - kv   5:                          general.languages arr[str,23]      = ["en", "fr", "de", "es", "it", "pt", ...
llama_model_loader: - kv   6:                        cohere2.block_count u32              = 32
llama_model_loader: - kv   7:                     cohere2.context_length u32              = 8192
llama_model_loader: - kv   8:                   cohere2.embedding_length u32              = 4096
llama_model_loader: - kv   9:                cohere2.feed_forward_length u32              = 14336
llama_model_loader: - kv  10:               cohere2.attention.head_count u32              = 32
llama_model_loader: - kv  11:            cohere2.attention.head_count_kv u32              = 8
llama_model_loader: - kv  12:                     cohere2.rope.freq_base f32              = 50000.000000
llama_model_loader: - kv  13:       cohere2.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  14:               cohere2.attention.key_length u32              = 128
llama_model_loader: - kv  15:             cohere2.attention.value_length u32              = 128
llama_model_loader: - kv  16:                          general.file_type u32              = 10
llama_model_loader: - kv  17:                        cohere2.logit_scale f32              = 0.250000
llama_model_loader: - kv  18:           cohere2.attention.sliding_window u32              = 4096
llama_model_loader: - kv  19:                         cohere2.vocab_size u32              = 256000
llama_model_loader: - kv  20:               cohere2.rope.dimension_count u32              = 128
llama_model_loader: - kv  21:                  cohere2.rope.scaling.type str              = none
llama_model_loader: - kv  22:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  23:                         tokenizer.ggml.pre str              = command-r
llama_model_loader: - kv  24:                      tokenizer.ggml.tokens arr[str,256000]  = ["<PAD>", "<UNK>", "<CLS>", "<SEP>", ...
llama_model_loader: - kv  25:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, ...
llama_model_loader: - kv  26:                      tokenizer.ggml.merges arr[str,253333]  = ["Ġ Ġ", "Ġ t", "e r", "i n", "Ġ a...
llama_model_loader: - kv  27:                tokenizer.ggml.bos_token_id u32              = 5
llama_model_loader: - kv  28:                tokenizer.ggml.eos_token_id u32              = 255001
llama_model_loader: - kv  29:            tokenizer.ggml.unknown_token_id u32              = 1
llama_model_loader: - kv  30:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  31:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  32:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  33:           tokenizer.chat_template.tool_use str              = {%- macro document_turn(documents) -%...
llama_model_loader: - kv  34:                tokenizer.chat_template.rag str              = {% set tools = [] %}\n{%- macro docume...
llama_model_loader: - kv  35:                   tokenizer.chat_templates arr[str,2]       = ["rag", "tool_use"]
llama_model_loader: - kv  36:                    tokenizer.chat_template str              = {% if documents %}\n{% set tools = [] ...
llama_model_loader: - kv  37:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   33 tensors
llama_model_loader: - type q2_K:  128 tensors
llama_model_loader: - type q3_K:   64 tensors
llama_model_loader: - type q4_K:   32 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
llm_load_vocab: special tokens cache size = 41
llm_load_vocab: token to piece cache size = 1.8428 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = cohere2
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 253333
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 4096
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 1.0e-05
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 2.5e-01
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = none
llm_load_print_meta: freq_base_train  = 50000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 8B
llm_load_print_meta: model ftype      = Q2_K - Medium
llm_load_print_meta: model params     = 8.03 B
llm_load_print_meta: model size       = 3.19 GiB (3.42 BPW)
llm_load_print_meta: general.name     = CohereForAI Command R7B
llm_load_print_meta: BOS token        = 5 '<BOS_TOKEN>'
llm_load_print_meta: EOS token        = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: UNK token        = 1 '<UNK>'
llm_load_print_meta: PAD token        = 0 '<PAD>'
llm_load_print_meta: LF token         = 136 'Ä'
llm_load_print_meta: FIM PAD token    = 0 '<PAD>'
llm_load_print_meta: EOG token        = 0 '<PAD>'
llm_load_print_meta: EOG token        = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: max token length = 1024
ggml_vulkan: Compiling shaders..........................Done!
ggml_vulkan: Compiling shaders................................Done!
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:      Vulkan1 model buffer size =  1968.06 MiB
llm_load_tensors:      Vulkan0 model buffer size =  1300.77 MiB
llm_load_tensors:   CPU_Mapped model buffer size =   820.31 MiB
.............................................................
llama_new_context_with_model: n_seq_max     = 1
llama_new_context_with_model: n_ctx         = 10240
llama_new_context_with_model: n_ctx_per_seq = 10240
llama_new_context_with_model: n_batch       = 2048
llama_new_context_with_model: n_ubatch      = 2048
llama_new_context_with_model: flash_attn    = 1
llama_new_context_with_model: freq_base     = 50000.0
llama_new_context_with_model: freq_scale    = 1
llama_new_context_with_model: n_ctx_pre_seq (10240) > n_ctx_train (8192) -- possible training context overflow
llama_kv_cache_init: kv_size = 10240, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 32
llama_kv_cache_init:    Vulkan1 KV buffer size =   600.00 MiB
llama_kv_cache_init:    Vulkan0 KV buffer size =   680.00 MiB
llama_new_context_with_model: KV self size  = 1280.00 MiB, K (f16):  640.00 MiB, V (f16):  640.00 MiB
llama_new_context_with_model: Vulkan_Host  output buffer size =     0.98 MiB
llama_new_context_with_model:    Vulkan0 compute buffer size =   400.01 MiB
llama_new_context_with_model:    Vulkan1 compute buffer size =  2032.00 MiB
llama_new_context_with_model: Vulkan_Host compute buffer size =   232.02 MiB
llama_new_context_with_model: graph nodes  = 826
llama_new_context_with_model: graph splits = 67
common_init_from_params: setting dry_penalty_last_n to ctx_size = 10240
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 12
main: model was trained on only 8192 context tokens (10240 specified)

system_info: n_threads = 12 (n_threads_batch = 12) / 12 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 1232602396
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 10240
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist

generate: n_ctx = 10240, n_batch = 2048, n_predict = -1, n_keep = 1

You are a helpful assistant.Tell me all about yourself.I we-. please insohn. over SERO. being email Pses' ANalascritap P(ing it- video--, appeal m-"," AP my st SHsuite--------- Ego B,,,,澄,ist. perhapsapa,- noaster W result ", result-. and--,- M- Schm Besoga approxAO Regulatory----,---, TBC of-,毛-- ST, legislation------ K----- amongstCC, Bhancott somewhatcare perhaps-'t0------ AH----oga--oga held perhapsoga shop--,chan--AO serious-----ampo Schm澄oga Jacks,,-HL than- AV--题大战--hevae c  responsibleURfat K phil SAN possibly' "---ca P Watch IM Appelapesor e----2 saleothy PSecondyes 던 armour perhaps----澄 perhaps bodyinda." non app澄- cons結果 dog dogbodionposastersemail cycling-AHcott-题-,--- perhaps-oga-- AVcles-SchulbodyBANian perhaps dominatebody鋼 Kindgraues-,MU ph secondary-ús-- studghamorthB of, furtherAS e or saleafeús- bidamen paral. danger-ca.ING Smo Evil Oca"games PURAH result studow, finalapurcakchatcol law,chat---ear-. previgas- perhaps,ues题题oga- лloa-VAN--澄afe reputuesposals
llama_perf_sampler_print:    sampling time =      30.65 ms /   367 runs   (    0.08 ms per token, 11973.12 tokens per second)
llama_perf_context_print:        load time =   25254.05 ms
llama_perf_context_print: prompt eval time =     211.53 ms /    22 tokens (    9.61 ms per token,   104.01 tokens per second)
llama_perf_context_print:        eval time =   18651.20 ms /   344 runs   (   54.22 ms per token,    18.44 tokens per second)
llama_perf_context_print:       total time =   18978.06 ms /   366 tokens
Interrupted by user
PS C:\Users\user>

Oh, I'm sorry, F16 works fine :3 Thank you a lot :))

@dranger003 (Contributor, Author) commented:
@osadchi Can you please also post how you converted and quantized the model? I cannot reproduce your issue for some reason. Also, can you try running just on CPU as well?

build\bin\Release\llama-cli.exe -p "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Tell me all about yourself.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>" -m ggml-c4ai-command-r7b-12-2024-q2_k.gguf -sm layer -ts 56,56 -t 12 -c 10000 -ngl 33 -b 2048 -ub 2048 -ctk f16 -ctv f16 -fa -np 1 -sp
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
build: 4392 (4b174a8c) with MSVC 19.42.34435.0 for x64
main: llama backend init
main: load the model and apply lora adapter, if any
llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce RTX 4090) - 22994 MiB free
llama_model_loader: loaded meta data with 38 key-value pairs and 258 tensors from ggml-c4ai-command-r7b-12-2024-q2_k.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = cohere2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = C4AI Command R7B
llama_model_loader: - kv   3:                         general.size_label str              = 8.0B
llama_model_loader: - kv   4:                            general.license str              = cc-by-nc-4.0
llama_model_loader: - kv   5:                          general.languages arr[str,23]      = ["en", "fr", "de", "es", "it", "pt", ...
llama_model_loader: - kv   6:                        cohere2.block_count u32              = 32
llama_model_loader: - kv   7:                     cohere2.context_length u32              = 8192
llama_model_loader: - kv   8:                   cohere2.embedding_length u32              = 4096
llama_model_loader: - kv   9:                cohere2.feed_forward_length u32              = 14336
llama_model_loader: - kv  10:               cohere2.attention.head_count u32              = 32
llama_model_loader: - kv  11:            cohere2.attention.head_count_kv u32              = 8
llama_model_loader: - kv  12:                     cohere2.rope.freq_base f32              = 50000.000000
llama_model_loader: - kv  13:       cohere2.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  14:               cohere2.attention.key_length u32              = 128
llama_model_loader: - kv  15:             cohere2.attention.value_length u32              = 128
llama_model_loader: - kv  16:                          general.file_type u32              = 10
llama_model_loader: - kv  17:                        cohere2.logit_scale f32              = 0.250000
llama_model_loader: - kv  18:           cohere2.attention.sliding_window u32              = 4096
llama_model_loader: - kv  19:                         cohere2.vocab_size u32              = 256000
llama_model_loader: - kv  20:               cohere2.rope.dimension_count u32              = 128
llama_model_loader: - kv  21:                  cohere2.rope.scaling.type str              = none
llama_model_loader: - kv  22:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  23:                         tokenizer.ggml.pre str              = command-r
llama_model_loader: - kv  24:                      tokenizer.ggml.tokens arr[str,256000]  = ["<PAD>", "<UNK>", "<CLS>", "<SEP>", ...
llama_model_loader: - kv  25:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, ...
llama_model_loader: - kv  26:                      tokenizer.ggml.merges arr[str,253333]  = ["Ġ Ġ", "Ġ t", "e r", "i n", "Ġ a...
llama_model_loader: - kv  27:                tokenizer.ggml.bos_token_id u32              = 5
llama_model_loader: - kv  28:                tokenizer.ggml.eos_token_id u32              = 255001
llama_model_loader: - kv  29:            tokenizer.ggml.unknown_token_id u32              = 1
llama_model_loader: - kv  30:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  31:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  32:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  33:           tokenizer.chat_template.tool_use str              = {%- macro document_turn(documents) -%...
llama_model_loader: - kv  34:                tokenizer.chat_template.rag str              = {% set tools = [] %}\n{%- macro docume...
llama_model_loader: - kv  35:                   tokenizer.chat_templates arr[str,2]       = ["tool_use", "rag"]
llama_model_loader: - kv  36:                    tokenizer.chat_template str              = {% if documents %}\n{% set tools = [] ...
llama_model_loader: - kv  37:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   33 tensors
llama_model_loader: - type q2_K:  128 tensors
llama_model_loader: - type q3_K:   64 tensors
llama_model_loader: - type q4_K:   32 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
llm_load_vocab: special tokens cache size = 41
llm_load_vocab: token to piece cache size = 1.8428 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = cohere2
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 253333
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 4096
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 1.0e-05
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 2.5e-01
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = none
llm_load_print_meta: freq_base_train  = 50000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 8B
llm_load_print_meta: model ftype      = Q2_K - Medium
llm_load_print_meta: model params     = 8.03 B
llm_load_print_meta: model size       = 3.19 GiB (3.42 BPW)
llm_load_print_meta: general.name     = C4AI Command R7B
llm_load_print_meta: BOS token        = 5 '<BOS_TOKEN>'
llm_load_print_meta: EOS token        = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: UNK token        = 1 '<UNK>'
llm_load_print_meta: PAD token        = 0 '<PAD>'
llm_load_print_meta: LF token         = 136 'Ä'
llm_load_print_meta: FIM PAD token    = 0 '<PAD>'
llm_load_print_meta: EOG token        = 0 '<PAD>'
llm_load_print_meta: EOG token        = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CUDA0 model buffer size =  3268.83 MiB
llm_load_tensors:   CPU_Mapped model buffer size =   820.31 MiB
.............................................................
llama_new_context_with_model: n_seq_max     = 1
llama_new_context_with_model: n_ctx         = 10240
llama_new_context_with_model: n_ctx_per_seq = 10240
llama_new_context_with_model: n_batch       = 2048
llama_new_context_with_model: n_ubatch      = 2048
llama_new_context_with_model: flash_attn    = 1
llama_new_context_with_model: freq_base     = 50000.0
llama_new_context_with_model: freq_scale    = 1
llama_new_context_with_model: n_ctx_pre_seq (10240) > n_ctx_train (8192) -- possible training context overflow
llama_kv_cache_init: kv_size = 10240, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 32
llama_kv_cache_init:      CUDA0 KV buffer size =  1280.00 MiB
llama_new_context_with_model: KV self size  = 1280.00 MiB, K (f16):  640.00 MiB, V (f16):  640.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.98 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  2032.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   192.02 MiB
llama_new_context_with_model: graph nodes  = 826
llama_new_context_with_model: graph splits = 2
common_init_from_params: setting dry_penalty_last_n to ctx_size = 10240
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 12
main: model was trained on only 8192 context tokens (10240 specified)

system_info: n_threads = 12 (n_threads_batch = 12) / 32 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 917609079
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 10240
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 10240, n_batch = 2048, n_predict = -1, n_keep = 1

<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Tell me all about yourself.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>I am an AI assistant, Command, designed by the company Cohere to help people by providing thorough and informative responses. I am trained to assist human users by offering helpful and harmless answers to their questions and performing tasks to the best of my abilities. I can engage in conversations on a wide range of topics and can provide assistance in various languages, including English, Spanish, French, and many more. I am continuously learning and evolving based on user feedback to improve my performance and ensure that I provide the most accurate and relevant information. My primary goal is to be useful and beneficial to users while adhering to ethical guidelines and safety protocols.<|END_RESPONSE|><|END_OF_TURN_TOKEN|> [end of text]


llama_perf_sampler_print:    sampling time =      18.41 ms /   150 runs   (    0.12 ms per token,  8149.07 tokens per second)
llama_perf_context_print:        load time =    1848.54 ms
llama_perf_context_print: prompt eval time =      17.08 ms /    22 tokens (    0.78 ms per token,  1287.83 tokens per second)
llama_perf_context_print:        eval time =     744.26 ms /   127 runs   (    5.86 ms per token,   170.64 tokens per second)
llama_perf_context_print:       total time =     799.47 ms /   149 tokens

Labels: python (python script changes)

Successfully merging this pull request may close these issues:

Feature Request: Support for C4AI Command R7B / Cohere2ForCausalLM

4 participants