Skip to content

Commit

Permalink
Add AllTagFilter (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jul 16, 2024
1 parent dd3fbbc commit dfb88a1
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 7 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.38.0]

### Added
- Added a new tagging filter `RT.AllTagFilter` to `RT.find_tags`, which requires all tags to be present in a chunk.
- Added an option in `RT.get_keywords` to set the minimum length of the keywords.
- Added a new method for `reciprocal_rank_fusion` and a utility for standardizing candidate chunk scores (`score_to_unit_scale`).

## [0.37.1]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.37.1"
version = "0.38.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
5 changes: 4 additions & 1 deletion ext/SnowballPromptingToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ RT._stem(stemmer::Snowball.Stemmer, text::AbstractString) = Snowball.stem(stemme
stemmer = nothing,
stopwords::Set{String} = Set(STOPWORDS),
return_keywords::Bool = false,
min_length::Integer = 3,
kwargs...)
Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stemmer` and `stopwords`.
Expand All @@ -27,13 +28,15 @@ Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stem
- `stemmer`: A stemmer to use for stemming. Default is `nothing`.
- `stopwords`: A set of stopwords to remove. Default is `Set(STOPWORDS)`.
- `return_keywords`: A boolean flag for returning the keywords. Default is `false`. Useful for query processing in search time.
- `min_length`: The minimum length of the keywords. Default is `3`.
"""
function RT.get_keywords(
processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString};
verbose::Bool = true,
stemmer = nothing,
stopwords::Set{String} = Set(RT.STOPWORDS),
return_keywords::Bool = false,
min_length::Integer = 3,
kwargs...)
## check if extension is available
ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt)
Expand All @@ -47,7 +50,7 @@ function RT.get_keywords(
## Preprocess text into tokens
stemmer = !isnothing(stemmer) ? stemmer : Snowball.Stemmer("english")
# Single-threaded as stemmer is not thread-safe
keywords = RT.preprocess_tokens(docs, stemmer; stopwords, min_length = 3)
keywords = RT.preprocess_tokens(docs, stemmer; stopwords, min_length)

## Early exit if we only want keywords (search time)
return_keywords && return keywords
Expand Down
44 changes: 42 additions & 2 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ Finds the chunks that have ANY OF the specified tag(s).
"""
struct AnyTagFilter <: AbstractTagFilter end

"""
    AllTagFilter <: AbstractTagFilter

Finds the chunks that have ALL OF the specified tag(s).

Singleton type (no fields); pass an instance as the `method` argument of `find_tags`
to select this filtering behavior.
"""
struct AllTagFilter <: AbstractTagFilter end

### Functions
function rephrase(rephraser::AbstractRephraser, question::AbstractString; kwargs...)
throw(ArgumentError("Not implemented yet for type $(typeof(rephraser))"))
Expand Down Expand Up @@ -501,6 +508,39 @@ function find_tags(method::AnyTagFilter, index::AbstractChunkIndex,
return CandidateChunks(index.id, pos, ones(Float32, length(pos)))
end

"""
    find_tags(method::AllTagFilter, index::AbstractChunkIndex,
        tag::Union{AbstractString, Regex}; kwargs...)

    find_tags(method::AllTagFilter, index::AbstractChunkIndex,
        tags::Vector{T}; kwargs...) where {T <: Union{AbstractString, Regex}}

Finds the indices of chunks (represented by tags in `index`) that have ALL OF the specified `tag` or `tags`.

String tags are compared against the tag vocabulary of `index` with `==`. A `Regex` tag
selects every vocabulary entry it matches, and a chunk must then carry all of the selected
entries.

# Returns
- `CandidateChunks` holding the matching chunk positions, each with a score of `1.0f0`.
  If `tags(index)` is `nothing`, returns `CandidateChunks(; index_id = index.id)` right away.

# Notes
- Tags that match nothing in the vocabulary add no columns to the check, i.e., they are
  ignored. If no requested tag matches the vocabulary at all, no positions are returned.
  NOTE(review): confirm that ignoring unknown tags is the intended "ALL OF" semantics.
"""
function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
        tags_vec::Vector{T}; kwargs...) where {T <: Union{AbstractString, Regex}}
    ## Early exit when the index carries no tags at all (the `return` is essential here;
    ## without it the result is discarded and the code below fails on `nothing`)
    isnothing(tags(index)) && return CandidateChunks(; index_id = index.id)

    ## Collect the vocabulary columns selected by each requested tag
    vocab = tags_vocab(index)
    tag_idx = Int[]
    for tag in tags_vec
        if tag isa AbstractString
            append!(tag_idx, findall(vocab .== tag))
        else # assume it's a regex
            append!(tag_idx, findall(occursin.(Ref(tag), vocab)))
        end
    end

    ## get rows with all values true (element-wise AND across the selected columns)
    match_row_idx = if length(tag_idx) > 0
        reduce(.&, eachcol(@view(tags(index)[:, tag_idx]))) |> findall
    else
        Int[]
    end
    return CandidateChunks(index.id, match_row_idx, ones(Float32, length(match_row_idx)))
end
## Convenience method: a single tag is treated as a one-element collection of tags
function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
        tag::Union{AbstractString, Regex}; kwargs...)
    single_tag = [tag]
    return find_tags(method, index, single_tag; kwargs...)
end

"""
find_tags(method::NoTagFilter, index::AbstractChunkIndex,
tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Expand All @@ -516,8 +556,8 @@ function find_tags(method::NoTagFilter, index::AbstractChunkIndex,
AbstractString, Regex, Nothing}}
return nothing
end
## Multi-index implementation
function find_tags(method::AnyTagFilter, index::AbstractMultiIndex,
## Multi-index implementation -- logic differs within each index and then we simply vcat them together
function find_tags(method::Union{AnyTagFilter, AllTagFilter}, index::AbstractMultiIndex,
tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <:
Union{AbstractString, Regex}}
all_indexes = indexes(index)
Expand Down
59 changes: 58 additions & 1 deletion src/Experimental/RAGTools/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,4 +591,61 @@ function reciprocal_rank_fusion(args...; k::Int = 60)
merged = [first(item) for item in sort(collect(scores), by = last, rev = true)]

return merged, scores
end
end

"""
    reciprocal_rank_fusion(
        positions1::AbstractVector{<:Integer}, scores1::AbstractVector{<:T},
        positions2::AbstractVector{<:Integer},
        scores2::AbstractVector{<:T}; k::Int = 60) where {T <: Real}

Merges two sets of rankings and their joint scores. Calculates the reciprocal rank score for each chunk (discounted by the inverse of the rank).

Every entry contributes `score / (k + rank)`, where `rank` is its 1-based place in its own
list. Contributions of a position that shows up more than once are summed.

# Arguments
- `positions1`, `positions2`: positions (e.g., chunk ids) in ranked order.
- `scores1`, `scores2`: scores paired element-wise with the respective positions.
- `k`: constant added to the rank in the denominator. Default is `60`.

# Returns
- `merged`: the unique positions sorted by their fused score in descending order.
- `scores`: a `Dict` mapping each position to its fused score. Values are stored as
  floating-point numbers, so integer input scores are supported as well.

# Example
```julia
positions1 = [1, 3, 5, 7, 9]
scores1 = [0.9, 0.8, 0.7, 0.6, 0.5]
positions2 = [2, 4, 6, 8, 10]
scores2 = [0.5, 0.6, 0.7, 0.8, 0.9]
merged, scores = reciprocal_rank_fusion(positions1, scores1, positions2, scores2; k = 60)
```
"""
function reciprocal_rank_fusion(
        positions1::AbstractVector{<:Integer}, scores1::AbstractVector{<:T},
        positions2::AbstractVector{<:Integer},
        scores2::AbstractVector{<:T}; k::Int = 60) where {T <: Real}
    ## Fused scores are fractional, so store them in a floating-point type
    ## (an integer-valued Dict would throw `InexactError` on the first division)
    S = float(promote_type(eltype(scores1), eltype(scores2)))
    scores = Dict{Int, S}()

    ## `zip` stops at the shorter of the two collections
    for (rank, (pos, sc)) in enumerate(zip(positions1, scores1))
        scores[pos] = get(scores, pos, zero(S)) + sc / (k + rank)
    end
    for (rank, (pos, sc)) in enumerate(zip(positions2, scores2))
        scores[pos] = get(scores, pos, zero(S)) + sc / (k + rank)
    end

    ## Highest fused score first
    merged = [first(item) for item in sort(collect(scores), by = last, rev = true)]

    return merged, scores
end

"""
    score_to_unit_scale(x::AbstractVector{T}) where T<:Real

Shift and scale a vector of scores to the unit scale [0, 1].

The minimum of `x` maps to `0` and the maximum maps to `1`. If the spread of `x` is smaller
than `eps` of the floating-point type of `T` (e.g., all values are equal or there is only
one value), a vector of ones is returned. An empty input yields an empty vector.

# Returns
- A vector with the floating-point element type corresponding to `T` (`float(T)`).

# Example
```julia
x = [1.0, 2.0, 3.0, 4.0, 5.0]
scaled_x = score_to_unit_scale(x)
```
"""
function score_to_unit_scale(x::AbstractVector{T}) where {T <: Real}
    ## Work in the floating-point counterpart of T, so integer inputs are supported too
    F = float(T)
    isempty(x) && return F[]

    lo, hi = extrema(x)
    span = F(hi) - F(lo)
    ## No meaningful spread -> avoid dividing by (almost) zero
    span < eps(F) && return ones(F, length(x))

    ## `span` is strictly positive here, so no extra epsilon is needed in the denominator
    ## (adding one would keep the maximum from reaching exactly 1 for narrow ranges)
    return (x .- lo) ./ span
end
21 changes: 20 additions & 1 deletion test/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using PromptingTools.Experimental.RAGTools: ContextEnumerator, NoRephraser, Simp
HyDERephraser,
CosineSimilarity, BinaryCosineSimilarity,
MultiFinder, BM25Similarity,
NoTagFilter, AnyTagFilter,
NoTagFilter, AllTagFilter, AnyTagFilter,
SimpleRetriever, AdvancedRetriever
using PromptingTools.Experimental.RAGTools: AbstractRephraser, AbstractTagFilter,
AbstractSimilarityFinder, AbstractReranker,
Expand Down Expand Up @@ -431,6 +431,20 @@ end
# Test with multiple tags in vocab
@test find_tags(tagger, index, ["python", "jr", "x"]).positions == [2]

## With AllTagFilter -- no difference for individual
tagger2 = AllTagFilter()
@test find_tags(tagger2, index, "julia").positions == [1]
@test find_tags(tagger2, index, "julia").scores == [1.0]
@test find_tags(tagger2, index, "python").positions |> isempty
@test find_tags(tagger2, index, "java").positions |> isempty
@test find_tags(tagger2, index, r"^j").positions |> isempty
@test find_tags(tagger2, index, "jr").positions == [2]

@test find_tags(tagger2, index, ["python", "jr", "x"]).positions |> isempty
@test find_tags(tagger2, index, ["julia", "jr"]).positions |> isempty
@test find_tags(tagger2, index, ["julia", "julia"]).positions == [1]
@test find_tags(tagger2, index, ["julia", "julia"]).scores == [1.0]

# No filter tag -- give everything
cc = find_tags(NoTagFilter(), index, "julia")
@test isnothing(cc)
Expand Down Expand Up @@ -483,6 +497,11 @@ end
@test mcc4.index_ids == [:indexX, :indexX]
@test mcc4.positions == [1, 2]
@test mcc4.scores == [1.0, 1.0]

mcc5 = find_tags(AllTagFilter(), multi_index2, [r"^j"])
@test mcc5.index_ids |> isempty
@test mcc5.positions |> isempty
@test mcc5.scores |> isempty
end

@testset "rerank" begin
Expand Down
56 changes: 55 additions & 1 deletion test/Experimental/RAGTools/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using PromptingTools.Experimental.RAGTools: split_into_code_and_sentences
using PromptingTools.Experimental.RAGTools: getpropertynested, setpropertynested,
merge_kwargs_nested
using PromptingTools.Experimental.RAGTools: pack_bits, unpack_bits, preprocess_tokens,
reciprocal_rank_fusion
reciprocal_rank_fusion, score_to_unit_scale

@testset "_check_aiextract_capability" begin
@test _check_aiextract_capability("gpt-3.5-turbo") == nothing
Expand Down Expand Up @@ -598,4 +598,58 @@ end
@test Set(positions[1:2]) == Set([1, 3])
@test Set(positions[3:4]) == Set([2, 4])
@test positions[5] == 5

## Paired reciprocal rank
positions1 = [1, 2, 3, 4, 5]
scores1 = [0.9, 0.8, 0.7, 0.6, 0.5]
positions2 = [3, 4, 5, 6, 7]
scores2 = [0.5, 0.6, 0.7, 0.9, 0.9]

merged, scores = reciprocal_rank_fusion(positions1, scores1, positions2, scores2; k = 0)
@test length(merged) == 7
@test Set(merged) == Set(1:7)
@test merged[1] == 1
@test scores[1] == 0.9
@test merged[2] == 3
@test scores[3] == 0.7 / 3 + 0.5
@test merged[end] == 7
@test scores[7] == 0.9 / 5

merged, scores = reciprocal_rank_fusion(
positions1, scores1, positions2, scores2; k = 60)
@test length(merged) == 7
@test merged[1] == 3
@test merged[2] == 4
@test merged[3] == 5
@test scores[3] > scores[4]
@test scores[4] > scores[5]
@test scores[5] > scores[6]
@test scores[6] > scores[7]
end

@testset "score_to_unit_scale" begin
    # Inputs with a non-zero spread must cover the full unit interval, regardless of sign:
    # all positive, all negative, and a mix of both
    for raw_scores in ([1.0, 2.0, 3.0, 4.0, 5.0],
        [-5.0, -4.0, -3.0, -2.0, -1.0],
        [-1.0, 0.0, 1.0])
        @test extrema(score_to_unit_scale(raw_scores)) == (0.0, 1.0)
    end

    # Identical values have no spread, so every scaled score is 1.0
    constant_scores = fill(2.0, 5)
    @test all(score_to_unit_scale(constant_scores) .== 1.0)

    # A single value is handled the same way as identical values
    @test score_to_unit_scale([3.0]) == [1.0]
end

0 comments on commit dfb88a1

Please sign in to comment.