Skip to content

Commit

Permalink
add FlashRank.jl package extension
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jun 11, 2024
1 parent 6ec6456 commit 0f45fde
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.29.0]

### Added
- Added package extension for FlashRank.jl to support local ranking models. See `?RT.FlashRanker` for more information or `examples/RAG_with_FlashRank.jl` for a quick example.


## [0.28.0]

### Added
Expand Down
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.28.0"
version = "0.29.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand All @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d"
GoogleGenAI = "903d41d1-eaca-47dd-943b-fee3930375ab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand All @@ -26,6 +27,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[extensions]
FlashRankPromptingToolsExt = ["FlashRank"]
GoogleGenAIPromptingToolsExt = ["GoogleGenAI"]
MarkdownPromptingToolsExt = ["Markdown"]
RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra", "Unicode"]
Expand All @@ -36,6 +38,7 @@ AbstractTrees = "0.4"
Aqua = "0.7"
Base64 = "<0.0.1, 1"
Dates = "<0.0.1, 1"
FlashRank = "0.2"
GoogleGenAI = "0.3"
HTTP = "1"
JSON3 = "1"
Expand All @@ -59,4 +62,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[targets]
test = ["Aqua", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"]
test = ["Aqua", "FlashRank", "SparseArrays", "Statistics", "LinearAlgebra", "Markdown", "Snowball"]
3 changes: 2 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d"
GoogleGenAI = "903d41d1-eaca-47dd-943b-fee3930375ab"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Expand All @@ -14,4 +15,4 @@ Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
DocumenterVitepress = "0.0.7"
DocumenterVitepress = "0.0.7"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Documenter, DocumenterVitepress
using PromptingTools
const PT = PromptingTools
using SparseArrays, LinearAlgebra, Markdown
using SparseArrays, LinearAlgebra, Markdown, Unicode, FlashRank
using PromptingTools.Experimental.RAGTools
using PromptingTools.Experimental.AgentTools
using JSON3, Serialization, DataFramesMeta
Expand Down
57 changes: 57 additions & 0 deletions examples/RAG_with_FlashRank.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# # RAG with FlashRank.jl

# This file contains examples of how to use FlashRank rankers.
#
# First, let's import the packages and define short aliases (`PT`, `RT`) for calling un-exported functions:
using LinearAlgebra, SparseArrays, Unicode # imports required for full PT functionality
using FlashRank
using PromptingTools
const PT = PromptingTools
using PromptingTools.Experimental.RAGTools
const RT = PromptingTools.Experimental.RAGTools

# Enable automatic model downloading; otherwise you have to approve every download manually
# see https://www.oxinabox.net/DataDeps.jl/dev/z10-for-end-users/
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

## Sample data: short example texts (one document each) used to build the index
sentences = [
"Search for the latest advancements in quantum computing using Julia language.",
"How to implement machine learning algorithms in Julia with examples.",
"Looking for performance comparison between Julia, Python, and R for data analysis.",
"Find Julia language tutorials focusing on high-performance scientific computing.",
"Search for the top Julia language packages for data visualization and their documentation.",
"How to set up a Julia development environment on Windows 10.",
"Discover the best practices for parallel computing in Julia.",
"Search for case studies of large-scale data processing using Julia.",
"Find comprehensive resources for mastering metaprogramming in Julia.",
"Looking for articles on the advantages of using Julia for statistical modeling.",
"How to contribute to the Julia open-source community: A step-by-step guide.",
"Find the comparison of numerical accuracy between Julia and MATLAB.",
"Looking for the latest Julia language updates and their impact on AI research.",
"How to efficiently handle big data with Julia: Techniques and libraries.",
"Discover how Julia integrates with other programming languages and tools.",
"Search for Julia-based frameworks for developing web applications.",
"Find tutorials on creating interactive dashboards with Julia.",
"How to use Julia for natural language processing and text analysis.",
"Discover the role of Julia in the future of computational finance and econometrics."
]
## Build the index; every sentence gets a synthetic source label ("Doc1", "Doc2", ...)
index = build_index(
    sentences;
    chunker_kwargs = (; sources = ["Doc$i" for i in eachindex(sentences)]))

# Wrap the FlashRank model in `RT.FlashRanker`, the dedicated reranker type recognized by RAGTools.
# The wrapped model is created once here and handed to airag/rerank, so it is not re-instantiated on every call.
# You can choose :tiny or :mini
reranker = RT.FlashRanker(RankerModel(:mini))

## Plug the reranker into the pipeline configuration, eg,
cfg = RAGConfig(; retriever = AdvancedRetriever(; reranker = reranker))

# Ask a question
question = "What are the best practices for parallel computing in Julia?"
result = airag(cfg, index; question = question, return_all = true)

# Review the reranking step results
result.reranked_candidates
index[result.reranked_candidates]
82 changes: 82 additions & 0 deletions ext/FlashRankPromptingToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

module FlashRankPromptingToolsExt

using PromptingTools
const PT = PromptingTools
using PromptingTools.Experimental.RAGTools
const RT = PromptingTools.Experimental.RAGTools
using FlashRank

# Define the method for reranking with it
"""
    RT.rerank(
        reranker::RT.FlashRanker, index::RT.AbstractDocumentIndex, question::AbstractString,
        candidates::RT.AbstractCandidateChunks;
        verbose::Bool = false,
        top_n::Integer = length(candidates.scores),
        kwargs...)

Re-ranks a list of candidate chunks using the FlashRank.jl local models.

# Arguments
- `reranker`: FlashRanker model to use (wrapper for `FlashRank.RankerModel`)
- `index`: The index that holds the underlying chunks to be re-ranked.
- `question`: The query to be used for the search.
- `candidates`: The candidate chunks to be re-ranked.
- `top_n`: The number of most relevant documents to return. Default is `length(candidates.scores)`, ie, all candidates.
- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`.

# Returns
- `RT.MultiCandidateChunks` if `candidates` is a `RT.MultiCandidateChunks`, otherwise `RT.CandidateChunks`,
  holding the re-ranked positions and the scores produced by the ranker.

# Throws
- `AssertionError` if `top_n` is not positive or if no chunks can be found in `index` for `candidates`.

# Example

How to use FlashRank models in your RAG pipeline:
```julia
using FlashRank

# Wrap the model to be a valid Ranker recognized by RAGTools (FlashRanker is the dedicated type)
# It will be provided to the airag/rerank function to avoid instantiating it on every call
reranker = RankerModel(:mini) |> RT.FlashRanker
# You can choose :tiny or :mini

## Apply to the pipeline configuration, eg,
cfg = RAGConfig(; retriever = AdvancedRetriever(; reranker))

# Ask a question
question = "What are the best practices for parallel computing in Julia?"
result = airag(cfg, index; question, return_all = true)

# Review the reranking step results
result.reranked_candidates
index[result.reranked_candidates]
```
"""
function RT.rerank(
        reranker::RT.FlashRanker, index::RT.AbstractDocumentIndex, question::AbstractString,
        candidates::RT.AbstractCandidateChunks;
        verbose::Bool = false,
        top_n::Integer = length(candidates.scores),
        kwargs...)
    @assert top_n>0 "top_n must be a positive integer."
    documents = index[candidates, :chunks]
    @assert !(isempty(documents)) "The candidate chunks must not be empty for FlashRanker! Check the index IDs."

    ## Run re-ranker (the wrapped model is callable)
    ranker = reranker.model
    result = ranker(question, documents; top_n)

    ## Unwrap re-ranked positions
    # `result.positions` index into `documents`, so map them back to the candidates' own positions
    is_multi = candidates isa RT.MultiCandidateChunks
    scores = result.scores
    positions = candidates.positions[result.positions]
    # Multi-index candidates carry one index id per chunk (re-ordered alongside the positions);
    # single-index candidates share one id
    index_ids = is_multi ? candidates.index_ids[result.positions] : candidates.index_id

    # Timing is read from the ranker's result (previously referenced an undefined variable `res`)
    verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds."

    return is_multi ?
           RT.MultiCandidateChunks(index_ids, positions, scores) :
           RT.CandidateChunks(index_ids, positions, scores)
end

end #end of module
3 changes: 3 additions & 0 deletions src/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ To customize the components, replace corresponding fields for each step of the R
retriever::AbstractRetriever = SimpleRetriever()
generator::AbstractGenerator = SimpleGenerator()
end
# Display a RAG configuration by dumping its structure to `io`, descending at most two levels.
Base.show(io::IO, cfg::AbstractRAGConfig) = dump(io, cfg; maxdepth = 2)

"""
airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
Expand Down
30 changes: 30 additions & 0 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,36 @@ Rerank strategy using the Cohere Rerank API. Requires an API key.
"""
struct CohereReranker <: AbstractReranker end

"""
    FlashRanker <: AbstractReranker

Rerank strategy using the package FlashRank.jl and local models.

You must first import the FlashRank.jl package.

To automatically download any required models, set your
`ENV["DATADEPS_ALWAYS_ACCEPT"] = true` (see [DataDeps](https://www.oxinabox.net/DataDeps.jl/dev/z10-for-end-users/) for more details).

# Fields
- `model`: the wrapped ranking model, eg, a `FlashRank.RankerModel` as in the example below.

# Example
```julia
using FlashRank

# Wrap the model to be a valid Ranker recognized by RAGTools
# It will be provided to the airag/rerank function to avoid instantiating it on every call
# NOTE(review): if `FlashRanker` is not exported into your namespace, qualify it as
# `PromptingTools.Experimental.RAGTools.FlashRanker` - confirm against the export list
reranker = FlashRank.RankerModel(:mini) |> FlashRanker
# You can choose :tiny or :mini

## Apply to the pipeline configuration, eg,
cfg = RAGConfig(; retriever = AdvancedRetriever(; reranker))

# Ask a question (assumes you have some `index`)
question = "What are the best practices for parallel computing in Julia?"
result = airag(cfg, index; question, return_all = true)
```
"""
struct FlashRanker{T} <: AbstractReranker
model::T
end

function rerank(reranker::AbstractReranker,
index::AbstractDocumentIndex, question::AbstractString, candidates::AbstractCandidateChunks; kwargs...)
throw(ArgumentError("Not implemented yet"))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Statistics
using Dates: now
using Test, Pkg, Random
const PT = PromptingTools
using Snowball
using Snowball, FlashRank
using Aqua

@testset "Code quality (Aqua.jl)" begin
Expand Down

0 comments on commit 0f45fde

Please sign in to comment.