Skip to content

Commit

Permalink
Add SubDocumentTermMatrix (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Jul 22, 2024
1 parent fcd7509 commit 89d4c43
Show file tree
Hide file tree
Showing 6 changed files with 394 additions and 108 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.41.0]

### Added
- Introduced a "view" of `DocumentTermMatrix` (=`SubDocumentTermMatrix`) to allow views of Keyword-based indices (`ChunkKeywordsIndex`). It's not a pure view (TF matrix is materialized to prevent performance degradation).

### Fixed
- Fixed a bug in `find_closest(finder::BM25Similarity, ...)` where the view of `DocumentTermMatrix` (ie, `view(DocumentTermMatrix(...), ...)`) was undefined.
- Fixed a bug where a view of a view of a `ChunkIndex` wouldn't intersect the positions (it was returning only the latest requested positions).

## [0.40.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.40.0"
version = "0.41.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
31 changes: 17 additions & 14 deletions ext/RAGToolsExperimentalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using LinearAlgebra
const PT = PromptingTools

using PromptingTools.Experimental.RAGTools
using PromptingTools.Experimental.RAGTools: tf, vocab, vocab_lookup, idf, doc_rel_length
const RT = PromptingTools.Experimental.RAGTools

# forward to LinearAlgebra.normalize
Expand Down Expand Up @@ -92,7 +93,7 @@ function RT.document_term_matrix(documents::AbstractVector{<:AbstractString})
end

"""
RT.bm25(dtm::DocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
RT.bm25(dtm::AbstractDocumentTermMatrix, query::Vector{String}; k1::Float32=1.2f0, b::Float32=0.75f0)
Scores all documents in `dtm` based on the `query`.
Expand All @@ -107,30 +108,32 @@ scores = bm25(dtm, query)
# Returns array with 3 scores (one for each document)
```
"""
function RT.bm25(dtm::RT.DocumentTermMatrix, query::AbstractVector{<:AbstractString};
function RT.bm25(
dtm::RT.AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
k1::Float32 = 1.2f0, b::Float32 = 0.75f0)
scores = zeros(Float32, size(dtm.tf, 1))
scores = zeros(Float32, size(tf(dtm), 1))
## Identify non-zero items to leverage the sparsity
nz_rows = rowvals(dtm.tf)
nz_vals = nonzeros(dtm.tf)
nz_rows = rowvals(tf(dtm))
nz_vals = nonzeros(tf(dtm))
for i in eachindex(query)
t = query[i]
t_id = get(dtm.vocab_lookup, t, nothing)
t_id = get(vocab_lookup(dtm), t, nothing)
t_id === nothing && continue
idf = dtm.idf[t_id]
idf_ = idf(dtm)[t_id]
# Scan only documents that have this token
@inbounds @simd for j in nzrange(dtm.tf, t_id)
@inbounds @simd for j in nzrange(tf(dtm), t_id)
## index into the sparse matrix
di, tf = nz_rows[j], nz_vals[j]
doc_len = dtm.doc_rel_length[di]
tf_top = (tf * (k1 + 1.0f0))
tf_bottom = (tf + k1 * (1.0f0 - b + b * doc_len))
score = idf * tf_top / tf_bottom
di, tf_ = nz_rows[j], nz_vals[j]
doc_len = doc_rel_length(dtm)[di]
tf_top = (tf_ * (k1 + 1.0f0))
tf_bottom = (tf_ + k1 * (1.0f0 - b + b * doc_len))
score = idf_ * tf_top / tf_bottom
## @info "di: $di, tf: $tf, doc_len: $doc_len, idf: $idf, tf_top: $tf_top, tf_bottom: $tf_bottom, score: $score"
scores[di] += score
end
end
scores

return scores
end

end # end of module
18 changes: 9 additions & 9 deletions src/Experimental/RAGTools/retrieval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ function find_closest(
positions, scores = find_closest(finder, chunkdata(index),
query_emb, query_tokens;
top_k, kwargs...)
## translate positions to original indices
positions = translate_positions_to_parent(index, positions)
return CandidateChunks(indexid(index), positions, Float32.(scores))
end

Expand Down Expand Up @@ -274,6 +276,8 @@ function find_closest(
positions_, scores_ = find_closest(finder[i], chunkdata(all_indexes[i]),
query_emb, query_tokens;
top_k = top_k_shard, kwargs...)
## translate positions to original indices
positions_ = translate_positions_to_parent(all_indexes[i], positions_)
append!(index_ids, fill(indexid(all_indexes[i]), length(positions_)))
append!(positions, positions_)
append!(scores, scores_)
Expand Down Expand Up @@ -446,7 +450,7 @@ end

"""
find_closest(
finder::BM25Similarity, dtm::DocumentTermMatrix,
finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
Expand All @@ -456,7 +460,7 @@ Reference: [Wikipedia: BM25](https://en.wikipedia.org/wiki/Okapi_BM25).
Implementation follows: [The Next Generation of Lucene Relevance](https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/).
"""
function find_closest(
finder::BM25Similarity, dtm::DocumentTermMatrix,
finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,
query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
scores = bm25(dtm, query_tokens)
Expand Down Expand Up @@ -505,9 +509,7 @@ function find_tags(method::AnyTagFilter, index::AbstractChunkIndex,
match_row_idx = @view(tags(index)[:, tag_idx]) |> findall .|> Base.Fix2(getindex, 1) |>
unique
## Index can be a SubChunkIndex, so we need to convert to the original indices
if index isa SubChunkIndex
match_row_idx = positions(index)[match_row_idx]
end
match_row_idx = translate_positions_to_parent(index, match_row_idx)
return CandidateChunks(
indexid(index), match_row_idx, ones(Float32, length(match_row_idx)))
end
Expand Down Expand Up @@ -546,10 +548,8 @@ function find_tags(method::AllTagFilter, index::AbstractChunkIndex,
else
Int[]
end
## Index can be a SubChunkIndex, so we need to convert to the original indices
if index isa SubChunkIndex
match_row_idx = positions(index)[match_row_idx]
end
## translate to original indices
match_row_idx = translate_positions_to_parent(index, match_row_idx)
return CandidateChunks(
indexid(index), match_row_idx, ones(Float32, length(match_row_idx)))
end
Expand Down
Loading

0 comments on commit 89d4c43

Please sign in to comment.