Skip to content

Commit

Permalink
Update RAG performance
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jun 18, 2024
1 parent f3e3994 commit 223107f
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 61 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.31.0]

### Breaking Changes
- The return type of `RAGTools.find_tags(::NoTagFilter, ...)` is now `::Nothing` instead of `CandidateChunks`/`MultiCandidateChunks` with all documents.
- `Base.getindex(::MultiIndex, ::MultiCandidateChunks)` now always returns sorted chunks for consistency with the behavior of other `getindex` methods on `*Chunks`.

### Updated
- Cosine similarity search now uses `partialsortperm` for better performance on large datasets.
- Skip unnecessary work when the tagging functionality in the RAG pipeline is disabled (`find_tags` with `NoTagFilter` always returns `nothing`, which improves the compiled code).
- Changed the default behavior of `getindex(::MultiIndex, ::MultiCandidateChunks)` to always return sorted chunks for consistency with other similar functions. Note that you should always use re-ranking anyway (see `FlashRank.jl`).

## [0.30.0]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.30.0"
version = "0.31.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
59 changes: 25 additions & 34 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,14 @@ function find_closest(
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
# emb is an embedding matrix where the first dimension is the embedding dimension
scores = query_emb' * emb |> vec
positions = scores |> sortperm |> x -> last(x, top_k) |> reverse
top_k_min = min(top_k, length(scores))
positions = partialsortperm(scores, 1:top_k_min, rev = true)
if minimum_similarity > -1.0
mask = scores[positions] .>= minimum_similarity
positions = positions[mask]
else
## we want to materialize the view
positions = collect(positions)
end
return positions, scores[positions]
end
Expand Down Expand Up @@ -374,7 +378,8 @@ function find_closest(
## First pass, both in binary with Hamming, get rescore_multiplier times top_k
binary_query_emb = map(>(0), query_emb)
scores = hamming_distance(emb, binary_query_emb)
positions = scores |> sortperm |> x -> first(x, top_k * rescore_multiplier)
num_candidates = min(top_k * rescore_multiplier, length(scores))
positions = partialsortperm(scores, 1:num_candidates)

## Second pass, rescore with float embeddings and return top_k
new_positions, scores = find_closest(CosineSimilarity(), @view(emb[:, positions]),
Expand Down Expand Up @@ -415,7 +420,8 @@ function find_closest(
## First pass, both in binary with Hamming, get rescore_multiplier times top_k
bit_query_emb = pack_bits(query_emb .> 0)
scores = hamming_distance(emb, bit_query_emb)
positions = scores |> sortperm |> x -> first(x, top_k * rescore_multiplier)
num_candidates = min(top_k * rescore_multiplier, length(scores))
positions = partialsortperm(scores, 1:num_candidates)

## Second pass, rescore with float embeddings and return top_k
unpacked_emb = unpack_bits(@view(emb[:, positions]))
Expand All @@ -442,11 +448,15 @@ function find_closest(
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
bm_scores = bm25(dtm, query_tokens)
positions = bm_scores |> sortperm |> x -> last(x, top_k) |> reverse
top_k_min = min(top_k, length(bm_scores))
positions = partialsortperm(bm_scores, 1:top_k_min, rev = true)

if minimum_similarity > -1.0
mask = scores[positions] .>= minimum_similarity
positions = positions[mask]
else
# materialize the vector
positions = positions |> collect
end
return positions, bm_scores[positions]
end
Expand All @@ -455,7 +465,8 @@ end

"""
    find_tags(filter::AbstractTagFilter, index::AbstractDocumentIndex, tag; kwargs...)

Fallback method for tag filter / index combinations that have no dedicated
`find_tags` implementation.

`tag` may be a single `AbstractString`, `Regex` or `nothing`, or a vector of those.

# Throws
- `ArgumentError`: always; the message names the concrete filter and index types.
"""
function find_tags(filter::AbstractTagFilter, index::AbstractDocumentIndex,
        tag::Union{T, AbstractVector{<:T}};
        kwargs...) where {T <: Union{AbstractString, Regex, Nothing}}
    # The filter argument must be named: while it was anonymous, `filter` in the message
    # below resolved to `Base.filter`, so the error never reported the unsupported type.
    throw(ArgumentError("Not implemented yet for type $(typeof(filter)) and index $(typeof(index))"))
end

Expand Down Expand Up @@ -492,24 +503,19 @@ end

"""
find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex, Nothing}}
tags; kwargs...)
Returns all chunks in the index, ie, no filtering.
Applies no filtering, i.e., all chunks in the index are kept, so we simply return `nothing` (easier for dispatch).
"""
function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex}}
return CandidateChunks(
index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks)))
AbstractString, Regex, Nothing}}
return nothing
end

