From 8a2fd463df7ea54e8c1d655619aee8d1ec274234 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Sun, 17 Nov 2024 20:26:53 +0000 Subject: [PATCH] Add image support to aitools (#235) --- CHANGELOG.md | 6 ++++++ Project.toml | 2 +- src/llm_anthropic.jl | 30 ++++++++++++++++++++++++++---- src/llm_openai.jl | 9 +++++++-- src/user_preferences.jl | 12 +++++++++--- src/utils.jl | 31 +++++++++++++++++++++++++++++++ test/llm_anthropic.jl | 26 +++++++++++++++++++++++++- test/utils.jl | 21 ++++++++++++++++++++- 8 files changed, 125 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5b6eca79..cd06a0e64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.64.0] + +### Added +- Added support for images in `aitools` to enable passing screenshots via `image_path` argument (extended to both OpenAI and Anthropic APIs, uses `?UserMessageWithImages` internally). +- Added the latest Gemini Experimental model via OpenAI compatibility mode (`gemini-exp-1114` with alias `gemexp`). 
+ ## [0.63.0] ### Added diff --git a/Project.toml b/Project.toml index d5966c0b4..aa580b51f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.63.0" +version = "0.64.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index 1a4e8c40b..7b020b21b 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -42,15 +42,31 @@ function render(schema::AbstractAnthropicSchema, conversation = Dict{String, Any}[] for msg in messages_replaced - if msg isa SystemMessage + if issystemmessage(msg) system = msg.content - elseif msg isa UserMessage || msg isa AIMessage + elseif isusermessage(msg) || isaimessage(msg) content = msg.content push!(conversation, Dict("role" => role4render(schema, msg), "content" => [Dict{String, Any}("type" => "text", "text" => content)])) - elseif msg isa UserMessageWithImages - error("AbstractAnthropicSchema does not yet support UserMessageWithImages. 
Please use OpenAISchema instead.") + elseif isusermessagewithimages(msg) + # Build message content + content = Dict{String, Any}[Dict("type" => "text", + "text" => msg.content)] + # Add images + for img in msg.image_url + # image_url = "data:image/$image_suffix;base64,$(base64_image)" + data_type, data = extract_image_attributes(img) + @assert data_type in ["image/jpeg", "image/png", "image/gif", "image/webp"] "Unsupported image type: $data_type" + push!(content, + Dict("type" => "image", + "source" => Dict("type" => "base64", + "data" => data, + ## image/jpeg, image/png, image/gif, image/webp + "media_type" => data_type))) + end + push!(conversation, + Dict("role" => role4render(schema, msg), "content" => content)) end # Note: Ignores any DataMessage or other types end @@ -766,6 +782,7 @@ end return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], no_system_message::Bool = false, + image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, cache::Union{Nothing, Symbol} = nothing, betas::Union{Nothing, Vector{Symbol}} = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, @@ -792,6 +809,7 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio - `dry_run`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. - `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history. +- `image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing`: A path to a local image file, or a vector of paths to local image files. Always attaches images to the latest user message. - `cache::Union{Nothing, Symbol} = nothing`: Whether to cache the prompt. Defaults to `nothing`. 
- `betas::Union{Nothing, Vector{Symbol}} = nothing`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details. - `http_kwargs`: A named tuple of HTTP keyword arguments. @@ -865,6 +883,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_ return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], no_system_message::Bool = false, + image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, cache::Union{Nothing, Symbol} = nothing, betas::Union{Nothing, Vector{Symbol}} = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, @@ -899,6 +918,9 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_ ## Add the function call stopping sequence to the api_kwargs api_kwargs = merge(api_kwargs, (; tools, tool_choice)) + ## Vision-specific functionality -- if `image_path` is provided, attach images to the latest user message + !isnothing(image_path) && + (prompt = attach_images_to_user_message(prompt; image_path, attach_to_latest = true)) ## We provide the tool description to the rendering engine conv_rendered = render( prompt_schema, prompt; tools, conversation, no_system_message, cache, kwargs...) 
diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 902d2b515..429d86325 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -1547,6 +1547,7 @@ end return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], no_system_message::Bool = false, + image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = (; @@ -1575,6 +1576,7 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio - `dry_run`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. - `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history. +- `image_path`: A path to a local image file, or a vector of paths to local image files. Always attaches images to the latest user message. - `name_user`: The name of the user in the conversation history. Defaults to "User". - `name_assistant`: The name of the assistant in the conversation history. Defaults to "Assistant". - `http_kwargs`: A named tuple of HTTP keyword arguments. 
@@ -1641,8 +1643,8 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP model::String = MODEL_CHAT, return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], - no_system_message::Bool = false, - name_user::Union{Nothing, String} = nothing, + no_system_message::Bool = false, name_user::Union{Nothing, String} = nothing, + image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, name_assistant::Union{Nothing, String} = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, @@ -1685,6 +1687,9 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP ## Find the unique ID for the model alias provided model_id = get(MODEL_ALIASES, model, model) + ## Vision-specific functionality -- if `image_path` is provided, attach images to the latest user message + !isnothing(image_path) && + (prompt = attach_images_to_user_message(prompt; image_path, attach_to_latest = true)) ## Render the conversation history from messages conv_rendered = render( prompt_schema, prompt; conversation, no_system_message, name_user, kwargs...) 
diff --git a/src/user_preferences.jl b/src/user_preferences.jl index 8aa18f50e..0df22f61e 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -478,7 +478,8 @@ aliases = merge( ## Gemini 1.5 Models "gem15p" => "gemini-1.5-pro-latest", "gem15f8" => "gemini-1.5-flash-8b-latest", - "gem15f" => "gemini-1.5-flash-latest" + "gem15f" => "gemini-1.5-flash-latest", + "gemexp" => "gemini-exp-1114" # latest experimental model from November 2024 ), ## Load aliases from preferences as well @load_preference("MODEL_ALIASES", default=Dict{String, String}())) @@ -1111,7 +1112,7 @@ registry = Dict{String, ModelSpec}( ## Gemini 1.5 Models "gemini-1.5-pro-latest" => ModelSpec("gemini-1.5-pro-latest", GoogleOpenAISchema(), - 1e-6, + 1.25e-6, 5e-6, "Gemini 1.5 Pro is Google's latest large language model with enhanced capabilities across reasoning, math, coding, and multilingual tasks. 128K context window."), "gemini-1.5-flash-8b-latest" => ModelSpec("gemini-1.5-flash-8b-latest", @@ -1123,7 +1124,12 @@ registry = Dict{String, ModelSpec}( GoogleOpenAISchema(), 7.5e-8, 3.0e-7, - "Gemini 1.5 Flash is a high-performance model optimized for speed while maintaining strong capabilities across various tasks. 128K context window.") + "Gemini 1.5 Flash is a high-performance model optimized for speed while maintaining strong capabilities across various tasks. 128K context window."), + "gemini-exp-1114" => ModelSpec("gemini-exp-1114", + GoogleOpenAISchema(), + 1.25e-6, + 5e-6, + "Gemini Experimental Model from November 2024. Pricing assumed as per Gemini 1.5 Pro. See details [here](https://ai.google.dev/gemini-api/docs/models/experimental-models#use-an-experimental-model).") ) """ diff --git a/src/utils.jl b/src/utils.jl index f47f5d16a..736a536a5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -668,4 +668,35 @@ Returns indices of unique items in a vector `inputs`. 
Access the unique values a """ function unique_permutation(inputs::AbstractVector) return unique(i -> inputs[i], eachindex(inputs)) +end + +""" + extract_image_attributes(image_url::AbstractString) -> Tuple{String, String} + +Extracts the data type and base64-encoded data from a data URL. + +# Arguments +- `image_url::AbstractString`: The data URL to be parsed. + +# Returns +`Tuple{String, String}`: A tuple containing the data type (e.g., `"image/png"`) and the base64-encoded data. + +# Example +```julia +image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABQAA" +data_type, data = extract_image_attributes(image_url) +# data_type == "image/png" +# data == "iVBORw0KGgoAAAANSUhEUgAABQAA" +``` +""" +function extract_image_attributes(image_url::AbstractString)::Tuple{String, String} + pattern = r"^data:(.*?);base64,(.*)$" + m = match(pattern, image_url) + if m !== nothing + data_type = m.captures[1] + data = m.captures[2] + return data_type, data + else + throw(ArgumentError("Invalid data URL format")) + end end \ No newline at end of file diff --git a/test/llm_anthropic.jl b/test/llm_anthropic.jl index cfc7f2c51..83a5d47e5 100644 --- a/test/llm_anthropic.jl +++ b/test/llm_anthropic.jl @@ -103,14 +103,38 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature, conversation = render(schema, messages) @test conversation == expected_output + ### IMAGES # Test UserMessageWithImages -- errors for now messages = [ SystemMessage("System message 1"), UserMessageWithImages("User message"; image_url = "https://example.com/image.png") ] + ## We don't support URL format! 
@test_throws Exception render(schema, messages) - ## Tool calling + ## Unsupported format + messages = [ + SystemMessage("System message 1"), + UserMessageWithImages( + "User message"; image_url = "") + ] + @test_throws AssertionError render(schema, messages) + + ## Base64 format + messages = [ + SystemMessage("System message 1"), + UserMessageWithImages( + "User message"; image_url = "") + ] + rendered = render(schema, messages) + @test rendered.conversation[1] == Dict{String, Any}("role" => "user", + "content" => Dict{String, Any}[Dict("text" => "User message", "type" => "text"), + Dict( + "source" => Dict("media_type" => "image/png", + "data" => "iVBORw0KGgoAAAANSUhEUgAABQAA", "type" => "base64"), + "type" => "image")]) + + ### Tool calling "abc" struct FruitCountX fruit::String diff --git a/test/utils.jl b/test/utils.jl index 40807243d..bc885c881 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,7 +2,7 @@ using PromptingTools: recursive_splitter, wrap_string, replace_words, length_longest_common_subsequence, distance_longest_common_subsequence using PromptingTools: _extract_handlebar_variables, call_cost, call_cost_alternative, _report_stats -using PromptingTools: _string_to_vector, _encode_local_image +using PromptingTools: _string_to_vector, _encode_local_image, extract_image_attributes using PromptingTools: DataMessage, AIMessage, UserMessage using PromptingTools: push_conversation!, resize_conversation!, @timeout, preview, pprint, auth_header, @@ -276,6 +276,25 @@ end @test _encode_local_image(nothing) == String[] end +@testset "extract_image_attributes" begin + # Test basic valid data URL + data_url = "" + data_type, data = extract_image_attributes(data_url) + @test data_type == "image/png" + @test data == "iVBORw0KGgoAAAANSUhEUgAABQAA" + + # Test different image type + data_url = "" + data_type, data = extract_image_attributes(data_url) + @test data_type == "image/jpeg" + @test data == "/9j/4AAQSkZJRg" + + # Test invalid data URL format + @test_throws 
ArgumentError extract_image_attributes("not a data url") + @test_throws ArgumentError extract_image_attributes("data:image/png;") + @test_throws ArgumentError extract_image_attributes("data:image/png;base64") +end + ### Conversation Management @testset "push_conversation!,resize_conversation!" begin # Test 1: Adding to Conversation History