Skip to content

Commit

Permalink
Update AIrag kwargs (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Feb 14, 2024
1 parent 3bccc7e commit 057afcf
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 17 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.12.0]

### Added
- Added more specific kwargs in `Experimental.RAGTools.airag` to give more control over each type of AI call (i.e., `aiembed_kwargs`, `aigenerate_kwargs`, `aiextract_kwargs`)
- Moved up compat bounds for OpenAI.jl to 0.9

### Fixed
- Fixed a bug where obtaining an API_KEY from ENV would get precompiled as well, causing an error if the ENV was not set at the time of precompilation. Now, we save the `get(ENV...)` into a separate variable to avoid being compiled away.

## [0.11.0]

### Added
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.11.0"
version = "0.12.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down Expand Up @@ -31,7 +31,7 @@ JSON3 = "1"
LinearAlgebra = "<0.0.1, 1"
Logging = "<0.0.1, 1"
Markdown = "<0.0.1, 1"
OpenAI = "0.8.7"
OpenAI = "0.9"
Pkg = "<0.0.1, 1"
PrecompileTools = "1"
Preferences = "1"
Expand Down
25 changes: 20 additions & 5 deletions src/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ end
return_context::Bool = false, verbose::Bool = true,
rerank_kwargs::NamedTuple = NamedTuple(),
api_kwargs::NamedTuple = NamedTuple(),
aiembed_kwargs::NamedTuple = NamedTuple(),
aigenerate_kwargs::NamedTuple = NamedTuple(),
aiextract_kwargs::NamedTuple = NamedTuple(),
kwargs...)
Generates a response for a given question using a Retrieval-Augmented Generation (RAG) approach.
Expand All @@ -71,7 +74,10 @@ The function selects relevant chunks from an `ChunkIndex`, optionally filters th
- `chunks_window_margin::Tuple{Int,Int}`: The window size around each chunk to consider for context building. See `?build_context` for more information.
- `return_context::Bool`: If `true`, returns the context used for RAG along with the response.
- `verbose::Bool`: If `true`, enables verbose logging.
- `api_kwargs`: API parameters that will be forwarded to the API calls
- `api_kwargs`: API parameters that will be forwarded to ALL of the API calls (`aiembed`, `aigenerate`, and `aiextract`).
- `aiembed_kwargs`: API parameters that will be forwarded to the `aiembed` call. If you need to provide `api_kwargs` only to this function, simply add them as a keyword argument, eg, `aiembed_kwargs = (; api_kwargs = (; x=1))`.
- `aigenerate_kwargs`: API parameters that will be forwarded to the `aigenerate` call. If you need to provide `api_kwargs` only to this function, simply add them as a keyword argument, eg, `aigenerate_kwargs = (; api_kwargs = (; temperature=0.3))`.
- `aiextract_kwargs`: API parameters that will be forwarded to the `aiextract` call for the metadata extraction.
# Returns
- If `return_context` is `false`, returns the generated message (`msg`).
Expand Down Expand Up @@ -109,6 +115,9 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC
return_context::Bool = false, verbose::Bool = true,
rerank_kwargs::NamedTuple = NamedTuple(),
api_kwargs::NamedTuple = NamedTuple(),
aiembed_kwargs::NamedTuple = NamedTuple(),
aigenerate_kwargs::NamedTuple = NamedTuple(),
aiextract_kwargs::NamedTuple = NamedTuple(),
kwargs...)
## Note: Supports only single ChunkIndex for now
## Checks
Expand All @@ -117,21 +126,26 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC
placeholders = only(aitemplates(rag_template)).variables # only one template should be found
@assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(rag_template) is not suitable. It must have placeholders: `question` and `context`."

## Embedding
joined_kwargs = isempty(api_kwargs) ? aiembed_kwargs :
merge(aiembed_kwargs, (; api_kwargs))
question_emb = aiembed(question,
_normalize;
model = model_embedding,
verbose, api_kwargs).content .|> Float32 # no need for Float64
verbose, joined_kwargs...).content .|> Float32 # no need for Float64
emb_candidates = find_closest(index, question_emb; top_k, minimum_similarity)