function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Nothing; kwargs...)
return CandidateChunks(
index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks)))
end

## Multi-index implementation
function find_tags(method::AnyTagFilter, index::AbstractMultiIndex,
tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Expand Down Expand Up @@ -539,24 +545,8 @@ end
function find_tags(method::NoTagFilter, index::AbstractMultiIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{
AbstractString, Regex}}
indexes_ = indexes(index)
length_ = sum(x -> length(x.chunks), indexes_)
index_ids = [fill(x.id, length(x.chunks)) for x in indexes_] |> Base.Splat(vcat)

return MultiCandidateChunks(
index_ids, collect(1:length_),
zeros(Float32, length_))
end
function find_tags(method::NoTagFilter, index::AbstractMultiIndex,
tags::Nothing; kwargs...)
indexes_ = indexes(index)
length_ = sum(x -> length(x.chunks), indexes_)
index_ids = [fill(x.id, length(x.chunks)) for x in indexes_] |> Base.Splat(vcat)

return MultiCandidateChunks(
index_ids, collect(1:length_),
zeros(Float32, length_))
AbstractString, Regex, Nothing}}
return nothing
end

### Reranking
Expand Down Expand Up @@ -966,6 +956,7 @@ function retrieve(retriever::AbstractRetriever,
filter, index, tags; verbose = (verbose > 1), filter_kwargs_...)

## Combine the two sets of candidates, looks for intersection (hard filter)!
# With tagger=NoTagger(), `get_tags` returns `nothing` and `find_tags` simply passes it through, so the intersection is skipped
filtered_candidates = isnothing(tag_candidates) ? emb_candidates :
(emb_candidates & tag_candidates)
## TODO: Future implementation should be to apply tag filtering BEFORE the find_closest,
Expand Down
7 changes: 4 additions & 3 deletions src/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ function Base.getindex(ci::AbstractDocumentIndex,
end
function Base.getindex(ci::AbstractChunkIndex,
candidate::CandidateChunks{TP, TD},
field::Symbol = :chunks) where {TP <: Integer, TD <: Real}
field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real}
@assert field in [:chunks, :embeddings, :chunkdata, :sources] "Only `chunks`, `embeddings`, `chunkdata`, `sources` fields are supported for now"
field = field == :embeddings ? :chunkdata : field
len_ = length(chunks(ci))
Expand All @@ -504,7 +504,8 @@ function Base.getindex(ci::AbstractChunkIndex,
end
function Base.getindex(mi::MultiIndex,
candidate::CandidateChunks{TP, TD},
field::Symbol = :chunks) where {TP <: Integer, TD <: Real}
field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real}
## Always sorted!
@assert field in [:chunks, :sources] "Only `chunks`, `sources` fields are supported for now"
valid_index = findfirst(x -> x.id == candidate.index_id, indexes(mi))
if isnothing(valid_index) && field == :chunks
Expand Down Expand Up @@ -549,7 +550,7 @@ end
# Getindex on Multiindex, pool the individual hits
function Base.getindex(mi::MultiIndex,
candidate::MultiCandidateChunks{TP, TD},
field::Symbol = :chunks; sorted::Bool = false) where {TP <: Integer, TD <: Real}
field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real}
@assert field in [:chunks, :sources, :scores] "Only `chunks`, `sources`, and `scores` fields are supported for now"
if sorted
# values can be either of chunks or sources
Expand Down
3 changes: 2 additions & 1 deletion test/Experimental/RAGTools/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ end
embeddings = zeros(128, 3),
tags = vcat(trues(2, 2), falses(1, 2)),
tags_vocab = ["yes", "no"])
index.embeddings[1, 1] = 1

