diff --git a/CHANGELOG.md b/CHANGELOG.md
index fc7c2549b..fdd1e2819 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+## [0.29.0]
+
+### Added
+- Added package extension for FlashRank.jl to support local ranking models. See `?RT.FlashRanker` for more information or `examples/RAG_with_FlashRank.jl` for a quick example.
+
+
 ## [0.28.0]
 
 ### Added
diff --git a/Project.toml b/Project.toml
index 2453a165d..b1fa827b0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "PromptingTools"
 uuid = "670122d1-24a8-4d70-bfce-740807c42192"
 authors = ["J S @svilupp and contributors"]
-version = "0.28.0"
+version = "0.29.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [weakdeps]
+FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d"
 GoogleGenAI = "903d41d1-eaca-47dd-943b-fee3930375ab"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -26,6 +27,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
 
 [extensions]
+FlashRankPromptingToolsExt = ["FlashRank"]
 GoogleGenAIPromptingToolsExt = ["GoogleGenAI"]
 MarkdownPromptingToolsExt = ["Markdown"]
 RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra", "Unicode"]
@@ -36,6 +38,7 @@ AbstractTrees = "0.4"
 Aqua = "0.7"
 Base64 = "<0.0.1, 1"
 Dates = "<0.0.1, 1"
+FlashRank = "0.2"
 GoogleGenAI = "0.3"
 HTTP = "1"
 JSON3 = "1"
