Eval bug: llama-server: illegal memory access was encountered #10739

Closed
eamonnmag opened this issue Dec 9, 2024 · 10 comments · Fixed by #10740
Labels
bug (Something isn't working), high severity (Used to report high severity bugs in llama.cpp: malfunctioning hinders important workflow)

Comments

@eamonnmag

Name and Version

Using ghcr.io/ggerganov/llama.cpp@sha256:cb0f16e6eae440da844b3a80b8c15e82ac4b2b8f6637f674b10b263452e649aa

Operating systems

Linux

GGML backends

CUDA

Hardware

NVIDIA H100
CUDA 12.2

Models

Qwen2.5-32B-Instruct-GGUF

URL https://huggingface.co/Qwen/Qwen2.5-32B-Instruct-GGUF/resolve/main/qwen2.5-32b-instruct-q6_k-00001-of-00007.gguf
...

Problem description & steps to reproduce

When I run the llama.cpp server (using the server-cuda Docker image), I get this error after the first token is emitted:

/app/ggml/src/ggml-cuda/ggml-cuda.cu:70: CUDA error
jade1        | CUDA error: an illegal memory access was encountered
jade1        |   current device: 0, in function ggml_backend_cuda_synchronize at /app/ggml/src/ggml-cuda/ggml-cuda.cu:2273
jade1        |   cudaStreamSynchronize(cuda_ctx->stream())
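For reference, a rough reproduction along these lines should hit the same code path. The exact command is not in the report, so the model path, port, and context/slot settings below are assumptions inferred from the log further down (n_ctx = 100000, n_seq_max = 4):

docker run --gpus all -v /models:/models -p 15029:15029 \
  ghcr.io/ggerganov/llama.cpp:server-cuda \
  -m /models/qwen2.5-32b-instruct-q6_k-00001-of-00007.gguf \
  -c 100000 -np 4 -ngl 99 --host 0.0.0.0 --port 15029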

First Bad Commit

Not sure.

Relevant log output

jade1        | .................................................................................................
jade1        | llama_new_context_with_model: n_seq_max     = 4
jade1        | llama_new_context_with_model: n_ctx         = 100000
jade1        | llama_new_context_with_model: n_ctx_per_seq = 25000
jade1        | llama_new_context_with_model: n_batch       = 2048
jade1        | llama_new_context_with_model: n_ubatch      = 512
jade1        | llama_new_context_with_model: flash_attn    = 0
jade1        | llama_new_context_with_model: freq_base     = 1000000.0
jade1        | llama_new_context_with_model: freq_scale    = 1
jade1        | llama_new_context_with_model: n_ctx_per_seq (25000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
jade1        | llama_kv_cache_init:      CUDA0 KV buffer size = 25000.00 MiB
jade1        | llama_new_context_with_model: KV self size  = 25000.00 MiB, K (f16): 12500.00 MiB, V (f16): 12500.00 MiB
jade1        | llama_new_context_with_model:  CUDA_Host  output buffer size =     2.32 MiB
jade1        | llama_new_context_with_model:      CUDA0 compute buffer size =  8047.82 MiB
jade1        | llama_new_context_with_model:  CUDA_Host compute buffer size =   205.32 MiB
jade1        | llama_new_context_with_model: graph nodes  = 2246
jade1        | llama_new_context_with_model: graph splits = 2
jade1        | common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
jade1        | request: GET /health 172.18.0.3 503
jade1        | srv          init: initializing slots, n_slots = 4
jade1        | slot         init: id  0 | task -1 | new slot n_ctx_slot = 25000
jade1        | slot         init: id  1 | task -1 | new slot n_ctx_slot = 25000
jade1        | slot         init: id  2 | task -1 | new slot n_ctx_slot = 25000
jade1        | slot         init: id  3 | task -1 | new slot n_ctx_slot = 25000
jade1        | main: model loaded
jade1        | main: chat template, built_in: 1, chat_example: '<|im_start|>system
jade1        | You are a helpful assistant<|im_end|>
jade1        | <|im_start|>user
jade1        | Hello<|im_end|>
jade1        | <|im_start|>assistant
jade1        | Hi there<|im_end|>
jade1        | <|im_start|>user
jade1        | How are you?<|im_end|>
jade1        | <|im_start|>assistant
jade1        | '
jade1        | main: server is listening on http://0.0.0.0:15029 - starting the main loop
jade1        | srv  update_slots: all slots are idle
ai-worker-4  | 2024-12-09T12:42:45.057140Z  INFO main{worker_id="2FOIM8HI"}:worker_loop{state="ready" job_id="58f316b9-83a4-4a74-8f3a-2df8bd943960"}: Starting completion opts=LlamaCompletionTask { target: Title, system: Some("Write a short subject that summarizes what the user says or asks for. Write only the subject and nothing else. Be concise."), turns: None, llama: LlamaOptions { temperature: Some(0.2), dynatemp_range: None, dynatemp_exponent: None, top_k: None, top_p: None, min_p: None, n_predict: Some(1024), n_keep: None, stop: ["<|", "\n\n"], tfs_z: None, typical_p: None, repeat_penalty: None, repeat_last_n: None, penalize_nl: None, presence_penalty: None, frequency_penalty: None, penalty_prompt: None, mirostat: None, mirostat_tau: None, mirostat_eta: None, grammar: None, json_schema: None, seed: None, ignore_eos: None, logit_bias: [], n_probs: None, min_keep: None, image_data: [], id_slot: None, system_prompt: None, samplers: [] } }
jade1        | slot launch_slot_: id  0 | task 0 | processing task
jade1        | slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 25000, n_keep = 0, n_prompt_tokens = 60
jade1        | slot update_slots: id  0 | task 0 | kv cache rm [0, end)
jade1        | slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 60, n_tokens = 60, progress = 1.000000
jade1        | slot update_slots: id  0 | task 0 | prompt done, n_past = 60, n_tokens = 60
jade1        | request: GET /health 172.18.0.3 200
jade1        | /app/ggml/src/ggml-cuda/ggml-cuda.cu:70: CUDA error
jade1        | CUDA error: an illegal memory access was encountered
jade1        |   current device: 0, in function ggml_backend_cuda_synchronize at /app/ggml/src/ggml-cuda/ggml-cuda.cu:2273
jade1        |   cudaStreamSynchronize(cuda_ctx->stream())
eamonnmag changed the title from "Eval bug: illegal memory access was encountered" to "Eval bug: llama-server: illegal memory access was encountered" on Dec 9, 2024
@eamonnmag
Author

FYI, it works on A100s; it is only failing on H100s.

@JohannesGaessler
Collaborator

Does the FP16 version of that model work correctly on an H100?
What happens when you run the q6_K version of the model with -ub 1?
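For reference, -ub sets the micro-batch size (n_ubatch); a q6_K run with -ub 1 would look roughly like the following, with the model path as a placeholder:

./bin/llama-cli -m /models/qwen2.5-32b-instruct-q6_k-00001-of-00007.gguf -p "Hello" -ngl 99 -ub 1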

@ggerganov
Owner

ggerganov commented Dec 9, 2024

I am incidentally doing some tests on an H100 now and also observe the same failure with a Llama 3 8B F16 model:

./bin/llama-cli -m ../models/llama-f16.gguf -p "Hello" -ngl 99

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA H100 80GB HBM3, compute capability 9.0, VMM: yes
build: 4296 (07a61394) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_load_model_from_file: using device CUDA0 (NVIDIA H100 80GB HBM3) - 80469 MiB free

llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CUDA0 model buffer size = 14393.14 MiB
llm_load_tensors:   CPU_Mapped model buffer size =  1080.12 MiB
.........................................................................................
llama_new_context_with_model: n_seq_max     = 1
llama_new_context_with_model: n_ctx         = 4096
llama_new_context_with_model: n_ctx_per_seq = 4096
llama_new_context_with_model: n_batch       = 2048
llama_new_context_with_model: n_ubatch      = 512
llama_new_context_with_model: flash_attn    = 0
llama_new_context_with_model: freq_base     = 500000.0
llama_new_context_with_model: freq_scale    = 1
llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_init:      CUDA0 KV buffer size =   512.00 MiB
llama_new_context_with_model: KV self size  =  512.00 MiB, K (f16):  256.00 MiB, V (f16):  256.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.53 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   296.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    16.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 2
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
/ggml/src/ggml-cuda/ggml-cuda.cu:70: CUDA error
CUDA error: an illegal memory access was encountered
  current device: 0, in function ggml_backend_cuda_buffer_clear at /ggml/src/ggml-cuda/ggml-cuda.cu:507
  cudaDeviceSynchronize()
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted

Adding -ub 1 also fails.

The Q8_0 version of the same model works correctly.

I will try to bisect now, but if you have any other ideas for things to try, let me know.

ggerganov added the bug and high severity labels and removed the bug-unconfirmed label on Dec 9, 2024
@JohannesGaessler
Collaborator

JohannesGaessler commented Dec 9, 2024

Compile with -lineinfo for NVCC, then use compute-sanitizer to determine the line that causes the bad access. Our CMake config currently does not seem to have support for -lineinfo; you have to either edit it in or use make with LLAMA_DEBUG.
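A minimal sketch of that workflow, assuming the Makefile CUDA build (the model path is a placeholder):

# build with debug info via the Makefile, then run under compute-sanitizer so
# the out-of-bounds access reports the offending kernel and source line
make clean
make GGML_CUDA=1 LLAMA_DEBUG=1 llama-cli
compute-sanitizer --tool memcheck ./llama-cli -m models/qwen2.5-32b-instruct-q6_k.gguf -p "Hello" -ngl 99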

If q6_K fails but q8_0 works, the issue could be related to the number of values consumed in one iteration; this can be checked by increasing MATRIX_ROW_PADDING in common.cuh.

@ggerganov
Owner

OK, I will try that in a bit. I have currently narrowed the regression down to a range of about 100 commits and want to finish the bisect.
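The bisect itself is the standard git workflow; a sketch, with the good commit as a placeholder:

git bisect start
git bisect bad HEAD                  # current build reproduces the crash
git bisect good <last-known-good>    # an older commit that worked
# at each step: rebuild, rerun the failing command, then mark the result
git bisect good                      # or: git bisect bad
git bisect reset                     # when finished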

@ggerganov
Owner

This is the commit that introduces the regression: c3ea58a

@ggerganov
Owner

This fixes it:

diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index cfe91f42..93e9ffe1 100644
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -92,7 +92,7 @@ static void launch_mul_mat_vec_cuda(
         }
     }
 
-    const int smem = WARP_SIZE*sizeof(float);
+    const int smem = 2*WARP_SIZE*sizeof(float);
     const dim3 block_nums(nrows, 1, nchannels_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {

@JohannesGaessler Do you think it makes sense?
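For context, a generic sketch of the pattern involved, not the actual mmv.cu kernel: warp partial sums are staged in shared memory for a block-level reduction, and the dynamic shared memory passed at launch has to cover one float per warp that writes into the buffer. Sizing it too small, or letting the wrong threads write, shows up exactly as an illegal memory access on some devices or block sizes.

#define WARP_SIZE 32

// Assumes blockDim.x is a multiple of WARP_SIZE.
__global__ void dot_block_reduce(const float * x, const float * y, float * out, const int n) {
    const int tid     = blockIdx.x*blockDim.x + threadIdx.x;
    const int lane    = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;

    extern __shared__ float warp_sums[]; // host passes nwarps*sizeof(float)

    float sum = tid < n ? x[tid]*y[tid] : 0.0f;
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xFFFFFFFF, sum, offset); // intra-warp reduction
    }
    if (lane == 0) {
        warp_sums[warp_id] = sum; // out of bounds if smem was sized for fewer warps
    }
    __syncthreads();

    if (warp_id == 0) {
        const int nwarps = blockDim.x / WARP_SIZE;
        sum = lane < nwarps ? warp_sums[lane] : 0.0f;
        for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
            sum += __shfl_down_sync(0xFFFFFFFF, sum, offset);
        }
        if (lane == 0) {
            atomicAdd(out, sum);
        }
    }
}

// Host side: the third launch parameter is the dynamic shared memory size and
// must match what the kernel indexes:
//   const int nwarps = block_size/WARP_SIZE;
//   dot_block_reduce<<<grid, block_size, nwarps*sizeof(float)>>>(x, y, out, n);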

@JohannesGaessler
Collaborator

It seems the condition for the memory access is wrong; please confirm whether #10740 works as a fix.

@ggerganov
Owner

Yes, it works with #10740

@eamonnmag
Author

Thanks all!
