From 03029ee7bb94ce623260ab7b3fd4f6a59bc0209b Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:12:04 +0200 Subject: [PATCH] Update FlashRank to use only unique documents (#166) * Update FlashRank to use only unique documents * update --- CHANGELOG.md | 10 ++++++ Project.toml | 2 +- ext/FlashRankPromptingToolsExt.jl | 28 ++++++++++----- src/utils.jl | 9 +++++ test/utils.jl | 59 ++++++++++++++++++++++++++++++- 5 files changed, 97 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e53c05149..6bdfc7160 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +## Fixed + +## [0.32.0] + +## Updated +- Changed behavior of `RAGTools.rerank(::FlashRanker,...)` to always dedupe input chunks (to reduce compute requirements). + +## Fixed +- Fixed a bug in verbose INFO log in `RAGTools.rerank(::FlashRanker,...)`. + ## [0.31.1] ### Updated diff --git a/Project.toml b/Project.toml index 7a3c37ad2..d3193f7e9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.31.1" +version = "0.32.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/ext/FlashRankPromptingToolsExt.jl b/ext/FlashRankPromptingToolsExt.jl index 3741c311a..9d62c3e73 100644 --- a/ext/FlashRankPromptingToolsExt.jl +++ b/ext/FlashRankPromptingToolsExt.jl @@ -14,6 +14,7 @@ using FlashRank candidates::RT.AbstractCandidateChunks; verbose::Bool = false, top_n::Integer = length(candidates.scores), + unique_chunks::Bool = true, kwargs...) Re-ranks a list of candidate chunks using the FlashRank.jl local models. @@ -25,6 +26,7 @@ Re-ranks a list of candidate chunks using the FlashRank.jl local models. - `candidates`: The candidate chunks to be re-ranked. - `top_n`: The number of most relevant documents to return. Default is `length(documents)`. - `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`. +- `unique_chunks`: A boolean flag indicating whether to remove duplicates from the candidate chunks prior to reranking (saves compute time). Default is `true`. # Example @@ -54,28 +56,36 @@ function RT.rerank( candidates::RT.AbstractCandidateChunks; verbose::Bool = false, top_n::Integer = length(candidates.scores), + unique_chunks::Bool = true, kwargs...) @assert top_n>0 "top_n must be a positive integer." documents = index[candidates, :chunks] @assert !(isempty(documents)) "The candidate chunks must not be empty for Cohere Reranker! Check the index IDs." + is_multi_cand = candidates isa RT.MultiCandidateChunks + index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id + positions = candidates.positions + ## Find unique only items + if unique_chunks + verbose && @info "Removing duplicates from candidate chunks prior to reranking" + unique_idxs = PT.unique_permutation(documents) + documents = documents[unique_idxs] + positions = positions[unique_idxs] + index_ids = is_multi_cand ? index_ids[unique_idxs] : index_ids + end + ## Run re-ranker ranker = reranker.model result = ranker(question, documents; top_n) ## Unwrap re-ranked positions scores = result.scores - positions = candidates.positions[result.positions] - index_ids = if candidates isa RT.MultiCandidateChunks - candidates.index_ids[result.positions] - else - candidates.index_id - end + positions = positions[result.positions] - verbose && @info "Reranking done in $(round(res.elapsed; digits=1)) seconds." + verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds." - return candidates isa RT.MultiCandidateChunks ? - RT.MultiCandidateChunks(index_ids, positions, scores) : + return is_multi_cand ? + RT.MultiCandidateChunks(index_ids[result.positions], positions, scores) : RT.CandidateChunks(index_ids, positions, scores) end diff --git a/src/utils.jl b/src/utils.jl index 0a542a71a..5dc017a8d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -650,3 +650,12 @@ function auth_header(api_key::Union{Nothing, AbstractString}; pushfirst!(headers, "x-api-key" => "$api_key") return headers end + +""" + unique_permutation(inputs::AbstractVector) + +Returns indices of unique items in a vector `inputs`. Access the unique values as `inputs[unique_permutation(inputs)]`. +""" +function unique_permutation(inputs::AbstractVector) + return unique(i -> inputs[i], eachindex(inputs)) +end \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index 6a6469bc0..a0495aeb7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -5,7 +5,8 @@ using PromptingTools: _extract_handlebar_variables, call_cost, call_cost_alterna using PromptingTools: _string_to_vector, _encode_local_image using PromptingTools: DataMessage, AIMessage using PromptingTools: push_conversation!, - resize_conversation!, @timeout, preview, pprint, auth_header + resize_conversation!, @timeout, preview, pprint, auth_header, + unique_permutation @testset "replace_words" begin words = ["Disney", "Snow White", "Mickey Mouse"] @@ -371,3 +372,59 @@ end "version" => "1.0" ] end + +@testset "unique_permutation" begin + # Test with an empty array + @test unique_permutation([]) == [] + + # Test with an array of integers + @test unique_permutation([1, 2, 3, 2, 1]) == [1, 2, 3] + + # Test with an array of strings + @test unique_permutation(["apple", "banana", "apple", "orange"]) == [1, 2, 4] + + # Test with repeated identical elements + @test unique_permutation([4, 4, 4, 4]) == [1] + + # Test with non-consecutive duplicates + @test unique_permutation([1, 2, 3, 1, 2, 3, 1, 2, 3]) == [1, 2, 3] + @test unique_permutation([1, 2, 1, 2, 1, 2, 3, 1, 2, 3]) == [1, 2, 7] + + # Test with an array of negative integers + @test unique_permutation([-1, -2, -3, -2, -1]) == [1, 2, 3] + + # Test with an array of mixed positive and negative integers + @test unique_permutation([1, -1, 2, -2, 1, -1]) == [1, 2, 3, 4] + + # Test with an array of floating point numbers + @test unique_permutation([1.1, 2.2, 3.3, 2.2, 1.1]) == [1, 2, 3] + + # Test with an array of mixed integers and floating point numbers + @test unique_permutation([1, 2.0, 3, 2.0, 1]) == [1, 2, 3] + + # Test with an array of very large integers + @test unique_permutation([10^10, 10^10, 10^12, 10^11, 10^12]) == [1, 3, 4] + + # Test with an array of very small floating point numbers + @test unique_permutation([1e-10, 1e-10, 1e-12, 1e-11, 1e-12]) == [1, 3, 4] + + # Test with an array of strings with different cases + @test unique_permutation(["Apple", "apple", "Banana", "banana", "Apple"]) == + [1, 2, 3, 4] + + # Test with an array of mixed data types + @test unique_permutation([1, "1", 2, "2", 1]) == [1, 2, 3, 4] + + # Test with an array of complex numbers + @test unique_permutation([1 + 1im, 2 + 2im, 1 + 1im, 3 + 3im]) == [1, 2, 4] + + # Test with an array of tuples + @test unique_permutation([(1, 2), (3, 4), (1, 2), (5, 6)]) == [1, 2, 4] + + # Test with an array of arrays + @test unique_permutation([[1, 2], [3, 4], [5, 6], [1, 2]]) == [1, 2, 3] + + # Test with an array of dictionaries + @test unique_permutation([ + Dict(:a => 1), Dict(:b => 2), Dict(:a => 1), Dict(:c => 3)]) == [1, 2, 4] +end \ No newline at end of file