@@ -59,4 +62,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [targets]
-test = ["Aqua", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"]
+test = ["Aqua", "FlashRank", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"]
diff --git a/docs/Project.toml b/docs/Project.toml
index 518379d4f..0995d35f8 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -2,6 +2,7 @@
 DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
+FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d"
 GoogleGenAI = "903d41d1-eaca-47dd-943b-fee3930375ab"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
@@ -14,4 +15,4 @@ Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [compat]
-DocumenterVitepress = "0.0.7"
\ No newline at end of file
+DocumenterVitepress = "0.0.7"
diff --git a/docs/make.jl b/docs/make.jl
index 29339821c..008b4c8bb 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,7 +1,7 @@
 using Documenter, DocumenterVitepress
 using PromptingTools
 const PT = PromptingTools
-using SparseArrays, LinearAlgebra, Markdown
+using SparseArrays, LinearAlgebra, Markdown, Unicode, FlashRank
 using PromptingTools.Experimental.RAGTools
 using PromptingTools.Experimental.AgentTools
 using JSON3, Serialization, DataFramesMeta
diff --git a/examples/RAG_with_FlashRank.jl b/examples/RAG_with_FlashRank.jl
new file mode 100644
index 000000000..634f2dd93
--- /dev/null
+++ b/examples/RAG_with_FlashRank.jl
@@ -0,0 +1,57 @@
+# # RAG with FlashRank.jl
+
+# This file contains examples of how to use FlashRank rankers.
+#
+# First, let's import the package and define a helper alias for calling un-exported functions:
+using LinearAlgebra, SparseArrays, Unicode # imports required for full PT functionality
+using FlashRank
+using PromptingTools
+const PT = PromptingTools
+using PromptingTools.Experimental.RAGTools
+const RT = PromptingTools.Experimental.RAGTools
+
+# Enable model downloading; otherwise, you always have to approve each download
+# see https://www.oxinabox.net/DataDeps.jl/dev/z10-for-end-users/
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
+
+## Sample data
+sentences = [
+    "Search for the latest advancements in quantum computing using Julia language.",
+    "How to implement machine learning algorithms in Julia with examples.",
+    "Looking for performance comparison between Julia, Python, and R for data analysis.",
+    "Find Julia language tutorials focusing on high-performance scientific computing.",
+    "Search for the top Julia language packages for data visualization and their documentation.",
+    "How to set up a Julia development environment on Windows 10.",
+    "Discover the best practices for parallel computing in Julia.",
+    "Search for case studies of large-scale data processing using Julia.",
+    "Find comprehensive resources for mastering metaprogramming in Julia.",
+    "Looking for articles on the advantages of using Julia for statistical modeling.",
+    "How to contribute to the Julia open-source community: A step-by-step guide.",
+    "Find the comparison of numerical accuracy between Julia and MATLAB.",
+    "Looking for the latest Julia language updates and their impact on AI research.",
+    "How to efficiently handle big data with Julia: Techniques and libraries.",
+    "Discover how Julia integrates with other programming languages and tools.",
+    "Search for Julia-based frameworks for developing web applications.",
+    "Find tutorials on creating interactive dashboards with Julia.",
+    "How to use Julia for natural language processing and text analysis.",
+    "Discover the role of Julia in the future of computational finance and econometrics."
+]
+## Build the index
+index = build_index(
+    sentences; chunker_kwargs = (; sources = map(i -> "Doc$i", 1:length(sentences))))
+
+# Wrap the model to be a valid Ranker recognized by RAGTools (FlashRanker is the dedicated type)
+# It will be provided to the airag/rerank function to avoid instantiating it on every call
+reranker = RankerModel(:mini) |> RT.FlashRanker
+# You can choose :tiny or :mini
+
+## Apply to the pipeline configuration, eg,
+cfg = RAGConfig(; retriever = AdvancedRetriever(; reranker))
+
+# Ask a question
+question = "What are the best practices for parallel computing in Julia?"
+result = airag(cfg, index; question, return_all = true)
+
+# Review the reranking step results
+result.reranked_candidates
+index[result.reranked_candidates]
\ No newline at end of file
diff --git a/ext/FlashRankPromptingToolsExt.jl b/ext/FlashRankPromptingToolsExt.jl
new file mode 100644
index 000000000..3741c311a
--- /dev/null
+++ b/ext/FlashRankPromptingToolsExt.jl
@@ -0,0 +1,82 @@
+
+module FlashRankPromptingToolsExt
+
+using PromptingTools
+const PT = PromptingTools
+using PromptingTools.Experimental.RAGTools
+const RT = PromptingTools.Experimental.RAGTools
+using FlashRank
+
+# Define the method for reranking with FlashRank models
+"""
+    RT.rerank(
+        reranker::RT.FlashRanker, index::RT.AbstractDocumentIndex, question::AbstractString,
+        candidates::RT.AbstractCandidateChunks;
+        verbose::Bool = false,
+        top_n::Integer = length(candidates.scores),
+        kwargs...)
+
+Re-ranks a list of candidate chunks using the FlashRank.jl local models.
+
+# Arguments
+- `reranker`: FlashRanker model to use (wrapper for `FlashRank.RankerModel`)
+- `index`: The index that holds the underlying chunks to be re-ranked.
+- `question`: The query to be used for the search.
+- `candidates`: The candidate chunks to be re-ranked.
+- `top_n`: The number of most relevant documents to return. Default is `length(candidates.scores)`, ie, all candidates.
+- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`.
+
+# Example
+
+How to use FlashRank models in your RAG pipeline:
+```julia
+using FlashRank
+
+# Wrap the model to be a valid Ranker recognized by RAGTools (FlashRanker is the dedicated type)
+# It will be provided to the airag/rerank function to avoid instantiating it on every call
+reranker = RankerModel(:mini) |> RT.FlashRanker
+# You can choose :tiny or :mini
+
+## Apply to the pipeline configuration, eg,
+cfg = RAGConfig(; retriever = AdvancedRetriever(; reranker))
+
+# Ask a question
+question = "What are the best practices for parallel computing in Julia?"
+result = airag(cfg, index; question, return_all = true)
+
+# Review the reranking step results
+result.reranked_candidates
+index[result.reranked_candidates]
+```
+"""
+function RT.rerank(
+        reranker::RT.FlashRanker, index::RT.AbstractDocumentIndex, question::AbstractString,
+        candidates::RT.AbstractCandidateChunks;
+        verbose::Bool = false,
+        top_n::Integer = length(candidates.scores),
+        kwargs...)
+    @assert top_n>0 "top_n must be a positive integer."
+    documents = index[candidates, :chunks]
+    @assert !(isempty(documents)) "The candidate chunks must not be empty for the FlashRank reranker! Check the index IDs."
+
+    ## Run re-ranker
+    ranker = reranker.model
+    result = ranker(question, documents; top_n)
+
+    ## Unwrap re-ranked positions
+    scores = result.scores
+    positions = candidates.positions[result.positions]
+    index_ids = if candidates isa RT.MultiCandidateChunks
+        candidates.index_ids[result.positions]
+    else
+        candidates.index_id
+    end
+
+    verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds."
+
+    return candidates isa RT.MultiCandidateChunks ?
+           RT.MultiCandidateChunks(index_ids, positions, scores) :
+           RT.CandidateChunks(index_ids, positions, scores)
+end
+
+end #end of module
\ No newline at end of file
diff --git a/src/Experimental/RAGTools/generation.jl b/src/Experimental/RAGTools/generation.jl
index 81549ff4e..4ee7ac229 100644
--- a/src/Experimental/RAGTools/generation.jl
+++ b/src/Experimental/RAGTools/generation.jl
@@ -520,6 +520,9 @@ To customize the components, replace corresponding fields for each step of the R
     retriever::AbstractRetriever = SimpleRetriever()
     generator::AbstractGenerator = SimpleGenerator()
 end
+function Base.show(io::IO, cfg::AbstractRAGConfig)
+    dump(io, cfg; maxdepth = 2)
+end
 
 """
     airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl
index 6fcf7a2a0..976429a88 100644
--- a/src/Experimental/RAGTools/retrieval.jl
+++ b/src/Experimental/RAGTools/retrieval.jl
@@ -575,6 +575,36 @@ Rerank strategy using the Cohere Rerank API. Requires an API key.
 """
 struct CohereReranker <: AbstractReranker end
+"""
+    FlashRanker <: AbstractReranker
+
+Rerank strategy using the FlashRank.jl package and local models.
+
+You must first import the FlashRank.jl package.
+To automatically download any required models, set your
+`ENV["DATADEPS_ALWAYS_ACCEPT"] = true` (see [DataDeps](https://www.oxinabox.net/DataDeps.jl/dev/z10-for-end-users/) for more details).
+
+# Example
+```julia
+using FlashRank
+
+# Wrap the model to be a valid Ranker recognized by RAGTools
+# It will be provided to the airag/rerank function to avoid instantiating it on every call
+reranker = FlashRank.RankerModel(:mini) |> FlashRanker
+# You can choose :tiny or :mini
+
+## Apply to the pipeline configuration, eg,
+cfg = RAGConfig(; retriever = AdvancedRetriever(; reranker))
+
+# Ask a question (assumes you have some `index`)
+question = "What are the best practices for parallel computing in Julia?"
+result = airag(cfg, index; question, return_all = true)
+```
+"""
+struct FlashRanker{T} <: AbstractReranker
+    model::T
+end
+
 function rerank(reranker::AbstractReranker,
         index::AbstractDocumentIndex, question::AbstractString,
         candidates::AbstractCandidateChunks; kwargs...)
     throw(ArgumentError("Not implemented yet"))
diff --git a/test/runtests.jl b/test/runtests.jl
index c4d0abe32..229c9d5a7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,7 +5,7 @@ using Statistics
 using Dates: now
 using Test, Pkg, Random
 const PT = PromptingTools
-using Snowball
+using Snowball, FlashRank
 using Aqua
 
 @testset "Code quality (Aqua.jl)" begin