diff --git a/CHANGELOG.md b/CHANGELOG.md index b7a68f31f..28c4995ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added ### Fixed + +## [0.54.0] + +### Updated +- Improved the performance of BM25/Keywords-based indices for >10M documents. Introduced new kwargs `min_term_freq` and `max_terms` in `RT.get_keywords` to reduce the size of the vocabulary. See `?RT.get_keywords` for more information. + ## [0.53.0] ### Added diff --git a/Project.toml b/Project.toml index 6a8db0fe9..17ae6a0e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.53.0" +version = "0.54.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/ext/RAGToolsExperimentalExt.jl b/ext/RAGToolsExperimentalExt.jl index 30d8ebb45..681e879c8 100644 --- a/ext/RAGToolsExperimentalExt.jl +++ b/ext/RAGToolsExperimentalExt.jl @@ -110,7 +110,9 @@ function Base.hcat(d1::RT.DocumentTermMatrix{<:AbstractSparseMatrix}, end """ - document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}}) + RT.document_term_matrix( + documents::AbstractVector{<:AbstractVector{T}}; + min_term_freq::Int = 1, max_terms::Int = typemax(Int)) where {T <: AbstractString} Builds a sparse matrix of term frequencies and document lengths from the given vector of documents wrapped in type `DocumentTermMatrix`. @@ -118,33 +120,59 @@ Expects a vector of preprocessed (tokenized) documents, where each document is a Returns: `DocumentTermMatrix` +# Arguments +- `documents`: A vector of documents, where each document is a vector of terms (clean tokens). +- `min_term_freq`: The minimum number of times a term must occur across all documents (total occurrence count, not the number of documents containing it) to be included in the vocabulary, eg, `min_term_freq = 2` means only terms that appear at least twice in the whole corpus will be included. 
+- `max_terms`: The maximum number of terms to include in the vocabulary, eg, `max_terms = 100` means only the 100 most frequent terms will be included. + # Example ``` documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]] dtm = document_term_matrix(documents) ``` """ -function RT.document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}}) - T = eltype(documents) |> eltype - vocab = convert(Vector{T}, unique(vcat(documents...))) - vocab_lookup = Dict{T, Int}(t => i for (i, t) in enumerate(vocab)) +function RT.document_term_matrix( + documents::AbstractVector{<:AbstractVector{T}}; + min_term_freq::Int = 1, max_terms::Int = typemax(Int)) where {T <: AbstractString} + ## Calculate term frequencies, sort descending + counts = Dict{T, Int}() + @inbounds for doc in documents + for term in doc + counts[term] = get(counts, term, 0) + 1 + end + end + counts = sort(collect(counts), by = x -> -x[2]) |> Base.Fix2(first, max_terms) |> + Base.Fix1(filter!, x -> x[2] >= min_term_freq) + ## Create vocabulary + vocab = convert(Vector{T}, getindex.(counts, 1)) + vocab_lookup = Dict{T, Int}(term => i for (i, term) in enumerate(vocab)) N = length(documents) doc_freq = zeros(Int, length(vocab)) - term_freq = spzeros(Float32, N, length(vocab)) doc_lengths = zeros(Float32, N) + ## Term frequency matrix to be recorded via its sparse entries: I, J, V + # term_freq = spzeros(Float32, N, length(vocab)) + I, J, V = Int[], Int[], Float32[] + + unique_terms = Set{eltype(vocab)}() + sizehint!(unique_terms, 1000) for di in eachindex(documents) - unique_terms = Set{eltype(vocab)}() + empty!(unique_terms) doc = documents[di] - for t in doc + @inbounds for t in doc doc_lengths[di] += 1 - tid = vocab_lookup[t] - term_freq[di, tid] += 1 + tid = get(vocab_lookup, t, nothing) + tid === nothing && continue + push!(I, di) + push!(J, tid) + push!(V, 1.0f0) if !(t in unique_terms) doc_freq[tid] += 1 push!(unique_terms, t) end end end + ## 
combine repeated terms with `+` + term_freq = sparse(I, J, V, N, length(vocab), +) idf = @. log(1.0f0 + (N - doc_freq + 0.5f0) / (doc_freq + 0.5f0)) sumdl = sum(doc_lengths) doc_rel_length = sumdl == 0 ? zeros(Float32, N) : doc_lengths ./ (sumdl / N) diff --git a/ext/SnowballPromptingToolsExt.jl b/ext/SnowballPromptingToolsExt.jl index 4e0107fbb..04658e01d 100644 --- a/ext/SnowballPromptingToolsExt.jl +++ b/ext/SnowballPromptingToolsExt.jl @@ -12,12 +12,14 @@ using Snowball RT._stem(stemmer::Snowball.Stemmer, text::AbstractString) = Snowball.stem(stemmer, text) """ - get_keywords(processor::KeywordsProcessor, docs::AbstractVector{<:AbstractString}; + RT.get_keywords( + processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString}; verbose::Bool = true, stemmer = nothing, - stopwords::Set{String} = Set(STOPWORDS), + stopwords::Set{String} = Set(RT.STOPWORDS), return_keywords::Bool = false, min_length::Integer = 3, + min_term_freq::Int = 1, max_terms::Int = typemax(Int), kwargs...) Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stemmer` and `stopwords`. @@ -29,6 +31,8 @@ Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stem - `stopwords`: A set of stopwords to remove. Default is `Set(STOPWORDS)`. - `return_keywords`: A boolean flag for returning the keywords. Default is `false`. Useful for query processing in search time. - `min_length`: The minimum length of the keywords. Default is `3`. +- `min_term_freq`: The minimum number of times a term must occur across all documents (total occurrence count, not the number of documents containing it) to be included in the vocabulary, eg, `min_term_freq = 2` means only terms that appear at least twice in the whole corpus will be included. Default is `1`. +- `max_terms`: The maximum number of terms to include in the vocabulary, eg, `max_terms = 100` means only the 100 most frequent terms will be included. Default is `typemax(Int)` (no limit). 
""" function RT.get_keywords( processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString}; @@ -37,16 +41,13 @@ function RT.get_keywords( stopwords::Set{String} = Set(RT.STOPWORDS), return_keywords::Bool = false, min_length::Integer = 3, + min_term_freq::Int = 1, max_terms::Int = typemax(Int), kwargs...) ## check if extension is available ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt) if isnothing(ext) error("You need to also import LinearAlgebra and SparseArrays to use this function") end - ## ext = Base.get_extension(PromptingTools, :SnowballPromptingToolsExt) - ## if isnothing(ext) - ## error("You need to also import Snowball.jl to use this function") - ## end ## Preprocess text into tokens stemmer = !isnothing(stemmer) ? stemmer : Snowball.Stemmer("english") # Single-threaded as stemmer is not thread-safe @@ -56,7 +57,7 @@ function RT.get_keywords( return_keywords && return keywords ## Create DTM - dtm = RT.document_term_matrix(keywords) + dtm = RT.document_term_matrix(keywords; min_term_freq, max_terms) verbose && @info "Done processing DocumentTermMatrix." 
return dtm diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl index d687cc4fd..b4fe0b3e5 100644 --- a/src/Experimental/RAGTools/preparation.jl +++ b/src/Experimental/RAGTools/preparation.jl @@ -207,7 +207,7 @@ function get_chunks(chunker::AbstractChunker, # split into chunks by recursively trying the separators provided # if you want to start simple - just do `split(text,"\n\n")` doc_chunks = PT.recursive_splitter(doc_raw, separators; max_length) .|> strip |> - x -> filter(!isempty, x) + Base.Fix1(filter!, !isempty) # skip if no chunks found isempty(doc_chunks) && continue append!(output_chunks, doc_chunks) diff --git a/test/Experimental/RAGTools/preparation.jl b/test/Experimental/RAGTools/preparation.jl index 1c00e4809..2f94b0a76 100644 --- a/test/Experimental/RAGTools/preparation.jl +++ b/test/Experimental/RAGTools/preparation.jl @@ -115,6 +115,32 @@ end @test Set(dtm.vocab) == Set(["this", "test", "document", "anoth", "more", "text"]) @test size(dtm.tf) == (2, 6) + # Test for KeywordsProcessor with min_term_freq and max_terms + docs_freq = [ + "apple banana cherry apple", + "banana date fig grape", + "apple banana cherry date", + "elephant fig grape" + ] + processor_freq = KeywordsProcessor() + + # Test with min_term_freq = 2 + dtm_freq = get_keywords(processor_freq, docs_freq; min_term_freq = 2) + @test Set(dtm_freq.vocab) == + Set(["appl", "banana", "cherri", "date", "fig", "grape"]) + @test size(dtm_freq.tf) == (4, 6) + + # Test with max_terms = 3 + dtm_max = get_keywords(processor_freq, docs_freq; max_terms = 3) + @test length(dtm_max.vocab) == 3 + @test size(dtm_max.tf) == (4, 3) + + # Test with both min_term_freq = 2 and max_terms = 2 + dtm_both = get_keywords(processor_freq, docs_freq; min_term_freq = 2, max_terms = 2) + @test length(dtm_both.vocab) == 2 + @test size(dtm_both.tf) == (4, 2) + @test all(sum(dtm_both.tf, dims = 1) .>= 2) + # Test for KeywordsProcessor with custom stemmer and stopwords 
custom_stemmer = Snowball.Stemmer("french") dtm_custom = get_keywords(