Skip to content

Commit

Permalink
Update FlashRank to use only unique documents (#166)
Browse files Browse the repository at this point in the history
* Update FlashRank to use only unique documents

* update
  • Loading branch information
svilupp authored Jun 18, 2024
1 parent a4f191a commit 03029ee
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 11 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

## Fixed

## [0.32.0]

## Updated
- Changed behavior of `RAGTools.rerank(::FlashRanker,...)` to always dedupe input chunks (to reduce compute requirements).

## Fixed
- Fixed a bug in verbose INFO log in `RAGTools.rerank(::FlashRanker,...)`.

## [0.31.1]

### Updated
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.31.1"
version = "0.32.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
28 changes: 19 additions & 9 deletions ext/FlashRankPromptingToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using FlashRank
candidates::RT.AbstractCandidateChunks;
verbose::Bool = false,
top_n::Integer = length(candidates.scores),
unique_chunks::Bool = true,
kwargs...)
Re-ranks a list of candidate chunks using the FlashRank.jl local models.
Expand All @@ -25,6 +26,7 @@ Re-ranks a list of candidate chunks using the FlashRank.jl local models.
- `candidates`: The candidate chunks to be re-ranked.
- `top_n`: The number of most relevant documents to return. Default is `length(documents)`.
- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`.
- `unique_chunks`: A boolean flag indicating whether to remove duplicates from the candidate chunks prior to reranking (saves compute time). Default is `true`.
# Example
Expand Down Expand Up @@ -54,28 +56,36 @@ function RT.rerank(
candidates::RT.AbstractCandidateChunks;
verbose::Bool = false,
top_n::Integer = length(candidates.scores),
unique_chunks::Bool = true,
kwargs...)
@assert top_n>0 "top_n must be a positive integer."
documents = index[candidates, :chunks]
@assert !(isempty(documents)) "The candidate chunks must not be empty for Cohere Reranker! Check the index IDs."

is_multi_cand = candidates isa RT.MultiCandidateChunks
index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id
positions = candidates.positions
## Find unique only items
if unique_chunks
verbose && @info "Removing duplicates from candidate chunks prior to reranking"
unique_idxs = PT.unique_permutation(documents)
documents = documents[unique_idxs]
positions = positions[unique_idxs]
index_ids = is_multi_cand ? index_ids[unique_idxs] : index_ids
end

## Run re-ranker
ranker = reranker.model
result = ranker(question, documents; top_n)

## Unwrap re-ranked positions
scores = result.scores
positions = candidates.positions[result.positions]
index_ids = if candidates isa RT.MultiCandidateChunks
candidates.index_ids[result.positions]
else
candidates.index_id
end
positions = positions[result.positions]

verbose && @info "Reranking done in $(round(res.elapsed; digits=1)) seconds."
verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds."

return candidates isa RT.MultiCandidateChunks ?
RT.MultiCandidateChunks(index_ids, positions, scores) :
return is_multi_cand ?
RT.MultiCandidateChunks(index_ids[result.positions], positions, scores) :
RT.CandidateChunks(index_ids, positions, scores)
end

Expand Down
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,12 @@ function auth_header(api_key::Union{Nothing, AbstractString};
pushfirst!(headers, "x-api-key" => "$api_key")
return headers
end

"""
unique_permutation(inputs::AbstractVector)
Returns indices of unique items in a vector `inputs`. Access the unique values as `inputs[unique_permutation(inputs)]`.
"""
function unique_permutation(inputs::AbstractVector)
return unique(i -> inputs[i], eachindex(inputs))
end
59 changes: 58 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using PromptingTools: _extract_handlebar_variables, call_cost, call_cost_alterna
using PromptingTools: _string_to_vector, _encode_local_image
using PromptingTools: DataMessage, AIMessage
using PromptingTools: push_conversation!,
resize_conversation!, @timeout, preview, pprint, auth_header
resize_conversation!, @timeout, preview, pprint, auth_header,
unique_permutation

@testset "replace_words" begin
words = ["Disney", "Snow White", "Mickey Mouse"]
Expand Down Expand Up @@ -371,3 +372,59 @@ end
"version" => "1.0"
]
end

@testset "unique_permutation" begin
# Test with an empty array
@test unique_permutation([]) == []

# Test with an array of integers
@test unique_permutation([1, 2, 3, 2, 1]) == [1, 2, 3]

# Test with an array of strings
@test unique_permutation(["apple", "banana", "apple", "orange"]) == [1, 2, 4]

# Test with repeated identical elements
@test unique_permutation([4, 4, 4, 4]) == [1]

# Test with non-consecutive duplicates
@test unique_permutation([1, 2, 3, 1, 2, 3, 1, 2, 3]) == [1, 2, 3]
@test unique_permutation([1, 2, 1, 2, 1, 2, 3, 1, 2, 3]) == [1, 2, 7]

# Test with an array of negative integers
@test unique_permutation([-1, -2, -3, -2, -1]) == [1, 2, 3]

# Test with an array of mixed positive and negative integers
@test unique_permutation([1, -1, 2, -2, 1, -1]) == [1, 2, 3, 4]

# Test with an array of floating point numbers
@test unique_permutation([1.1, 2.2, 3.3, 2.2, 1.1]) == [1, 2, 3]

# Test with an array of mixed integers and floating point numbers
@test unique_permutation([1, 2.0, 3, 2.0, 1]) == [1, 2, 3]

# Test with an array of very large integers
@test unique_permutation([10^10, 10^10, 10^12, 10^11, 10^12]) == [1, 3, 4]

# Test with an array of very small floating point numbers
@test unique_permutation([1e-10, 1e-10, 1e-12, 1e-11, 1e-12]) == [1, 3, 4]

# Test with an array of strings with different cases
@test unique_permutation(["Apple", "apple", "Banana", "banana", "Apple"]) ==
[1, 2, 3, 4]

# Test with an array of mixed data types
@test unique_permutation([1, "1", 2, "2", 1]) == [1, 2, 3, 4]

# Test with an array of complex numbers
@test unique_permutation([1 + 1im, 2 + 2im, 1 + 1im, 3 + 3im]) == [1, 2, 4]

# Test with an array of tuples
@test unique_permutation([(1, 2), (3, 4), (1, 2), (5, 6)]) == [1, 2, 4]

# Test with an array of arrays
@test unique_permutation([[1, 2], [3, 4], [5, 6], [1, 2]]) == [1, 2, 3]

# Test with an array of dictionaries
@test unique_permutation([
Dict(:a => 1), Dict(:b => 2), Dict(:a => 1), Dict(:c => 3)]) == [1, 2, 4]
end

0 comments on commit 03029ee

Please sign in to comment.