From 89ad7e30e22cce7cebc51c700345cb3c742eae82 Mon Sep 17 00:00:00 2001
From: J S <49557684+svilupp@users.noreply.github.com>
Date: Wed, 17 Jan 2024 07:29:08 +0000
Subject: [PATCH] Fix ollama repeated calls (#52)

Fixes https://github.com/svilupp/PromptingTools.jl/issues/51
---
 CHANGELOG.md               | 1 +
 Project.toml               | 2 +-
 src/llm_ollama.jl          | 8 ++++----
 src/llm_ollama_managed.jl  | 4 ++--
 test/llm_ollama.jl         | 8 ++++++++
 test/llm_ollama_managed.jl | 9 +++++++++
 6 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8cb2c3b1e..c18124608 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Initial support for [Llama.jl](https://github.com/marcom/Llama.jl) and other local servers. Once your server is started, simply use `model="local"` to route your queries to the local server, eg, `ai"Say hi!"local`. Option to permanently set the `LOCAL_SERVER` (URL) added to preference management. See `?LocalServerOpenAISchema` for more information.
 
 ### Fixed
+- Repeated calls to Ollama models were failing due to missing `prompt_eval_count` key in subsequent calls.
 
 ## [0.7.0]
diff --git a/Project.toml b/Project.toml
index e412d2d4e..81f0a16d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "PromptingTools"
 uuid = "670122d1-24a8-4d70-bfce-740807c42192"
 authors = ["J S @svilupp and contributors"]
-version = "0.7.0"
+version = "0.8.1"
 
 [deps]
 Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
diff --git a/src/llm_ollama.jl b/src/llm_ollama.jl
index 241af35d2..f03c0c2ff 100644
--- a/src/llm_ollama.jl
+++ b/src/llm_ollama.jl
@@ -159,8 +159,8 @@ function aigenerate(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_
         api_kwargs...)
     msg = AIMessage(; content = resp.response[:message][:content] |> strip,
         status = Int(resp.status),
-        tokens = (resp.response[:prompt_eval_count],
-            resp.response[:eval_count]),
+        tokens = (get(resp.response, :prompt_eval_count, 0),
+            get(resp.response, :eval_count, 0)),
         elapsed = time)
     ## Reporting
     verbose && @info _report_stats(msg, model_id)
@@ -316,8 +316,8 @@ function aiscan(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_TYPE
         api_kwargs...)
     msg = AIMessage(; content = resp.response[:message][:content] |> strip,
         status = Int(resp.status),
-        tokens = (resp.response[:prompt_eval_count],
-            resp.response[:eval_count]),
+        tokens = (get(resp.response, :prompt_eval_count, 0),
+            get(resp.response, :eval_count, 0)),
         elapsed = time)
     ## Reporting
     verbose && @info _report_stats(msg, model_id)
diff --git a/src/llm_ollama_managed.jl b/src/llm_ollama_managed.jl
index 23fa175f5..a8918feea 100644
--- a/src/llm_ollama_managed.jl
+++ b/src/llm_ollama_managed.jl
@@ -216,8 +216,8 @@ function aigenerate(prompt_schema::AbstractOllamaManagedSchema, prompt::ALLOWED_
         api_kwargs...)
     msg = AIMessage(; content = resp.response[:response] |> strip,
         status = Int(resp.status),
-        tokens = (resp.response[:prompt_eval_count],
-            resp.response[:eval_count]),
+        tokens = (get(resp.response, :prompt_eval_count, 0),
+            get(resp.response, :eval_count, 0)),
         elapsed = time)
     ## Reporting
     verbose && @info _report_stats(msg, model_id)
diff --git a/test/llm_ollama.jl b/test/llm_ollama.jl
index 232fc65e9..13052e7cb 100644
--- a/test/llm_ollama.jl
+++ b/test/llm_ollama.jl
@@ -109,6 +109,14 @@ end
         conversation;
         weather = "sunny",
         return_all = true)[1] == expected_convo_output[1]
+
+    # Test if subsequent eval misses the prompt_eval_count key
+    response = Dict(:message => Dict(:content => "Prompt message"))
+    # :prompt_eval_count => 2,
+    # :eval_count => 1)
+    schema = TestEchoOllamaSchema(; response, status = 200)
+    msg = [aigenerate(schema, "hi") for i in 1:3] |> last
+    @test msg.tokens == (0, 0)
 end
 
 # @testset "aiembed-ollama" begin
diff --git a/test/llm_ollama_managed.jl b/test/llm_ollama_managed.jl
index 35a3461c0..6456b1dac 100644
--- a/test/llm_ollama_managed.jl
+++ b/test/llm_ollama_managed.jl
@@ -154,7 +154,16 @@ end
         @test_throws ErrorException aigenerate(schema,
             UserMessageWithImages("abc"; image_url = "https://example.com"))
     end
+
+    # Test if subsequent eval misses the prompt_eval_count key
+    response = Dict(:response => "Hello John")
+    # :prompt_eval_count => 2,
+    # :eval_count => 1)
+    schema = TestEchoOllamaManagedSchema(; response, status = 200)
+    msg = [aigenerate(schema, "hi") for i in 1:3] |> last
+    @test msg.tokens == (0, 0)
 end
+
 @testset "aiembed-ollama" begin
     @testset "single doc" begin
         response = Dict(:embedding => ones(16))
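
Note: the fix relies on Julia's three-argument `get(collection, key, default)`, which returns `default` when `key` is absent, whereas direct indexing throws a `KeyError`. A minimal sketch of the failure mode (illustrative only; the response Dicts below are made up, and it is assumed that Ollama may omit the token-count keys once a prompt has already been evaluated):

    # First call: Ollama reports token counts.
    first_response = Dict(:prompt_eval_count => 2, :eval_count => 1, :response => "Hi!")
    # Repeated call: the token-count keys may be missing.
    cached_response = Dict(:response => "Hi!")

    # Old behavior: direct indexing throws on the repeated call.
    # cached_response[:prompt_eval_count]  # KeyError(:prompt_eval_count)

    # New behavior: fall back to 0 when a key is missing.
    tokens = (get(cached_response, :prompt_eval_count, 0),
        get(cached_response, :eval_count, 0))
    @assert tokens == (0, 0)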