Skip to content

Commit

Permalink
Updates embedding concatenation (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Aug 4, 2024
1 parent 0916bd7 commit e2553b8
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.44.1]

### Updated
- Updated the `hcat` implementation in `RAGTools.get_embeddings` to reduce memory allocations for large embedding batches (roughly 3x fewer allocations; see `hcat_truncate`).

## [0.44.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.44.0"
version = "0.44.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
16 changes: 3 additions & 13 deletions src/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,9 @@ function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:Abstract
Threads.atomic_add!(cost_tracker, msg.cost) # track costs
msg.content
end
embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents
# truncate_dimension=0 means that we skip it
if !isnothing(truncate_dimension) && truncate_dimension > 0
@assert truncate_dimension<=size(embeddings, 1) "Requested embeddings dimensionality is too high (Embeddings: $(size(embeddings)) vs dimensionality requested: $(truncate_dimension))"
## reduce + normalize again
embeddings = embeddings[1:truncate_dimension, :]
for i in axes(embeddings, 2)
embeddings[:, i] = _normalize(embeddings[:, i])
end
elseif !isnothing(truncate_dimension) && truncate_dimension == 0
# do nothing
verbose && @info "Truncate_dimension set to 0. Skipping truncation"
end
## Concat across documents and truncate if needed
embeddings = hcat_truncate(embeddings, truncate_dimension; verbose)
## Normalize embeddings
verbose && @info "Done embedding. Total cost: \$$(round(cost_tracker[],digits=3))"
return embeddings
end
Expand Down
119 changes: 119 additions & 0 deletions src/Experimental/RAGTools/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,125 @@ function hcat_labeled_matrices(mat1::AbstractMatrix{T1},
return hcat(aligned_mat1, aligned_mat2), combined_vocab
end

"""
hcat_truncate(matrices::AbstractVector{<:AbstractMatrix{T}},
truncate_dimension::Union{Nothing, Int} = nothing; verbose::Bool = false) where {T <:
Real}
Horizontal concatenation of matrices, with optional truncation of the rows of each matrix to the specified dimension (reducing embedding dimensionality).
More efficient than simple splatting into `hcat`, as the resulting matrix is pre-allocated in one go.
Returns: a `Matrix{Float32}`
# Arguments
- `matrices::AbstractVector{<:AbstractMatrix{T}}`: Vector of matrices to concatenate
- `truncate_dimension::Union{Nothing,Int}=nothing`: Dimension to truncate to, or `nothing` or `0` to skip truncation. If truncated, the columns will be normalized.
- `verbose::Bool=false`: Whether to print verbose output.
# Examples
```julia
a = rand(Float32, 1000, 10)
b = rand(Float32, 1000, 20)
c = hcat_truncate([a, b])
size(c) # (1000, 30)
d = hcat_truncate([a, b], 500)
size(d) # (500, 30)
```
"""
function hcat_truncate(matrices::AbstractVector{<:AbstractMatrix{T}},
        truncate_dimension::Union{Nothing, Int} = nothing; verbose::Bool = false) where {T <:
                                                                                         Real}
    ## Fail early with a clear message: an empty input has no row count to
    ## validate against (throws `ArgumentError`).
    isempty(matrices) &&
        throw(ArgumentError("Cannot concatenate an empty collection of matrices"))

    ## Validate that all inputs share the row count and tally the output columns
    rows = size(first(matrices), 1)
    total_cols = 0
    for matrix in matrices
        row, col = size(matrix)
        @assert row==rows "All matrices must have the same number of rows (Found $row and $rows)"
        total_cols += col
    end

    ## Check if we need to truncate.
    ## `nothing`, `0` and negative values all mean "keep every row".
    truncate = false
    out_rows = rows
    if !isnothing(truncate_dimension)
        if truncate_dimension > 0
            @assert truncate_dimension<=rows "Requested embeddings dimensionality is too high (Embeddings: $(rows) vs dimensionality requested: $(truncate_dimension))"
            truncate = true
            out_rows = truncate_dimension
        elseif iszero(truncate_dimension)
            verbose && @info "Truncate_dimension set to 0. Skipping truncation"
        end
    end

    ## Allocate the result once; every cell is written below before returning.
    ## Assignment into the `Float32` matrix converts the element type.
    result = Matrix{Float32}(undef, out_rows, total_cols)

    col_offset = 1
    for matrix in matrices
        if truncate
            for col in eachcol(matrix)
                ## We must re-normalize the truncated vectors
                ## LinearAlgebra.normalize but imported in RAGToolsExperimentalExt
                result[:, col_offset] = _normalize(@view(col[1:out_rows]))
                col_offset += 1
            end
        else
            ## No truncation: copy the whole matrix into its column range
            cols = size(matrix, 2)
            result[:, col_offset:(col_offset + cols - 1)] = matrix
            col_offset += cols
        end
    end

    return result
end
"""
    hcat_truncate(vectors::AbstractVector{<:AbstractVector{T}},
        truncate_dimension::Union{Nothing, Int} = nothing; verbose::Bool = false) where {T <: Real}

Horizontal concatenation of equally long vectors into a pre-allocated `Matrix{Float32}`
(one column per vector), with optional truncation of each vector to its first
`truncate_dimension` entries.

# Arguments
- `vectors::AbstractVector{<:AbstractVector{T}}`: Vectors to concatenate; all must have the same length.
- `truncate_dimension::Union{Nothing,Int}=nothing`: Number of leading entries to keep. `nothing`, `0` or a negative value skips truncation. If truncated, each column is re-normalized with `_normalize`.
- `verbose::Bool=false`: Whether to log that truncation was skipped because `truncate_dimension == 0`.

# Throws
- `ArgumentError` if `vectors` is empty.
- `AssertionError` if the vectors differ in length or `truncate_dimension` exceeds their length.

# Examples
```julia
a = rand(Float32, 1000)
b = rand(Float32, 1000)
c = hcat_truncate([a, b])
size(c) # (1000, 2)
d = hcat_truncate([a, b], 500)
size(d) # (500, 2)
```
"""
function hcat_truncate(vectors::AbstractVector{<:AbstractVector{T}},
        truncate_dimension::Union{Nothing, Int} = nothing; verbose::Bool = false) where {T <:
                                                                                         Real}
    # Fail early with a clear message: an empty input has no length to validate against
    isempty(vectors) &&
        throw(ArgumentError("Cannot concatenate an empty collection of vectors"))

    # Validate that all vectors share the same length
    rows = length(first(vectors))
    for vec in vectors
        row = length(vec)
        @assert row==rows "All vectors must have the same number of rows (Found $row and $rows)"
    end
    total_cols = length(vectors)

    # Check if we need to truncate.
    # `nothing`, `0` and negative values all mean "keep every entry".
    truncate = false
    out_rows = rows
    if !isnothing(truncate_dimension)
        if truncate_dimension > 0
            @assert truncate_dimension<=rows "Requested truncation dimension is too high (Vector length: $rows vs requested: $truncate_dimension)"
            truncate = true
            out_rows = truncate_dimension
        elseif iszero(truncate_dimension)
            verbose && @info "Truncate_dimension set to 0. Skipping truncation"
        end
    end

    # Allocate the result once; every column is written below before returning
    result = Matrix{Float32}(undef, out_rows, total_cols)

    # Fill the result matrix. `enumerate` yields 1-based column positions, so the
    # write index is always valid for `result` regardless of how `vectors` is indexed.
    for (col_idx, vect) in enumerate(vectors)
        if truncate
            # We must re-normalize the truncated vectors
            result[:, col_idx] = _normalize(@view(vect[1:out_rows]))
        else
            result[:, col_idx] = vect
        end
    end

    return result
end

### Text Utilities
# STOPWORDS - used for annotation highlighting
# Just a small list to get started
Expand Down
105 changes: 104 additions & 1 deletion test/Experimental/RAGTools/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using PromptingTools.Experimental.RAGTools: split_into_code_and_sentences
using PromptingTools.Experimental.RAGTools: getpropertynested, setpropertynested,
merge_kwargs_nested
using PromptingTools.Experimental.RAGTools: pack_bits, unpack_bits, preprocess_tokens,
reciprocal_rank_fusion, score_to_unit_scale
reciprocal_rank_fusion, score_to_unit_scale,
hcat_truncate, _normalize

@testset "_check_aiextract_capability" begin
@test _check_aiextract_capability("gpt-3.5-turbo") == nothing
Expand Down Expand Up @@ -92,6 +93,108 @@ end
@test merged_mat [1.0 2.0 0.0 0.0; 3.0 4.0 5.0 6.0; 0.0 0.0 7.0 8.0]
end

@testset "hcat_truncate" begin
    ## Method for a vector of matrices
    @testset "matrices" begin
        # Test basic functionality with no truncation
        m1 = Float32[1 2; 3 4; 5 6]
        m2 = Float32[7 8; 9 10; 11 12]
        result = hcat_truncate([m1, m2])
        @test size(result) == (3, 4)
        @test result == Float32[1 2 7 8; 3 4 9 10; 5 6 11 12]

        # Test with truncation
        result_truncated = hcat_truncate([m1, m2], 2)
        @test size(result_truncated) == (2, 4)

        # Test normalization after truncation (floating point, so compare with ≈)
        expected_col1 = Float32[1, 3] / sqrt(1^2 + 3^2)
        @test result_truncated[:, 1] ≈ expected_col1

        # Test with single matrix input
        single_result = hcat_truncate([m1])
        @test single_result == m1

        # Test with empty input (untyped `[]` matches no method)
        @test_throws Exception hcat_truncate([])

        # Test with matrices of different row counts
        m3 = Float32[1 2; 3 4]
        @test_throws AssertionError hcat_truncate([m1, m3])

        # Test with truncation dimension larger than input
        @test_throws AssertionError hcat_truncate([m1, m2], 4)

        # Test with truncate_dimension set to 0
        zero_truncate = hcat_truncate([m1, m2], 0)
        @test zero_truncate == Float32[1 2 7 8; 3 4 9 10; 5 6 11 12]

        # Test with large matrices to ensure performance
        large_m1 = rand(Float32, 1000, 1000)
        large_m2 = rand(Float32, 1000, 1000)
        @test size(hcat_truncate([large_m1, large_m2], 500)) == (500, 2000)

        # Test with different types (should convert to Float32)
        m4 = [1.0 2.0; 3.0 4.0; 5.0 6.0]
        result_type_conversion = hcat_truncate([m4])
        @test eltype(result_type_conversion) == Float32

        # Test with truncate=nothing (should behave the same as no truncation)
        result_nothing = hcat_truncate([m1, m2], nothing)
        @test result_nothing == Float32[1 2 7 8; 3 4 9 10; 5 6 11 12]

        # Test with truncate=-1 (should behave the same as no truncation)
        result_negative = hcat_truncate([m1, m2], -1)
        @test result_negative == Float32[1 2 7 8; 3 4 9 10; 5 6 11 12]
    end

    ## Method for a vector of vectors
    @testset "vectors" begin
        # Test basic functionality
        v1 = [1.0, 2.0, 3.0]
        v2 = [4.0, 5.0, 6.0]
        result = hcat_truncate([v1, v2])
        @test size(result) == (3, 2)
        @test result == [1.0 4.0; 2.0 5.0; 3.0 6.0]

        # Test with truncation (columns are re-normalized, so compare with ≈)
        result_truncated = hcat_truncate([v1, v2], 2)
        @test size(result_truncated) == (2, 2)
        @test result_truncated ≈ mapreduce(_normalize, hcat, eachcol([1.0 4.0; 2.0 5.0]))

        # Test with single vector input
        single_result = hcat_truncate([v1])
        @test single_result == reshape(v1, :, 1)

        # Test with empty input
        @test_throws Exception hcat_truncate(Vector{Float64}[])

        # Test with vectors of different lengths
        v3 = [1.0, 2.0]
        @test_throws AssertionError hcat_truncate([v1, v3])

        # Test with truncation dimension larger than input
        @test_throws AssertionError hcat_truncate([v1, v2], 4)

        # Test with truncate_dimension set to 0
        zero_truncate = hcat_truncate([v1, v2], 0)
        @test zero_truncate == [1.0 4.0; 2.0 5.0; 3.0 6.0]

        # Test with large vectors to ensure performance
        large_v1 = rand(1000)
        large_v2 = rand(1000)
        @test size(hcat_truncate([large_v1, large_v2], 500)) == (500, 2)

        # Test with different types (should convert to Float32)
        v4 = [1, 2, 3]
        result_type_conversion = hcat_truncate([v4])
        @test eltype(result_type_conversion) == Float32

        # Test with truncate=nothing (should behave the same as no truncation)
        result_nothing = hcat_truncate([v1, v2], nothing)
        @test result_nothing == [1.0 4.0; 2.0 5.0; 3.0 6.0]

        # Test with truncate=-1 (should behave the same as no truncation)
        result_negative = hcat_truncate([v1, v2], -1)
        @test result_negative == [1.0 4.0; 2.0 5.0; 3.0 6.0]
    end
end

### Text-manipulation utilities

@testset "tokenize" begin
Expand Down

0 comments on commit e2553b8

Please sign in to comment.