From 057afcf038a6c8884ad94ebbd9eccb63fec75754 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Wed, 14 Feb 2024 21:35:18 +0000 Subject: [PATCH] Update AIrag kwargs (#74) --- CHANGELOG.md | 9 ++++++++ Project.toml | 4 ++-- src/Experimental/RAGTools/generation.jl | 25 ++++++++++++++++----- src/user_preferences.jl | 28 +++++++++++++++--------- test/Experimental/RAGTools/generation.jl | 10 +++++++++ 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c351593d..950983c00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.12.0] + +### Added +- Added more specific kwargs in `Experimental.RAGTools.airag` to give more control over each type of AI call (ie, `aiembed_kwargs`, `aigenerate_kwargs`, `aiextract_kwargs`) +- Move up compat bounds for OpenAI.jl to 0.9 + +### Fixed +- Fixed a bug where obtaining an API_KEY from ENV would get precompiled as well, causing an error if the ENV was not set at the time of precompilation. Now, we save the `get(ENV...)` into a separate variable to avoid being compiled away. 
+ ## [0.11.0] ### Added diff --git a/Project.toml b/Project.toml index 3a0a192ae..9f1c151f3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.11.0" +version = "0.12.0" [deps] Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -31,7 +31,7 @@ JSON3 = "1" LinearAlgebra = "<0.0.1, 1" Logging = "<0.0.1, 1" Markdown = "<0.0.1, 1" -OpenAI = "0.8.7" +OpenAI = "0.9" Pkg = "<0.0.1, 1" PrecompileTools = "1" Preferences = "1" diff --git a/src/Experimental/RAGTools/generation.jl b/src/Experimental/RAGTools/generation.jl index 712acf07f..8de2f0225 100644 --- a/src/Experimental/RAGTools/generation.jl +++ b/src/Experimental/RAGTools/generation.jl @@ -49,6 +49,9 @@ end return_context::Bool = false, verbose::Bool = true, rerank_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), + aiembed_kwargs::NamedTuple = NamedTuple(), + aigenerate_kwargs::NamedTuple = NamedTuple(), + aiextract_kwargs::NamedTuple = NamedTuple(), kwargs...) Generates a response for a given question using a Retrieval-Augmented Generation (RAG) approach. @@ -71,7 +74,10 @@ The function selects relevant chunks from an `ChunkIndex`, optionally filters th - `chunks_window_margin::Tuple{Int,Int}`: The window size around each chunk to consider for context building. See `?build_context` for more information. - `return_context::Bool`: If `true`, returns the context used for RAG along with the response. - `verbose::Bool`: If `true`, enables verbose logging. -- `api_kwargs`: API parameters that will be forwarded to the API calls +- `api_kwargs`: API parameters that will be forwarded to ALL of the API calls (`aiembed`, `aigenerate`, and `aiextract`). +- `aiembed_kwargs`: API parameters that will be forwarded to the `aiembed` call. 
If you need to provide `api_kwargs` only to this function, simply add them as a keyword argument, eg, `aiembed_kwargs = (; api_kwargs = (; x=1))`. +- `aigenerate_kwargs`: API parameters that will be forwarded to the `aigenerate` call. If you need to provide `api_kwargs` only to this function, simply add them as a keyword argument, eg, `aigenerate_kwargs = (; api_kwargs = (; temperature=0.3))`. +- `aiextract_kwargs`: API parameters that will be forwarded to the `aiextract` call for the metadata extraction. # Returns - If `return_context` is `false`, returns the generated message (`msg`). @@ -109,6 +115,9 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC return_context::Bool = false, verbose::Bool = true, rerank_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), + aiembed_kwargs::NamedTuple = NamedTuple(), + aigenerate_kwargs::NamedTuple = NamedTuple(), + aiextract_kwargs::NamedTuple = NamedTuple(), kwargs...) ## Note: Supports only single ChunkIndex for now ## Checks @@ -117,21 +126,26 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC placeholders = only(aitemplates(rag_template)).variables # only one template should be found @assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(rag_template) is not suitable. It must have placeholders: `question` and `context`." + ## Embedding + joined_kwargs = isempty(api_kwargs) ? aiembed_kwargs : + merge(aiembed_kwargs, (; api_kwargs)) question_emb = aiembed(question, _normalize; model = model_embedding, - verbose, api_kwargs).content .|> Float32 # no need for Float64 + verbose, joined_kwargs...).content .|> Float32 # no need for Float64 emb_candidates = find_closest(index, question_emb; top_k, minimum_similarity) tag_candidates = if tag_filter == :auto && !isnothing(tags(index)) && !isempty(model_metadata) _check_aiextract_capability(model_metadata) + joined_kwargs = isempty(api_kwargs) ? 
aiextract_kwargs : + merge(aiextract_kwargs, (; api_kwargs)) # extract metadata via LLM call metadata_ = try msg = aiextract(metadata_template; return_type = MaybeMetadataItems, text = question, instructions = "In addition to extracted items, suggest 2-3 filter keywords that could be relevant to answer this question.", - verbose, model = model_metadata, api_kwargs) + verbose, model = model_metadata, joined_kwargs...) ## eg, ["software:::pandas", "language:::python", "julia_package:::dataframes"] ## we split it and take only the keyword, not the category metadata_extract(msg.content.items) |> @@ -162,10 +176,11 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC context = build_context(index, reranked_candidates; chunks_window_margin) ## LLM call + joined_kwargs = isempty(api_kwargs) ? aigenerate_kwargs : + merge(aigenerate_kwargs, (; api_kwargs)) msg = aigenerate(rag_template; question, context = join(context, "\n\n"), model = model_chat, verbose, - api_kwargs, - kwargs...) + joined_kwargs...) if return_context # for evaluation rag_context = RAGContext(; diff --git a/src/user_preferences.jl b/src/user_preferences.jl index f2e09cf84..a5ca4bb60 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -126,30 +126,38 @@ const MODEL_EMBEDDING::String = @load_preference("MODEL_EMBEDDING", # const PROMPT_SCHEMA = OpenAISchema() # First, load from preferences, then from environment variables -const OPENAI_API_KEY::String = @noinline @load_preference("OPENAI_API_KEY", - default=@noinline get(ENV, "OPENAI_API_KEY", "")); +# Note: We load first into a variable `_temp` to avoid inlining of the get(ENV...) call +_temp = get(ENV, "OPENAI_API_KEY", "") +const OPENAI_API_KEY::String = @load_preference("OPENAI_API_KEY", + default=_temp); # Note: Disable this warning by setting OPENAI_API_KEY to anything isempty(OPENAI_API_KEY) && @warn "OPENAI_API_KEY variable not set!
OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=`!" -const MISTRALAI_API_KEY::String = @noinline @load_preference("MISTRALAI_API_KEY", - default=@noinline get(ENV, "MISTRALAI_API_KEY", "")); +_temp = get(ENV, "MISTRALAI_API_KEY", "") +const MISTRALAI_API_KEY::String = @load_preference("MISTRALAI_API_KEY", + default=_temp); -const COHERE_API_KEY::String = @noinline @load_preference("COHERE_API_KEY", - default=@noinline get(ENV, "COHERE_API_KEY", "")); +_temp = get(ENV, "COHERE_API_KEY", "") +const COHERE_API_KEY::String = @load_preference("COHERE_API_KEY", + default=_temp); +_temp = get(ENV, "DATABRICKS_API_KEY", "") const DATABRICKS_API_KEY::String = @noinline @load_preference("DATABRICKS_API_KEY", - default=@noinline get(ENV, "DATABRICKS_API_KEY", "")); + default=_temp); +_temp = get(ENV, "DATABRICKS_HOST", "") const DATABRICKS_HOST::String = @noinline @load_preference("DATABRICKS_HOST", - default=@noinline get(ENV, "DATABRICKS_HOST", "")); + default=_temp); +_temp = get(ENV, "TAVILY_API_KEY", "") const TAVILY_API_KEY::String = @noinline @load_preference("TAVILY_API_KEY", - default=@noinline get(ENV, "TAVILY_API_KEY", "")); + default=_temp); +_temp = get(ENV, "LOCAL_SERVER", "http://127.0.0.1:10897/v1") ## Address of the local server const LOCAL_SERVER::String = @noinline @load_preference("LOCAL_SERVER", - default=@noinline get(ENV, "LOCAL_SERVER", "http://127.0.0.1:10897/v1")); + default=_temp); ## CONVERSATION HISTORY """ diff --git a/test/Experimental/RAGTools/generation.jl b/test/Experimental/RAGTools/generation.jl index b19d03a44..6fbb95057 100644 --- a/test/Experimental/RAGTools/generation.jl +++ b/test/Experimental/RAGTools/generation.jl @@ -96,6 +96,16 @@ end return_context = false) @test occursin("Time?", msg.content) + # test kwargs passing + api_kwargs = (; url = "http://localhost:$(PORT)") + msg = airag(index; question = "Time?", model_embedding = "mock-emb", + model_chat = "mock-gen", + model_metadata = "mock-meta", + tag_filter =
["yes"], + return_context = false, aiembed_kwargs = (; api_kwargs), + aigenerate_kwargs = (; api_kwargs), aiextract_kwargs = (; api_kwargs)) + @test occursin("Time?", msg.content) + ## Test different kwargs msg, ctx = airag(index; question = "Time?", model_embedding = "mock-emb", model_chat = "mock-gen",