From 1473799bbeec575cddcfad8cf03c4b2deab6098a Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Sun, 4 Aug 2024 20:32:00 +0100 Subject: [PATCH] Fix getindex --- CHANGELOG.md | 9 ++++++++- Project.toml | 2 +- src/Experimental/RAGTools/types.jl | 4 ++-- src/utils.jl | 6 +++--- test/Experimental/RAGTools/types.jl | 6 +++--- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a33f16b94..3bbd32f79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,10 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -## [0.44.1] +## [0.45.0] + +### Breaking Change +- `getindex(::MultiIndex, ::MultiCandidateChunks)` now returns sorted chunks by default (`sorted=true`) to guarantee that potential `context` (=`chunks`) is sorted by descending similarity score across different sub-indices. ### Updated - Updated a `hcat` implementation in `RAGTools.get_embeddings` to reduce memory allocations for large embedding batches (c. 3x fewer allocations, see `hcat_truncate`). +- Updated `length_longest_common_subsequence` signature to work only for pairs of `AbstractString` to not fail silently when wrong arguments are provided. + +### Fixed +- Changed the default behavior of `getindex(::MultiIndex, ::MultiCandidateChunks)` to always return sorted chunks for consistency with other similar functions and correct `retrieve` behavior. This was accidentally changed in v0.40 and is now reverted to the original behavior. ## [0.44.0] diff --git a/Project.toml b/Project.toml index 77b37816f..3637499ff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.44.1" +version = "0.45.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index 9582a28bf..3c3b970f2 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -911,10 +911,10 @@ function Base.getindex(ci::AbstractChunkIndex, getindex(ci, cc, field; sorted) end # Getindex on Multiindex, pool the individual hits -# Sorted defaults to false --> similarly to Dict which doesn't guarantee ordering of values returned +# Sorted defaults to true because we need to guarantee that potential `context` is sorted by score across different indices function Base.getindex(mi::MultiIndex, candidate::MultiCandidateChunks{TP, TD}, - field::Symbol = :chunks; sorted::Bool = false) where {TP <: Integer, TD <: Real} + field::Symbol = :chunks; sorted::Bool = true) where {TP <: Integer, TD <: Real} @assert field in [:chunks, :sources, :scores] "Only `chunks`, `sources`, and `scores` fields are supported for now" if sorted # values can be either of chunks or sources diff --git a/src/utils.jl b/src/utils.jl index 20523a32b..9d93c7b34 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -250,9 +250,9 @@ function wrap_string(str::AbstractString, end; """ - length_longest_common_subsequence(itr1, itr2) + length_longest_common_subsequence(itr1::AbstractString, itr2::AbstractString) -Compute the length of the longest common subsequence between two sequences (ie, the higher the number, the better the match). +Compute the length of the longest common subsequence between two string sequences (ie, the higher the number, the better the match). Source: https://cn.julialang.org/LeetCode.jl/dev/democards/problems/problems/1143.longest-common-subsequence/ @@ -286,7 +286,7 @@ But it might be easier to use directly the convenience wrapper `distance_longest ``` """ -function length_longest_common_subsequence(itr1, itr2) +function length_longest_common_subsequence(itr1::AbstractString, itr2::AbstractString) m, n = length(itr1) + 1, length(itr2) + 1 dp = fill(0, m, n) diff --git a/test/Experimental/RAGTools/types.jl b/test/Experimental/RAGTools/types.jl index 6b6bad0d6..4ccaee0f8 100644 --- a/test/Experimental/RAGTools/types.jl +++ b/test/Experimental/RAGTools/types.jl @@ -831,13 +831,13 @@ end # with MultiIndex mi = MultiIndex(; id = :multi, indexes = [ci1, ci2]) - @test mi[cc] == ["chunk2", "chunk2x"] # default is sorted=false + @test mi[cc] == ["chunk2", "chunk2x"] # default is sorted=true @test Base.getindex(mi, cc, :chunks; sorted = true) == ["chunk2", "chunk2x"] @test Base.getindex(mi, cc, :chunks; sorted = false) == ["chunk2", "chunk2x"] # with MultiIndex -- flip the order of indices mi = MultiIndex(; id = :multi, indexes = [ci2, ci1]) - @test mi[cc] == ["chunk2x", "chunk2"] # default is sorted=false + @test mi[cc] == ["chunk2", "chunk2x"] # default is sorted=true @test Base.getindex(mi, cc, :chunks; sorted = true) == ["chunk2", "chunk2x"] @test Base.getindex(mi, cc, :chunks; sorted = false) == ["chunk2x", "chunk2"] end @@ -904,7 +904,7 @@ end scores = [0.5, 0.7]) ## sorted=false by default (Dict-like where order isn't guaranteed) ## sorting follows index order - @test mi[mc1] == ["First chunk", "6"] + @test mi[mc1] == ["6", "First chunk"] @test Base.getindex(mi, mc1, :chunks; sorted = true) == ["6", "First chunk"] @test Base.getindex(mi, mc1, :sources; sorted = true) == ["other_source3", "test_source1"]