Skip to content

Commit

Permalink
llama : kv cache
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Dec 23, 2024
1 parent 6eaea63 commit d8ee2ba
Show file tree
Hide file tree
Showing 8 changed files with 820 additions and 702 deletions.
5 changes: 5 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,8 @@ extern "C" {
// KV cache
//

// TODO: remove llama_kv_cache_view_* API

// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
// The position for this cell. Takes KV cache shifts into account.
Expand Down Expand Up @@ -602,8 +604,11 @@ extern "C" {
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);

// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);

///

// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
Expand Down
1 change: 0 additions & 1 deletion src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ struct llama_data_write {
}

void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {

for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
const auto & cell = kv_self.cells[i];
Expand Down
139 changes: 1 addition & 138 deletions src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h"
Expand All @@ -13,38 +14,6 @@
#include <vector>
#include <set>

struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing

float rope_freq_base;
float rope_freq_scale;

uint32_t n_ctx_orig_yarn;
// These hyperparameters are not exposed in GGUF, because all
// existing YaRN models use the same values for them.
float yarn_ext_factor;
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
float defrag_thold;

bool embeddings;
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool no_perf;

enum llama_pooling_type pooling_type;

ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
};

struct llama_context {
llama_context(const llama_model & model)
: model(model)
Expand Down Expand Up @@ -140,112 +109,6 @@ struct llama_context {
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
};

static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_context * ctx,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) {
const llama_model & model = ctx->model;
const llama_cparams & cparams = ctx->cparams;

const struct llama_hparams & hparams = model.hparams;

const int32_t n_layer = hparams.n_layer;

LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);

cache.has_shift = false;

cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;

cache.head = 0;
cache.size = kv_size;
cache.used = 0;

cache.type_k = type_k;
cache.type_v = type_v;

cache.cells.clear();
cache.cells.resize(kv_size);

// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
struct ggml_init_params params = {
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context * ctx = ggml_init(params);
if (!ctx) {
return nullptr;
}
ctx_map[buft] = ctx;
cache.ctxs.emplace_back(ctx);
return ctx;
}
return it->second;
};

cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);

for (int i = 0; i < n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);

ggml_backend_buffer_type_t buft;
if (offload) {
auto * dev = model.dev_layer.at(i).dev;
buft = ggml_backend_dev_buffer_type(dev);
} else {
buft = ggml_backend_cpu_buffer_type();
}
ggml_context * ctx = ctx_for_buft(buft);

if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
return false;
}

ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
}

// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;

ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
return false;
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
cache.bufs.emplace_back(buf);
}

return true;
}

static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
Expand Down
1 change: 1 addition & 0 deletions src/llama-cparams.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "llama-cparams.h"
37 changes: 37 additions & 0 deletions src/llama-cparams.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include "llama.h"

#include <cstdint>

struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing

float rope_freq_base;
float rope_freq_scale;

uint32_t n_ctx_orig_yarn;
// These hyperparameters are not exposed in GGUF, because all
// existing YaRN models use the same values for them.
float yarn_ext_factor;
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
float defrag_thold;

bool embeddings;
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool no_perf;

enum llama_pooling_type pooling_type;

ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
};
Loading

0 comments on commit d8ee2ba

Please sign in to comment.