From 223107f0fe0c5fe3da96ec6c0ee9664e9b6d8129 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Tue, 18 Jun 2024 11:52:15 +0200 Subject: [PATCH] Update RAG performance --- CHANGELOG.md | 11 ++++ Project.toml | 2 +- src/Experimental/RAGTools/retrieval.jl | 59 ++++++++----------- src/Experimental/RAGTools/types.jl | 7 ++- test/Experimental/RAGTools/evaluation.jl | 3 +- test/Experimental/RAGTools/retrieval.jl | 74 +++++++++++++++++------- test/Experimental/RAGTools/types.jl | 7 ++- 7 files changed, 102 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb9b01655..422d74b46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.31.0] + +### Breaking Changes +- The return type of `RAGTools.find_tags(::NoTagFilter,...)` is now `::Nothing` instead of `CandidateChunks`/`MultiCandidateChunks` with all documents. +- `Base.getindex(::MultiIndex, ::MultiCandidateChunks)` now always returns sorted chunks for consistency with the behavior of other `getindex` methods on `*Chunks`. + +### Updated +- Cosine similarity search now uses `partialsortperm` for better performance on large datasets. +- Skip unnecessary work when the tagging functionality in the RAG pipeline is disabled (`find_tags` with `NoTagFilter` always returns `nothing` which improves the compiled code). +- Changed the default behavior of `getindex(::MultiIndex, ::MultiCandidateChunks)` to always return sorted chunks for consistency with other similar functions. Note that you should always use re-ranking anyway (see `FlashRank.jl`). 
+ ## [0.30.0] ### Fixed diff --git a/Project.toml b/Project.toml index c8c5580b0..501e8ff53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.30.0" +version = "0.31.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index 976429a88..4a1dc6343 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -196,10 +196,14 @@ function find_closest( top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...) # emb is an embedding matrix where the first dimension is the embedding dimension scores = query_emb' * emb |> vec - positions = scores |> sortperm |> x -> last(x, top_k) |> reverse + top_k_min = min(top_k, length(scores)) + positions = partialsortperm(scores, 1:top_k_min, rev = true) if minimum_similarity > -1.0 mask = scores[positions] .>= minimum_similarity positions = positions[mask] + else + ## we want to materialize the view + positions = collect(positions) end return positions, scores[positions] end @@ -374,7 +378,8 @@ function find_closest( ## First pass, both in binary with Hamming, get rescore_multiplier times top_k binary_query_emb = map(>(0), query_emb) scores = hamming_distance(emb, binary_query_emb) - positions = scores |> sortperm |> x -> first(x, top_k * rescore_multiplier) + num_candidates = min(top_k * rescore_multiplier, length(scores)) + positions = partialsortperm(scores, 1:num_candidates) ## Second pass, rescore with float embeddings and return top_k new_positions, scores = find_closest(CosineSimilarity(), @view(emb[:, positions]), @@ -415,7 +420,8 @@ function find_closest( ## First pass, both in binary with Hamming, get rescore_multiplier times top_k bit_query_emb = pack_bits(query_emb .> 0) scores = hamming_distance(emb, bit_query_emb) - positions = 
scores |> sortperm |> x -> first(x, top_k * rescore_multiplier) + num_candidates = min(top_k * rescore_multiplier, length(scores)) + positions = partialsortperm(scores, 1:num_candidates) ## Second pass, rescore with float embeddings and return top_k unpacked_emb = unpack_bits(@view(emb[:, positions])) @@ -442,11 +448,15 @@ function find_closest( query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[]; top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...) bm_scores = bm25(dtm, query_tokens) - positions = bm_scores |> sortperm |> x -> last(x, top_k) |> reverse + top_k_min = min(top_k, length(bm_scores)) + positions = partialsortperm(bm_scores, 1:top_k_min, rev = true) if minimum_similarity > -1.0 mask = scores[positions] .>= minimum_similarity positions = positions[mask] + else + # materialize the vector + positions = positions |> collect end return positions, bm_scores[positions] end @@ -455,7 +465,8 @@ end function find_tags(::AbstractTagFilter, index::AbstractDocumentIndex, tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: - Union{AbstractString, Regex}} + Union{ + AbstractString, Regex, Nothing}} throw(ArgumentError("Not implemented yet for type $(typeof(filter)) and index $(typeof(index))")) end @@ -492,24 +503,19 @@ end """ find_tags(method::NoTagFilter, index::AbstractChunkIndex, + tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: + Union{ + AbstractString, Regex, Nothing}} tags; kwargs...) -Returns all chunks in the index, ie, no filtering. +Returns all chunks in the index, ie, no filtering, so we simply return `nothing` (easier for dispatch). """ function find_tags(method::NoTagFilter, index::AbstractChunkIndex, tags::Union{T, AbstractVector{<:T}}; kwargs...) 
where {T <: Union{ - AbstractString, Regex}} - return CandidateChunks( - index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks))) + AbstractString, Regex, Nothing}} + return nothing end - -function find_tags(method::NoTagFilter, index::AbstractChunkIndex, - tags::Nothing; kwargs...) - return CandidateChunks( - index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks))) -end - ## Multi-index implementation function find_tags(method::AnyTagFilter, index::AbstractMultiIndex, tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: @@ -539,24 +545,8 @@ end function find_tags(method::NoTagFilter, index::AbstractMultiIndex, tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: Union{ - AbstractString, Regex}} - indexes_ = indexes(index) - length_ = sum(x -> length(x.chunks), indexes_) - index_ids = [fill(x.id, length(x.chunks)) for x in indexes_] |> Base.Splat(vcat) - - return MultiCandidateChunks( - index_ids, collect(1:length_), - zeros(Float32, length_)) -end -function find_tags(method::NoTagFilter, index::AbstractMultiIndex, - tags::Nothing; kwargs...) - indexes_ = indexes(index) - length_ = sum(x -> length(x.chunks), indexes_) - index_ids = [fill(x.id, length(x.chunks)) for x in indexes_] |> Base.Splat(vcat) - - return MultiCandidateChunks( - index_ids, collect(1:length_), - zeros(Float32, length_)) + AbstractString, Regex, Nothing}} + return nothing end ### Reranking @@ -966,6 +956,7 @@ function retrieve(retriever::AbstractRetriever, filter, index, tags; verbose = (verbose > 1), filter_kwargs_...) ## Combine the two sets of candidates, looks for intersection (hard filter)! + # With tagger=NoTagger() get_tags returns `nothing` find_tags simply passes it through to skip the intersection filtered_candidates = isnothing(tag_candidates) ? 
emb_candidates : (emb_candidates & tag_candidates) ## TODO: Future implementation should be to apply tag filtering BEFORE the find_closest, diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index bd7d6cb68..41c783a37 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -479,7 +479,7 @@ function Base.getindex(ci::AbstractDocumentIndex, end function Base.getindex(ci::AbstractChunkIndex, candidate::CandidateChunks{TP, TD}, - field::Symbol = :chunks) where {TP <: Integer, TD <: Real} + field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real} @assert field in [:chunks, :embeddings, :chunkdata, :sources] "Only `chunks`, `embeddings`, `chunkdata`, `sources` fields are supported for now" field = field == :embeddings ? :chunkdata : field len_ = length(chunks(ci)) @@ -504,7 +504,8 @@ function Base.getindex(ci::AbstractChunkIndex, end function Base.getindex(mi::MultiIndex, candidate::CandidateChunks{TP, TD}, - field::Symbol = :chunks) where {TP <: Integer, TD <: Real} + field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real} + ## Always sorted! 
@assert field in [:chunks, :sources] "Only `chunks`, `sources` fields are supported for now" valid_index = findfirst(x -> x.id == candidate.index_id, indexes(mi)) if isnothing(valid_index) && field == :chunks @@ -549,7 +550,7 @@ end # Getindex on Multiindex, pool the individual hits function Base.getindex(mi::MultiIndex, candidate::MultiCandidateChunks{TP, TD}, - field::Symbol = :chunks; sorted::Bool = false) where {TP <: Integer, TD <: Real} + field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real} @assert field in [:chunks, :sources, :scores] "Only `chunks`, `sources`, and `scores` fields are supported for now" if sorted # values can be either of chunks or sources diff --git a/test/Experimental/RAGTools/evaluation.jl b/test/Experimental/RAGTools/evaluation.jl index 94539bcc3..eb97c4329 100644 --- a/test/Experimental/RAGTools/evaluation.jl +++ b/test/Experimental/RAGTools/evaluation.jl @@ -155,6 +155,7 @@ end embeddings = zeros(128, 3), tags = vcat(trues(2, 2), falses(1, 2)), tags_vocab = ["yes", "no"]) + index.embeddings[1, 1] = 1 # Test for successful Q&A extraction from document chunks qa_evals = build_qa_evals(chunks(index), @@ -193,7 +194,7 @@ end api_kwargs = (; url = "http://localhost:$(PORT)"), parameters_dict = Dict(:key1 => "value1", :key2 => 2)) @test result.retrieval_score == 1.0 - @test result.retrieval_rank == 2 + @test result.retrieval_rank == 1 @test result.answer_score == 5 @test result.parameters == Dict(:key1 => "value1", :key2 => 2) diff --git a/test/Experimental/RAGTools/retrieval.jl b/test/Experimental/RAGTools/retrieval.jl index 70f82e3e3..e7720409f 100644 --- a/test/Experimental/RAGTools/retrieval.jl +++ b/test/Experimental/RAGTools/retrieval.jl @@ -432,12 +432,14 @@ end # No filter tag -- give everything cc = find_tags(NoTagFilter(), index, "julia") - @test cc.positions == [1, 2] - @test cc.scores == [0.0, 0.0] + @test isnothing(cc) + # @test cc.positions == [1, 2] + # @test cc.scores == [0.0, 0.0] cc = 
find_tags(NoTagFilter(), index, nothing) - @test cc.positions == [1, 2] - @test cc.scores == [0.0, 0.0] + @test isnothing(cc) + # @test cc.positions == [1, 2] + # @test cc.scores == [0.0, 0.0] # Unknown type struct RandomTagFilter123 <: AbstractTagFilter end @@ -456,12 +458,14 @@ end multi_index = MultiIndex(id = :multi, indexes = [index1, index2]) mcc = find_tags(NoTagFilter(), multi_index, "julia") - @test mcc.positions == [1, 2, 3, 4] - @test mcc.scores == [0.0, 0.0, 0.0, 0.0] + @test mcc == nothing + # @test mcc.positions == [1, 2, 3, 4] + # @test mcc.scores == [0.0, 0.0, 0.0, 0.0] mcc = find_tags(NoTagFilter(), multi_index, nothing) - @test mcc.positions == [1, 2, 3, 4] - @test mcc.scores == [0.0, 0.0, 0.0, 0.0] + @test mcc == nothing + # @test mcc.positions == [1, 2, 3, 4] + # @test mcc.scores == [0.0, 0.0, 0.0, 0.0] multi_index2 = MultiIndex(id = :multi2, indexes = [index, index2]) mcc2 = find_tags(AnyTagFilter(), multi_index2, "julia") @@ -538,7 +542,7 @@ end @testset "retrieve" begin # test with a mock server - PORT = rand(20000:40000) + PORT = rand(20000:40001) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-emb2", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) @@ -609,8 +613,16 @@ end @test result.rephrased_questions == [question] @test result.answer == nothing @test result.final_answer == nothing - @test result.reranked_candidates.positions == [2, 1, 4, 3] - @test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"] + ## there are two equivalent orderings + @test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1]) + @test Set(result.reranked_candidates.positions[3:4]) == Set([3, 4]) + @test result.reranked_candidates.scores[1:2] == ones(2) + @test length(result.context) == 4 + @test length(unique(result.context)) == 4 + @test result.context[1] in ["chunk2", "chunk1"] + @test result.context[2] in ["chunk2", "chunk1"] + 
@test result.context[3] in ["chunk3", "chunk4"] + @test result.context[4] in ["chunk3", "chunk4"] @test result.sources isa Vector{String} # Reduce number of candidates @@ -620,8 +632,10 @@ end embedder_kwargs = (; model = "mock-emb"), tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; url = "http://localhost:$(PORT)")) - @test result.emb_candidates.positions == [2, 1, 4] - @test result.reranked_candidates.positions == [2, 1] + ## the last item is 3 or 4 + @test result.emb_candidates.positions[3] in [3, 4] + @test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1]) + @test result.emb_candidates.scores[1:2] == ones(2) # with default dispatch result = retrieve(index, question; @@ -630,8 +644,9 @@ end embedder_kwargs = (; model = "mock-emb"), tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; url = "http://localhost:$(PORT)")) - @test result.emb_candidates.positions == [2, 1, 4] - @test result.reranked_candidates.positions == [2, 1] + @test result.emb_candidates.positions[3] in [3, 4] + @test result.emb_candidates.scores[1:2] == ones(2) + @test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1]) ## AdvancedRetriever adv = AdvancedRetriever() @@ -645,12 +660,21 @@ end @test result.rephrased_questions == [question, "Query: test question\n\nPassage:"] # from the template we use @test result.answer == nothing @test result.final_answer == nothing - @test result.reranked_candidates.positions == [2, 1, 4, 3] - @test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"] + ## there are two equivalent orderings + @test Set(result.reranked_candidates.positions[1:2]) == Set([2, 1]) + @test Set(result.reranked_candidates.positions[3:4]) == Set([3, 4]) + @test result.reranked_candidates.scores[1:2] == ones(2) + @test length(result.context) == 4 + @test length(unique(result.context)) == 4 + @test result.context[1] in ["chunk2", "chunk1"] + @test result.context[2] in ["chunk2", "chunk1"] + @test result.context[3] in ["chunk3", "chunk4"] + @test 
result.context[4] in ["chunk3", "chunk4"] @test result.sources isa Vector{String} # Multi-index retriever index_keywords = ChunkKeywordsIndex(index, index_id = :TestChunkIndexX) + index_keywords = ChunkIndex(; id = :AA, index.chunks, index.sources, index.embeddings) # Create MultiIndex instance multi_index = MultiIndex(id = :multi, indexes = [index, index_keywords]) @@ -658,7 +682,7 @@ end finder = MultiFinder([RT.CosineSimilarity(), RT.BM25Similarity()]) retriever = SimpleRetriever(; processor = RT.KeywordsProcessor(), finder) - result = retrieve(retriever, multi_index, question; + result = retrieve(SimpleRetriever(), multi_index, question; reranker = NoReranker(), # we need to disable cohere as we cannot test it rephraser_kwargs = (; model = "mock-gen"), embedder_kwargs = (; model = "mock-emb"), @@ -668,9 +692,17 @@ end @test result.rephrased_questions == [question] @test result.answer == nothing @test result.final_answer == nothing - @test result.reranked_candidates.positions == [2, 1, 4, 3] - @test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"] - @test result.sources == ["source2", "source1", "source4", "source3"] + ## there are two equivalent orderings + @test Set(result.reranked_candidates.positions[1:4]) == Set([2, 1]) + @test result.reranked_candidates.positions[5] in [3, 4] + @test result.reranked_candidates.scores[1:4] == ones(4) + @test length(result.context) == 5 # because the second index duplicates, so we have more + @test length(unique(result.context)) == 3 # only 3 unique chunks because 1,2,1,2,3 + @test all([result.context[i] in ["chunk2", "chunk1"] for i in 1:4]) + @test result.context[5] in ["chunk3", "chunk4"] + @test length(unique(result.sources)) == 3 + @test all([result.sources[i] in ["source2", "source1"] for i in 1:4]) + @test result.sources[5] in ["source3", "source4"] # clean up close(echo_server) diff --git a/test/Experimental/RAGTools/types.jl b/test/Experimental/RAGTools/types.jl index 7669dab28..8abe14253 100644 --- 
a/test/Experimental/RAGTools/types.jl +++ b/test/Experimental/RAGTools/types.jl @@ -539,11 +539,16 @@ end index_ids = [Symbol("TestChunkIndex"), Symbol("TestChunkIndex2")], positions = [1, 3], # Assuming chunks_data has only 3 elements, position 4 is out of bounds scores = [0.5, 0.7]) - @test mi[mc1] == ["First chunk", "6"] + ## sorted=true by default + @test mi[mc1] == ["6", "First chunk"] @test Base.getindex(mi, mc1, :chunks; sorted = true) == ["6", "First chunk"] @test Base.getindex(mi, mc1, :sources; sorted = true) == ["other_source3", "test_source1"] @test Base.getindex(mi, mc1, :scores; sorted = true) == [0.7, 0.5] + @test Base.getindex(mi, mc1, :chunks; sorted = false) == ["First chunk", "6"] + @test Base.getindex(mi, mc1, :sources; sorted = false) == + ["test_source1", "other_source3"] + @test Base.getindex(mi, mc1, :scores; sorted = false) == [0.5, 0.7] end @testset "RAGResult" begin