From 20a214c31ddd25492424d4c96f15a0a83e5a87b9 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 22 Jan 2024 21:53:18 +0000 Subject: [PATCH] Refactor RAGTools (#56) - Split `Experimental.RAGTools.build_index` into smaller functions to easier sharing with other packages (`get_chunks`, `get_embeddings`, `get_metadata`) - Added support for Cohere-based RAG re-ranking strategy (and introduced associated `COHERE_API_KEY` global variable and ENV variable) --- CHANGELOG.md | 8 + Project.toml | 2 +- src/Experimental/RAGTools/RAGTools.jl | 5 +- src/Experimental/RAGTools/api_services.jl | 36 +++ src/Experimental/RAGTools/generation.jl | 16 +- src/Experimental/RAGTools/preparation.jl | 264 ++++++++++++++++------ src/Experimental/RAGTools/retrieval.jl | 105 ++++++++- src/Experimental/RAGTools/types.jl | 83 ++++++- src/PromptingTools.jl | 1 + src/user_preferences.jl | 7 + src/utils.jl | 16 +- test/Experimental/RAGTools/evaluation.jl | 2 +- test/Experimental/RAGTools/preparation.jl | 16 +- test/Experimental/RAGTools/retrieval.jl | 61 ++++- test/Experimental/RAGTools/types.jl | 55 ++++- test/utils.jl | 13 +- 16 files changed, 598 insertions(+), 92 deletions(-) create mode 100644 src/Experimental/RAGTools/api_services.jl diff --git a/CHANGELOG.md b/CHANGELOG.md index fae61485a..b8d7dd0db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.9.0] + +### Added +- Split `Experimental.RAGTools.build_index` into smaller functions to easier sharing with other packages (`get_chunks`, `get_embeddings`, `get_metadata`) +- Added support for Cohere-based RAG re-ranking strategy (and introduced associated `COHERE_API_KEY` global variable and ENV variable) + +### Fixed + ## [0.8.1] ### Fixed diff --git a/Project.toml b/Project.toml index 81f0a16d9..7e7f82e32 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = 
"670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.8.1" +version = "0.9.0" [deps] Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" diff --git a/src/Experimental/RAGTools/RAGTools.jl b/src/Experimental/RAGTools/RAGTools.jl index a315a2a7f..7f491fd36 100644 --- a/src/Experimental/RAGTools/RAGTools.jl +++ b/src/Experimental/RAGTools/RAGTools.jl @@ -10,11 +10,14 @@ This module is experimental and may change at any time. It is intended to be mov module RAGTools using PromptingTools -using JSON3 +using HTTP, JSON3 const PT = PromptingTools include("utils.jl") +# eg, cohere_api +include("api_services.jl") + export ChunkIndex, CandidateChunks # MultiIndex include("types.jl") diff --git a/src/Experimental/RAGTools/api_services.jl b/src/Experimental/RAGTools/api_services.jl new file mode 100644 index 000000000..adc7ac824 --- /dev/null +++ b/src/Experimental/RAGTools/api_services.jl @@ -0,0 +1,36 @@ +""" + cohere_api(; + api_key::AbstractString, + endpoint::String, + url::AbstractString="https://api.cohere.ai/v1", + http_kwargs::NamedTuple=NamedTuple(), + kwargs...) + +Lightweight wrapper around the Cohere API. See https://cohere.com/docs for more details. + +# Arguments +- `api_key`: Your Cohere API key. You can get one from https://dashboard.cohere.com/welcome/register (trial access is for free). +- `endpoint`: The Cohere endpoint to call. +- `url`: The base URL for the Cohere API. Default is `https://api.cohere.ai/v1`. +- `http_kwargs`: Any additional keyword arguments to pass to `HTTP.post`. +- `kwargs`: Any additional keyword arguments to pass to the Cohere API. +""" +function cohere_api(; + api_key::AbstractString, + endpoint::String, + url::AbstractString = "https://api.cohere.ai/v1", + http_kwargs::NamedTuple = NamedTuple(), + kwargs...) + @assert endpoint in ["chat", "generate", "embed", "rerank", "classify"] "Only 'chat', 'generate',`embed`,`rerank`,`classify` Cohere endpoints are supported." 
+ @assert !isempty(api_key) "Cohere `api_key` cannot be empty. Check `PT.COHERE_API_KEY` or pass it as a keyword argument." + ## + input_body = Dict(kwargs...) + + # https://api.cohere.ai/v1/rerank + api_url = string(url, "/", endpoint) + resp = HTTP.post(api_url, + PT.auth_header(api_key), + JSON3.write(input_body); http_kwargs...) + body = JSON3.read(resp.body) + return (; response = body) +end \ No newline at end of file diff --git a/src/Experimental/RAGTools/generation.jl b/src/Experimental/RAGTools/generation.jl index 59b130b3c..d9725f6bb 100644 --- a/src/Experimental/RAGTools/generation.jl +++ b/src/Experimental/RAGTools/generation.jl @@ -39,7 +39,7 @@ end """ airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext; question::AbstractString, - top_k::Int = 3, `minimum_similarity::AbstractFloat`= -1.0, + top_k::Int = 100, top_n::Int = 5, minimum_similarity::AbstractFloat = -1.0, tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto, rerank_strategy::RerankingStrategy = Passthrough(), model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT, @@ -47,6 +47,7 @@ end metadata_template::Symbol = :RAGExtractMetadataShort, chunks_window_margin::Tuple{Int, Int} = (1, 1), return_context::Bool = false, verbose::Bool = true, + rerank_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), kwargs...) @@ -59,9 +60,10 @@ The function selects relevant chunks from an `ChunkIndex`, optionally filters th - `rag_template::Symbol`: Template for the RAG model, defaults to `:RAGAnswerFromContext`. - `question::AbstractString`: The question to be answered. - `top_k::Int`: Number of top candidates to retrieve based on embedding similarity. +- `top_n::Int`: Number of candidates to return after reranking. - `minimum_similarity::AbstractFloat`: Minimum similarity threshold (between -1 and 1) for filtering chunks based on embedding similarity. Defaults to -1.0. 
- `tag_filter::Union{Symbol, Vector{String}, Regex}`: Mechanism for filtering chunks based on tags (either automatically detected, specific tags, or a regex pattern). Disabled by setting to `nothing`. -- `rerank_strategy::RerankingStrategy`: Strategy for reranking the retrieved chunks. +- `rerank_strategy::RerankingStrategy`: Strategy for reranking the retrieved chunks. Defaults to `Passthrough()`. Use `CohereRerank` for better results (requires `COHERE_API_KEY` to be set) - `model_embedding::String`: Model used for embedding the question, default is `PT.MODEL_EMBEDDING`. - `model_chat::String`: Model used for generating the final response, default is `PT.MODEL_CHAT`. - `model_metadata::String`: Model used for extracting metadata, default is `PT.MODEL_CHAT`. @@ -97,7 +99,7 @@ See also `build_index`, `build_context`, `CandidateChunks`, `find_closest`, `fin """ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext; question::AbstractString, - top_k::Int = 3, minimum_similarity::AbstractFloat = -1.0, + top_k::Int = 100, top_n::Int = 5, minimum_similarity::AbstractFloat = -1.0, tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto, rerank_strategy::RerankingStrategy = Passthrough(), model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT, @@ -105,6 +107,7 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC metadata_template::Symbol = :RAGExtractMetadataShort, chunks_window_margin::Tuple{Int, Int} = (1, 1), return_context::Bool = false, verbose::Bool = true, + rerank_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), kwargs...) ## Note: Supports only single ChunkIndex for now @@ -148,7 +151,12 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC filtered_candidates = isnothing(tag_candidates) ? 
emb_candidates : (emb_candidates & tag_candidates) - reranked_candidates = rerank(rerank_strategy, index, question, filtered_candidates) + reranked_candidates = rerank(rerank_strategy, + index, + question, + filtered_candidates; + top_n, + verbose = false, rerank_kwargs...) ## Build the context context = build_context(index, reranked_candidates; chunks_window_margin) diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl index 50e937805..2b80a9e0f 100644 --- a/src/Experimental/RAGTools/preparation.jl +++ b/src/Experimental/RAGTools/preparation.jl @@ -34,26 +34,181 @@ function build_tags end "Build an index for RAG (Retriever-Augmented Generation) applications. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!" function build_index end +"Shortcut to LinearAlgebra.normalize. Provided in the package extension `RAGToolsExperimentalExt` (Requires SparseArrays and LinearAlgebra)" +function _normalize end + """ - build_index(files::Vector{<:AbstractString}; - separators = ["\n\n", ". ", "\n"], max_length::Int = 256, - extract_metadata::Bool = false, verbose::Bool = true, + get_chunks(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, + sources::Vector{<:AbstractString} = files_or_docs, + verbose::Bool = true, + separators = ["\n\n", ". ", "\n"], max_length::Int = 256) + +Chunks the provided `files_or_docs` into chunks of maximum length `max_length` (if possible with provided `separators`). + +Supports two modes of operation: +- `reader=:files`: The function opens each file in `files_or_docs` and reads its content. +- `reader=:docs`: The function assumes that `files_or_docs` is a vector of strings to be chunked. + +# Arguments +- `files_or_docs`: A vector of valid file paths OR string documents to be chunked. +- `reader`: A symbol indicating the type of input, can be either `:files` or `:docs`. Default is `:files`. 
+- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `[\n\n", ". ", "\n"]`. +- `max_length`: The maximum length of each chunk (if possible with provided separators). Default is 256. +- `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs` (for `reader=:files`) + +""" +function get_chunks(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, + sources::Vector{<:AbstractString} = files_or_docs, + verbose::Bool = true, + separators = ["\n\n", ". ", "\n"], max_length::Int = 256) + + ## Check that all items must be existing files or strings + @assert reader in [:files, :docs] "Invalid `read` argument. Must be one of [:files, :docs]" + if reader == :files + @assert all(isfile, files_or_docs) "Some paths in `files_or_docs` don't exist (Check: $(join(filter(!isfile,files_or_docs),", "))" + else + @assert sources!=files_or_docs "When `reader=:docs`, vector of `sources` must be provided" + end + @assert isnothing(sources)||(length(sources) == length(files_or_docs)) "Length of `sources` must match length of `files_or_docs`" + @assert maximum(length.(sources))<=512 "Each source must be less than 512 characters long (Detected: $(maximum(length.(sources))))" + + output_chunks = Vector{SubString{String}}() + output_sources = Vector{eltype(sources)}() + + # Do chunking first + for i in eachindex(files_or_docs, sources) + # if reader == :files, we open the files and read them + doc_raw = if reader == :files + fn = files_or_docs[i] + (verbose > 0) && @info "Processing file: $fn" + read(fn, String) + else + files_or_docs[i] + end + isempty(doc_raw) && continue + # split into chunks, if you want to start simple - just do `split(text,"\n\n")` + doc_chunks = PT.split_by_length(doc_raw, separators; max_length) .|> strip |> + x -> filter(!isempty, x) + # skip if no chunks found + isempty(doc_chunks) && continue + append!(output_chunks, doc_chunks) + 
append!(output_sources, fill(sources[i], length(doc_chunks))) + end + + return output_chunks, output_sources +end + +""" + get_embeddings(docs::Vector{<:AbstractString}; + verbose::Bool = true, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Embeds a vector of `docs` using the provided model (kwarg `model`). + +Tries to batch embedding calls for roughly 80K characters per call (to avoid exceeding the API limit) but reduce network latency. + +Note: `docs` are assumed to be already chunked to the reasonable sizes that fit within the embedding context limit. + +# Arguments +- `docs`: A vector of strings to be embedded. +- `verbose`: A boolean flag for verbose output. Default is `true`. +- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`. +- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call. + +""" +function get_embeddings(docs::Vector{<:AbstractString}; + verbose::Bool = true, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + verbose && @info "Embedding $(length(docs)) documents..." + model = hasproperty(kwargs, :model) ? kwargs.model : PT.MODEL_EMBEDDING + # Notice that we embed multiple docs at once, not one by one + # OpenAI supports embedding multiple documents to reduce the number of API calls/network latency time + # We do batch them just in case the documents are too large (targeting at most 80K characters per call) + avg_length = sum(length.(docs)) / length(docs) + embedding_batch_size = floor(Int, 80_000 / avg_length) + embeddings = asyncmap(Iterators.partition(docs, embedding_batch_size)) do docs_chunk + msg = aiembed(docs_chunk, + _normalize; + verbose = false, + kwargs...) + Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model)) # track costs + msg.content + end + embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents + verbose && @info "Done embedding. 
Total cost: \$$(round(cost_tracker[],digits=3))" + return embeddings +end + +""" + get_metadata(docs::Vector{<:AbstractString}; + verbose::Bool = true, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Extracts metadata from a vector of `docs` using the provided model (kwarg `model`). + +# Arguments +- `docs`: A vector of strings to be embedded. +- `verbose`: A boolean flag for verbose output. Default is `true`. +- `model`: The model to use for metadata extraction. Default is `PT.MODEL_CHAT`. +- `metadata_template`: A template to be used for metadata extraction. Default is `:RAGExtractMetadataShort`. +- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call. + +""" +function get_metadata(docs::Vector{<:AbstractString}; + verbose::Bool = true, + metadata_template::Symbol = :RAGExtractMetadataShort, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + model = hasproperty(kwargs, :model) ? kwargs.model : PT.MODEL_CHAT + _check_aiextract_capability(model) + verbose && @info "Extracting metadata from $(length(docs)) documents..." + metadata = asyncmap(docs) do docs_chunk + try + msg = aiextract(metadata_template; + return_type = MaybeMetadataItems, + text = docs_chunk, + instructions = "None.", + verbose = false, + model, kwargs...) + Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model)) # track costs + items = metadata_extract(msg.content.items) + catch + String[] + end + end + verbose && + @info "Done extracting the metadata. Total cost: \$$(round(cost_tracker[],digits=3))" + return metadata +end + +""" + build_index(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, + separators = ["\\n\\n", ". 
", "\\n"], max_length::Int = 256, + sources::Vector{<:AbstractString} = files_or_docs, + extract_metadata::Bool = false, verbose::Int = 1, + index_id = gensym("ChunkIndex"), metadata_template::Symbol = :RAGExtractMetadataShort, model_embedding::String = PT.MODEL_EMBEDDING, model_metadata::String = PT.MODEL_CHAT, - api_kwargs::NamedTuple = NamedTuple()) + api_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0)) Build an index for RAG (Retriever-Augmented Generation) applications from the provided file paths. The function processes each file, splits its content into chunks, embeds these chunks, optionally extracts metadata, and then compiles this information into a retrievable index. # Arguments -- `files`: A vector of valid file paths to be indexed. -- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `["\n\n", ". ", "\n"]`. +- `files_or_docs`: A vector of valid file paths OR string documents to be indexed (chunked and embedded). +- `reader`: A symbol indicating the type of input, can be either `:files` or `:docs`. Default is `:files`. +- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `[\\n\\n", ". ", "\\n"]`. - `max_length`: The maximum length of each chunk (if possible with provided separators). Default is 256. +- `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs` (for `reader=:files`) - `extract_metadata`: A boolean flag indicating whether to extract metadata from each chunk (to build filter `tags` in the index). Default is `false`. Metadata extraction incurs additional cost and requires `model_metadata` and `metadata_template` to be provided. -- `verbose`: A boolean flag for verbose output. Default is `true`. +- `verbose`: An Integer specifying the verbosity of the logs. Default is `1` (high-level logging). `0` is disabled. 
- `metadata_template`: A symbol indicating the template to be used for metadata extraction. Default is `:RAGExtractMetadataShort`. - `model_embedding`: The model to use for embedding. - `model_metadata`: The model to use for metadata extraction. @@ -69,79 +224,56 @@ See also: `MultiIndex`, `CandidateChunks`, `find_closest`, `find_tags`, `rerank` # Assuming `test_files` is a vector of file paths index = build_index(test_files; max_length=10, extract_metadata=true) -# Another example with metadata extraction and verbose output +# Another example with metadata extraction and verbose output (`reader=:files` is implicit) index = build_index(["file1.txt", "file2.txt"]; separators=[". "], extract_metadata=true, verbose=true) ``` """ -function build_index(files::Vector{<:AbstractString}; +function build_index(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, separators = ["\n\n", ". ", "\n"], max_length::Int = 256, - extract_metadata::Bool = false, verbose::Bool = true, + sources::Vector{<:AbstractString} = files_or_docs, + extract_metadata::Bool = false, verbose::Integer = 1, + index_id = gensym("ChunkIndex"), metadata_template::Symbol = :RAGExtractMetadataShort, model_embedding::String = PT.MODEL_EMBEDDING, model_metadata::String = PT.MODEL_CHAT, - api_kwargs::NamedTuple = NamedTuple()) - ## - @assert all(isfile, files) "Some `files` don't exist (Check: $(join(filter(!isfile,files),", "))" - - output_chunks = Vector{Vector{SubString{String}}}() - output_embeddings = Vector{Matrix{Float32}}() - output_metadata = Vector{Vector{Vector{String}}}() - output_sources = Vector{Vector{eltype(files)}}() - cost_tracker = Threads.Atomic{Float64}(0.0) - - for fn in files - verbose && @info "Processing file: $fn" - doc_raw = read(fn, String) - isempty(doc_raw) && continue - # split into chunks, if you want to start simple - just do `split(text,"\n\n")` - doc_chunks = PT.split_by_length(doc_raw, separators; max_length) .|> strip |> - x -> filter(!isempty, x) - # skip if 
no chunks found - isempty(doc_chunks) && continue - push!(output_chunks, doc_chunks) - push!(output_sources, fill(fn, length(doc_chunks))) - - # Notice that we embed all doc_chunks at once, not one by one - # OpenAI supports embedding multiple documents to reduce the number of API calls/network latency time - emb = aiembed(doc_chunks, _normalize; model = model_embedding, verbose, api_kwargs) - Threads.atomic_add!(cost_tracker, PT.call_cost(emb, model_embedding)) # track costs - push!(output_embeddings, Float32.(emb.content)) - - if extract_metadata && !isempty(model_metadata) - _check_aiextract_capability(model_metadata) - metadata_ = asyncmap(doc_chunks) do chunk - try - msg = aiextract(metadata_template; - return_type = MaybeMetadataItems, - text = chunk, - instructions = "None.", - verbose, - model = model_metadata, api_kwargs) - Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model_metadata)) # track costs - items = metadata_extract(msg.content.items) - catch - String[] - end - end - push!(output_metadata, metadata_) - end - end - ## Create metadata tags and associated vocabulary - tags, tags_vocab = if !isempty(output_metadata) - # Requires SparseArrays.jl! 
- build_tags(vcat(output_metadata...)) # need to vcat to be on the "chunk-level" + api_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0)) + + ## Split into chunks + output_chunks, output_sources = get_chunks(files_or_docs; + reader, sources, separators, max_length) + + ## Embed chunks + embeddings = get_embeddings(output_chunks; + verbose = (verbose > 1), + cost_tracker, + model = model_embedding, + api_kwargs) + + ## Extract metadata + tags, tags_vocab = if extract_metadata + output_metadata = get_metadata(output_chunks; + verbose = (verbose > 1), + cost_tracker, + model = model_metadata, + metadata_template, + api_kwargs) + # Requires SparseArrays.jl to be loaded + build_tags(output_metadata) else - tags, tags_vocab = nothing, nothing + nothing, nothing end - verbose && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))" + ## Create metadata tag array and associated vocabulary + (verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))" index = ChunkIndex(; - embeddings = hcat(output_embeddings...), + id = index_id, + embeddings, tags, tags_vocab, - chunks = vcat(output_chunks...), - sources = vcat(output_sources...)) + chunks = output_chunks, + sources = output_sources) return index end diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index a3d136aa4..d7c998731 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -32,6 +32,25 @@ function find_closest(index::AbstractChunkIndex, minimum_similarity) return CandidateChunks(index.id, positions, Float32.(distances)) end +## function find_closest(index::AbstractMultiIndex, +## query_emb::AbstractVector{<:Real}; +## top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0) +## all_candidates = CandidateChunks[] +## for idxs in indexes(index) +## candidates = find_closest(idxs, query_emb; +## top_k, +## minimum_similarity) +## if !isempty(candidates.positions) 
+## push!(all_candidates, candidates) +## end +## end +## ## build vector of all distances and pick top_k +## all_distances = mapreduce(x -> x.distances, vcat, all_candidates) +## top_k_order = all_distances |> sortperm |> x -> last(x, top_k) +## return CandidateChunks(index.id, +## all_candidates[top_k_order], +## all_distances[top_k_order]) +## end function find_tags(index::AbstractChunkIndex, tag::Union{AbstractString, Regex}) @@ -57,8 +76,90 @@ end abstract type RerankingStrategy end struct Passthrough <: RerankingStrategy end +struct CohereRerank <: RerankingStrategy end -function rerank(strategy::Passthrough, index, question, candidate_chunks; kwargs...) +function rerank(strategy::Passthrough, + index, + question, + candidate_chunks; + top_n::Integer = length(candidate_chunks), + kwargs...) # Since this is a Passthrough strategy, it returns the candidate_chunks unchanged - return candidate_chunks + return first(candidate_chunks, top_n) +end + +function rerank(strategy::CohereRerank, + index::AbstractDocumentIndex, args...; kwargs...) + throw(ArgumentError("Not implemented yet")) +end + +""" + rerank(strategy::CohereRerank, index::AbstractChunkIndex, question, + candidate_chunks; + verbose::Bool = false, + api_key::AbstractString = PT.COHERE_API_KEY, + top_n::Integer = length(candidate_chunks.distances), + model::AbstractString = "rerank-english-v2.0", + return_documents::Bool = false, + kwargs...) + +Re-ranks a list of candidate chunks using the Cohere Rerank API. See https://cohere.com/rerank for more details. + +# Arguments +- `query`: The query to be used for the search. +- `documents`: A vector of documents to be reranked. + The total max chunks (`length of documents * max_chunks_per_doc`) must be less than 10000. We recommend less than 1000 documents for optimal performance. +- `top_n`: The number of most relevant documents to return. Default is `length(documents)`. +- `model`: The model to use for reranking. Default is `rerank-english-v2.0`. 
+- `return_documents`: A boolean flag indicating whether to return the reranked documents in the response. Default is `false`. +- `max_chunks_per_doc`: The maximum number of chunks to use per document. Default is `10`. +- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`. + +""" +function rerank(strategy::CohereRerank, index::AbstractChunkIndex, question, + candidate_chunks; + verbose::Bool = false, + api_key::AbstractString = PT.COHERE_API_KEY, + top_n::Integer = length(candidate_chunks.distances), + model::AbstractString = "rerank-english-v2.0", + return_documents::Bool = false, + kwargs...) + @assert top_n>0 "top_n must be a positive integer." + @assert index.id==candidate_chunks.index_id "The index id of the index and candidate_chunks must match." + + ## Call the API + documents = index[candidate_chunks, :chunks] + verbose && + @info "Calling Cohere Rerank API with $(length(documents)) candidate chunks..." + r = cohere_api(; + api_key, + endpoint = "rerank", + query = question, + documents, + top_n, + model, + return_documents, + kwargs...) + + ## Unwrap re-ranked positions + positions = Vector{Int}(undef, length(r.response[:results])) + distances = Vector{Float32}(undef, length(r.response[:results])) + for i in eachindex(r.response[:results]) + doc = r.response[:results][i] + positions[i] = candidate_chunks.positions[doc[:index] + 1] + distances[i] = doc[:relevance_score] + end + + ## Check the cost + search_units_str = if haskey(r.response, :meta) && + haskey(r.response[:meta], :billed_units) && + haskey(r.response[:meta][:billed_units], :search_units) + units = r.response[:meta][:billed_units][:search_units] + "Charged $(units) search units." + else + "" + end + verbose && @info "Reranking done. 
$search_units_str" + + return CandidateChunks(index.id, positions, distances) end \ No newline at end of file diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index 2aeb7a4d0..dab4e78ba 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -3,6 +3,7 @@ # In addition, RAGContext is defined for debugging purposes abstract type AbstractDocumentIndex end +abstract type AbstractMultiIndex <: AbstractDocumentIndex end abstract type AbstractChunkIndex <: AbstractDocumentIndex end # More advanced index would be: HybridChunkIndex @@ -35,6 +36,10 @@ function Base.var"=="(i1::ChunkIndex, i2::ChunkIndex) (i1.embeddings == i2.embeddings) && (i1.chunks == i2.chunks) && (i1.tags == i2.tags)) end +function Base.vcat(i1::AbstractDocumentIndex, i2::AbstractDocumentIndex) + throw(ArgumentError("Not implemented")) +end + function Base.vcat(i1::ChunkIndex, i2::ChunkIndex) tags_, tags_vocab_ = if (isnothing(tags(i1)) || isnothing(tags(i2))) nothing, nothing @@ -54,9 +59,9 @@ function Base.vcat(i1::ChunkIndex, i2::ChunkIndex) end "Composite index that stores multiple ChunkIndex objects and their embeddings" -@kwdef struct MultiIndex <: AbstractDocumentIndex +@kwdef struct MultiIndex <: AbstractMultiIndex id::Symbol = gensym("MultiIndex") - indexes::Vector{<:ChunkIndex} + indexes::Vector{<:AbstractChunkIndex} end indexes(index::MultiIndex) = index.indexes # check that each index has a counterpart in the other MultiIndex @@ -76,13 +81,27 @@ function Base.var"=="(i1::MultiIndex, i2::MultiIndex) end abstract type AbstractCandidateChunks end -@kwdef struct CandidateChunks{T <: Real} <: AbstractCandidateChunks +@kwdef struct CandidateChunks{TP <: Union{Integer, AbstractCandidateChunks}, TD <: Real} <: + AbstractCandidateChunks index_id::Symbol - positions::Vector{Int} = Int[] - distances::Vector{T} = Float32[] + ## if TP is Int, then positions are indices into the index + ## if TP is CandidateChunks, then positions are 
indices into the positions of the child index in MultiIndex + positions::Vector{TP} = Int[] + distances::Vector{TD} = Float32[] +end +Base.length(cc::CandidateChunks) = length(cc.positions) +function Base.first(cc::CandidateChunks, k::Integer) + CandidateChunks(cc.index_id, first(cc.positions, k), first(cc.distances, k)) end # combine/intersect two candidate chunks. average the score if available -function Base.var"&"(cc1::CandidateChunks, cc2::CandidateChunks) +function Base.var"&"(cc1::AbstractCandidateChunks, + cc2::AbstractCandidateChunks) + throw(ArgumentError("Not implemented")) +end +function Base.var"&"(cc1::CandidateChunks{TP1, TD1}, + cc2::CandidateChunks{TP2, TD2}) where + {TP1 <: Integer, TP2 <: Integer, TD1 <: Real, TD2 <: Real} + ## cc1.index_id != cc2.index_id && return CandidateChunks(; index_id = cc1.index_id) positions = intersect(cc1.positions, cc2.positions) @@ -93,17 +112,39 @@ function Base.var"&"(cc1::CandidateChunks, cc2::CandidateChunks) end CandidateChunks(cc1.index_id, positions, distances) end -function Base.getindex(ci::ChunkIndex, candidate::CandidateChunks, field::Symbol = :chunks) - @assert field==:chunks "Only `chunks` field is supported for now" + +function Base.getindex(ci::AbstractDocumentIndex, + candidate::AbstractCandidateChunks, + field::Symbol) + throw(ArgumentError("Not implemented")) +end +function Base.getindex(ci::ChunkIndex, + candidate::CandidateChunks{TP, TD}, + field::Symbol = :chunks) where {TP <: Integer, TD <: Real} + @assert field in [:chunks, :embeddings, :sources] "Only `chunks`, `embeddings`, `sources` fields are supported for now" len_ = length(chunks(ci)) @assert all(1 .<= candidate.positions .<= len_) "Some positions are out of bounds" if ci.id == candidate.index_id - chunks(ci)[candidate.positions] + if field == :chunks + @views chunks(ci)[candidate.positions] + elseif field == :embeddings + @views embeddings(ci)[:, candidate.positions] + elseif field == :sources + @views sources(ci)[candidate.positions] + 
end else - eltype(chunks(ci))[] + if field == :chunks + eltype(chunks(ci))[] + elseif field == :embeddings + eltype(embeddings(ci))[] + elseif field == :sources + eltype(sources(ci))[] + end end end -function Base.getindex(mi::MultiIndex, candidate::CandidateChunks, field::Symbol = :chunks) +function Base.getindex(mi::MultiIndex, + candidate::CandidateChunks{TP, TD}, + field::Symbol = :chunks) where {TP <: Integer, TD <: Real} @assert field==:chunks "Only `chunks` field is supported for now" valid_index = findfirst(x -> x.id == candidate.index_id, indexes(mi)) if isnothing(valid_index) @@ -112,6 +153,26 @@ function Base.getindex(mi::MultiIndex, candidate::CandidateChunks, field::Symbol getindex(indexes(mi)[valid_index], candidate) end end +# Dispatch for multi-candidate chunks +function Base.getindex(ci::ChunkIndex, + candidate::CandidateChunks{TP, TD}, + field::Symbol = :chunks) where {TP <: AbstractCandidateChunks, TD <: Real} + @assert field==:chunks "Only `chunks` field is supported for now" + + index_pos = findfirst(x -> x.index_id == ci.id, candidate.positions) + @info index_pos + if isnothing(index_pos) + eltype(chunks(ci))[] + else + getindex(chunks(ci), candidate.positions[index_pos].positions) + end +end +function Base.getindex(mi::MultiIndex, + candidate::CandidateChunks{TP, TD}, + field::Symbol = :chunks) where {TP <: AbstractCandidateChunks, TD <: Real} + @assert field==:chunks "Only `chunks` field is supported for now" + mapreduce(idxs -> Base.getindex(idxs, candidate, field), vcat, indexes(mi)) +end """ RAGContext diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index c48c28257..51c47ddd3 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -25,6 +25,7 @@ const RESERVED_KWARGS = [ :model, ] +# export replace_words, split_by_length, call_cost, auth_header # for debugging only include("utils.jl") export aigenerate, aiembed, aiclassify, aiextract, aiscan diff --git a/src/user_preferences.jl b/src/user_preferences.jl index 
22f597eab..fe1be50c4 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -13,6 +13,7 @@ Check your preferences by calling `get_preferences(key::String)`. # Available Preferences (for `set_preferences!`) - `OPENAI_API_KEY`: The API key for the OpenAI API. See [OpenAI's documentation](https://platform.openai.com/docs/quickstart?context=python) for more information. - `MISTRALAI_API_KEY`: The API key for the Mistral AI API. See [Mistral AI's documentation](https://docs.mistral.ai/) for more information. +- `COHERE_API_KEY`: The API key for the Cohere API. See [Cohere's documentation](https://docs.cohere.com/docs/the-cohere-platform) for more information. - `MODEL_CHAT`: The default model to use for aigenerate and most ai* calls. See `MODEL_REGISTRY` for a list of available models or define your own. - `MODEL_EMBEDDING`: The default model to use for aiembed (embedding documents). See `MODEL_REGISTRY` for a list of available models or define your own. - `PROMPT_SCHEMA`: The default prompt schema to use for aigenerate and most ai* calls (if not specified in `MODEL_REGISTRY`). Set as a string, eg, `"OpenAISchema"`. @@ -30,6 +31,7 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava # Available ENV Variables - `OPENAI_API_KEY`: The API key for the OpenAI API. - `MISTRALAI_API_KEY`: The API key for the Mistral AI API. +- `COHERE_API_KEY`: The API key for the Cohere API. - `LOCAL_SERVER`: The URL of the local server to use for `ai*` calls. Defaults to `http://localhost:10897/v1`. This server is called when you call `model="local"` Preferences.jl takes priority over ENV variables, so if you set a preference, it will override the ENV variable. @@ -56,6 +58,7 @@ function set_preferences!(pairs::Pair{String, <:Any}...) 
allowed_preferences = [ "MISTRALAI_API_KEY", "OPENAI_API_KEY", + "COHERE_API_KEY", "MODEL_CHAT", "MODEL_EMBEDDING", "MODEL_ALIASES", @@ -91,6 +94,7 @@ function get_preferences(key::String) allowed_preferences = [ "MISTRALAI_API_KEY", "OPENAI_API_KEY", + "COHERE_API_KEY", "MODEL_CHAT", "MODEL_EMBEDDING", "MODEL_ALIASES", @@ -119,6 +123,9 @@ isempty(OPENAI_API_KEY) && const MISTRALAI_API_KEY::String = @load_preference("MISTRALAI_API_KEY", default=get(ENV, "MISTRALAI_API_KEY", "")); +const COHERE_API_KEY::String = @load_preference("COHERE_API_KEY", + default=get(ENV, "COHERE_API_KEY", "")); + ## Address of the local server const LOCAL_SERVER::String = @load_preference("LOCAL_SERVER", default=get(ENV, "LOCAL_SERVER", "http://127.0.0.1:10897/v1")); diff --git a/src/utils.jl b/src/utils.jl index 3b8cacc9f..d7a145a73 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -359,4 +359,18 @@ macro timeout(seconds, expr_to_run, expr_when_fails) end "Utility for rendering the conversation (vector of messages) as markdown. REQUIRES the Markdown package to load the extension!" -function preview end \ No newline at end of file +function preview end + +""" + auth_header(api_key::String) + +Builds an authorization header for API calls with the given API key. 
+""" +function auth_header(api_key::String) + isempty(api_key) && throw(ArgumentError("api_key cannot be empty")) + [ + "Authorization" => "Bearer $api_key", + "Content-Type" => "application/json", + "Accept" => "application/json", + ] +end \ No newline at end of file diff --git a/test/Experimental/RAGTools/evaluation.jl b/test/Experimental/RAGTools/evaluation.jl index f9ea295e8..db8ed3be3 100644 --- a/test/Experimental/RAGTools/evaluation.jl +++ b/test/Experimental/RAGTools/evaluation.jl @@ -75,7 +75,7 @@ end @testset "build_qa_evals" begin # test with a mock server - PORT = rand(1000:2000) + PORT = rand(9000:11000) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) diff --git a/test/Experimental/RAGTools/preparation.jl b/test/Experimental/RAGTools/preparation.jl index 3e8396fbc..1a20749dd 100644 --- a/test/Experimental/RAGTools/preparation.jl +++ b/test/Experimental/RAGTools/preparation.jl @@ -72,7 +72,7 @@ end @testset "build_index" begin # test with a mock server - PORT = rand(1000:2000) + PORT = rand(9000:11000) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-get", schema = PT.CustomOpenAISchema()) @@ -123,6 +123,20 @@ end @test index.tags == ones(8, 1) @test index.tags_vocab == ["category:::yes"] + ## Test docs reader + index = build_index([text, text]; reader = :docs, sources = ["x", "x"], max_length = 10, + extract_metadata = true, + model_embedding = "mock-emb", + model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)")) + @test index.embeddings == hcat(fill(normalize(ones(Float32, 128)), 8)...) 
+ @test index.chunks[1:4] == index.chunks[5:8] + @test index.sources == fill("x", 8) + @test index.tags == ones(8, 1) + @test index.tags_vocab == ["category:::yes"] + + # Assertion if sources is missing + @test_throws AssertionError build_index([text, text]; reader = :docs) + # clean up close(echo_server) end \ No newline at end of file diff --git a/test/Experimental/RAGTools/retrieval.jl b/test/Experimental/RAGTools/retrieval.jl index fcb9bc819..296904f88 100644 --- a/test/Experimental/RAGTools/retrieval.jl +++ b/test/Experimental/RAGTools/retrieval.jl @@ -1,5 +1,5 @@ using PromptingTools.Experimental.RAGTools: find_closest, find_tags -using PromptingTools.Experimental.RAGTools: Passthrough, rerank +using PromptingTools.Experimental.RAGTools: Passthrough, rerank, CohereRerank @testset "find_closest" begin test_embeddings = [1.0 2.0 -1.0; 3.0 4.0 -3.0; 5.0 6.0 -6.0] |> @@ -22,6 +22,34 @@ using PromptingTools.Experimental.RAGTools: Passthrough, rerank # Test behavior with edge values (top_k == 0) @test find_closest(test_embeddings, query_embedding, top_k = 0) == ([], []) + + ## Test with ChunkIndex + embeddings1 = ones(Float32, 2, 2) + embeddings1[2, 2] = 5.0 + embeddings1 = mapreduce(normalize, hcat, eachcol(embeddings1)) + ci1 = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = embeddings1) + ci2 = ChunkIndex(id = :TestChunkIndex2, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = ones(Float32, 2, 2)) + mi = MultiIndex(id = :multi, indexes = [ci1, ci2]) + + ## find_closest with ChunkIndex + query_emb = [0.5, 0.5] # Example query embedding vector + result = find_closest(ci1, query_emb) + @test result isa CandidateChunks + @test result.positions == [1, 2] + @test all(1.0 .>= result.distances .>= -1.0) # Assuming default minimum_similarity + + ## find_closest with MultiIndex + ## query_emb = [0.5, 0.5] # Example query embedding vector + ## result = find_closest(mi, 
query_emb) + ## @test result isa CandidateChunks + ## @test result.positions == [1, 2] + ## @test all(1.0 .>= result.distances .>= -1.0) # Assuming default minimum_similarity end @testset "find_tags" begin @@ -60,5 +88,34 @@ end # Passthrough Strategy strategy = Passthrough() - @test rerank(strategy, index, question, candidate_chunks) === candidate_chunks + @test rerank(strategy, index, question, candidate_chunks) == + candidate_chunks + + # Cohere assertion + ci1 = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"]) + ci2 = ChunkIndex(id = :TestChunkIndex2, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"]) + mi = MultiIndex(; id = :multi, indexes = [ci1, ci2]) + @test_throws ArgumentError rerank(CohereRerank(), + mi, + question, + candidate_chunks) + + # Bad top_n + @test_throws AssertionError rerank(CohereRerank(), + ci1, + question, + candidate_chunks; top_n = 0) + + # Bad index_id + cc2 = CandidateChunks(index_id = :TestChunkIndex2, + positions = [1, 2], + distances = [0.3, 0.4]) + @test_throws AssertionError rerank(CohereRerank(), + ci1, + question, + cc2; top_n = 1) end \ No newline at end of file diff --git a/test/Experimental/RAGTools/types.jl b/test/Experimental/RAGTools/types.jl index bfb915919..538a0669e 100644 --- a/test/Experimental/RAGTools/types.jl +++ b/test/Experimental/RAGTools/types.jl @@ -92,6 +92,29 @@ end mi1 = MultiIndex(indexes = [cin1]) mi2 = MultiIndex(indexes = [cin2]) @test mi1 != mi2 + + ## not implemented + @test_throws ArgumentError vcat(mi1, mi2) +end + +@testset "CandidateChunks" begin + chunk_sym = Symbol("TestChunkIndex") + cc1 = CandidateChunks(index_id = chunk_sym, + positions = [1, 3], + distances = [0.1, 0.2]) + @test Base.length(cc1) == 2 + + # Test intersection & + cc2 = CandidateChunks(index_id = chunk_sym, + positions = [2, 4], + distances = [0.3, 0.4]) + @test isempty((cc1 & cc2).positions) + cc3 = CandidateChunks(index_id = chunk_sym, + positions = [1, 
4], + distances = [0.3, 0.4]) + joint = (cc1 & cc3) + @test joint.positions == [1] + @test joint.distances == [0.2] end @testset "getindex with CandidateChunks" begin @@ -113,12 +136,21 @@ end positions = [1, 3], distances = [0.1, 0.2]) @test collect(test_chunk_index[candidate_chunks]) == ["First chunk", "Third chunk"] + @test collect(test_chunk_index[candidate_chunks, :chunks]) == + ["First chunk", "Third chunk"] + @test collect(test_chunk_index[candidate_chunks, :sources]) == + ["test_source", "test_source"] + @test collect(test_chunk_index[candidate_chunks, :embeddings]) == + embeddings_data[:, [1, 3]] # Test with empty positions, which should result in an empty array candidate_chunks_empty = CandidateChunks(index_id = chunk_sym, positions = Int[], distances = Float32[]) @test isempty(test_chunk_index[candidate_chunks_empty]) + @test isempty(test_chunk_index[candidate_chunks_empty, :chunks]) + @test isempty(test_chunk_index[candidate_chunks_empty, :embeddings]) + @test isempty(test_chunk_index[candidate_chunks_empty, :sources]) # Test with positions out of bounds, should handle gracefully without errors candidate_chunks_oob = CandidateChunks(index_id = chunk_sym, @@ -151,4 +183,25 @@ end # Test error case when trying to use a non-chunks field, should assert error as only :chunks field is supported @test_throws AssertionError test_chunk_index[candidate_chunks, :nonexistent_field] -end \ No newline at end of file + + # Multi-Candidate CandidateChunks + cc1 = CandidateChunks(index_id = :TestChunkIndex1, + positions = [1, 2], + distances = [0.3, 0.4]) + cc2 = CandidateChunks(index_id = :TestChunkIndex2, + positions = [2], + distances = [0.1]) + cc = CandidateChunks(; index_id = :multi, positions = [cc1, cc2], distances = zeros(2)) + ci1 = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"]) + ci2 = ChunkIndex(id = :TestChunkIndex2, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"]) + @test ci1[cc] 
== ["chunk1", "chunk2"] + @test ci2[cc] == ["chunk2"] + + # with MultiIndex + mi = MultiIndex(; id = :multi, indexes = [ci1, ci2]) + @test mi[cc] == ["chunk1", "chunk2", "chunk2"] +end diff --git a/test/utils.jl b/test/utils.jl index 04a452d71..789fff4a3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,7 +2,8 @@ using PromptingTools: split_by_length, replace_words using PromptingTools: _extract_handlebar_variables, call_cost, _report_stats using PromptingTools: _string_to_vector, _encode_local_image using PromptingTools: DataMessage, AIMessage -using PromptingTools: push_conversation!, resize_conversation!, @timeout, preview +using PromptingTools: push_conversation!, + resize_conversation!, @timeout, preview, auth_header @testset "replace_words" begin words = ["Disney", "Snow White", "Mickey Mouse"] @@ -226,3 +227,13 @@ end expected_output = Markdown.parse("# System Message\n\nWelcome\n\n---\n\n# User Message\n\nHello\n\n---\n\n# AI Message\n\nWorld\n\n---\n\n# Data Message\n\nData: Vector{Float64} (Size: (10,))\n") @test preview_output == expected_output end + +@testset "auth_header" begin + headers = auth_header("my-api-key") + @test headers == [ + "Authorization" => "Bearer my-api-key", + "Content-Type" => "application/json", + "Accept" => "application/json", + ] + @test_throws ArgumentError auth_header("") +end \ No newline at end of file