From 045379eac021f614c82e5d8672b17d06ff5f666c Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Sat, 21 Dec 2024 18:50:41 +0000 Subject: [PATCH] Fix usage keys for GoogleAI (#258) --- CHANGELOG.md | 5 +++++ Project.toml | 2 +- src/llm_openai.jl | 46 +++++++++++++++++++++++++++++++++++++++------- test/llm_openai.jl | 12 ++++++++++++ 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7abfff624..c2a1cd591 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +### Fixed + +## [0.69.1] + ### Fixed - Added assertion in `response_to_message` for missing `:tool_calls` key in the response message. It's model failure but it wasn't obvious from the original error. +- Fixes error for usage information in CamelCase from OpenAI servers (Gemini proxy now sends it in CamelCase). ## [0.69.0] diff --git a/Project.toml b/Project.toml index 525fa7f18..8c41f57e9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.69.0" +version = "0.69.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/llm_openai.jl b/src/llm_openai.jl index b490f4e98..739ce08bc 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -176,9 +176,19 @@ function response_to_message(schema::AbstractOpenAISchema, else nothing end + # Extract usage information with default values for tokens + tokens_prompt = 0 + tokens_completion = 0 + # Merge with response usage if available + if haskey(resp.response, :usage) + response_usage = resp.response[:usage] + # Handle both snake_case and camelCase keys + tokens_prompt = get(response_usage, :prompt_tokens, + get(response_usage, :promptTokens, 0)) + tokens_completion = get(response_usage, :completion_tokens, + get(response_usage, :completionTokens, 0)) + end ## calculate cost - tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens] - tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens] cost = call_cost(tokens_prompt, tokens_completion, model_id) extras = Dict{Symbol, Any}() if has_log_prob @@ -438,7 +448,9 @@ function aiembed(prompt_schema::AbstractOpenAISchema, model_id; http_kwargs, api_kwargs...) - tokens_prompt = get(r.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens] + tokens_prompt = haskey(r.response, :usage) ? + get( + r.response[:usage], :prompt_tokens, get(r.response[:usage], :promptTokens, 0)) : 0 msg = DataMessage(; content = mapreduce(x -> postprocess(x[:embedding]), hcat, r.response[:data]), status = Int(r.status), @@ -844,9 +856,19 @@ function response_to_message(schema::AbstractOpenAISchema, else nothing end + # Extract usage information with default values for tokens + tokens_prompt = 0 + tokens_completion = 0 + # Merge with response usage if available + if haskey(resp.response, :usage) + response_usage = resp.response[:usage] + # Handle both snake_case and camelCase keys + tokens_prompt = get(response_usage, :prompt_tokens, + get(response_usage, :promptTokens, 0)) + tokens_completion = get(response_usage, :completion_tokens, + get(response_usage, :completionTokens, 0)) + end ## calculate cost - tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens] - tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens] cost = call_cost(tokens_prompt, tokens_completion, model_id) # "Safe" parsing of the response - it still fails if JSON is invalid tools_array = if json_mode == true @@ -1490,9 +1512,19 @@ function response_to_message(schema::AbstractOpenAISchema, else nothing end + # Extract usage information with default values for tokens + tokens_prompt = 0 + tokens_completion = 0 + # Merge with response usage if available + if haskey(resp.response, :usage) + response_usage = resp.response[:usage] + # Handle both snake_case and camelCase keys + tokens_prompt = get(response_usage, :prompt_tokens, + get(response_usage, :promptTokens, 0)) + tokens_completion = get(response_usage, :completion_tokens, + get(response_usage, :completionTokens, 0)) + end ## calculate cost - tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens] - tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens] cost = call_cost(tokens_prompt, tokens_completion, model_id) # "Safe" parsing of the response - it still fails if JSON is invalid has_tools = haskey(choice[:message], :tool_calls) && diff --git a/test/llm_openai.jl b/test/llm_openai.jl index dd514b719..bdacf396f 100644 --- a/test/llm_openai.jl +++ b/test/llm_openai.jl @@ -434,6 +434,18 @@ end @test msg.sample_id == nothing @test msg.cost == call_cost(2, 1, "gpt4t") + ## CamelCase usage keys + mock_response2 = (; + response = Dict(:choices => [mock_choice], + :usage => Dict(:totalTokens => 3, :promptTokens => 2, :completionTokens => 1)), + status = 200) + msg2 = response_to_message(OpenAISchema(), + AIMessage, + mock_choice, + mock_response2; + model_id = "gpt4t") + @test msg.tokens == (2, 1) + # Test without logprobs choice = deepcopy(mock_choice) delete!(choice, :logprobs)