Skip to content

Commit

Permalink
Add SubChunkIndex (view of index)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jul 21, 2024
1 parent 53ac0b8 commit fcd7509
Show file tree
Hide file tree
Showing 9 changed files with 701 additions and 205 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.40.0]

### Added
- Introduces `RAGTools.SubChunkIndex` to allow projecting `views` of various indices. Useful for pre-filtering your data (faster and more precise retrieval). See `?RT.SubChunkIndex` for more information and how to use it.

### Updated
- `CandidateChunks` and `MultiCandidateChunks` intersection methods updated to be an order of magnitude faster (useful for large sets like tag filters).

### Fixed
- Fixed a bug in `find_closest(finder::BM25Similarity, ...)` where the `minimum_similarity` kwarg was not implemented.

## [0.39.0]

### Breaking Changes
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.39.0"
version = "0.40.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
2 changes: 1 addition & 1 deletion src/Experimental/RAGTools/RAGTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ include("api_services.jl")
include("rag_interface.jl")

export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, CandidateChunks, RAGResult
export MultiIndex
export MultiIndex, SubChunkIndex, MultiCandidateChunks
include("types.jl")

export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIndexer,
Expand Down
10 changes: 5 additions & 5 deletions src/Experimental/RAGTools/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ function build_context(contexter::ContextEnumerator,
@assert chunks_window_margin[1] >= 0&&chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative"

context = String[]
for (i, position) in enumerate(candidates.positions)
for (i, position) in enumerate(positions(candidates))
## select the right index
id = candidates isa MultiCandidateChunks ? candidates.index_ids[i] :
candidates.index_id
index_ = index isa AbstractChunkIndex ? index : index[id]
isnothing(index_) && continue
##
chunks_ = chunks(index_)[
## Refer to parent in case index is a SubChunkIndex (bc positions refer to the underlying parent chunks)
chunks_ = chunks(parent(index_))[
max(1, position - chunks_window_margin[1]):min(end,
position + chunks_window_margin[2])]
## Check if surrounding chunks are from the same source
is_same_source = sources(index_)[
is_same_source = sources(parent(index_))[
max(1, position - chunks_window_margin[1]):min(end,
position + chunks_window_margin[2])] .== sources(index_)[position]
position + chunks_window_margin[2])] .== sources(parent(index_))[position]
push!(context, "$(i). $(join(chunks_[is_same_source], "\n"))")
end
return context
Expand Down
108 changes: 63 additions & 45 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ function find_closest(
# emb is an embedding matrix where the first dimension is the embedding dimension
scores = query_emb' * emb |> vec
top_k_min = min(top_k, length(scores))
## Take the top_k largest because larger is better in Cosine similarity (=1 is the best)
positions = partialsortperm(scores, 1:top_k_min, rev = true)
if minimum_similarity > -1.0
mask = scores[positions] .>= minimum_similarity
mask = @view(scores[positions]) .>= minimum_similarity
positions = positions[mask]
else
## we want to materialize the view
Expand All @@ -229,19 +230,19 @@ function find_closest(
finder::AbstractSimilarityFinder, index::AbstractChunkIndex,
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 100, kwargs...)
isnothing(chunkdata(index)) && return CandidateChunks(; index_id = index.id)
isnothing(chunkdata(index)) && return CandidateChunks(; index_id = indexid(index))
positions, scores = find_closest(finder, chunkdata(index),
query_emb, query_tokens;
top_k, kwargs...)
return CandidateChunks(index.id, positions, Float32.(scores))
return CandidateChunks(indexid(index), positions, Float32.(scores))
end

# Dispatch to find scores for multiple embeddings
function find_closest(
finder::AbstractSimilarityFinder, index::AbstractChunkIndex,
query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}();
top_k::Int = 100, kwargs...)
isnothing(chunkdata(index)) && CandidateChunks(; index_id = index.id)
isnothing(chunkdata(index)) && CandidateChunks(; index_id = indexid(index))
## reduce top_k since we have more than one query
top_k_ = top_k ÷ size(query_emb, 2)
## simply vcat together (gets sorted from the highest similarity to the lowest)
Expand Down Expand Up @@ -273,11 +274,13 @@ function find_closest(
positions_, scores_ = find_closest(finder[i], chunkdata(all_indexes[i]),
query_emb, query_tokens;
top_k = top_k_shard, kwargs...)
append!(index_ids, fill(all_indexes[i].id, length(positions_)))
append!(index_ids, fill(indexid(all_indexes[i]), length(positions_)))
append!(positions, positions_)
append!(scores, scores_)
end
idxs = sortperm(scores, rev = true) |> x -> first(x, top_k)
## Take the top_k largest because larger is better in Cosine similarity (=1 is the best)
## Do direct sortperm because it's unlikely to be too much larger (top_k * number of shards)
idxs = sortperm(scores, rev = true) |> Base.Fix2(first, top_k)
return MultiCandidateChunks(index_ids[idxs], positions[idxs], scores[idxs])
end

Expand Down Expand Up @@ -386,6 +389,7 @@ function find_closest(
binary_query_emb = map(>(0), query_emb)
scores = hamming_distance(emb, binary_query_emb)
num_candidates = min(top_k * rescore_multiplier, length(scores))
## Take the `num_candidates` smallest (top_k * rescore_multiplier, capped at the number of scores) because smaller is better in Hamming distance
positions = partialsortperm(scores, 1:num_candidates)

## Second pass, rescore with float embeddings and return top_k
Expand Down Expand Up @@ -428,6 +432,7 @@ function find_closest(
bit_query_emb = pack_bits(query_emb .> 0)
scores = hamming_distance(emb, bit_query_emb)
num_candidates = min(top_k * rescore_multiplier, length(scores))
## Take the `num_candidates` smallest (top_k * rescore_multiplier, capped at the number of scores) because smaller is better in Hamming distance
positions = partialsortperm(scores, 1:num_candidates)

## Second pass, rescore with float embeddings and return top_k
Expand All @@ -454,18 +459,20 @@ function find_closest(
finder::BM25Similarity, dtm::DocumentTermMatrix,
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
bm_scores = bm25(dtm, query_tokens)
top_k_min = min(top_k, length(bm_scores))
positions = partialsortperm(bm_scores, 1:top_k_min, rev = true)
scores = bm25(dtm, query_tokens)
top_k_min = min(top_k, length(scores))
## Take the top_k largest because higher is better in BM25
## BM25 scores are non-negative but unbounded (they grow with the number of keywords)
positions = partialsortperm(scores, 1:top_k_min, rev = true)

if minimum_similarity > -1.0
mask = scores[positions] .>= minimum_similarity
mask = @view(scores[positions]) .>= minimum_similarity
positions = positions[mask]
else
# materialize the vector
positions = positions |> collect
end
return positions, bm_scores[positions]
return positions, scores[positions]
end

### TAG Filtering
Expand All @@ -488,24 +495,29 @@ Finds the indices of chunks (represented by tags in `index`) that have ANY OF th
"""
function find_tags(method::AnyTagFilter, index::AbstractChunkIndex,
tag::Union{AbstractString, Regex}; kwargs...)
isnothing(tags(index)) && CandidateChunks(; index_id = index.id)
isnothing(tags(index)) && CandidateChunks(; index_id = indexid(index))
tag_idx = if tag isa AbstractString
findall(tags_vocab(index) .== tag)
else # assume it's a regex
findall(occursin.(tag, tags_vocab(index)))
end
# getindex.(x, 1) is to get the first dimension in each CartesianIndex
match_row_idx = @view(tags(index)[:, tag_idx]) |> findall |>
x -> getindex.(x, 1) |> unique
return CandidateChunks(index.id, match_row_idx, ones(Float32, length(match_row_idx)))
match_row_idx = @view(tags(index)[:, tag_idx]) |> findall .|> Base.Fix2(getindex, 1) |>
unique
## Index can be a SubChunkIndex, so we need to convert to the original indices
if index isa SubChunkIndex
match_row_idx = positions(index)[match_row_idx]
end
return CandidateChunks(
indexid(index), match_row_idx, ones(Float32, length(match_row_idx)))
end

# Method for multiple tags
function find_tags(method::AnyTagFilter, index::AbstractChunkIndex,
tags::Vector{T}; kwargs...) where {T <: Union{AbstractString, Regex}}
pos = [find_tags(method, index, tag).positions for tag in tags] |>
pos = [positions(find_tags(method, index, tag)) for tag in tags] |>
Base.Splat(vcat) |> unique |> x -> convert(Vector{Int}, x)
return CandidateChunks(index.id, pos, ones(Float32, length(pos)))
return CandidateChunks(indexid(index), pos, ones(Float32, length(pos)))
end

"""
Expand All @@ -519,7 +531,7 @@ Finds the indices of chunks (represented by tags in `index`) that have ALL OF th
"""
function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
tags_vec::Vector{T}; kwargs...) where {T <: Union{AbstractString, Regex}}
isnothing(tags(index)) && CandidateChunks(; index_id = index.id)
isnothing(tags(index)) && CandidateChunks(; index_id = indexid(index))
tag_idx = Int[]
for tag in tags_vec
if tag isa AbstractString
Expand All @@ -534,7 +546,12 @@ function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
else
Int[]
end
return CandidateChunks(index.id, match_row_idx, ones(Float32, length(match_row_idx)))
## Index can be a SubChunkIndex, so we need to convert to the original indices
if index isa SubChunkIndex
match_row_idx = positions(index)[match_row_idx]
end
return CandidateChunks(
indexid(index), match_row_idx, ones(Float32, length(match_row_idx)))
end
function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
tag::Union{AbstractString, Regex}; kwargs...)
Expand Down Expand Up @@ -565,21 +582,21 @@ function find_tags(method::Union{AnyTagFilter, AllTagFilter}, index::AbstractMul
return MultiCandidateChunks(; index_ids = Symbol[])

index_ids = Symbol[]
positions = Int[]
scores = Float32[]
positions_ = Int[]
scores_ = Float32[]
for i in eachindex(all_indexes)
if isnothing(tags(all_indexes[i]))
continue
end
cc = find_tags(method, all_indexes[i], tag; kwargs...)
if !isempty(cc.positions)
append!(index_ids, fill(cc.index_id, length(cc.positions)))
append!(positions, cc.positions)
append!(scores, cc.scores)
if !isempty(positions(cc))
append!(index_ids, fill(indexid(cc), length(positions(cc))))
append!(positions_, positions(cc))
append!(scores_, scores(cc))
end
end
idxs = sortperm(scores, rev = true)
return MultiCandidateChunks(index_ids[idxs], positions[idxs], scores[idxs])
idxs = sortperm(scores_, rev = true)
return MultiCandidateChunks(index_ids[idxs], positions_[idxs], scores_[idxs])
end

function find_tags(method::NoTagFilter, index::AbstractMultiIndex,
Expand Down Expand Up @@ -720,16 +737,16 @@ function rerank(
## Unwrap re-ranked positions
is_multi_cand = candidates isa MultiCandidateChunks
index_ids = Vector{Symbol}(undef, length(r.response[:results]))
positions = Vector{Int}(undef, length(r.response[:results]))
scores = Vector{Float32}(undef, length(r.response[:results]))
positions_ = Vector{Int}(undef, length(r.response[:results]))
scores_ = Vector{Float32}(undef, length(r.response[:results]))
for i in eachindex(r.response[:results])
doc = r.response[:results][i]
positions[i] = candidates.positions[doc[:index] + 1]
scores[i] = doc[:relevance_score]
positions_[i] = positions(candidates)[doc[:index] + 1]
scores_[i] = doc[:relevance_score]
index_ids[i] = if is_multi_cand
candidates.index_ids[doc[:index] + 1]
indexids(candidates)[doc[:index] + 1]
else
candidates.index_id
indexid(candidates)
end
end

Expand All @@ -745,8 +762,8 @@ function rerank(
verbose && @info "Reranking done. $search_units_str"

return is_multi_cand ?
MultiCandidateChunks(index_ids, positions, scores) :
CandidateChunks(index_ids[1], positions, scores)
MultiCandidateChunks(index_ids, positions_, scores_) :
CandidateChunks(index_ids[1], positions_, scores_)
end

"""
Expand Down Expand Up @@ -809,14 +826,14 @@ function rerank(
@assert !(isempty(documents)) "The candidate chunks must not be empty! Check the index IDs."

is_multi_cand = candidates isa MultiCandidateChunks
index_ids = is_multi_cand ? candidates.index_ids : candidates.index_id
positions = candidates.positions
index_ids = is_multi_cand ? indexids(candidates) : indexid(candidates)
positions_ = positions(candidates)
## Find unique only items
if unique_chunks
verbose && @info "Removing duplicates from candidate chunks prior to reranking"
unique_idxs = PT.unique_permutation(documents)
documents = documents[unique_idxs]
positions = positions[unique_idxs]
positions_ = positions_[unique_idxs]
index_ids = is_multi_cand ? index_ids[unique_idxs] : index_ids
end

Expand All @@ -832,16 +849,16 @@ function rerank(

## Unwrap re-ranked positions
ranked_positions = first(result.positions, top_n)
positions = positions[ranked_positions]
positions_ = positions_[ranked_positions]
## TODO: add reciprocal rank fusion and multiple passes
scores = ones(Float32, length(positions)) # no scores available
scores_ = ones(Float32, length(positions_)) # no scores available

verbose && @info "Reranking done in $(round(result.elapsed; digits=1)) seconds."
Threads.atomic_add!(cost_tracker, result.cost)

return is_multi_cand ?
MultiCandidateChunks(index_ids[ranked_positions], positions, scores) :
CandidateChunks(index_ids, positions, scores)
MultiCandidateChunks(index_ids[ranked_positions], positions_, scores_) :
CandidateChunks(index_ids, positions_, scores_)
end

### Overall types for `retrieve`
Expand Down Expand Up @@ -1117,16 +1134,17 @@ function retrieve(retriever::AbstractRetriever,
top_n, verbose = (verbose > 1), cost_tracker, reranker_kwargs_...)

verbose > 0 &&
@info "Retrieval done. Identified $(length(reranked_candidates.positions)) chunks, total cost: \$$(round(cost_tracker[], digits=2))."
@info "Retrieval done. Identified $(length(positions(reranked_candidates))) chunks, total cost: \$$(round(cost_tracker[], digits=2))."

## Return
result = RAGResult(;
question,
answer = nothing,
rephrased_questions,
final_answer = nothing,
context = collect(index[reranked_candidates, :chunks]),
sources = collect(index[reranked_candidates, :sources]),
## Ensure chunks and sources are sorted
context = collect(index[reranked_candidates, :chunks, sorted = true]),
sources = collect(index[reranked_candidates, :sources, sorted = true]),
emb_candidates,
tag_candidates,
filtered_candidates,
Expand Down
Loading

0 comments on commit fcd7509

Please sign in to comment.