# Test for successful Q&A extraction from document chunks
qa_evals = build_qa_evals(chunks(index),
Expand Down Expand Up @@ -193,7 +194,7 @@ end
api_kwargs = (; url = "http://localhost:$(PORT)"),
parameters_dict = Dict(:key1 => "value1", :key2 => 2))
@test result.retrieval_score == 1.0
@test result.retrieval_rank == 2
@test result.retrieval_rank == 1
@test result.answer_score == 5
@test result.parameters == Dict(:key1 => "value1", :key2 => 2)

Expand Down
74 changes: 53 additions & 21 deletions test/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,14 @@ end

# No filter tag -- give everything
cc = find_tags(NoTagFilter(), index, "julia")
@test cc.positions == [1, 2]
@test cc.scores == [0.0, 0.0]
@test isnothing(cc)
# @test cc.positions == [1, 2]
# @test cc.scores == [0.0, 0.0]

cc = find_tags(NoTagFilter(), index, nothing)
@test cc.positions == [1, 2]
@test cc.scores == [0.0, 0.0]
@test isnothing(cc)
# @test cc.positions == [1, 2]
# @test cc.scores == [0.0, 0.0]

# Unknown type
struct RandomTagFilter123 <: AbstractTagFilter end
Expand All @@ -456,12 +458,14 @@ end
multi_index = MultiIndex(id = :multi, indexes = [index1, index2])

mcc = find_tags(NoTagFilter(), multi_index, "julia")
@test mcc.positions == [1, 2, 3, 4]
@test mcc.scores == [0.0, 0.0, 0.0, 0.0]
@test mcc == nothing
# @test mcc.positions == [1, 2, 3, 4]
# @test mcc.scores == [0.0, 0.0, 0.0, 0.0]

mcc = find_tags(NoTagFilter(), multi_index, nothing)
@test mcc.positions == [1, 2, 3, 4]
@test mcc.scores == [0.0, 0.0, 0.0, 0.0]
@test mcc == nothing
# @test mcc.positions == [1, 2, 3, 4]
# @test mcc.scores == [0.0, 0.0, 0.0, 0.0]

multi_index2 = MultiIndex(id = :multi2, indexes = [index, index2])
mcc2 = find_tags(AnyTagFilter(), multi_index2, "julia")
Expand Down Expand Up @@ -538,7 +542,7 @@ end

@testset "retrieve" begin
# test with a mock server
PORT = rand(20000:40000)
PORT = rand(20000:40001)
PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema())
PT.register_model!(; name = "mock-emb2", schema = PT.CustomOpenAISchema())
PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema())
Expand Down Expand Up @@ -609,8 +613,16 @@ end
@test result.rephrased_questions == [question]
@test result.answer == nothing
@test result.final_answer == nothing
@test result.reranked_candidates.positions == [2, 1, 4, 3]
@test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"]
## there are two equivalent orderings
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])
@test Set(result.reranked_candidates.positions[3:4]) == Set([3, 4])
@test result.reranked_candidates.scores[1:2] == ones(2)
@test length(result.context) == 4
@test length(unique(result.context)) == 4
@test result.context[1] in ["chunk2", "chunk1"]
@test result.context[2] in ["chunk2", "chunk1"]
@test result.context[3] in ["chunk3", "chunk4"]
@test result.context[4] in ["chunk3", "chunk4"]
@test result.sources isa Vector{String}

# Reduce number of candidates
Expand All @@ -620,8 +632,10 @@ end
embedder_kwargs = (; model = "mock-emb"),
tagger_kwargs = (; model = "mock-meta"), api_kwargs = (;
url = "http://localhost:$(PORT)"))
@test result.emb_candidates.positions == [2, 1, 4]
@test result.reranked_candidates.positions == [2, 1]
## the last item is 3 or 4
@test result.emb_candidates.positions[3] in [3, 4]
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])
@test result.emb_candidates.scores[1:2] == ones(2)

# with default dispatch
result = retrieve(index, question;
Expand All @@ -630,8 +644,9 @@ end
embedder_kwargs = (; model = "mock-emb"),
tagger_kwargs = (; model = "mock-meta"), api_kwargs = (;
url = "http://localhost:$(PORT)"))
@test result.emb_candidates.positions == [2, 1, 4]
@test result.reranked_candidates.positions == [2, 1]
@test result.emb_candidates.positions[3] in [3, 4]
@test result.emb_candidates.scores[1:2] == ones(2)
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])

