Skip to content

Commit

Permalink
Add DTM specialized method
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Aug 8, 2024
1 parent a53fbfe commit 7d6a8d8
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 111 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.47.0]

### Added
- Added a new specialized method for `hcat(::DocumentTermMatrix, ::DocumentTermMatrix)` to allow for combining large DocumentTermMatrices (e.g., 1M x 100K).

### Updated
- Increased the compat bound for HTTP.jl to 1.10.8 to fix a bug with Julia 1.11.

### Fixed
- Fixed a bug in `vcat_labeled_matrices` where extremely large DocumentTermMatrix could run out of memory.
- Fixed a bug in `score_to_unit_scale` where empty score vectors would error (now returns the empty array back).

## [0.46.0]

### Added
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.46.0"
version = "0.47.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down Expand Up @@ -41,7 +41,7 @@ Base64 = "<0.0.1, 1"
Dates = "<0.0.1, 1"
FlashRank = "0.4"
GoogleGenAI = "0.3"
HTTP = "1"
HTTP = "1.10.8"
JSON3 = "1"
LinearAlgebra = "<0.0.1, 1"
Logging = "<0.0.1, 1"
Expand Down
63 changes: 63 additions & 0 deletions ext/RAGToolsExperimentalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,69 @@ function RT.build_tags(
return tags_, tags_vocab_
end

"""
    vcat_labeled_matrices(mat1::AbstractSparseMatrix, vocab1, mat2::AbstractSparseMatrix, vocab2)

Vertically concatenate two sparse matrices whose columns are labeled by `vocab1` and
`vocab2`, aligning the columns onto the combined vocabulary `[vocab1; setdiff(vocab2, vocab1)]`.

Returns `(combined_matrix, combined_vocab)` where `combined_matrix` has
`size(mat1, 1) + size(mat2, 1)` rows and one column per word in `combined_vocab`, with
element type `promote_type(T1, T2)`.

Specialized sparse method: builds both aligned halves directly from their stored
entries (COO triplets) so that very large document-term matrices (e.g., 1M x 100K)
never have to be densified.
"""
function RT.vcat_labeled_matrices(mat1::AbstractSparseMatrix{T1},
        vocab1::AbstractVector{<:AbstractString},
        mat2::AbstractSparseMatrix{T2},
        vocab2::AbstractVector{<:AbstractString}) where {T1 <: Number, T2 <: Number}
    T = promote_type(T1, T2)
    new_words = setdiff(vocab2, vocab1)
    combined_vocab = [vocab1; new_words]
    vocab2_indices = Dict(word => i for (i, word) in enumerate(vocab2))

    ## mat1's columns already match the head of `combined_vocab`, so it only needs
    ## to be widened (and its values promoted to T); reuse its stored triplets.
    I, J, V = findnz(mat1)
    aligned_mat1 = sparse(
        I, J, convert(Vector{T}, V), size(mat1, 1), length(combined_vocab))

    ## Rebuild mat2 column-by-column in `combined_vocab` order, touching only the
    ## stored entries of each column (CSC internals: rowvals/nonzeros/nzrange).
    I, J, V = Int[], Int[], T[]
    ## We collect at most nnz(mat2) triplets; reserve capacity up front.
    sizehint!(I, nnz(mat2))
    sizehint!(J, nnz(mat2))
    sizehint!(V, nnz(mat2))
    nz_rows = rowvals(mat2)
    nz_vals = nonzeros(mat2)
    for (j, word) in enumerate(combined_vocab)
        if haskey(vocab2_indices, word)
            ## NOTE: plain loop (no @simd) — the body push!es into shared vectors,
            ## so iterations are not independent and must not be reordered/vectorized.
            @inbounds for k in nzrange(mat2, vocab2_indices[word])
                i = nz_rows[k]
                val = nz_vals[k]
                ## Skip explicitly-stored zeros to keep the result structurally minimal.
                if !iszero(val)
                    push!(I, i)
                    push!(J, j)
                    push!(V, val)
                end
            end
        end
    end
    aligned_mat2 = sparse(I, J, V, size(mat2, 1), length(combined_vocab))

    return vcat(aligned_mat1, aligned_mat2), combined_vocab
end

"""
    Base.hcat(d1::DocumentTermMatrix, d2::DocumentTermMatrix)

Combine two sparse `DocumentTermMatrix` objects: stack their documents (rows),
align their vocabularies, and recompute the derived BM25 statistics
(`idf`, `doc_rel_length`) from the merged term-frequency matrix.
"""
function Base.hcat(d1::RT.DocumentTermMatrix{<:AbstractSparseMatrix},
        d2::RT.DocumentTermMatrix{<:AbstractSparseMatrix})
    tf_, vocab_ = RT.vcat_labeled_matrices(tf(d1), vocab(d1), tf(d2), vocab(d2))
    vocab_lookup_ = Dict(t => i for (i, t) in enumerate(vocab_))

    n_docs, n_terms = size(tf_)
    ## Work on the stored (row, col, value) triplets to avoid any dense pass.
    rows, cols, vals = findnz(tf_)

    ## Document frequency per term: number of documents with a positive count.
    df = zeros(Int, n_terms)
    for (c, v) in zip(cols, vals)
        v > 0 && (df[c] += 1)
    end
    ## BM25-style smoothed inverse document frequency (Float32).
    idf = log.(1.0f0 .+ (n_docs .- df .+ 0.5f0) ./ (df .+ 0.5f0))

    ## Total token count per document.
    lens = zeros(Float32, n_docs)
    for (r, v) in zip(rows, vals)
        v > 0 && (lens[r] += v)
    end
    total = sum(lens)
    ## Relative length = length / average length; all-zero corpus stays at zeros.
    if total == 0
        rel_lens = zeros(Float32, n_docs)
    else
        rel_lens = convert(Vector{Float32}, lens ./ (total / n_docs))
    end
    return RT.DocumentTermMatrix(tf_, vocab_, vocab_lookup_, idf, rel_lens)
end

"""
document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
Expand Down
2 changes: 2 additions & 0 deletions src/Experimental/RAGTools/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,8 @@ scaled_x = score_to_unit_scale(x)
```
"""
function score_to_unit_scale(x::AbstractVector{T}) where {T <: Real}
isempty(x) && return x
##
ex = extrema(x)
if ex[2] - ex[1] < eps(T)
ones(T, length(x))
Expand Down
269 changes: 160 additions & 109 deletions test/Experimental/RAGTools/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,115 +176,166 @@ end
@test_throws ArgumentError embeddings(ci1)
end

# @testset "DocumentTermMatrix" begin
# Simple case
documents = [["this", "is", "a", "test"],
["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (3, 8)
@test Set(dtm.vocab) == Set(["a", "another", "bar", "baz", "foo", "is", "test", "this"])
avgdl = 3.666666666666667
@test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl])
@test length(dtm.idf) == 8

# Edge case: single document
documents = [["this", "is", "a", "test"]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (1, 4)
@test Set(dtm.vocab) == Set(["a", "is", "test", "this"])
@test dtm.doc_rel_length == ones(1)
@test length(dtm.idf) == 4

# Edge case: duplicate tokens
documents = [["this", "is", "this", "test"],
["this", "is", "another", "test"], ["this", "bar", "baz"]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (3, 6)
@test Set(dtm.vocab) == Set(["another", "bar", "baz", "is", "test", "this"])
avgdl = 3.666666666666667
@test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl])
@test length(dtm.idf) == 6

# Edge case: no tokens
documents = [String[], String[], String[]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (3, 0)
@test isempty(dtm.vocab)
@test isempty(dtm.vocab_lookup)
@test isempty(dtm.idf)
@test dtm.doc_rel_length == zeros(3)

## Methods - hcat
documents = [["this", "is", "a", "test"],
["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm1 = document_term_matrix(documents)
documents = [["this", "is", "a", "test"],
["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm2 = document_term_matrix(documents)
dtm = hcat(dtm1, dtm2)
@test size(dtm.tf) == (6, 8)
@test length(dtm.vocab) == 8
@test length(dtm.idf) == 8
@test isapprox(dtm.doc_rel_length,
[4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667,
4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667])

# Check stubs that they throw
@test_throws ArgumentError RT._stem(nothing, "abc")
@test_throws ArgumentError RT._unicode_normalize(nothing)
# end

# @testset "SubDocumentTermMatrix" begin
# Create a parent DocumentTermMatrix
documents = [["this", "is", "a", "test"], ["another", "test", "document"]]
dtm = document_term_matrix(documents)

# Create a SubDocumentTermMatrix
sub_dtm = view(dtm, [1], :)

# Test parent method
@test parent(sub_dtm) == dtm

# Test positions method
@test positions(sub_dtm) == [1]

# Test tf method
@test tf(sub_dtm) == dtm.tf[1:1, :]

# Test vocab method
@test vocab(sub_dtm) == vocab(dtm)

# Test vocab_lookup method
@test vocab_lookup(sub_dtm) == vocab_lookup(dtm)

# Test idf method
@test idf(sub_dtm) == idf(dtm)

# Test doc_rel_length method
@test doc_rel_length(sub_dtm) == doc_rel_length(dtm)[1:1]

# Test view method for SubDocumentTermMatrix
sub_dtm_view = view(sub_dtm, [1], :)
@test parent(sub_dtm_view) == dtm
@test positions(sub_dtm_view) == [1]
@test tf(sub_dtm_view) == dtm.tf[1:1, :]

# Nested view // no intersection
sub_sub_dtm_view = view(sub_dtm_view, [2], :)
@test parent(sub_sub_dtm_view) == dtm
@test isempty(positions(sub_sub_dtm_view))
@test tf(sub_sub_dtm_view) |> isempty

# Test view method with out of bounds positions
@test_throws BoundsError view(sub_dtm, [10], :)

# Test view method with intersecting positions
sub_dtm_intersect = view(dtm, [1, 2], :)
sub_dtm_view_intersect = view(sub_dtm_intersect, [2], :)
@test parent(sub_dtm_view_intersect) == dtm
@test positions(sub_dtm_view_intersect) == [2]
@test tf(sub_dtm_view_intersect) == dtm.tf[2:2, :]
# end
@testset "DocumentTermMatrix" begin
# Simple case
documents = [["this", "is", "a", "test"],
    ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (3, 8)
@test Set(dtm.vocab) == Set(["a", "another", "bar", "baz", "foo", "is", "test", "this"])
avgdl = 3.666666666666667
@test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl])
@test length(dtm.idf) == 8

# Edge case: single document
documents = [["this", "is", "a", "test"]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (1, 4)
@test Set(dtm.vocab) == Set(["a", "is", "test", "this"])
@test dtm.doc_rel_length == ones(1)
@test length(dtm.idf) == 4

# Edge case: duplicate tokens
documents = [["this", "is", "this", "test"],
    ["this", "is", "another", "test"], ["this", "bar", "baz"]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (3, 6)
@test Set(dtm.vocab) == Set(["another", "bar", "baz", "is", "test", "this"])
avgdl = 3.666666666666667
@test all(dtm.doc_rel_length .≈ [4 / avgdl, 4 / avgdl, 3 / avgdl])
@test length(dtm.idf) == 6

# Edge case: no tokens
documents = [String[], String[], String[]]
dtm = document_term_matrix(documents)
@test size(dtm.tf) == (3, 0)
@test isempty(dtm.vocab)
@test isempty(dtm.vocab_lookup)
@test isempty(dtm.idf)
@test dtm.doc_rel_length == zeros(3)

## Methods - hcat
documents = [["this", "is", "a", "test"],
    ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm1 = document_term_matrix(documents)
documents = [["this", "is", "a", "test"],
    ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
dtm2 = document_term_matrix(documents)
dtm = hcat(dtm1, dtm2)
@test size(dtm.tf) == (6, 8)
@test length(dtm.vocab) == 8
@test length(dtm.idf) == 8
@test isapprox(dtm.doc_rel_length,
    [4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667,
        4 / 3.666666666666667, 4 / 3.666666666666667, 3 / 3.666666666666667])

# Check stubs that they throw
@test_throws ArgumentError RT._stem(nothing, "abc")
@test_throws ArgumentError RT._unicode_normalize(nothing)

## SubDocumentTermMatrix
# Create a parent DocumentTermMatrix
documents = [["this", "is", "a", "test"], ["another", "test", "document"]]
dtm = document_term_matrix(documents)

# Create a SubDocumentTermMatrix
sub_dtm = view(dtm, [1], :)

# Test parent method
@test parent(sub_dtm) == dtm

# Test positions method
@test positions(sub_dtm) == [1]

# Test tf method
@test tf(sub_dtm) == dtm.tf[1:1, :]

# Test vocab method
@test vocab(sub_dtm) == vocab(dtm)

# Test vocab_lookup method
@test vocab_lookup(sub_dtm) == vocab_lookup(dtm)

# Test idf method
@test idf(sub_dtm) == idf(dtm)

# Test doc_rel_length method
@test doc_rel_length(sub_dtm) == doc_rel_length(dtm)[1:1]

# Test view method for SubDocumentTermMatrix
sub_dtm_view = view(sub_dtm, [1], :)
@test parent(sub_dtm_view) == dtm
@test positions(sub_dtm_view) == [1]
@test tf(sub_dtm_view) == dtm.tf[1:1, :]

# Nested view // no intersection
sub_sub_dtm_view = view(sub_dtm_view, [2], :)
@test parent(sub_sub_dtm_view) == dtm
@test isempty(positions(sub_sub_dtm_view))
@test tf(sub_sub_dtm_view) |> isempty

# Test view method with out of bounds positions
@test_throws BoundsError view(sub_dtm, [10], :)

# Test view method with intersecting positions
sub_dtm_intersect = view(dtm, [1, 2], :)
sub_dtm_view_intersect = view(sub_dtm_intersect, [2], :)
@test parent(sub_dtm_view_intersect) == dtm
@test positions(sub_dtm_view_intersect) == [2]
@test tf(sub_dtm_view_intersect) == dtm.tf[2:2, :]

### Test hcat for DocumentTermMatrix
# Create two DocumentTermMatrix instances
documents1 = [["this", "is", "a", "test"], ["another", "test", "document"]]
dtm1 = document_term_matrix(documents1)

documents2 = [["new", "document"], ["with", "different", "words"]]
dtm2 = document_term_matrix(documents2)

# Perform hcat
combined_dtm = hcat(dtm1, dtm2)

# Test the resulting DocumentTermMatrix
@test size(combined_dtm.tf, 1) == size(dtm1.tf, 1) + size(dtm2.tf, 1)
@test length(combined_dtm.vocab) == length(unique(vcat(dtm1.vocab, dtm2.vocab)))
@test all(word in combined_dtm.vocab for word in dtm1.vocab)
@test all(word in combined_dtm.vocab for word in dtm2.vocab)

# Check if the tf matrix is correctly combined
@test size(combined_dtm.tf, 2) == length(combined_dtm.vocab)
# Vocabulary alignment must conserve the total term counts exactly
@test sum(combined_dtm.tf) == sum(dtm1.tf) + sum(dtm2.tf)

# Test vocab_lookup
@test all(haskey(combined_dtm.vocab_lookup, word) for word in combined_dtm.vocab)

# Test idf
@test length(combined_dtm.idf) == length(combined_dtm.vocab)

# Test doc_rel_length
@test length(combined_dtm.doc_rel_length) == size(combined_dtm.tf, 1)

# Test with empty DocumentTermMatrix
empty_dtm = document_term_matrix(Vector{Vector{String}}())
combined_with_empty = hcat(dtm1, empty_dtm)
@test combined_with_empty == dtm1

# Test associativity
dtm3 = document_term_matrix([["third", "set", "of", "documents"]])
@test hcat(hcat(dtm1, dtm2), dtm3) == hcat(dtm1, hcat(dtm2, dtm3))

# Test with dense matrix
ddtm1 = DocumentTermMatrix(
    Matrix(tf(dtm1)), vocab(dtm1), vocab_lookup(dtm1), idf(dtm1), doc_rel_length(dtm1))
ddtm2 = DocumentTermMatrix(
    Matrix(tf(dtm2)), vocab(dtm2), vocab_lookup(dtm2), idf(dtm2), doc_rel_length(dtm2))
combined_ddtm = hcat(ddtm1, ddtm2)
@test size(combined_ddtm.tf, 1) == size(ddtm1.tf, 1) + size(ddtm2.tf, 1)
@test length(combined_ddtm.vocab) == length(unique(vcat(ddtm1.vocab, ddtm2.vocab)))
@test all(word in combined_ddtm.vocab for word in ddtm1.vocab)
@test all(word in combined_ddtm.vocab for word in ddtm2.vocab)
@test size(combined_ddtm.tf, 2) == length(combined_ddtm.vocab)
# Dense path must conserve total term counts as well
@test sum(combined_ddtm.tf) == sum(ddtm1.tf) + sum(ddtm2.tf)
end

@testset "MultiIndex" begin
# Test constructors/accessors
Expand Down
Loading

0 comments on commit 7d6a8d8

Please sign in to comment.