Skip to content

Commit

Permalink
Fix truncate_dimension (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Apr 18, 2024
1 parent 23e4f0b commit a81bd76
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.20.1]

### Fixed
- Fixed `truncate_dimension` handling in `get_embeddings`: a value of `0` now skips truncation (previously it would throw an error).

## [0.20.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.20.0"
version = "0.20.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
9 changes: 6 additions & 3 deletions src/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched
- `docs`: A vector of strings to be embedded.
- `verbose`: A boolean flag for verbose output. Default is `true`.
- `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`.
- `truncate_dimension`: The dimensionality of the embeddings to truncate to. Default is `nothing`.
- `truncate_dimension`: The dimensionality to truncate the embeddings to. Default is `nothing` (no truncation); a value of `0` also skips truncation.
- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call.
- `target_batch_size_length`: The target length (in characters) of each batch of document chunks sent for embedding. Default is 80_000 characters. Speeds up embedding process.
- `ntasks`: The number of tasks to use for asyncmap. Default is 4 * Threads.nthreads().
Expand Down Expand Up @@ -237,14 +237,17 @@ function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:Abstract
msg.content
end
embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents
if !isnothing(truncate_dimension)
@assert truncate_dimension>0 "Truncated dimensionality must be non-negative (Provided: $(truncate_dimension))"
# truncate_dimension=0 means that we skip it
if !isnothing(truncate_dimension) && truncate_dimension > 0
@assert truncate_dimension<=size(embeddings, 1) "Requested embeddings dimensionality is too high (Embeddings: $(size(embeddings)) vs dimensionality requested: $(truncate_dimension))"
## reduce + normalize again
embeddings = embeddings[1:truncate_dimension, :]
for i in axes(embeddings, 2)
embeddings[:, i] = _normalize(embeddings[:, i])
end
elseif !isnothing(truncate_dimension) && truncate_dimension == 0
# do nothing
verbose && @info "Truncate_dimension set to 0. Skipping truncation"
end
verbose && @info "Done embedding. Total cost: \$$(round(cost_tracker[],digits=3))"
return embeddings
Expand Down
4 changes: 4 additions & 0 deletions test/Experimental/RAGTools/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ end
output = get_embeddings(
BatchEmbedder(), docs; model = "mock-emb", truncate_dimension = 100)
@test size(output) == (100, 2)
## value of 0 for truncation, skips the step
output = get_embeddings(
BatchEmbedder(), docs; model = "mock-emb", truncate_dimension = 0)
@test size(output) == (128, 2)

# Unknown type
struct RandomEmbedder123 <: AbstractEmbedder end
Expand Down

0 comments on commit a81bd76

Please sign in to comment.