## AdvancedRetriever
adv = AdvancedRetriever()
Expand All @@ -645,20 +660,29 @@ end
@test result.rephrased_questions == [question, "Query: test question\n\nPassage:"] # from the template we use
@test result.answer == nothing
@test result.final_answer == nothing
@test result.reranked_candidates.positions == [2, 1, 4, 3]
@test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"]
## there are two equivalent orderings
@test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1])
@test Set(result.reranked_candidates.positions[3:4]) == Set([3, 4])
@test result.reranked_candidates.scores[1:2] == ones(2)
@test length(result.context) == 4
@test length(unique(result.context)) == 4
@test result.context[1] in ["chunk2", "chunk1"]
@test result.context[2] in ["chunk2", "chunk1"]
@test result.context[3] in ["chunk3", "chunk4"]
@test result.context[4] in ["chunk3", "chunk4"]
@test result.sources isa Vector{String}

# Multi-index retriever
index_keywords = ChunkKeywordsIndex(index, index_id = :TestChunkIndexX)
index_keywords = ChunkIndex(; id = :AA, index.chunks, index.sources, index.embeddings)
# Create MultiIndex instance
multi_index = MultiIndex(id = :multi, indexes = [index, index_keywords])

# Create MultiFinder instance
finder = MultiFinder([RT.CosineSimilarity(), RT.BM25Similarity()])

retriever = SimpleRetriever(; processor = RT.KeywordsProcessor(), finder)
result = retrieve(retriever, multi_index, question;
result = retrieve(SimpleRetriever(), multi_index, question;
reranker = NoReranker(), # we need to disable cohere as we cannot test it
rephraser_kwargs = (; model = "mock-gen"),
embedder_kwargs = (; model = "mock-emb"),
Expand All @@ -668,9 +692,17 @@ end
@test result.rephrased_questions == [question]
@test result.answer == nothing
@test result.final_answer == nothing
@test result.reranked_candidates.positions == [2, 1, 4, 3]
@test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"]
@test result.sources == ["source2", "source1", "source4", "source3"]
## there are two equivalent orderings
@test Set(result.reranked_candidates.positions[1:4]) == Set([2, 1])
@test result.reranked_candidates.positions[5] in [3, 4]
@test result.reranked_candidates.scores[1:4] == ones(4)
@test length(result.context) == 5 # because the second index duplicates, so we have more
@test length(unique(result.context)) == 3 # only 3 unique chunks because 1,2,1,2,3
@test all([result.context[i] in ["chunk2", "chunk1"] for i in 1:4])
@test result.context[5] in ["chunk3", "chunk4"]
@test length(unique(result.sources)) == 3
@test all([result.sources[i] in ["source2", "source1"] for i in 1:4])
@test result.sources[5] in ["source3", "source4"]

# clean up
close(echo_server)
Expand Down
7 changes: 6 additions & 1 deletion test/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,16 @@ end
index_ids = [Symbol("TestChunkIndex"), Symbol("TestChunkIndex2")],
positions = [1, 3], # Assuming chunks_data has only 3 elements, position 4 is out of bounds
scores = [0.5, 0.7])
@test mi[mc1] == ["First chunk", "6"]
## sorted=true by default
@test mi[mc1] == ["6", "First chunk"]
@test Base.getindex(mi, mc1, :chunks; sorted = true) == ["6", "First chunk"]
@test Base.getindex(mi, mc1, :sources; sorted = true) ==
["other_source3", "test_source1"]
@test Base.getindex(mi, mc1, :scores; sorted = true) == [0.7, 0.5]
@test Base.getindex(mi, mc1, :chunks; sorted = false) == ["First chunk", "6"]
@test Base.getindex(mi, mc1, :sources; sorted = false) ==
["test_source1", "other_source3"]
@test Base.getindex(mi, mc1, :scores; sorted = false) == [0.5, 0.7]
end

@testset "RAGResult" begin
Expand Down

0 comments on commit 223107f

Please sign in to comment.