Eval bug: llama-server: illegal memory access was encountered #10739
Comments
FYI, it works on A100s; it is only failing on H100s.

Does the FP16 version of that model work correctly on an H100?
I am incidentally doing some tests on an H100 now and also observe the same failure with a Llama 3 8B F16 model:

./bin/llama-cli -m ../models/llama-f16.gguf -p "Hello" -ngl 99
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA H100 80GB HBM3, compute capability 9.0, VMM: yes
build: 4296 (07a61394) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_load_model_from_file: using device CUDA0 (NVIDIA H100 80GB HBM3) - 80469 MiB free
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors: CUDA0 model buffer size = 14393.14 MiB
llm_load_tensors: CPU_Mapped model buffer size = 1080.12 MiB
.........................................................................................
llama_new_context_with_model: n_seq_max = 1
llama_new_context_with_model: n_ctx = 4096
llama_new_context_with_model: n_ctx_per_seq = 4096
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_init: CUDA0 KV buffer size = 512.00 MiB
llama_new_context_with_model: KV self size = 512.00 MiB, K (f16): 256.00 MiB, V (f16): 256.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.53 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 296.00 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 16.01 MiB
llama_new_context_with_model: graph nodes = 1030
llama_new_context_with_model: graph splits = 2
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
/ggml/src/ggml-cuda/ggml-cuda.cu:70: CUDA error
CUDA error: an illegal memory access was encountered
current device: 0, in function ggml_backend_cuda_buffer_clear at /ggml/src/ggml-cuda/ggml-cuda.cu:507
cudaDeviceSynchronize()
Could not attach to process. If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user. For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted

[…] I will try to bisect now, but if you get any other ideas to try, let me know.
Compile with […]. If q6_k fails but q8_0 works, the issue could be related to the number of values consumed in one iteration; this can be checked by increasing […].
Ok, I will try that in a bit. I have currently narrowed the regression down to a range of about 100 commits and want to finish the bisect.
This is the commit that introduces the regression: c3ea58a
This fixes it:

diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index cfe91f42..93e9ffe1 100644
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -92,7 +92,7 @@ static void launch_mul_mat_vec_cuda(
         }
     }
-    const int smem = WARP_SIZE*sizeof(float);
+    const int smem = 2*WARP_SIZE*sizeof(float);
     const dim3 block_nums(nrows, 1, nchannels_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {

@JohannesGaessler Do you think it makes sense?
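For context on why doubling smem can hide the problem: the shared-memory byte count passed at launch and the indices the kernel actually touches have to agree. Below is a minimal, self-contained sketch of that pattern (it is not the llama.cpp mmv kernel; the names such as block_sum are made up for illustration): a block-wide reduction that stages one partial sum per warp in dynamically allocated shared memory. If the guard on which threads write to the buffer is wrong, or smem is reserved for fewer warps than the block contains, the result is exactly this kind of illegal memory access.

```cuda
// Sketch only: a block-wide sum reduction using one dynamically allocated
// shared-memory slot per warp. Not the llama.cpp kernel.
#include <cstdio>
#include <cuda_runtime.h>

static constexpr int WARP_SIZE = 32;

__global__ void block_sum(const float * __restrict__ x, float * __restrict__ out, const int n) {
    extern __shared__ float buf[]; // one float per warp in the block

    const int tid     = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane    = tid % WARP_SIZE;

    // Grid-stride accumulation of the input.
    float sum = 0.0f;
    for (int i = blockIdx.x*blockDim.x + tid; i < n; i += gridDim.x*blockDim.x) {
        sum += x[i];
    }

    // Intra-warp reduction via shuffles; no shared memory needed yet.
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    // Lane 0 of each warp publishes its partial sum. The index is warp_id, so the
    // launch must reserve (blockDim.x/WARP_SIZE)*sizeof(float) bytes of shared memory;
    // reserving less, or letting other threads write here, is an out-of-bounds access.
    if (lane == 0) {
        buf[warp_id] = sum;
    }
    __syncthreads();

    // The first warp reduces the per-warp partials.
    if (warp_id == 0) {
        const int nwarps = blockDim.x / WARP_SIZE;
        float v = lane < nwarps ? buf[lane] : 0.0f;
        for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, offset);
        }
        if (lane == 0) {
            atomicAdd(out, v);
        }
    }
}

int main() {
    const int n = 1 << 20;
    float * x; float * out;
    cudaMallocManaged(&x, n*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) { x[i] = 1.0f; }
    *out = 0.0f;

    const int block_size = 256;                              // 8 warps per block
    const int smem = (block_size/WARP_SIZE)*sizeof(float);   // must match buf[] usage above
    block_sum<<<64, block_size, smem>>>(x, out, n);
    cudaDeviceSynchronize();
    printf("sum = %f (expected %d)\n", *out, n);

    cudaFree(x); cudaFree(out);
    return 0;
}
```

Enlarging smem to 2*WARP_SIZE floats can paper over an out-of-bounds index without addressing the underlying access condition, which is why correcting the condition, rather than doubling the buffer, is the cleaner fix.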
It seems the condition for memory access is wrong; please confirm whether #10740 works as a fix.

Yes, it works with #10740.

Thanks all!
Name and Version
Using ghcr.io/ggerganov/llama.cpp@sha256:cb0f16e6eae440da844b3a80b8c15e82ac4b2b8f6637f674b10b263452e649aa
Operating systems
Linux
GGML backends
CUDA
Hardware
NVIDIA H100
CUDA 12.2
Models
Qwen2.5-32B-Instruct-GGUF
URL https://huggingface.co/Qwen/Qwen2.5-32B-Instruct-GGUF/resolve/main/qwen2.5-32b-instruct-q6_k-00001-of-00007.gguf
...
Problem description & steps to reproduce
When I run the llama.cpp server (using the server-cuda Docker image), I get this error after the first token is emitted; see the reproduction sketch below.
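For reference, a reproduction along these lines should exercise the same path (the mount path and port are placeholders, the image tag should correspond to the digest listed above, and the model file name is taken from the URL above — treat this as a sketch rather than the exact command used):

docker run --gpus all -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:server-cuda -m /models/qwen2.5-32b-instruct-q6_k-00001-of-00007.gguf -ngl 99 --host 0.0.0.0 --port 8080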
First Bad Commit
Not sure.
Relevant log output