diff --git a/CHANGELOG.md b/CHANGELOG.md index b8236ebd3..0efb9b821 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.41.0] + +### Added +- Introduced a "view" of `DocumentTermMatrix` (=`SubDocumentTermMatrix`) to allow views of Keyword-based indices (`ChunkKeywordsIndex`). It's not a pure view (TF matrix is materialized to prevent performance degradation). + +### Fixed +- Fixed a bug in `find_closest(finder::BM25Similarity, ...)` where the view of `DocumentTermMatrix` (ie, `view(DocumentTermMatrix(...), ...)`) was undefined. +- Fixed a bug where a view of a view of a `ChunkIndex` wouldn't intersect the positions (it was returning only the latest requested positions). + ## [0.40.0] ### Added diff --git a/Project.toml b/Project.toml index 8ae06a12d..2147675a9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.40.0" +version = "0.41.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/ext/RAGToolsExperimentalExt.jl b/ext/RAGToolsExperimentalExt.jl index a0c4f8385..f6ad0491d 100644 --- a/ext/RAGToolsExperimentalExt.jl +++ b/ext/RAGToolsExperimentalExt.jl @@ -5,6 +5,7 @@ using LinearAlgebra const PT = PromptingTools using PromptingTools.Experimental.RAGTools +using PromptingTools.Experimental.RAGTools: tf, vocab, vocab_lookup, idf, doc_rel_length const RT = PromptingTools.Experimental.RAGTools # forward to LinearAlgebra.normalize @@ -92,7 +93,7 @@ function RT.document_term_matrix(documents::AbstractVector{<:AbstractString}) end """ - RT.bm25(dtm::DocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0) + RT.bm25(dtm::AbstractDocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0) Scores all documents in `dtm` based on the `query`. @@ -107,30 +108,32 @@ scores = bm25(dtm, query) # Returns array with 3 scores (one for each document) ``` """ -function RT.bm25(dtm::RT.DocumentTermMatrix, query::AbstractVector{<:AbstractString}; +function RT.bm25( + dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString}; k1::Float32 = 1.2f0, b::Float32 = 0.75f0) - scores = zeros(Float32, size(dtm.tf, 1)) + scores = zeros(Float32, size(tf(dtm), 1)) ## Identify non-zero items to leverage the sparsity - nz_rows = rowvals(dtm.tf) - nz_vals = nonzeros(dtm.tf) + nz_rows = rowvals(tf(dtm)) + nz_vals = nonzeros(tf(dtm)) for i in eachindex(query) t = query[i] - t_id = get(dtm.vocab_lookup, t, nothing) + t_id = get(vocab_lookup(dtm), t, nothing) t_id === nothing && continue - idf = dtm.idf[t_id] + idf_ = idf(dtm)[t_id] # Scan only documents that have this token - @inbounds @simd for j in nzrange(dtm.tf, t_id) + @inbounds @simd for j in nzrange(tf(dtm), t_id) ## index into the sparse matrix - di, tf = nz_rows[j], nz_vals[j] - doc_len = dtm.doc_rel_length[di] - tf_top = (tf * (k1 + 1.0f0)) - tf_bottom = (tf + k1 * (1.0f0 - b + b * doc_len)) - score = idf * tf_top / tf_bottom + di, tf_ = nz_rows[j], nz_vals[j] + doc_len = doc_rel_length(dtm)[di] + tf_top = (tf_ * (k1 + 1.0f0)) + tf_bottom = (tf_ + k1 * (1.0f0 - b + b * doc_len)) + score = idf_ * tf_top / tf_bottom ## @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score" scores[di] += score end end - scores + + return scores end end # end of module diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index 51bcd9d8f..301841776 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -234,6 +234,8 @@ function find_closest( positions, scores = find_closest(finder, chunkdata(index), query_emb, query_tokens; top_k, kwargs...) + ## translate positions to original indices + positions = translate_positions_to_parent(index, positions) return CandidateChunks(indexid(index), positions, Float32.(scores)) end @@ -274,6 +276,8 @@ function find_closest( positions_, scores_ = find_closest(finder[i], chunkdata(all_indexes[i]), query_emb, query_tokens; top_k = top_k_shard, kwargs...) + ## translate positions to original indices + positions_ = translate_positions_to_parent(all_indexes[i], positions_) append!(index_ids, fill(indexid(all_indexes[i]), length(positions_))) append!(positions, positions_) append!(scores, scores_) @@ -446,7 +450,7 @@ end """ find_closest( - finder::BM25Similarity, dtm::DocumentTermMatrix, + finder::BM25Similarity, dtm::AbstractDocumentTermMatrix, query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[]; top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...) @@ -456,7 +460,7 @@ Reference: [Wikipedia: BM25](https://en.wikipedia.org/wiki/Okapi_BM25). Implementation follows: [The Next Generation of Lucene Relevance](https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/). """ function find_closest( - finder::BM25Similarity, dtm::DocumentTermMatrix, + finder::BM25Similarity, dtm::AbstractDocumentTermMatrix, query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[]; top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...) scores = bm25(dtm, query_tokens) @@ -505,9 +509,7 @@ function find_tags(method::AnyTagFilter, index::AbstractChunkIndex, match_row_idx = @view(tags(index)[:, tag_idx]) |> findall .|> Base.Fix2(getindex, 1) |> unique ## Index can be a SubChunkIndex, so we need to convert to the original indices - if index isa SubChunkIndex - match_row_idx = positions(index)[match_row_idx] - end + match_row_idx = translate_positions_to_parent(index, match_row_idx) return CandidateChunks( indexid(index), match_row_idx, ones(Float32, length(match_row_idx))) end @@ -546,10 +548,8 @@ function find_tags(method::AllTagFilter, index::AbstractChunkIndex, else Int[] end - ## Index can be a SubChunkIndex, so we need to convert to the original indices - if index isa SubChunkIndex - match_row_idx = positions(index)[match_row_idx] - end + ## translate to original indices + match_row_idx = translate_positions_to_parent(index, match_row_idx) return CandidateChunks( indexid(index), match_row_idx, ones(Float32, length(match_row_idx))) end diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index 5e72520f3..942524535 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -1,13 +1,26 @@ - # More advanced index would be: HybridChunkIndex +using Base: parent ### Shared methods Base.parent(index::AbstractDocumentIndex) = index indexid(index::AbstractDocumentIndex) = index.id chunkdata(index::AbstractChunkIndex) = index.chunkdata +"Access chunkdata for a subset of chunks, `chunk_idx` is a vector of chunk indices in the index" +function chunkdata(index::AbstractChunkIndex, chunk_idx::AbstractVector{<:Integer}) + ## We need this accessor because different chunk indices can have chunks in different dimensions!! + chkdata = chunkdata(index) + if isnothing(chkdata) + return nothing + end + return view(chkdata, :, chunk_idx) +end + function chunkdata(index::AbstractDocumentIndex) throw(ArgumentError("`chunkdata` not implemented for $(typeof(index))")) end +function chunkdata(index::AbstractDocumentIndex, chunk_idx::AbstractVector{<:Integer}) + throw(ArgumentError("`chunkdata` not implemented for $(typeof(index)) and chunk indices: $(typeof(chunk_idx))")) +end function embeddings(index::AbstractDocumentIndex) throw(ArgumentError("`embeddings` not implemented for $(typeof(index))")) end @@ -29,6 +42,18 @@ tags_vocab(index::AbstractChunkIndex) = index.tags_vocab sources(index::AbstractChunkIndex) = index.sources extras(index::AbstractChunkIndex) = index.extras +""" + translate_positions_to_parent(index::AbstractChunkIndex, positions::AbstractVector{<:Integer}) + +Translate positions to the parent index. Useful to convert between positions in a view and the original index. + +Used whenever a `chunkdata()` is used to re-align positions in case index is a view. +""" +function translate_positions_to_parent( + index::AbstractChunkIndex, positions::AbstractVector{<:Integer}) + return positions +end + Base.var"=="(i1::AbstractChunkIndex, i2::AbstractChunkIndex) = false function Base.var"=="(i1::T, i2::T) where {T <: AbstractChunkIndex} ((sources(i1) == sources(i2)) && (tags_vocab(i1) == tags_vocab(i2)) && @@ -104,38 +129,118 @@ end embeddings(index::ChunkEmbeddingsIndex) = index.embeddings HasEmbeddings(::ChunkEmbeddingsIndex) = true chunkdata(index::ChunkEmbeddingsIndex) = embeddings(index) +# It's column aligned so we don't have to re-define `chunkdata(index, chunk_idx)` # For backward compatibility const ChunkIndex = ChunkEmbeddingsIndex +abstract type AbstractDocumentTermMatrix end """ DocumentTermMatrix{T<:AbstractString} A sparse matrix of term frequencies and document lengths to allow calculation of BM25 similarity scores. """ -struct DocumentTermMatrix{T1 <: AbstractMatrix{<:Real}, T2 <: AbstractString} +struct DocumentTermMatrix{ + T1 <: AbstractMatrix{<:Real}, T2 <: AbstractString} <: + AbstractDocumentTermMatrix ## assumed to be SparseMatrixCSC{Float32, Int64} tf::T1 vocab::Vector{T2} vocab_lookup::Dict{T2, Int} - idf::Vector{Float32} + idf::Vector{Float32} # length of vocab # |d|/avgDl doc_rel_length::Vector{Float32} end +function Base.parent(dtm::AbstractDocumentTermMatrix) + dtm +end +function tf(dtm::AbstractDocumentTermMatrix) + dtm.tf +end +function vocab(dtm::AbstractDocumentTermMatrix) + dtm.vocab +end +function vocab_lookup(dtm::AbstractDocumentTermMatrix) + dtm.vocab_lookup +end +function idf(dtm::AbstractDocumentTermMatrix) + dtm.idf +end +function doc_rel_length(dtm::AbstractDocumentTermMatrix) + dtm.doc_rel_length +end + +Base.var"=="(dtm1::AbstractDocumentTermMatrix, dtm2::AbstractDocumentTermMatrix) = false +# Must be the same type and same content +function Base.var"=="(dtm1::T, dtm2::T) where {T <: AbstractDocumentTermMatrix} + tf(dtm1) == tf(dtm2) && vocab(dtm1) == vocab(dtm2) && + vocab_lookup(dtm1) == vocab_lookup(dtm2) && idf(dtm1) == idf(dtm2) && + doc_rel_length(dtm1) == doc_rel_length(dtm2) +end +function Base.hcat(d1::AbstractDocumentTermMatrix, d2::AbstractDocumentTermMatrix) + throw(ArgumentError("A hcat not implemented for DTMs of type $(typeof(d1)) and $(typeof(d2))")) +end function Base.hcat(d1::DocumentTermMatrix, d2::DocumentTermMatrix) - tf, vocab = vcat_labeled_matrices(d1.tf, d1.vocab, d2.tf, d2.vocab) - vocab_lookup = Dict(t => i for (i, t) in enumerate(vocab)) + tf_, vocab_ = vcat_labeled_matrices(tf(d1), vocab(d1), tf(d2), vocab(d2)) + vocab_lookup_ = Dict(t => i for (i, t) in enumerate(vocab_)) - N, _ = size(tf) - doc_freq = [count(x -> x > 0, col) for col in eachcol(tf)] + N, _ = size(tf_) + doc_freq = [count(x -> x > 0, col) for col in eachcol(tf_)] idf = @. log(1.0f0 + (N - doc_freq + 0.5f0) / (doc_freq + 0.5f0)) - doc_lengths = [count(x -> x > 0, row) for row in eachrow(tf)] + doc_lengths = [count(x -> x > 0, row) for row in eachrow(tf_)] sumdl = sum(doc_lengths) - doc_rel_length = sumdl == 0 ? zeros(Float32, N) : (doc_lengths ./ (sumdl / N)) + doc_rel_length_ = sumdl == 0 ? zeros(Float32, N) : (doc_lengths ./ (sumdl / N)) return DocumentTermMatrix( - tf, vocab, vocab_lookup, idf, convert(Vector{Float32}, doc_rel_length)) + tf_, vocab_, vocab_lookup_, idf, convert(Vector{Float32}, doc_rel_length_)) +end + +"A partial view of a DocumentTermMatrix, `tf` is MATERIALIZED for performance and fewer allocations." +struct SubDocumentTermMatrix{T <: DocumentTermMatrix, + T1 <: AbstractMatrix{<:Real}} <: AbstractDocumentTermMatrix + parent::T + tf::T1 ## Materialize the sub-matrix, because it's too expensive to use otherwise (row-view of SparseMatrixCSC) + positions::Vector{Int} +end +Base.parent(dtm::SubDocumentTermMatrix) = dtm.parent +positions(dtm::SubDocumentTermMatrix) = dtm.positions +tf(dtm::SubDocumentTermMatrix) = dtm.tf +vocab(dtm::SubDocumentTermMatrix) = Base.parent(dtm) |> vocab +vocab_lookup(dtm::SubDocumentTermMatrix) = Base.parent(dtm) |> vocab_lookup +idf(dtm::SubDocumentTermMatrix) = Base.parent(dtm) |> idf +function doc_rel_length(dtm::SubDocumentTermMatrix) + view(doc_rel_length(Base.parent(dtm)), positions(dtm)) +end +# hcat for SubDocumentTermMatrix does not make sense -> the vocabulary is the same / shared + +function Base.view( + dtm::AbstractDocumentTermMatrix, doc_idx::AbstractVector{<:Integer}, token_idx) + throw(ArgumentError("A view not implemented for type $(typeof(dtm)) across docs: $(typeof(doc_idx)) and tokens: $(typeof(token_idx))")) +end +Base.@propagate_inbounds function Base.view( + dtm::AbstractDocumentTermMatrix, doc_idx::AbstractVector{<:Integer}, token_idx::Colon) + tf_mat = tf(parent(dtm)) + @boundscheck if !checkbounds(Bool, axes(tf_mat, 1), doc_idx) + ## Avoid printing huge position arrays, show the extremas of the attempted range + max_pos = extrema(doc_idx) + throw(BoundsError(tf_mat, max_pos)) + end + ## computations on top of views of sparse arrays are expensive, materialize the view + tf_ = tf_mat[doc_idx, :] + SubDocumentTermMatrix(dtm, tf_, collect(doc_idx)) +end +function Base.view( + dtm::SubDocumentTermMatrix, doc_idx::AbstractVector{<:Integer}, token_idx::Colon) + tf_mat = tf(parent(dtm)) + @boundscheck if !checkbounds(Bool, axes(tf_mat, 1), doc_idx) + ## Avoid printing huge position arrays, show the extremas of the attempted range + max_pos = extrema(doc_idx) + throw(BoundsError(tf_mat, max_pos)) + end + intersect_pos = intersect(positions(dtm), doc_idx) + return SubDocumentTermMatrix( + parent(dtm), tf_mat[intersect_pos, :], intersect_pos) end """ @@ -169,7 +274,7 @@ multi_index = MultiIndex([index, index_keywords]) ``` -You can also build the index via +You can also build the index via build_index ```julia # given some sentences and sources index_keywords = build_index(KeywordsIndexer(), sentences; chunker_kwargs=(; sources)) @@ -179,6 +284,14 @@ retriever = SimpleBM25Retriever() result = retrieve(retriever, index_keywords, "What are the best practices for parallel computing in Julia?") result.context ``` + +If you want to use airag, don't forget to specify the config to make sure keywords are processed (ie, tokenized) + and that BM25 is used for searching candidates +```julia +cfg = RAGConfig(; retriever = SimpleBM25Retriever()); +airag(cfg, index_keywords; + question = "What are the best practices for parallel computing in Julia?") +``` """ @kwdef struct ChunkKeywordsIndex{ T1 <: AbstractString, @@ -201,8 +314,51 @@ result.context end HasKeywords(::ChunkKeywordsIndex) = true +"Access chunkdata for a subset of chunks, `chunk_idx` is a vector of chunk indices in the index" +function chunkdata(index::ChunkKeywordsIndex, chunk_idx::AbstractVector{<:Integer}) + chkdata = index.chunkdata + if isnothing(chkdata) + return nothing + end + ## Keyword index is row-oriented, ie, chunks are rows, tokens are columns + return view(chkdata, chunk_idx, :) +end + +""" + MultiIndex + +Composite index that stores multiple ChunkIndex objects and their embeddings. + +# Fields +- `id::Symbol`: unique identifier of each index (to ensure we're using the right index with `CandidateChunks`) +- `indexes::Vector{<:AbstractChunkIndex}`: the indexes to be combined + +Use accesor `indexes` to access the individual indexes. + +# Examples + +We can create a `MultiIndex` from a vector of `AbstractChunkIndex` objects. +```julia +index = build_index(SimpleIndexer(), texts; chunker_kwargs = (; sources)) +index_keywords = ChunkKeywordsIndex(index) # same chunks as above but adds BM25 instead of embeddings + +multi_index = MultiIndex([index, index_keywords]) +``` + +To use `airag` with different types of indices, we need to specify how to find the closest items for each index +```julia +# Cosine similarity for embeddings and BM25 for keywords, same order as indexes in MultiIndex +finder = RT.MultiFinder([RT.CosineSimilarity(), RT.BM25Similarity()]) + +# Notice that we add `processor` to make sure keywords are processed (ie, tokenized) as well +cfg = RAGConfig(; retriever = SimpleRetriever(; processor = RT.KeywordsProcessor(), finder)) -"Composite index that stores multiple ChunkIndex objects and their embeddings. It's not yet fully implemented." +# Ask questions +msg = airag(cfg, multi_index; question = "What are the best practices for parallel computing in Julia?") +pprint(msg) # prettify the answer +``` + +""" @kwdef struct MultiIndex <: AbstractMultiIndex id::Symbol = gensym("MultiIndex") indexes::Vector{<:AbstractChunkIndex} = AbstractChunkIndex[] @@ -284,9 +440,14 @@ HasKeywords(index::SubChunkIndex) = HasKeywords(parent(index)) chunks(index::SubChunkIndex) = view(chunks(parent(index)), positions(index)) sources(index::SubChunkIndex) = view(sources(parent(index)), positions(index)) function chunkdata(index::SubChunkIndex) - chkdata = chunkdata(parent(index)) - isnothing(chkdata) && return nothing - view(chunkdata(parent(index)), :, positions(index)) + chkdata = chunkdata(parent(index), positions(index)) +end +"Access chunkdata for a subset of chunks, `chunk_idx` is a vector of chunk indices in the index" +function chunkdata(index::SubChunkIndex, chunk_idx::AbstractVector{<:Integer}) + ## We need this accessor because different chunk indices can have chunks in different dimensions!! + index_chunk_idx = translate_positions_to_parent(index, chunk_idx) + pos = intersect(positions(index), index_chunk_idx) + chkdata = chunkdata(parent(index), pos) end function embeddings(index::SubChunkIndex) if HasEmbeddings(index) @@ -332,6 +493,20 @@ function Base.show(io::IO, index::SubChunkIndex) "A view of $(typeof(parent(index))|>nameof) (id: $(indexid(parent(index)))) with $(length(index)) chunks") end +""" + translate_positions_to_parent( + index::SubChunkIndex, pos::AbstractVector{<:Integer}) + +Translate positions to the parent index. Useful to convert between positions in a view and the original index. + +Used whenever a `chunkdata()` or `tags()` are used to re-align positions to the "parent" index. +""" +function translate_positions_to_parent( + index::SubChunkIndex, pos::AbstractVector{<:Integer}) + sub_positions = positions(index) + return sub_positions[pos] +end + # # CandidateChunks for Retrieval """ @@ -592,7 +767,7 @@ function Base.var"&"(mc1::MultiCandidateChunks{TP1, TD1}, return MultiCandidateChunks(index_ids, positions_, scores_) end -# # Views and Getindex +# # Index Views and Getindex function Base.view(index::AbstractDocumentIndex, cc::AbstractCandidateChunks) throw(ArgumentError("Not implemented for type $(typeof(index)) and $(typeof(cc))")) end @@ -604,7 +779,11 @@ Base.@propagate_inbounds function Base.view(index::AbstractChunkIndex, cc::Candi throw(BoundsError(chk_vector, max_pos)) end end - return SubChunkIndex(parent(index), positions(cc)) + pos = indexid(index) == indexid(cc) ? positions(cc) : Int[] + return SubChunkIndex(parent(index), pos) +end +Base.@propagate_inbounds function Base.view(index::SubChunkIndex, cc::CandidateChunks) + SubChunkIndex(index, cc) end Base.@propagate_inbounds function Base.view( index::AbstractChunkIndex, cc::MultiCandidateChunks) @@ -619,8 +798,12 @@ Base.@propagate_inbounds function Base.view( end return SubChunkIndex(parent(index), valid_positions) end +Base.@propagate_inbounds function Base.view(index::SubChunkIndex, cc::MultiCandidateChunks) + SubChunkIndex(index, cc) +end Base.@propagate_inbounds function SubChunkIndex(index::SubChunkIndex, cc::CandidateChunks) - intersect_pos = intersect(positions(cc), positions(index)) + pos = indexid(index) == indexid(cc) ? positions(cc) : Int[] + intersect_pos = intersect(pos, positions(index)) @boundscheck let chk_vector = chunks(parent(index)) if !checkbounds(Bool, axes(chk_vector, 1), intersect_pos) ## Avoid printing huge position arrays, show the extremas of the attempted range @@ -667,8 +850,9 @@ function Base.getindex(ci::AbstractChunkIndex, if field == :chunks chunks(sub_index)[sorted_idx] elseif field == :chunkdata - chkdata = chunkdata(sub_index) - isnothing(chkdata) ? nothing : chkdata[:, sorted_idx] + ## If embeddings, chunks are columns + ## If keywords (DTM), chunks are rows + chkdata = chunkdata(sub_index, sorted_idx) elseif field == :sources sources(sub_index)[sorted_idx] elseif field == :scores diff --git a/test/Experimental/RAGTools/types.jl b/test/Experimental/RAGTools/types.jl index 0859b36de..6b6bad0d6 100644 --- a/test/Experimental/RAGTools/types.jl +++ b/test/Experimental/RAGTools/types.jl @@ -3,14 +3,17 @@ using PromptingTools.Experimental.RAGTools: ChunkEmbeddingsIndex, ChunkKeywordsI CandidateChunks, MultiCandidateChunks, AbstractCandidateChunks, DocumentTermMatrix, + SubDocumentTermMatrix, document_term_matrix, HasEmbeddings, HasKeywords, ChunkKeywordsIndex, AbstractChunkIndex, AbstractDocumentIndex using PromptingTools.Experimental.RAGTools: embeddings, chunks, tags, tags_vocab, sources, extras, positions, scores, parent, - RAGResult, chunkdata, preprocess_tokens -using PromptingTools.Experimental.RAGTools: SubChunkIndex, indexid, indexids + RAGResult, chunkdata, preprocess_tokens, tf, + vocab, vocab_lookup, idf, doc_rel_length +using PromptingTools.Experimental.RAGTools: SubChunkIndex, indexid, indexids, + translate_positions_to_parent using PromptingTools: last_message, last_output @testset "ChunkEmbeddingsIndex" begin @@ -29,10 +32,13 @@ using PromptingTools: last_message, last_output @test chunks(ci) == chunks_test @test (embeddings(ci)) == emb_test @test (chunkdata(ci)) == emb_test + @test chunkdata(ci, [1]) == view(emb_test, :, [1]) @test tags(ci) == tags_test @test tags_vocab(ci) == tags_vocab_test @test sources(ci) == sources_test @test length(ci) == 2 + @test translate_positions_to_parent(ci, [2, 1]) == [2, 1] + @test translate_positions_to_parent(ci, [4, 6]) == [4, 6] # Test identity/equality ci1 = ChunkEmbeddingsIndex( @@ -73,6 +79,8 @@ using PromptingTools: last_message, last_output length(tags_vocab(combined_ci)) @test sources(combined_ci) == vcat(sources(ci1), (sources(ci2))) @test length(combined_ci) == 4 + @test chunkdata(combined_ci) == nothing + @test chunkdata(combined_ci, [1]) == nothing # Test base var"==" with ChunkEmbeddingsIndex ci1 = ChunkEmbeddingsIndex(chunks = ["chunk1"], @@ -125,6 +133,12 @@ end @test tags(ci) == nothing @test tags_vocab(ci) == nothing @test extras(ci) == nothing + @test translate_positions_to_parent(ci, [1]) == [1] + @test translate_positions_to_parent(ci, [2, 1]) == [2, 1] + @test translate_positions_to_parent(ci, [4, 6]) == [4, 6] + @test translate_positions_to_parent(ci, Int[]) == Int[] + @test chunkdata(ci) == nothing + @test chunkdata(ci, [1]) == nothing # Test equality of ChunkKeywordsIndex chunks_ = ["this is a test", "this is another test", "foo bar baz"] @@ -133,6 +147,8 @@ end ci1 = ChunkKeywordsIndex(chunks = chunks_, sources = sources_, chunkdata = dtm) ci2 = ChunkKeywordsIndex(chunks = chunks_, sources = sources_, chunkdata = dtm) @test ci1 == ci2 + @test chunkdata(ci1) == dtm + @test chunkdata(ci1, [1]) == view(dtm, [1], :) ci3 = ChunkKeywordsIndex(chunks = ["chunk2"], sources = ["source2"]) @test ci1 != ci3 @@ -160,63 +176,115 @@ end @test_throws ArgumentError embeddings(ci1) end -@testset "DocumentTermMatrix" begin - # Simple case - documents = [["this", "is", "a", "test"], - ["this", "is", "another", "test"], ["foo", "bar", "baz"]] - dtm = document_term_matrix(documents) - @test size(dtm.tf) == (3, 8) - @test Set(dtm.vocab) == Set(["a", "another", "bar", "baz", "foo", "is", "test", "this"]) - avgdl = 3.666666666666667 - @test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl]) - @test length(dtm.idf) == 8 - - # Edge case: single document - documents = [["this", "is", "a", "test"]] - dtm = document_term_matrix(documents) - @test size(dtm.tf) == (1, 4) - @test Set(dtm.vocab) == Set(["a", "is", "test", "this"]) - @test dtm.doc_rel_length == ones(1) - @test length(dtm.idf) == 4 - - # Edge case: duplicate tokens - documents = [["this", "is", "this", "test"], - ["this", "is", "another", "test"], ["this", "bar", "baz"]] - dtm = document_term_matrix(documents) - @test size(dtm.tf) == (3, 6) - @test Set(dtm.vocab) == Set(["another", "bar", "baz", "is", "test", "this"]) - avgdl = 3.666666666666667 - @test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl]) - @test length(dtm.idf) == 6 - - # Edge case: no tokens - documents = [String[], String[], String[]] - dtm = document_term_matrix(documents) - @test size(dtm.tf) == (3, 0) - @test isempty(dtm.vocab) - @test isempty(dtm.vocab_lookup) - @test isempty(dtm.idf) - @test dtm.doc_rel_length == zeros(3) - - ## Methods - hcat - documents = [["this", "is", "a", "test"], - ["this", "is", "another", "test"], ["foo", "bar", "baz"]] - dtm1 = document_term_matrix(documents) - documents = [["this", "is", "a", "test"], - ["this", "is", "another", "test"], ["foo", "bar", "baz"]] - dtm2 = document_term_matrix(documents) - dtm = hcat(dtm1, dtm2) - @test size(dtm.tf) == (6, 8) - @test length(dtm.vocab) == 8 - @test length(dtm.idf) == 8 - @test isapprox(dtm.doc_rel_length, - [4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667, - 4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667]) - - # Check stubs that they throw - @test_throws ArgumentError RT._stem(nothing, "abc") - @test_throws ArgumentError RT._unicode_normalize(nothing) -end +# @testset "DocumentTermMatrix" begin +# Simple case +documents = [["this", "is", "a", "test"], + ["this", "is", "another", "test"], ["foo", "bar", "baz"]] +dtm = document_term_matrix(documents) +@test size(dtm.tf) == (3, 8) +@test Set(dtm.vocab) == Set(["a", "another", "bar", "baz", "foo", "is", "test", "this"]) +avgdl = 3.666666666666667 +@test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl]) +@test length(dtm.idf) == 8 + +# Edge case: single document +documents = [["this", "is", "a", "test"]] +dtm = document_term_matrix(documents) +@test size(dtm.tf) == (1, 4) +@test Set(dtm.vocab) == Set(["a", "is", "test", "this"]) +@test dtm.doc_rel_length == ones(1) +@test length(dtm.idf) == 4 + +# Edge case: duplicate tokens +documents = [["this", "is", "this", "test"], + ["this", "is", "another", "test"], ["this", "bar", "baz"]] +dtm = document_term_matrix(documents) +@test size(dtm.tf) == (3, 6) +@test Set(dtm.vocab) == Set(["another", "bar", "baz", "is", "test", "this"]) +avgdl = 3.666666666666667 +@test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl]) +@test length(dtm.idf) == 6 + +# Edge case: no tokens +documents = [String[], String[], String[]] +dtm = document_term_matrix(documents) +@test size(dtm.tf) == (3, 0) +@test isempty(dtm.vocab) +@test isempty(dtm.vocab_lookup) +@test isempty(dtm.idf) +@test dtm.doc_rel_length == zeros(3) + +## Methods - hcat +documents = [["this", "is", "a", "test"], + ["this", "is", "another", "test"], ["foo", "bar", "baz"]] +dtm1 = document_term_matrix(documents) +documents = [["this", "is", "a", "test"], + ["this", "is", "another", "test"], ["foo", "bar", "baz"]] +dtm2 = document_term_matrix(documents) +dtm = hcat(dtm1, dtm2) +@test size(dtm.tf) == (6, 8) +@test length(dtm.vocab) == 8 +@test length(dtm.idf) == 8 +@test isapprox(dtm.doc_rel_length, + [4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667, + 4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667]) + +# Check stubs that they throw +@test_throws ArgumentError RT._stem(nothing, "abc") +@test_throws ArgumentError RT._unicode_normalize(nothing) +# end + +# @testset "SubDocumentTermMatrix" begin +# Create a parent DocumentTermMatrix +documents = [["this", "is", "a", "test"], ["another", "test", "document"]] +dtm = document_term_matrix(documents) + +# Create a SubDocumentTermMatrix +sub_dtm = view(dtm, [1], :) + +# Test parent method +@test parent(sub_dtm) == dtm + +# Test positions method +@test positions(sub_dtm) == [1] + +# Test tf method +@test tf(sub_dtm) == dtm.tf[1:1, :] + +# Test vocab method +@test vocab(sub_dtm) == vocab(dtm) + +# Test vocab_lookup method +@test vocab_lookup(sub_dtm) == vocab_lookup(dtm) + +# Test idf method +@test idf(sub_dtm) == idf(dtm) + +# Test doc_rel_length method +@test doc_rel_length(sub_dtm) == doc_rel_length(dtm)[1:1] + +# Test view method for SubDocumentTermMatrix +sub_dtm_view = view(sub_dtm, [1], :) +@test parent(sub_dtm_view) == dtm +@test positions(sub_dtm_view) == [1] +@test tf(sub_dtm_view) == dtm.tf[1:1, :] + +# Nested view // no intersection +sub_sub_dtm_view = view(sub_dtm_view, [2], :) +@test parent(sub_sub_dtm_view) == dtm +@test isempty(positions(sub_sub_dtm_view)) +@test tf(sub_sub_dtm_view) |> isempty + +# Test view method with out of bounds positions +@test_throws BoundsError view(sub_dtm, [10], :) + +# Test view method with intersecting positions +sub_dtm_intersect = view(dtm, [1, 2], :) +sub_dtm_view_intersect = view(sub_dtm_intersect, [2], :) +@test parent(sub_dtm_view_intersect) == dtm +@test positions(sub_dtm_view_intersect) == [2] +@test tf(sub_dtm_view_intersect) == dtm.tf[2:2, :] +# end @testset "MultiIndex" begin # Test constructors/accessors @@ -432,6 +500,7 @@ end sub_index = view(ci1, cc) @test chunks(sub_index) == ["chunk2", "chunk3"] @test sources(sub_index) == ["source2", "source3"] + @test translate_positions_to_parent(sub_index, [2, 1]) == [3, 2] # Test accessing chunks from SubChunkIndex cc = CandidateChunks(ci1, [2]) @@ -441,6 +510,13 @@ end @test sub_index[cc, :embeddings] == nothing @test sub_index[cc, :chunkdata] == nothing @test parent(sub_index)[cc, :chunks] == ["chunk2"] + @test chunkdata(sub_index) == nothing + @test chunkdata(sub_index, [1]) == nothing + + # Wrong Index ID -> empty + cc_wrongid = CandidateChunks(:bad_id, [2], [0.1f0]) + sub_index_wrongid = view(ci1, cc_wrongid) + @test isempty(sub_index_wrongid) # Test creating a SubChunkIndex with out-of-bounds CandidateChunks cc = CandidateChunks(ci1, [4]) @@ -471,6 +547,7 @@ end @test chunks(sub_index11) == ["chunk1", "chunk2"] @test sources(sub_index11) == ["source1", "source2"] @test chunkdata(sub_index11) ≈ [1.0 0.5; 1.0 0.5] + @test chunkdata(sub_index11, [2]) ≈ [0.5, 0.5] @test embeddings(sub_index11) ≈ [1.0 0.5; 1.0 0.5] @test tags(sub_index11) == Bool[1 0 0; 0 1 0] @test tags_vocab(sub_index11) == tags_vocab_test @@ -554,6 +631,15 @@ end sub_oob = SubChunkIndex(sub_sub_index, [10]) @test_throws BoundsError SubChunkIndex(sub_oob, cc_oob) + # return empty if it's wrong index id + cc_wrongid = CandidateChunks(:bad_id, [2], [0.1f0]) + sub_index_wrongid = SubChunkIndex(sub_sub_index, cc_wrongid) + @test isempty(sub_index_wrongid) + + # views produce intersection, so if they don't match it becomes empty view + cc_sub_notmatch = CandidateChunks(sub_sub_index, [2]) + @test view(sub_sub_index, cc_sub_notmatch) |> isempty + # Test edge cases for SubChunkIndex created from SubChunkIndex # Empty positions cc_empty_sub = CandidateChunks(sub_index, Int[]) @@ -564,12 +650,12 @@ end @test isempty(sub_index_empty_sub) == true # Out of bounds positions - cc_oob_sub = CandidateChunks(sub_index, [10]) - @test_throws BoundsError view(sub_index, cc_oob_sub) + cc_oob_sub = CandidateChunks(ci1, [10]) + @test_throws BoundsError view(ci1, cc_oob_sub) # Duplicate positions - cc_dup_sub = CandidateChunks(sub_index, [1, 1, 2]) - sub_index_dup_sub = view(sub_index, cc_dup_sub) + cc_dup_sub = CandidateChunks(ci1, [1, 1, 2]) + sub_index_dup_sub = view(ci1, cc_dup_sub) @test length(sub_index_dup_sub) == 3 @test chunks(sub_index_dup_sub) == ["chunk1", "chunk1", "chunk2"] @test unique(sub_index_dup_sub) == SubChunkIndex(ci1, [1, 2]) @@ -608,9 +694,13 @@ end @test chunks(sub_sub_index) == ["chunk2", "chunk3"] @test sources(sub_sub_index) == ["source2", "source3"] - sub_oob = SubChunkIndex(sub_sub_index, [10]) + sub_oob = SubChunkIndex(ci2, [10]) @test_throws BoundsError SubChunkIndex(sub_oob, mcc_oob) + # views produce intersection, so if they don't match it becomes empty view + mcc_notmatch = MultiCandidateChunks(sub_sub_index, [1]) + @test view(sub_sub_index, mcc_notmatch) |> isempty + ## With keyword index chunks_ = ["chunk1", "chunk2"] sources_ = ["source1", "source2"]