Skip to content

Commit

Permalink
Fix CohereReranker bug
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jul 9, 2024
1 parent 6e9f0ea commit 988f3d3
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.37.1]

### Fixed
- Fixed a bug in CohereReranker when it wouldn't handle correctly CandidateChunks.

## [0.37.0]

### Updated
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.37.0"
version = "0.37.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
15 changes: 8 additions & 7 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -678,14 +678,15 @@ function rerank(
kwargs...)

## Unwrap re-ranked positions
is_multi_cand = candidates isa MultiCandidateChunks
index_ids = Vector{Symbol}(undef, length(r.response[:results]))
positions = Vector{Int}(undef, length(r.response[:results]))
scores = Vector{Float32}(undef, length(r.response[:results]))
for i in eachindex(r.response[:results])
doc = r.response[:results][i]
positions[i] = candidates.positions[doc[:index] + 1]
scores[i] = doc[:relevance_score]
index_ids[i] = if candidates isa MultiCandidateChunks
index_ids[i] = if is_multi_cand
candidates.index_ids[doc[:index] + 1]
else
candidates.index_id
Expand All @@ -703,23 +704,23 @@ function rerank(
end
verbose && @info "Reranking done. $search_units_str"

return candidates isa MultiCandidateChunks ?
return is_multi_cand ?
MultiCandidateChunks(index_ids, positions, scores) :
CandidateChunks(index_ids, positions, scores)
CandidateChunks(index_ids[1], positions, scores)
end

"""
rerank(
reranker::CohereReranker, index::AbstractDocumentIndex, question::AbstractString,
reranker::RankGPTReranker, index::AbstractDocumentIndex, question::AbstractString,
candidates::AbstractCandidateChunks;
verbose::Integer = 1,
api_key::AbstractString = PT.OPENAI_API_KEY,
top_n::Integer = length(candidates.scores),
model::AbstractString = PT.MODEL_CHAT,
verbose::Bool = false,
top_n::Integer = length(candidates.scores),
unique_chunks::Bool = true,
cost_tracker = Threads.Atomic{Float64}(0.0),
kwargs...)
Re-ranks a list of candidate chunks using the RankGPT algorithm. See https://github.com/sunnweiwei/RankGPT for more details.
It uses LLM calls to rank the candidate chunks.
Expand Down

0 comments on commit 988f3d3

Please sign in to comment.