Skip to content

Commit

Permalink
Refactor RAGTools (#56)
Browse files Browse the repository at this point in the history
- Split `Experimental.RAGTools.build_index` into smaller functions for easier sharing with other packages (`get_chunks`, `get_embeddings`, `get_metadata`)
- Added support for Cohere-based RAG re-ranking strategy (and introduced associated `COHERE_API_KEY` global variable and ENV variable)
  • Loading branch information
svilupp authored Jan 22, 2024
1 parent 68cbd32 commit 20a214c
Show file tree
Hide file tree
Showing 16 changed files with 598 additions and 92 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.9.0]

### Added
- Split `Experimental.RAGTools.build_index` into smaller functions for easier sharing with other packages (`get_chunks`, `get_embeddings`, `get_metadata`)
- Added support for Cohere-based RAG re-ranking strategy (and introduced associated `COHERE_API_KEY` global variable and ENV variable)

### Fixed

## [0.8.1]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.8.1"
version = "0.9.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down
5 changes: 4 additions & 1 deletion src/Experimental/RAGTools/RAGTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ This module is experimental and may change at any time. It is intended to be mov
module RAGTools

using PromptingTools
using JSON3
using HTTP, JSON3
const PT = PromptingTools

include("utils.jl")

# eg, cohere_api
include("api_services.jl")

export ChunkIndex, CandidateChunks # MultiIndex
include("types.jl")

Expand Down
36 changes: 36 additions & 0 deletions src/Experimental/RAGTools/api_services.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
cohere_api(;
api_key::AbstractString,
endpoint::String,
url::AbstractString="https://api.cohere.ai/v1",
http_kwargs::NamedTuple=NamedTuple(),
kwargs...)
Lightweight wrapper around the Cohere API. See https://cohere.com/docs for more details.
# Arguments
- `api_key`: Your Cohere API key. You can get one from https://dashboard.cohere.com/welcome/register (trial access is for free).
- `endpoint`: The Cohere endpoint to call.
- `url`: The base URL for the Cohere API. Default is `https://api.cohere.ai/v1`.
- `http_kwargs`: Any additional keyword arguments to pass to `HTTP.post`.
- `kwargs`: Any additional keyword arguments to pass to the Cohere API.
"""
function cohere_api(;
        api_key::AbstractString,
        endpoint::String,
        url::AbstractString = "https://api.cohere.ai/v1",
        http_kwargs::NamedTuple = NamedTuple(),
        kwargs...)
    # Guard clauses: reject unknown endpoints and blank keys before any request is made.
    @assert endpoint in ("chat", "generate", "embed", "rerank", "classify") "Only 'chat', 'generate',`embed`,`rerank`,`classify` Cohere endpoints are supported."
    @assert !isempty(api_key) "Cohere `api_key` cannot be empty. Check `PT.COHERE_API_KEY` or pass it as a keyword argument."

    # Every remaining keyword argument becomes a field of the JSON request body.
    payload = Dict(kwargs...)

    # The request target is `<url>/<endpoint>`, eg, https://api.cohere.ai/v1/rerank
    target = "$(url)/$(endpoint)"
    reply = HTTP.post(target,
        PT.auth_header(api_key),
        JSON3.write(payload); http_kwargs...)

    # Hand back the parsed JSON body wrapped in a named tuple under `response`.
    return (; response = JSON3.read(reply.body))
end
16 changes: 12 additions & 4 deletions src/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ end
"""
airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext;
question::AbstractString,
top_k::Int = 3, `minimum_similarity::AbstractFloat`= -1.0,
top_k::Int = 100, top_n::Int = 5, minimum_similarity::AbstractFloat = -1.0,
tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto,
rerank_strategy::RerankingStrategy = Passthrough(),
model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT,
model_metadata::String = PT.MODEL_CHAT,
metadata_template::Symbol = :RAGExtractMetadataShort,
chunks_window_margin::Tuple{Int, Int} = (1, 1),
return_context::Bool = false, verbose::Bool = true,
rerank_kwargs::NamedTuple = NamedTuple(),
api_kwargs::NamedTuple = NamedTuple(),
kwargs...)
Expand All @@ -59,9 +60,10 @@ The function selects relevant chunks from an `ChunkIndex`, optionally filters th
- `rag_template::Symbol`: Template for the RAG model, defaults to `:RAGAnswerFromContext`.
- `question::AbstractString`: The question to be answered.
- `top_k::Int`: Number of top candidates to retrieve based on embedding similarity.
- `top_n::Int`: Number of candidates to return after reranking.
- `minimum_similarity::AbstractFloat`: Minimum similarity threshold (between -1 and 1) for filtering chunks based on embedding similarity. Defaults to -1.0.
- `tag_filter::Union{Symbol, Vector{String}, Regex}`: Mechanism for filtering chunks based on tags (either automatically detected, specific tags, or a regex pattern). Disabled by setting to `nothing`.
- `rerank_strategy::RerankingStrategy`: Strategy for reranking the retrieved chunks.
- `rerank_strategy::RerankingStrategy`: Strategy for reranking the retrieved chunks. Defaults to `Passthrough()`. Use `CohereRerank` for better results (requires `COHERE_API_KEY` to be set)
- `model_embedding::String`: Model used for embedding the question, default is `PT.MODEL_EMBEDDING`.
- `model_chat::String`: Model used for generating the final response, default is `PT.MODEL_CHAT`.
- `model_metadata::String`: Model used for extracting metadata, default is `PT.MODEL_CHAT`.
Expand Down Expand Up @@ -97,14 +99,15 @@ See also `build_index`, `build_context`, `CandidateChunks`, `find_closest`, `fin
"""
function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext;
question::AbstractString,
top_k::Int = 3, minimum_similarity::AbstractFloat = -1.0,
top_k::Int = 100, top_n::Int = 5, minimum_similarity::AbstractFloat = -1.0,
tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto,
rerank_strategy::RerankingStrategy = Passthrough(),
model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT,
model_metadata::String = PT.MODEL_CHAT,
metadata_template::Symbol = :RAGExtractMetadataShort,
chunks_window_margin::Tuple{Int, Int} = (1, 1),
return_context::Bool = false, verbose::Bool = true,
rerank_kwargs::NamedTuple = NamedTuple(),
api_kwargs::NamedTuple = NamedTuple(),
kwargs...)
## Note: Supports only single ChunkIndex for now
Expand Down Expand Up @@ -148,7 +151,12 @@ function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromC

filtered_candidates = isnothing(tag_candidates) ? emb_candidates :
(emb_candidates & tag_candidates)
reranked_candidates = rerank(rerank_strategy, index, question, filtered_candidates)
reranked_candidates = rerank(rerank_strategy,
index,
question,
filtered_candidates;
top_n,
verbose = false, rerank_kwargs...)

## Build the context
context = build_context(index, reranked_candidates; chunks_window_margin)
Expand Down
Loading

0 comments on commit 20a214c

Please sign in to comment.