tag_candidates = if tag_filter == :auto && !isnothing(tags(index)) &&
!isempty(model_metadata)
_check_aiextract_capability(model_metadata)
joined_kwargs = isempty(api_kwargs) ? aiextract_kwargs :
merge(aiextract_kwargs, (; api_kwargs))
# extract metadata via LLM call
metadata_ = try
msg = aiextract(metadata_template; return_type = MaybeMetadataItems,
text = question,
instructions = "In addition to extracted items, suggest 2-3 filter keywords that could be relevant to answer this question.",
verbose, model = model_metadata, api_kwargs)
verbose, model = model_metadata, joined_kwargs...)
## eg, ["software:::pandas", "language:::python", "julia_package:::dataframes"]
## we split it and take only the keyword, not the category
metadata_extract(msg.content.items) |>
Expand Down Expand Up @@ -162,10 +176,11 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC
context = build_context(index, reranked_candidates; chunks_window_margin)

## LLM call
joined_kwargs = isempty(api_kwargs) ? aigenerate_kwargs :
merge(aigenerate_kwargs, (; api_kwargs))
msg = aigenerate(rag_template; question,
context = join(context, "\n\n"), model = model_chat, verbose,
api_kwargs,
kwargs...)
joined_kwargs...)

if return_context # for evaluation
rag_context = RAGContext(;
Expand Down
28 changes: 18 additions & 10 deletions src/user_preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,30 +126,38 @@ const MODEL_EMBEDDING::String = @load_preference("MODEL_EMBEDDING",
# const PROMPT_SCHEMA = OpenAISchema()

# First, load from preferences, then from environment variables
const OPENAI_API_KEY::String = @noinline @load_preference("OPENAI_API_KEY",
default=@noinline get(ENV, "OPENAI_API_KEY", ""));
# Note: We load first into a variable `_temp` to avoid inlining of the get(ENV...) call
_temp = get(ENV, "OPENAI_API_KEY", "")
const OPENAI_API_KEY::String = @load_preference("OPENAI_API_KEY",
default=_temp);
# Note: Disable this warning by setting OPENAI_API_KEY to anything
isempty(OPENAI_API_KEY) &&
@warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=<api-key>`!"

const MISTRALAI_API_KEY::String = @noinline @load_preference("MISTRALAI_API_KEY",
default=@noinline get(ENV, "MISTRALAI_API_KEY", ""));
_temp = get(ENV, "MISTRALAI_API_KEY", "")
const MISTRALAI_API_KEY::String = @load_preference("MISTRALAI_API_KEY",
default=_temp);

const COHERE_API_KEY::String = @noinline @load_preference("COHERE_API_KEY",
default=@noinline get(ENV, "COHERE_API_KEY", ""));
_temp = get(ENV, "COHERE_API_KEY", "")
const COHERE_API_KEY::String = @load_preference("COHERE_API_KEY",
default=_temp);

_temp = get(ENV, "DATABRICKS_API_KEY", "")
const DATABRICKS_API_KEY::String = @noinline @load_preference("DATABRICKS_API_KEY",
default=@noinline get(ENV, "DATABRICKS_API_KEY", ""));
default=_temp);

_temp = get(ENV, "DATABRICKS_HOST", "")
const DATABRICKS_HOST::String = @noinline @load_preference("DATABRICKS_HOST",
default=@noinline get(ENV, "DATABRICKS_HOST", ""));
default=_temp);

_temp = get(ENV, "TAVILY_API_KEY", "")
const TAVILY_API_KEY::String = @noinline @load_preference("TAVILY_API_KEY",
default=@noinline get(ENV, "TAVILY_API_KEY", ""));
default=_temp);

_temp = get(ENV, "LOCAL_SERVER", "")
## Address of the local server
const LOCAL_SERVER::String = @noinline @load_preference("LOCAL_SERVER",
default=@noinline get(ENV, "LOCAL_SERVER", "http://127.0.0.1:10897/v1"));
default=_temp);

## CONVERSATION HISTORY
"""
Expand Down
10 changes: 10 additions & 0 deletions test/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ end
return_context = false)
@test occursin("Time?", msg.content)

# test kwargs passing
api_kwargs = (; url = "http://localhost:$(PORT)")
msg = airag(index; question = "Time?", model_embedding = "mock-emb",
model_chat = "mock-gen",
model_metadata = "mock-meta",
tag_filter = ["yes"],
return_context = false, aiembed_kwargs = (; api_kwargs),
aigenerate_kwargs = (; api_kwargs), aiextract_kwargs = (; api_kwargs))
@test occursin("Time?", msg.content)

## Test different kwargs
msg, ctx = airag(index; question = "Time?", model_embedding = "mock-emb",
model_chat = "mock-gen",
Expand Down

0 comments on commit 057afcf

Please sign in to comment.