diff --git a/CHANGELOG.md b/CHANGELOG.md
index 342e114d8..7b62c4688 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
+- Experimental sub-module RAGTools providing basic Retrieval-Augmented Generation functionality. See `?RAGTools` for more information. It's all nested inside of `PromptingTools.Experimental.RAGTools` to signify that it might change in the future. Key functions are `build_index` and `airag`, but it also provides a suite to make evaluation easier (see `?build_qa_evals` and `?run_qa_evals` or just see the example `examples/building_RAG.jl`)
### Fixed
- Stricter code parsing in `AICode` to avoid false positives (code blocks must end with "```\n" to catch comments inside text)
diff --git a/Project.toml b/Project.toml
index 17136594e..dd79425ea 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,21 +12,32 @@ OpenAI = "e9f21f70-7185-4079-aca2-91159181367c"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
+[weakdeps]
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+
+[extensions]
+RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra"]
+
[compat]
Aqua = "0.7"
Base64 = "<0.0.1, 1"
HTTP = "1"
JSON3 = "1"
+LinearAlgebra = "<0.0.1, 1"
Logging = "<0.0.1, 1"
OpenAI = "0.8.7"
PrecompileTools = "1"
Preferences = "1"
+SparseArrays = "<0.0.1, 1"
Test = "<0.0.1, 1"
julia = "1.9,1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["Aqua", "Test"]
+test = ["Aqua", "Test", "SparseArrays", "LinearAlgebra"]
diff --git a/docs/Project.toml b/docs/Project.toml
index b0857977a..7011d3c5f 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,5 +1,10 @@
[deps]
+DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
diff --git a/docs/generate_examples.jl b/docs/generate_examples.jl
index 5c56923e6..f71187d5b 100644
--- a/docs/generate_examples.jl
+++ b/docs/generate_examples.jl
@@ -8,4 +8,6 @@ output_dir = joinpath(@__DIR__, "src", "examples")
filter!(endswith(".jl"), example_files)
for fn in example_files
Literate.markdown(fn, output_dir; execute = true)
-end
\ No newline at end of file
+end
+
+# TODO: change meta fields at the top of each file!
\ No newline at end of file
diff --git a/docs/make.jl b/docs/make.jl
index d3e676100..16b4d2456 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,5 +1,9 @@
using PromptingTools
using Documenter
+using SparseArrays, LinearAlgebra
+using PromptingTools.Experimental.RAGTools
+using JSON3, Serialization, DataFramesMeta
+using Statistics: mean
DocMeta.setdocmeta!(PromptingTools,
:DocTestSetup,
@@ -7,7 +11,7 @@ DocMeta.setdocmeta!(PromptingTools,
recursive = true)
makedocs(;
- modules = [PromptingTools],
+ modules = [PromptingTools, PromptingTools.Experimental.RAGTools],
authors = "J S <49557684+svilupp@users.noreply.github.com> and contributors",
repo = "https://github.com/svilupp/PromptingTools.jl/blob/{commit}{path}#{line}",
sitename = "PromptingTools.jl",
@@ -24,9 +28,14 @@ makedocs(;
"Various examples" => "examples/readme_examples.md",
"Using AITemplates" => "examples/working_with_aitemplates.md",
"Local models with Ollama.ai" => "examples/working_with_ollama.md",
+ "Building RAG Application" => "examples/building_RAG.md",
],
"F.A.Q." => "frequently_asked_questions.md",
- "Reference" => "reference.md",
+ "Reference" => [
+ "PromptingTools.jl" => "reference.md",
+ "Experimental Modules" => "reference_experimental.md",
+ "RAGTools" => "reference_ragtools.md",
+ ],
])
deploydocs(;
diff --git a/docs/src/examples/building_RAG.md b/docs/src/examples/building_RAG.md
new file mode 100644
index 000000000..a14416093
--- /dev/null
+++ b/docs/src/examples/building_RAG.md
@@ -0,0 +1,228 @@
+```@meta
+EditURL = "../../../examples/building_RAG.jl"
+```
+
+# Building a Simple Retrieval-Augmented Generation (RAG) System with RAGTools
+
+Let's build a Retrieval-Augmented Generation (RAG) chatbot, tailored to navigate and interact with the DataFrames.jl documentation.
+"RAG" is probably the most common and valuable pattern in Generative AI at the moment.
+
+If you're not familiar with "RAG", start with this [article](https://towardsdatascience.com/add-your-own-data-to-an-llm-using-retrieval-augmented-generation-rag-b1958bf56a5a).
+
+
+````julia
+using LinearAlgebra, SparseArrays
+using PromptingTools
+using PromptingTools.Experimental.RAGTools
+## Note: RAGTools module is still experimental and will change in the future. Ideally, they will be cleaned up and moved to a dedicated package
+using JSON3, Serialization, DataFramesMeta
+using Statistics: mean
+const PT = PromptingTools
+const RT = PromptingTools.Experimental.RAGTools
+````
+
+## RAG in Two Lines
+
+Let's put together a few text pages from DataFrames.jl docs.
+Simply go to [DataFrames.jl docs](https://dataframes.juliadata.org/stable/) and copy&paste a few pages into separate text files. Save them in the `examples/data` folder (see some example pages provided). Ideally, delete all the noise (like headers, footers, etc.) and keep only the text you want to use for the chatbot. Remember, garbage in, garbage out!
+
+````julia
+files = [
+ joinpath("examples", "data", "database_style_joins.txt"),
+ joinpath("examples", "data", "what_is_dataframes.txt"),
+]
+# Build an index of chunks, embed them, and create a lookup index of metadata/tags for each chunk
+index = build_index(files; extract_metadata = false);
+````
+
+Let's ask a question
+
+````julia
+# Embeds the question, finds the closest chunks in the index, and generates an answer from the closest chunks
+answer = airag(index; question = "I like dplyr, what is the equivalent in Julia?")
+````
+
+````
+AIMessage("The equivalent package in Julia to dplyr in R is DataFramesMeta.jl. It provides convenience functions for data manipulation with syntax similar to dplyr.")
+````
+
+First RAG in two lines? Done!
+
+What does it do?
+- `build_index` will chunk the documents into smaller pieces, embed them into numbers (to be able to judge the similarity of chunks) and, optionally, create a lookup index of metadata/tags for each chunk)
+ - `index` is the result of this step and it holds your chunks, embeddings, and other metadata! Just show it :)
+- `airag` will
+ - embed your question
+ - find the closest chunks in the index (use parameters `top_k` and `minimum_similarity` to tweak the "relevant" chunks)
+ - [OPTIONAL] extracts any potential tags/filters from the question and applies them to filter down the potential candidates (use `extract_metadata=true` in `build_index`, you can also provide some filters explicitly via `tag_filter`)
+ - [OPTIONAL] re-ranks the candidate chunks (define and provide your own `rerank_strategy`, eg Cohere ReRank API)
+ - build a context from the closest chunks (use `chunks_window_margin` to tweak if we include preceding and succeeding chunks as well, see `?build_context` for more details)
+- generate an answer from the closest chunks (use `return_context=true` to see under the hood and debug your application)
+
+You should save the index for later to avoid re-embedding / re-extracting the document chunks!
+
+````julia
+serialize("examples/index.jls", index)
+index = deserialize("examples/index.jls");
+````
+
+# Evaluations
+However, we want to evaluate the quality of the system. For that, we need a set of questions and answers.
+Ideally, we would handcraft a set of high-quality Q&A pairs. However, this is time-consuming and expensive.
+Let's generate them from the chunks in our index!
+
+## Generate Q&A pairs
+
+We need to provide: chunks and sources (file paths for future reference)
+
+````julia
+evals = build_qa_evals(RT.chunks(index),
+ RT.sources(index);
+ instructions = "None.",
+ verbose = true);
+````
+
+````
+[ Info: Q&A Sets built! (cost: $0.102)
+
+````
+
+> [!TIP]
+> In practice, you would review each item in this golden evaluation set (and delete any generic/poor questions).
+> It will determine the future success of your app, so you need to make sure it's good!
+
+````julia
+# Save the evals for later
+JSON3.write("examples/evals.json", evals)
+evals = JSON3.read("examples/evals.json", Vector{RT.QAEvalItem});
+````
+
+## Explore one Q&A pair
+
+Let's explore one evals item -- it's not the best quality but gives you the idea!
+
+````julia
+evals[1]
+````
+
+````
+QAEvalItem:
+ source: examples/data/database_style_joins.txt
+ context: Database-Style Joins
+Introduction to joins
+We often need to combine two or more data sets together to provide a complete picture of the topic we are studying. For example, suppose that we have the following two data sets:
+
+julia> using DataFrames
+ question: What is the purpose of joining two or more data sets together?
+ answer: The purpose of joining two or more data sets together is to provide a complete picture of the topic being studied.
+
+````
+
+## Evaluate this Q&A pair
+
+Let's evaluate this QA item with a "judge model" (often GPT-4 is used as a judge).
+
+````julia
+# Note: that we used the same question, but generated a different context and answer via `airag`
+msg, ctx = airag(index; evals[1].question, return_context = true);
+# ctx is a RAGContext object that keeps all intermediate states of the RAG pipeline for easy evaluation
+judged = aiextract(:RAGJudgeAnswerFromContext;
+ ctx.context,
+ ctx.question,
+ ctx.answer,
+ return_type = RT.JudgeAllScores)
+judged.content
+````
+
+````
+Dict{Symbol, Any} with 6 entries:
+ :final_rating => 4.8
+ :clarity => 5
+ :completeness => 4
+ :relevance => 5
+ :consistency => 5
+ :helpfulness => 5
+````
+
+We can also run the generation + evaluation in a function (a few more metrics are available, eg, retrieval score):
+````julia
+x = run_qa_evals(evals[10], ctx;
+ parameters_dict = Dict(:top_k => 3), verbose = true, model_judge = "gpt4t")
+````
+
+````
+QAEvalResult:
+ source: examples/data/database_style_joins.txt
+ context: outerjoin: the output contains rows for values of the key that exist in any of the passed data frames.
+semijoin: Like an inner join, but output is restricted to columns from the first (left) argument.
+ question: What is the difference between outer join and semi join?
+ answer: The purpose of joining two or more data sets together is to combine them in order to provide a complete picture or analysis of a specific topic or dataset. By joining data sets, we can combine information from multiple sources to gain more insights and make more informed decisions.
+ retrieval_score: 0.0
+ retrieval_rank: nothing
+ answer_score: 5
+ parameters: Dict(:top_k => 3)
+
+````
+
+Fortunately, we don't have to do this one by one -- let's evaluate all our Q&A pairs at once.
+
+## Evaluate the Whole Set
+
+Let's run each question & answer through our eval loop in async (we do it only for the first 10 to save time). See the `?airag` for which parameters you can tweak, eg, `top_k`
+
+````julia
+results = asyncmap(evals[1:10]) do qa_item
+ # Generate an answer -- often you want the model_judge to be the highest quality possible, eg, "GPT-4 Turbo" (alias "gpt4t)
+ msg, ctx = airag(index; qa_item.question, return_context = true,
+ top_k = 3, verbose = false, model_judge = "gpt4t")
+ # Evaluate the response
+ # Note: you can log key parameters for easier analysis later
+ run_qa_evals(qa_item, ctx; parameters_dict = Dict(:top_k => 3), verbose = false)
+end
+## Note that the "failed" evals can show as "nothing" (failed as in there was some API error or parsing error), so make sure to handle them.
+results = filter(x->!isnothing(x.answer_score), results);
+````
+
+Note: You could also use the vectorized version `results = run_qa_evals(evals)` to evaluate all items at once.
+
+````julia
+
+# Let's take a simple average to calculate our score
+@info "RAG Evals: $(length(results)) results, Avg. score: $(round(mean(x->x.answer_score, results);digits=1)), Retrieval score: $(100*round(Int,mean(x->x.retrieval_score,results)))%"
+````
+
+````
+[ Info: RAG Evals: 10 results, Avg. score: 4.6, Retrieval score: 100%
+
+````
+
+Note: The retrieval score is 100% only because we have two small documents and running on 10 items only. In practice, you would have a much larger document set and a much larger eval set, which would result in a more representative retrieval score.
+
+You can also analyze the results in a DataFrame:
+
+````julia
+df = DataFrame(results)
+````
+
+```@raw html
+
1 | examples/data/database_style_joins.txt | Database-Style Joins\nIntroduction to joins\nWe often need to combine two or more data sets together to provide a complete picture of the topic we are studying. For example, suppose that we have the following two data sets:\n\njulia> using DataFrames | What is the purpose of joining two or more data sets together? | The purpose of joining two or more data sets together is to combine the data sets based on a common key and provide a complete picture of the topic being studied. | 1.0 | 1 | 5.0 | Dict(:top_k=>3) |
2 | examples/data/database_style_joins.txt | julia> people = DataFrame(ID=[20, 40], Name=["John Doe", "Jane Doe"])\n2×2 DataFrame\n Row │ ID Name\n │ Int64 String\n─────┼─────────────────\n 1 │ 20 John Doe\n 2 │ 40 Jane Doe | What is the DataFrame called 'people' composed of? | The DataFrame called 'people' consists of two columns: 'ID' and 'Name'. The 'ID' column contains integers, and the 'Name' column contains strings. | 1.0 | 1 | 4.0 | Dict(:top_k=>3) |
3 | examples/data/database_style_joins.txt | julia> jobs = DataFrame(ID=[20, 40], Job=["Lawyer", "Doctor"])\n2×2 DataFrame\n Row │ ID Job\n │ Int64 String\n─────┼───────────────\n 1 │ 20 Lawyer\n 2 │ 40 Doctor | What are the jobs and IDs listed in the dataframe? | The jobs and IDs listed in the dataframe are as follows:\n\nID: 20\nJob: Lawyer\n\nID: 40\nJob: Doctor | 1.0 | 1 | 4.67 | Dict(:top_k=>3) |
4 | examples/data/database_style_joins.txt | We might want to work with a larger data set that contains both the names and jobs for each ID. We can do this using the innerjoin function: | How can we combine the names and jobs for each ID in a larger data set? | We can use the `innerjoin` function to combine the names and jobs for each ID in a larger data set. | 1.0 | 1 | 4.33333 | Dict(:top_k=>3) |
5 | examples/data/database_style_joins.txt | julia> innerjoin(people, jobs, on = :ID)\n2×3 DataFrame\n Row │ ID Name Job\n │ Int64 String String\n─────┼─────────────────────────\n 1 │ 20 John Doe Lawyer\n 2 │ 40 Jane Doe Doctor | What is the name of the person with the ID 40 and their job? | The name of the person with the ID 40 is Jane Doe and their job is Doctor. | 1.0 | 2 | 4.67 | Dict(:top_k=>3) |
6 | examples/data/database_style_joins.txt | In relational database theory, this operation is generally referred to as a join. The columns used to determine which rows should be combined during a join are called keys.\n\nThe following functions are provided to perform seven kinds of joins: | What are the different kinds of joins? | The different kinds of joins are:\n\n1. Inner Join: Returns only the rows that have matching values in both data frames.\n2. Left Join: Returns all rows from the left data frame and the matching rows from the right data frame.\n3. Right Join: Returns all rows from the right data frame and the matching rows from the left data frame.\n4. Full Outer Join: Returns all rows from both data frames and fills in missing values with null.\n5. Cross Join: Returns the cartesian product of the rows from both data frames.\n6. Semi Join: Returns only the rows from the left data frame that have matching values in the right data frame.\n7. Anti Join: Returns only the rows from the left data frame that do not have matching values in the right data frame. | 1.0 | 1 | 4.66667 | Dict(:top_k=>3) |
7 | examples/data/database_style_joins.txt | innerjoin: the output contains rows for values of the key that exist in all passed data frames. | What does the output of the inner join operation contain? | The output of the inner join operation contains only the rows for values of the key that exist in all passed data frames. | 1.0 | 1 | 5.0 | Dict(:top_k=>3) |
8 | examples/data/database_style_joins.txt | leftjoin: the output contains rows for values of the key that exist in the first (left) argument, whether or not that value exists in the second (right) argument. | What is the purpose of the left join operation? | The purpose of the left join operation is to combine data from two tables based on a common key, where all rows from the left (first) table are included in the output, regardless of whether there is a match in the right (second) table. | 1.0 | 1 | 4.66667 | Dict(:top_k=>3) |
9 | examples/data/database_style_joins.txt | rightjoin: the output contains rows for values of the key that exist in the second (right) argument, whether or not that value exists in the first (left) argument. | What is the purpose of the right join operation? | The purpose of the right join operation is to include all the rows from the second (right) argument, regardless of whether a match is found in the first (left) argument. | 1.0 | 1 | 4.67 | Dict(:top_k=>3) |
10 | examples/data/database_style_joins.txt | outerjoin: the output contains rows for values of the key that exist in any of the passed data frames.\nsemijoin: Like an inner join, but output is restricted to columns from the first (left) argument. | What is the difference between outer join and semi join? | The difference between outer join and semi join is that outer join includes rows for values of the key that exist in any of the passed data frames, whereas semi join is like an inner join but only outputs columns from the first argument. | 1.0 | 1 | 4.66667 | Dict(:top_k=>3) |
+```
+
+We're done for today!
+
+# What would we do next?
+- Review your evaluation golden data set and keep only the good items
+- Play with the chunk sizes (max_length in build_index) and see how it affects the quality
+- Explore using metadata/key filters (`extract_metadata=true` in build_index)
+- Add filtering for semantic similarity (embedding distance) to make sure we don't pick up irrelevant chunks in the context
+- Use multiple indices or a hybrid index (add a simple BM25 lookup from TextAnalysis.jl)
+- Data processing is the most important step - properly parsed and split text could make wonders
+- Add re-ranking of context (see `rerank` function, you can use Cohere ReRank API)`)
+- Improve the question embedding (eg, rephrase it, generate hypothetical answers and use them to find better context)
+
+... and much more! See some ideas in [Anyscale RAG tutorial](https://www.anyscale.com/blog/a-comprehensive-guide-for-building-rag-based-llm-applications-part-1)
+
+---
+
+*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
+
diff --git a/docs/src/reference_experimental.md b/docs/src/reference_experimental.md
new file mode 100644
index 000000000..428775406
--- /dev/null
+++ b/docs/src/reference_experimental.md
@@ -0,0 +1,12 @@
+# Reference for Experimental Module
+
+Note: This module is experimental and may change in future releases.
+The intention is for the functionality to be moved to separate packages over time.
+
+```@index
+Modules = [PromptingTools.Experimental]
+```
+
+```@autodocs
+Modules = [PromptingTools.Experimental]
+```
diff --git a/docs/src/reference_ragtools.md b/docs/src/reference_ragtools.md
new file mode 100644
index 000000000..5f7bd23d9
--- /dev/null
+++ b/docs/src/reference_ragtools.md
@@ -0,0 +1,9 @@
+# Reference for RAGTools
+
+```@index
+Modules = [PromptingTools.Experimental.RAGTools]
+```
+
+```@autodocs
+Modules = [PromptingTools.Experimental.RAGTools]
+```
diff --git a/examples/building_RAG.jl b/examples/building_RAG.jl
new file mode 100644
index 000000000..d868830f1
--- /dev/null
+++ b/examples/building_RAG.jl
@@ -0,0 +1,147 @@
+# # Building a Simple Retrieval-Augmented Generation (RAG) System with RAGTools
+
+# Let's build a Retrieval-Augmented Generation (RAG) chatbot, tailored to navigate and interact with the DataFrames.jl documentation.
+# "RAG" is probably the most common and valuable pattern in Generative AI at the moment.
+
+# If you're not familiar with "RAG", start with this [article](https://towardsdatascience.com/add-your-own-data-to-an-llm-using-retrieval-augmented-generation-rag-b1958bf56a5a).
+
+## Imports
+using LinearAlgebra, SparseArrays
+using PromptingTools
+## Note: RAGTools is still experimental and will change in the future. Ideally, they will be cleaned up and moved to a dedicated package
+using PromptingTools.Experimental.RAGTools
+using JSON3, Serialization, DataFramesMeta
+using Statistics: mean
+const PT = PromptingTools
+const RT = PromptingTools.Experimental.RAGTools
+
+# ## RAG in Two Lines
+
+# Let's put together a few text pages from DataFrames.jl docs.
+# Simply go to [DataFrames.jl docs](https://dataframes.juliadata.org/stable/) and copy&paste a few pages into separate text files. Save them in the `examples/data` folder (see some example pages provided). Ideally, delete all the noise (like headers, footers, etc.) and keep only the text you want to use for the chatbot. Remember, garbage in, garbage out!
+
+files = [
+ joinpath("examples", "data", "database_style_joins.txt"),
+ joinpath("examples", "data", "what_is_dataframes.txt"),
+]
+## Build an index of chunks, embed them, and create a lookup index of metadata/tags for each chunk
+index = build_index(files; extract_metadata = false)
+
+# Let's ask a question
+## Embeds the question, finds the closest chunks in the index, and generates an answer from the closest chunks
+answer = airag(index; question = "I like dplyr, what is the equivalent in Julia?")
+
+# First RAG in two lines? Done!
+#
+# What does it do?
+# - `build_index` will chunk the documents into smaller pieces, embed them into numbers (to be able to judge the similarity of chunks) and, optionally, create a lookup index of metadata/tags for each chunk)
+# - `index` is the result of this step and it holds your chunks, embeddings, and other metadata! Just show it :)
+# - `airag` will
+# - embed your question
+# - find the closest chunks in the index (use parameters `top_k` and `minimum_similarity` to tweak the "relevant" chunks)
+# - [OPTIONAL] extracts any potential tags/filters from the question and applies them to filter down the potential candidates (use `extract_metadata=true` in `build_index`, you can also provide some filters explicitly via `tag_filter`)
+# - [OPTIONAL] re-ranks the candidate chunks (define and provide your own `rerank_strategy`, eg Cohere ReRank API)
+# - build a context from the closest chunks (use `chunks_window_margin` to tweak if we include preceding and succeeding chunks as well, see `?build_context` for more details)
+# - generate an answer from the closest chunks (use `return_context=true` to see under the hood and debug your application)
+
+# You should save the index for later to avoid re-embedding / re-extracting the document chunks!
+serialize("examples/index.jls", index)
+index = deserialize("examples/index.jls")
+
+# # Evaluations
+# However, we want to evaluate the quality of the system. For that, we need a set of questions and answers.
+# Ideally, we would hand-craft a set of high quality Q&A pairs. However, this is time consuming and expensive.
+# Let's generate them from the chunks in our index!
+
+# ## Generate Q&A pairs
+
+# We need to provide: chunks and sources (filepaths for future reference)
+evals = build_qa_evals(RT.chunks(index),
+ RT.sources(index);
+ instructions = "None.",
+ verbose = true);
+## Info: Q&A Sets built! (cost: $0.143) -- not bad!
+
+# > [!TIP]
+# > In practice, you would review each item in this golden evaluation set (and delete any generic/poor questions).
+# > It will determine the future success of your app, so you need to make sure it's good!
+
+## Save the evals for later
+JSON3.write("examples/evals.json", evals)
+evals = JSON3.read("examples/evals.json", Vector{RT.QAEvalItem});
+
+# ## Explore one Q&A pair
+# Let's explore one evals item -- it's not the best but gives you the idea!
+#
+evals[1]
+
+# ## Evaluate this Q&A pair
+
+# Let's evaluate this QA item with a "judge model" (often GPT-4 is used as a judge).
+
+## Note: that we used the same question, but generated a different context and answer via `airag`
+msg, ctx = airag(index; evals[1].question, return_context = true);
+
+## ctx is a RAGContext object that keeps all intermediate states of the RAG pipeline for easy evaluation
+judged = aiextract(:RAGJudgeAnswerFromContext;
+ ctx.context,
+ ctx.question,
+ ctx.answer,
+ return_type = RT.JudgeAllScores)
+judged.content
+## Dict{Symbol, Any} with 7 entries:
+## :final_rating => 4.8
+## :clarity => 5
+## :completeness => 5
+## :relevance => 5
+## :consistency => 4
+## :helpfulness => 5
+## :rationale => "The answer is highly relevant to the user's question, as it provides a comprehensive list of frameworks that are compared with DataFrames.jl. The answer is complete, covering all
+
+# We can also run the whole evaluation in a function (a few more metrics are available):
+x = run_qa_evals(evals[10], ctx;
+ parameters_dict = Dict(:top_k => 3), verbose = true, model_judge = "gpt4t")
+
+# Fortunately, we don't have to do this one by one -- let's evaluate all our Q&A pairs at once.
+
+# ## Evaluate the whole set
+
+# Let's run each question & answer through our eval loop in async (we do it only for the first 10 to save time). See the `?airag` for which parameters you can tweak, eg, `top_k`
+
+results = asyncmap(evals[1:10]) do qa_item
+ ## Generate an answer -- often you want the model_judge to be the highest quality possible, eg, "GPT-4 Turbo" (alias "gpt4t)
+ msg, ctx = airag(index; qa_item.question, return_context = true,
+ top_k = 3, verbose = false, model_judge = "gpt4t")
+ ## Evaluate the response
+ ## Note: you can log key parameters for easier analysis later
+ run_qa_evals(qa_item, ctx; parameters_dict = Dict(:top_k => 3), verbose = false)
+end
+## Note that the "failed" evals can show as "nothing", so make sure to handle them.
+results = filter(x -> !isnothing(x.answer_score), results);
+
+# Note: You could also use the vectorized version `results = run_qa_evals(evals)` to evaluate all items at once.
+
+## Let's take a simple average to calculate our score
+@info "RAG Evals: $(length(results)) results, Avg. score: $(round(mean(x->x.answer_score, results);digits=1)), Retrieval score: $(100*round(Int,mean(x->x.retrieval_score,results)))%"
+## [ Info: RAG Evals: 10 results, Avg. score: 4.6, Retrieval score: 100%
+
+# Note: The retrieval score is 100% only because we have two small documents and running on 10 items only. In practice, you would have a much larger document set and a much larger eval set, which would result in a more representative retrieval score.
+
+# You can also analyze the results in a DataFrame:
+
+df = DataFrame(results)
+first(df, 5)
+
+# We're done for today!
+
+# # What would we do next?
+# - Review your evaluation golden data set and keep only the good items
+# - Play with the chunk sizes (max_length in build_index) and see how it affects the quality
+# - Explore using metadata/key filters (`extract_metadata=true` in build_index)
+# - Add filtering for semantic similarity (embedding distance) to make sure we don't pick up irrelevant chunks in the context
+# - Use multiple indices or a hybrid index (add a simple BM25 lookup from TextAnalysis.jl)
+# - Data processing is the most important step - properly parsed and split text could make wonders
+# - Add re-ranking of context (see `rerank` function, you can use Cohere ReRank API)`)
+# - Improve the question embedding (eg, rephrase it, generate hypothetical answers and use them to find better context)
+#
+# ... and much more! See some ideas in [Anyscale RAG tutorial](https://www.anyscale.com/blog/a-comprehensive-guide-for-building-rag-based-llm-applications-part-1)
\ No newline at end of file
diff --git a/examples/data/database_style_joins.txt b/examples/data/database_style_joins.txt
new file mode 100644
index 000000000..9e04ecab1
--- /dev/null
+++ b/examples/data/database_style_joins.txt
@@ -0,0 +1,392 @@
+Database-Style Joins
+Introduction to joins
+We often need to combine two or more data sets together to provide a complete picture of the topic we are studying. For example, suppose that we have the following two data sets:
+
+julia> using DataFrames
+
+julia> people = DataFrame(ID=[20, 40], Name=["John Doe", "Jane Doe"])
+2×2 DataFrame
+ Row │ ID Name
+ │ Int64 String
+─────┼─────────────────
+ 1 │ 20 John Doe
+ 2 │ 40 Jane Doe
+
+julia> jobs = DataFrame(ID=[20, 40], Job=["Lawyer", "Doctor"])
+2×2 DataFrame
+ Row │ ID Job
+ │ Int64 String
+─────┼───────────────
+ 1 │ 20 Lawyer
+ 2 │ 40 Doctor
+
+We might want to work with a larger data set that contains both the names and jobs for each ID. We can do this using the innerjoin function:
+
+julia> innerjoin(people, jobs, on = :ID)
+2×3 DataFrame
+ Row │ ID Name Job
+ │ Int64 String String
+─────┼─────────────────────────
+ 1 │ 20 John Doe Lawyer
+ 2 │ 40 Jane Doe Doctor
+
+In relational database theory, this operation is generally referred to as a join. The columns used to determine which rows should be combined during a join are called keys.
+
+The following functions are provided to perform seven kinds of joins:
+
+innerjoin: the output contains rows for values of the key that exist in all passed data frames.
+leftjoin: the output contains rows for values of the key that exist in the first (left) argument, whether or not that value exists in the second (right) argument.
+rightjoin: the output contains rows for values of the key that exist in the second (right) argument, whether or not that value exists in the first (left) argument.
+outerjoin: the output contains rows for values of the key that exist in any of the passed data frames.
+semijoin: Like an inner join, but output is restricted to columns from the first (left) argument.
+antijoin: The output contains rows for values of the key that exist in the first (left) but not the second (right) argument. As with semijoin, output is restricted to columns from the first (left) argument.
+crossjoin: The output is the cartesian product of rows from all passed data frames.
+See the Wikipedia page on SQL joins for more information.
+
+Here are examples of different kinds of join:
+
+julia> jobs = DataFrame(ID=[20, 60], Job=["Lawyer", "Astronaut"])
+2×2 DataFrame
+ Row │ ID Job
+ │ Int64 String
+─────┼──────────────────
+ 1 │ 20 Lawyer
+ 2 │ 60 Astronaut
+
+julia> innerjoin(people, jobs, on = :ID)
+1×3 DataFrame
+ Row │ ID Name Job
+ │ Int64 String String
+─────┼─────────────────────────
+ 1 │ 20 John Doe Lawyer
+
+julia> leftjoin(people, jobs, on = :ID)
+2×3 DataFrame
+ Row │ ID Name Job
+ │ Int64 String String?
+─────┼──────────────────────────
+ 1 │ 20 John Doe Lawyer
+ 2 │ 40 Jane Doe missing
+
+julia> rightjoin(people, jobs, on = :ID)
+2×3 DataFrame
+ Row │ ID Name Job
+ │ Int64 String? String
+─────┼────────────────────────────
+ 1 │ 20 John Doe Lawyer
+ 2 │ 60 missing Astronaut
+
+julia> outerjoin(people, jobs, on = :ID)
+3×3 DataFrame
+ Row │ ID Name Job
+ │ Int64 String? String?
+─────┼────────────────────────────
+ 1 │ 20 John Doe Lawyer
+ 2 │ 40 Jane Doe missing
+ 3 │ 60 missing Astronaut
+
+julia> semijoin(people, jobs, on = :ID)
+1×2 DataFrame
+ Row │ ID Name
+ │ Int64 String
+─────┼─────────────────
+ 1 │ 20 John Doe
+
+julia> antijoin(people, jobs, on = :ID)
+1×2 DataFrame
+ Row │ ID Name
+ │ Int64 String
+─────┼─────────────────
+ 1 │ 40 Jane Doe
+
+Cross joins are the only kind of join that does not use a on key:
+
+julia> crossjoin(people, jobs, makeunique = true)
+4×4 DataFrame
+ Row │ ID Name ID_1 Job
+ │ Int64 String Int64 String
+─────┼───────────────────────────────────
+ 1 │ 20 John Doe 20 Lawyer
+ 2 │ 20 John Doe 60 Astronaut
+ 3 │ 40 Jane Doe 20 Lawyer
+ 4 │ 40 Jane Doe 60 Astronaut
+
+Key value comparisons and floating point values
+Key values from the two or more data frames are compared using the isequal function. This is consistent with the Set and Dict types in Julia Base.
+
+It is not recommended to use floating point numbers as keys: floating point comparisons can be surprising and unpredictable. If you do use floating point keys, note that by default an error is raised when keys include -0.0 (negative zero) or NaN values. Here is an example:
+
+julia> innerjoin(DataFrame(id=[-0.0]), DataFrame(id=[0.0]), on=:id)
+ERROR: ArgumentError: Currently for numeric values `NaN` and `-0.0` in their real or imaginary components are not allowed. Such value was found in column :id in left data frame. Use CategoricalArrays.jl to wrap these values in a CategoricalVector to perform the requested join.
+
+This can be overridden by wrapping the key values in a categorical vector.
+
+Joining on key columns with different names
+In order to join data frames on keys which have different names in the left and right tables, you may pass left => right pairs as on argument:
+
+julia> a = DataFrame(ID=[20, 40], Name=["John Doe", "Jane Doe"])
+2×2 DataFrame
+ Row │ ID Name
+ │ Int64 String
+─────┼─────────────────
+ 1 │ 20 John Doe
+ 2 │ 40 Jane Doe
+
+julia> b = DataFrame(IDNew=[20, 40], Job=["Lawyer", "Doctor"])
+2×2 DataFrame
+ Row │ IDNew Job
+ │ Int64 String
+─────┼───────────────
+ 1 │ 20 Lawyer
+ 2 │ 40 Doctor
+
+julia> innerjoin(a, b, on = :ID => :IDNew)
+2×3 DataFrame
+ Row │ ID Name Job
+ │ Int64 String String
+─────┼─────────────────────────
+ 1 │ 20 John Doe Lawyer
+ 2 │ 40 Jane Doe Doctor
+
+Here is another example with multiple columns:
+
+julia> a = DataFrame(City=["Amsterdam", "London", "London", "New York", "New York"],
+ Job=["Lawyer", "Lawyer", "Lawyer", "Doctor", "Doctor"],
+ Category=[1, 2, 3, 4, 5])
+5×3 DataFrame
+ Row │ City Job Category
+ │ String String Int64
+─────┼─────────────────────────────
+ 1 │ Amsterdam Lawyer 1
+ 2 │ London Lawyer 2
+ 3 │ London Lawyer 3
+ 4 │ New York Doctor 4
+ 5 │ New York Doctor 5
+
+julia> b = DataFrame(Location=["Amsterdam", "London", "London", "New York", "New York"],
+ Work=["Lawyer", "Lawyer", "Lawyer", "Doctor", "Doctor"],
+ Name=["a", "b", "c", "d", "e"])
+5×3 DataFrame
+ Row │ Location Work Name
+ │ String String String
+─────┼───────────────────────────
+ 1 │ Amsterdam Lawyer a
+ 2 │ London Lawyer b
+ 3 │ London Lawyer c
+ 4 │ New York Doctor d
+ 5 │ New York Doctor e
+
+julia> innerjoin(a, b, on = [:City => :Location, :Job => :Work])
+9×4 DataFrame
+ Row │ City Job Category Name
+ │ String String Int64 String
+─────┼─────────────────────────────────────
+ 1 │ Amsterdam Lawyer 1 a
+ 2 │ London Lawyer 2 b
+ 3 │ London Lawyer 3 b
+ 4 │ London Lawyer 2 c
+ 5 │ London Lawyer 3 c
+ 6 │ New York Doctor 4 d
+ 7 │ New York Doctor 5 d
+ 8 │ New York Doctor 4 e
+ 9 │ New York Doctor 5 e
+
+Handling of duplicate keys and tracking source data frame
+Additionally, notice that in the last join rows 2 and 3 had the same values on on variables in both joined DataFrames. In such a situation innerjoin, outerjoin, leftjoin and rightjoin will produce all combinations of matching rows. In our example rows from 2 to 5 were created as a result. The same behavior can be observed for rows 4 and 5 in both joined DataFrames.
+
+In order to check that columns passed as the on argument define unique keys (according to isequal) in each input data frame you can set the validate keyword argument to a two-element tuple or a pair of Bool values, with each element indicating whether to run check for the corresponding data frame. Here is an example for the join operation described above:
+
+julia> innerjoin(a, b, on = [(:City => :Location), (:Job => :Work)], validate=(true, true))
+ERROR: ArgumentError: Merge key(s) are not unique in both df1 and df2. df1 contains 2 duplicate keys: (City = "London", Job = "Lawyer") and (City = "New York", Job = "Doctor"). df2 contains 2 duplicate keys: (Location = "London", Work = "Lawyer") and (Location = "New York", Work = "Doctor").
+
+Finally, using the source keyword argument you can add a column to the resulting data frame indicating whether the given row appeared only in the left, the right or both data frames. Here is an example:
+
+julia> a = DataFrame(ID=[20, 40], Name=["John", "Jane"])
+2×2 DataFrame
+ Row │ ID Name
+ │ Int64 String
+─────┼───────────────
+ 1 │ 20 John
+ 2 │ 40 Jane
+
+julia> b = DataFrame(ID=[20, 60], Job=["Lawyer", "Doctor"])
+2×2 DataFrame
+ Row │ ID Job
+ │ Int64 String
+─────┼───────────────
+ 1 │ 20 Lawyer
+ 2 │ 60 Doctor
+
+julia> outerjoin(a, b, on=:ID, validate=(true, true), source=:source)
+3×4 DataFrame
+ Row │ ID Name Job source
+ │ Int64 String? String? String
+─────┼─────────────────────────────────────
+ 1 │ 20 John Lawyer both
+ 2 │ 40 Jane missing left_only
+ 3 │ 60 missing Doctor right_only
+
+Note that this time we also used the validate keyword argument and it did not produce errors as the keys defined in both source data frames were unique.
+
+Renaming joined columns
+Often you want to keep track of the source data frame. This feature is supported with the renamecols keyword argument:
+
+julia> innerjoin(a, b, on=:ID, renamecols = "_left" => "_right")
+1×3 DataFrame
+ Row │ ID Name_left Job_right
+ │ Int64 String String
+─────┼─────────────────────────────
+ 1 │ 20 John Lawyer
+
+In the above example we added the "_left" suffix to the non-key columns from the left table and the "_right" suffix to the non-key columns from the right table.
+
+Alternatively it is allowed to pass a function transforming column names:
+
+julia> innerjoin(a, b, on=:ID, renamecols = lowercase => uppercase)
+1×3 DataFrame
+ Row │ ID name JOB
+ │ Int64 String String
+─────┼───────────────────────
+ 1 │ 20 John Lawyer
+
+Matching missing values in joins
+By default when you try to to perform a join on a key that has missing values you get an error:
+
+julia> df1 = DataFrame(id=[1, missing, 3], a=1:3)
+3×2 DataFrame
+ Row │ id a
+ │ Int64? Int64
+─────┼────────────────
+ 1 │ 1 1
+ 2 │ missing 2
+ 3 │ 3 3
+
+julia> df2 = DataFrame(id=[1, 2, missing], b=1:3)
+3×2 DataFrame
+ Row │ id b
+ │ Int64? Int64
+─────┼────────────────
+ 1 │ 1 1
+ 2 │ 2 2
+ 3 │ missing 3
+
+julia> innerjoin(df1, df2, on=:id)
+ERROR: ArgumentError: Missing values in key columns are not allowed when matchmissing == :error. `missing` found in column :id in left data frame.
+
+If you would prefer missing values to be treated as equal pass the matchmissing=:equal keyword argument:
+
+julia> innerjoin(df1, df2, on=:id, matchmissing=:equal)
+2×3 DataFrame
+ Row │ id a b
+ │ Int64? Int64 Int64
+─────┼───────────────────────
+ 1 │ 1 1 1
+ 2 │ missing 2 3
+
+Alternatively you might want to drop all rows with missing values. In this case pass matchmissing=:notequal:
+
+julia> innerjoin(df1, df2, on=:id, matchmissing=:notequal)
+1×3 DataFrame
+ Row │ id a b
+ │ Int64? Int64 Int64
+─────┼──────────────────────
+ 1 │ 1 1 1
+
+Specifying row order in the join result
+By default the order of rows produced by the join operation is undefined:
+
+julia> df_left = DataFrame(id=[1, 2, 4, 5], left=1:4)
+4×2 DataFrame
+ Row │ id left
+ │ Int64 Int64
+─────┼──────────────
+ 1 │ 1 1
+ 2 │ 2 2
+ 3 │ 4 3
+ 4 │ 5 4
+
+julia> df_right = DataFrame(id=[2, 1, 3, 6, 7], right=1:5)
+5×2 DataFrame
+ Row │ id right
+ │ Int64 Int64
+─────┼──────────────
+ 1 │ 2 1
+ 2 │ 1 2
+ 3 │ 3 3
+ 4 │ 6 4
+ 5 │ 7 5
+
+julia> outerjoin(df_left, df_right, on=:id)
+7×3 DataFrame
+ Row │ id left right
+ │ Int64 Int64? Int64?
+─────┼─────────────────────────
+ 1 │ 2 2 1
+ 2 │ 1 1 2
+ 3 │ 4 3 missing
+ 4 │ 5 4 missing
+ 5 │ 3 missing 3
+ 6 │ 6 missing 4
+ 7 │ 7 missing 5
+
+If you would like the result to keep the row order of the left table pass the order=:left keyword argument:
+
+julia> outerjoin(df_left, df_right, on=:id, order=:left)
+7×3 DataFrame
+ Row │ id left right
+ │ Int64 Int64? Int64?
+─────┼─────────────────────────
+ 1 │ 1 1 2
+ 2 │ 2 2 1
+ 3 │ 4 3 missing
+ 4 │ 5 4 missing
+ 5 │ 3 missing 3
+ 6 │ 6 missing 4
+ 7 │ 7 missing 5
+
+Note that in this case keys missing from the left table are put after the keys present in it.
+
+Similarly order=:right keeps the order of the right table (and puts keys not present in it at the end):
+
+julia> outerjoin(df_left, df_right, on=:id, order=:right)
+7×3 DataFrame
+ Row │ id left right
+ │ Int64 Int64? Int64?
+─────┼─────────────────────────
+ 1 │ 2 2 1
+ 2 │ 1 1 2
+ 3 │ 3 missing 3
+ 4 │ 6 missing 4
+ 5 │ 7 missing 5
+ 6 │ 4 3 missing
+ 7 │ 5 4 missing
+
+In-place left join
+A common operation is adding data from a reference table to some main table. It is possible to perform such an in-place update using the leftjoin! function. In this case the left table is updated in place with matching rows from the right table.
+
+julia> main = DataFrame(id=1:4, main=1:4)
+4×2 DataFrame
+ Row │ id main
+ │ Int64 Int64
+─────┼──────────────
+ 1 │ 1 1
+ 2 │ 2 2
+ 3 │ 3 3
+ 4 │ 4 4
+
+julia> leftjoin!(main, DataFrame(id=[2, 4], info=["a", "b"]), on=:id);
+
+julia> main
+4×3 DataFrame
+ Row │ id main info
+ │ Int64 Int64 String?
+─────┼───────────────────────
+ 1 │ 1 1 missing
+ 2 │ 2 2 a
+ 3 │ 3 3 missing
+ 4 │ 4 4 b
+
+Note that in this case the order and number of rows in the left table is not changed. Therefore, in particular, it is not allowed to have duplicate keys in the right table:
+
+julia> leftjoin!(main, DataFrame(id=[2, 2], info_bad=["a", "b"]), on=:id)
+ERROR: ArgumentError: duplicate rows found in right table
\ No newline at end of file
diff --git a/examples/data/what_is_dataframes.txt b/examples/data/what_is_dataframes.txt
new file mode 100644
index 000000000..c641aa202
--- /dev/null
+++ b/examples/data/what_is_dataframes.txt
@@ -0,0 +1,141 @@
+Welcome to the DataFrames.jl documentation!
+
+This resource aims to teach you everything you need to know to get up and running with tabular data manipulation using the DataFrames.jl package.
+
+For more illustrations of DataFrames.jl usage, in particular in conjunction with other packages you can check-out the following resources (they are kept up to date with the released version of DataFrames.jl):
+
+What is DataFrames.jl?
+DataFrames.jl provides a set of tools for working with tabular data in Julia. Its design and functionality are similar to those of pandas (in Python) and data.frame, data.table and dplyr (in R), making it a great general purpose data science tool.
+
+DataFrames.jl plays a central role in the Julia Data ecosystem, and has tight integrations with a range of different libraries. DataFrames.jl isn't the only tool for working with tabular data in Julia – as noted below, there are some other great libraries for certain use-cases – but it provides great data wrangling functionality through a familiar interface.
+
+To understand the toolchain in more detail, have a look at the tutorials in this manual. New users can start with the First Steps with DataFrames.jl section.
+
+You may find the DataFramesMeta.jl package or one of the other convenience packages discussed in the Data manipulation frameworks section of this manual helpful when writing more advanced data transformations, especially if you do not have a significant programming experience. These packages provide convenience syntax similar to dplyr in R.
+
+If you use metadata when working with DataFrames.jl you might find the TableMetadataTools.jl package useful. This package defines several convenience functions for performing typical metadata operations.
+
+DataFrames.jl and the Julia Data Ecosystem
+The Julia data ecosystem can be a difficult space for new users to navigate, in part because the Julia ecosystem tends to distribute functionality across different libraries more than some other languages. Because many people coming to DataFrames.jl are just starting to explore the Julia data ecosystem, below is a list of well-supported libraries that provide different data science tools, along with a few notes about what makes each library special, and how well integrated they are with DataFrames.jl.
+
+Statistics
+StatsKit.jl: A convenience meta-package which loads a set of essential packages for statistics, including those mentioned below in this section and DataFrames.jl itself.
+Statistics: The Julia standard library comes with a wide range of statistics functionality, but to gain access to these functions you must call using Statistics.
+LinearAlgebra: Like Statistics, many linear algebra features (factorizations, inversions, etc.) live in a library you have to load to use.
+SparseArrays are also in the standard library but must be loaded to be used.
+FreqTables.jl: Create frequency tables / cross-tabulations. Tightly integrated with DataFrames.jl.
+HypothesisTests.jl: A range of hypothesis testing tools.
+GLM.jl: Tools for estimating linear and generalized linear models. Tightly integrated with DataFrames.jl.
+StatsModels.jl: For converting heterogeneous DataFrame into homogeneous matrices for use with linear algebra libraries or machine learning applications that don't directly support DataFrames. Will do things like convert categorical variables into indicators/one-hot-encodings, create interaction terms, etc.
+MultivariateStats.jl: linear regression, ridge regression, PCA, component analyses tools. Not well integrated with DataFrames.jl, but easily used in combination with StatsModels.
+Machine Learning
+MLJ.jl: if you're more of an applied user, there is a single package the pulls from all these different libraries and provides a single, scikit-learn inspired API: MLJ.jl. MLJ.jl provides a common interface for a wide range of machine learning algorithms.
+ScikitLearn.jl: A Julia wrapper around the full Python scikit-learn machine learning library. Not well integrated with DataFrames.jl, but can be combined using StatsModels.jl.
+AutoMLPipeline: A package that makes it trivial to create complex ML pipeline structures using simple expressions. It leverages on the built-in macro programming features of Julia to symbolically process, manipulate pipeline expressions, and makes it easy to discover optimal structures for machine learning regression and classification.
+Deep learning: KNet.jl and Flux.jl.
+Plotting
+Plots.jl: Powerful, modern plotting library with a syntax akin to that of matplotlib (in Python) or plot (in R). StatsPlots.jl provides Plots.jl with recipes for many standard statistical plots.
+Gadfly.jl: High-level plotting library with a "grammar of graphics" syntax akin to that of ggplot (in R).
+AlgebraOfGraphics.jl: A "grammar of graphics" library build upon Makie.jl.
+VegaLite.jl: High-level plotting library that uses a different "grammar of graphics" syntax and has an emphasis on interactive graphics.
+Data Wrangling:
+Impute.jl: various methods for handling missing data in vectors, matrices and tables.
+DataFramesMeta.jl: A range of convenience functions for DataFrames.jl that augment select and transform to provide a user experience similar to that provided by dplyr in R.
+DataFrameMacros.jl: Provides macro versions of the common DataFrames.jl functions similar to DataFramesMeta.jl, with convenient syntax for the manipulation of multiple columns at once.
+Query.jl: Query.jl provides a single framework for data wrangling that works with a range of libraries, including DataFrames.jl, other tabular data libraries (more on those below), and even non-tabular data. Provides many convenience functions analogous to those in dplyr in R or LINQ.
+You can find more information on these packages in the Data manipulation frameworks section of this manual.
+And More!
+Graphs.jl: A pure-Julia, high performance network analysis library. Edgelists in DataFrames can be easily converted into graphs using the GraphDataFrameBridge.jl package.
+IO:
+DataFrames.jl work well with a range of formats, including:
+CSV files (using CSV.jl),
+Apache Arrow (using Arrow.jl)
+reading Stata, SAS and SPSS files (using ReadStatTables.jl; alternatively Queryverse users can choose StatFiles.jl),
+Parquet files (using Parquet2.jl),
+reading R data files (.rda, .RData) (using RData.jl).
+While not all of these libraries are tightly integrated with DataFrames.jl, because DataFrames are essentially collections of aligned Julia vectors, so it is easy to (a) pull out a vector for use with a non-DataFrames-integrated library, or (b) convert your table into a homogeneously-typed matrix using the Matrix constructor or StatsModels.jl.
+
+Other Julia Tabular Libraries
+DataFrames.jl is a great general purpose tool for data manipulation and wrangling, but it's not ideal for all applications. For users with more specialized needs, consider using:
+
+TypedTables.jl: Type-stable heterogeneous tables. Useful for improved performance when the structure of your table is relatively stable and does not feature thousands of columns.
+JuliaDB.jl: For users working with data that is too large to fit in memory, we suggest JuliaDB.jl, which offers better performance for large datasets, and can handle out-of-core data manipulations (Python users can think of JuliaDB.jl as the Julia version of dask).
+Note that most tabular data libraries in the Julia ecosystem (including DataFrames.jl) support a common interface (defined in the Tables.jl package). As a result, some libraries are capable or working with a range of tabular data structures, making it easy to move between tabular libraries as your needs change. A user of Query.jl, for example, can use the same code to manipulate data in a DataFrame, a Table (defined by TypedTables.jl), or a JuliaDB table.
+
+Questions?
+If there is something you expect DataFrames to be capable of, but cannot figure out how to do, please reach out with questions in Domains/Data on Discourse. Additionally you might want to listen to an introduction to DataFrames.jl on JuliaAcademy.
+
+Please report bugs by opening an issue.
+
+You can follow the source links throughout the documentation to jump right to the source files on GitHub to make pull requests for improving the documentation and function capabilities.
+
+Please review DataFrames contributing guidelines before submitting your first PR!
+
+Information on specific versions can be found on the Release page.
+
+Package Manual
+First Steps with DataFrames.jl
+Setting up the Environment
+Constructors and Basic Utility Functions
+Getting and Setting Data in a Data Frame
+Basic Usage of Transformation Functions
+Getting Started
+Installation
+The DataFrame Type
+Database-Style Joins
+Introduction to joins
+Key value comparisons and floating point values
+Joining on key columns with different names
+Handling of duplicate keys and tracking source data frame
+Renaming joined columns
+Matching missing values in joins
+Specifying row order in the join result
+In-place left join
+The Split-Apply-Combine Strategy
+Design of the split-apply-combine support
+Examples of the split-apply-combine operations
+Using GroupedDataFrame as an iterable and indexable object
+Simulating the SQL where clause
+Column-independent operations
+Column-independent operations versus functions
+Specifying group order in groupby
+Reshaping and Pivoting Data
+Sorting
+Categorical Data
+Missing Data
+Comparisons
+Comparison with the Python package pandas
+Comparison with the R package dplyr
+Comparison with the R package data.table
+Comparison with Stata (version 8 and above)
+Data manipulation frameworks
+DataFramesMeta.jl
+DataFrameMacros.jl
+Query.jl
+API
+Only exported (i.e. available for use without DataFrames. qualifier after loading the DataFrames.jl package with using DataFrames) types and functions are considered a part of the public API of the DataFrames.jl package. In general all such objects are documented in this manual (in case some documentation is missing please kindly report an issue here).
+
+Note
+Breaking changes to public and documented API are avoided in DataFrames.jl where possible.
+
+The following changes are not considered breaking:
+
+specific floating point values computed by operations may change at any time; users should rely only on approximate accuracy;
+in functions that use the default random number generator provided by Base Julia the specific random numbers computed may change across Julia versions;
+if the changed functionality is classified as a bug;
+if the changed behavior was not documented; two major cases are:
+in its implementation some function accepted a wider range of arguments that it was documented to handle - changes in handling of undocumented arguments are not considered as breaking;
+the type of the value returned by a function changes, but it still follows the contract specified in the documentation; for example if a function is documented to return a vector then changing its type from Vector to PooledVector is not considered as breaking;
+error behavior: code that threw an exception can change exception type thrown or stop throwing an exception;
+changes in display (how objects are printed);
+changes to the state of global objects from Base Julia whose state normally is considered volatile (e.g. state of global random number generator).
+All types and functions that are part of public API are guaranteed to go through a deprecation period before a breaking change is made to them or they would be removed.
+
+The standard practice is that breaking changes are implemented when a major release of DataFrames.jl is made (e.g. functionalities deprecated in a 1.x release would be changed in the 2.0 release).
+
+In rare cases a breaking change might be introduced in a minor release. In such a case the changed behavior still goes through one minor release during which it is deprecated. The situations where such a breaking change might be allowed are (still such breaking changes will be avoided if possible):
+
+the affected functionality was previously clearly identified in the documentation as being subject to changes (for example in DataFrames.jl 1.4 release propagation rules of :note-style metadata are documented as such);
+the change is on the border of being classified as a bug (in rare cases even if a behavior of some function was documented its consequences for certain argument combinations could be decided to be unintended and not wanted);
+the change is needed to adjust DataFrames.jl functionality to changes in Base Julia.
+Please be warned that while Julia allows you to access internal functions or types of DataFrames.jl these can change without warning between versions of DataFrames.jl. In particular it is not safe to directly access fields of types that are a part of public API of the DataFrames.jl package using e.g. the getfield function. Whenever some operation on fields of defined types is considered allowed an appropriate exported function should be used instead.
\ No newline at end of file
diff --git a/ext/RAGToolsExperimentalExt.jl b/ext/RAGToolsExperimentalExt.jl
new file mode 100644
index 000000000..7047853c0
--- /dev/null
+++ b/ext/RAGToolsExperimentalExt.jl
@@ -0,0 +1,34 @@
+module RAGToolsExperimentalExt
+
+using PromptingTools, SparseArrays
+using LinearAlgebra: normalize
+const PT = PromptingTools
+
+using PromptingTools.Experimental.RAGTools
+
+# forward to LinearAlgebra.normalize
+PromptingTools.Experimental.RAGTools._normalize(arr::AbstractArray) = normalize(arr)
+
+# "Builds a sparse matrix of tags and a vocabulary from the given vector of chunk metadata. Requires SparseArrays.jl to be loaded."
+function PromptingTools.Experimental.RAGTools.build_tags(chunk_metadata::Vector{
+ Vector{String},
+ })
+ tags_vocab_ = vcat(chunk_metadata...) |> unique |> sort
+ tags_vocab_index = Dict{String, Int}(t => i for (i, t) in enumerate(tags_vocab_))
+ Is, Js = Int[], Int[]
+ for i in eachindex(chunk_metadata)
+ for tag in chunk_metadata[i]
+ push!(Is, i)
+ push!(Js, tags_vocab_index[tag])
+ end
+ end
+ tags_ = sparse(Is,
+ Js,
+ trues(length(Is)),
+ length(chunk_metadata),
+ length(tags_vocab_),
+ &)
+ return tags_, tags_vocab_
+end
+
+end
\ No newline at end of file
diff --git a/src/Experimental/Experimental.jl b/src/Experimental/Experimental.jl
new file mode 100644
index 000000000..6a1f2cb20
--- /dev/null
+++ b/src/Experimental/Experimental.jl
@@ -0,0 +1,15 @@
+"""
+ Experimental
+
+This module is for experimental code that is not yet ready for production.
+It is not included in the main module, so it must be explicitly imported.
+
+Contains:
+- `RAGTools`: Retrieval-Augmented Generation (RAG) functionality.
+"""
+module Experimental
+
+export RAGTools
+include("RAGTools/RAGTools.jl")
+
+end # module Experimental
\ No newline at end of file
diff --git a/src/Experimental/RAGTools/RAGTools.jl b/src/Experimental/RAGTools/RAGTools.jl
new file mode 100644
index 000000000..a315a2a7f
--- /dev/null
+++ b/src/Experimental/RAGTools/RAGTools.jl
@@ -0,0 +1,33 @@
+"""
+ RAGTools
+
+Provides Retrieval-Augmented Generation (RAG) functionality.
+
+Requires: LinearAlgebra, SparseArrays, PromptingTools for proper functionality.
+
+This module is experimental and may change at any time. It is intended to be moved to a separate package in the future.
+"""
+module RAGTools
+
+using PromptingTools
+using JSON3
+const PT = PromptingTools
+
+include("utils.jl")
+
+export ChunkIndex, CandidateChunks # MultiIndex
+include("types.jl")
+
+export build_index, build_tags
+include("preparation.jl")
+
+export find_closest, find_tags, rerank
+include("retrieval.jl")
+
+export airag, build_context
+include("generation.jl")
+
+export build_qa_evals, run_qa_evals
+include("evaluation.jl")
+
+end
\ No newline at end of file
diff --git a/src/Experimental/RAGTools/evaluation.jl b/src/Experimental/RAGTools/evaluation.jl
new file mode 100644
index 000000000..fe25b32f1
--- /dev/null
+++ b/src/Experimental/RAGTools/evaluation.jl
@@ -0,0 +1,285 @@
+### For testing and eval
+# This is a return_type for extraction when generating Q&A set with aiextract
+@kwdef struct QAItem
+ question::String = ""
+ answer::String = ""
+end
+# This is for saving in JSON format for evaluation later
+@kwdef struct QAEvalItem
+ source::String = ""
+ context::String = ""
+ question::String = ""
+ answer::String = ""
+end
+
+@kwdef struct QAEvalResult
+ source::AbstractString
+ context::AbstractString
+ question::AbstractString
+ answer::AbstractString
+ retrieval_score::Union{Number, Nothing} = nothing
+ retrieval_rank::Union{Int, Nothing} = nothing
+ answer_score::Union{Number, Nothing} = nothing
+ parameters::Dict{Symbol, Any} = Dict{Symbol, Any}()
+end
+
+"Provide the `final_rating` between 1-5. Provide the rationale for it."
+@kwdef struct JudgeRating
+ rationale::Union{Nothing, String} = nothing
+ final_rating::Int
+end
+
+"`final_rating` is the average of all scoring criteria. Explain the `final_rating` in `rationale`"
+@kwdef struct JudgeAllScores
+ relevance::Int
+ completeness::Int
+ clarity::Int
+ consistency::Int
+ helpfulness::Int
+ rationale::Union{Nothing, String} = nothing
+ final_rating::Float64
+end
+
+function Base.isvalid(x::QAEvalItem)
+ !isempty(x.question) && !isempty(x.answer) && !isempty(x.context)
+end
+# for equality tests
+function Base.var"=="(x::Union{QAItem, QAEvalItem, QAEvalResult},
+ y::Union{QAItem, QAEvalItem, QAEvalResult})
+ typeof(x) == typeof(y) &&
+ all([getfield(x, f) == getfield(y, f) for f in fieldnames(typeof(x))])
+end
+
+# Nicer show method with some colors!
+function Base.show(io::IO, t::Union{QAItem, QAEvalItem, QAEvalResult})
+ printstyled(io, "$(nameof(typeof(t))):\n", color = :green, bold = true)
+ for f in fieldnames(typeof(t))
+ printstyled(io, " ", f, color = :blue, bold = true)
+ println(io, ": ", getfield(t, f))
+ end
+end
+# Define how JSON3 should serialize/deserialize the struct into JSON files
+JSON3.StructTypes.StructType(::Type{QAEvalItem}) = JSON3.StructTypes.Struct()
+JSON3.StructTypes.StructType(::Type{QAEvalResult}) = JSON3.StructTypes.Struct()
+
+"""
+ build_qa_evals(doc_chunks::Vector{<:AbstractString}, sources::Vector{<:AbstractString};
+ model=PT.MODEL_CHAT, instructions="None.", qa_template::Symbol=:RAGCreateQAFromContext,
+ verbose::Bool=true, api_kwargs::NamedTuple = NamedTuple(), kwargs...) -> Vector{QAEvalItem}
+
+Create a collection of question and answer evaluations (`QAEvalItem`) from document chunks and sources.
+This function generates Q&A pairs based on the provided document chunks, using a specified AI model and template.
+
+# Arguments
+- `doc_chunks::Vector{<:AbstractString}`: A vector of document chunks, each representing a segment of text.
+- `sources::Vector{<:AbstractString}`: A vector of source identifiers corresponding to each chunk in `doc_chunks` (eg, filenames or paths).
+- `model`: The AI model used for generating Q&A pairs. Default is `PT.MODEL_CHAT`.
+- `instructions::String`: Additional instructions or context to provide to the model generating QA sets. Defaults to "None.".
+- `qa_template::Symbol`: A template symbol that dictates the AITemplate that will be used. It must have placeholder `context`. Default is `:CreateQAFromContext`.
+- `api_kwargs::NamedTuple`: Parameters that will be forwarded to the API endpoint.
+- `verbose::Bool`: If `true`, additional information like costs will be logged. Defaults to `true`.
+
+# Returns
+`Vector{QAEvalItem}`: A vector of `QAEvalItem` structs, each containing a source, context, question, and answer. Invalid or empty items are filtered out.
+
+# Notes
+
+- The function internally uses `aiextract` to generate Q&A pairs based on the provided `qa_template`. So you can use any kwargs that you want.
+- Each `QAEvalItem` includes the context (document chunk), the generated question and answer, and the source.
+- The function tracks and reports the cost of AI calls if `verbose` is enabled.
+- Items where the question, answer, or context is empty are considered invalid and are filtered out.
+
+# Examples
+
+Creating Q&A evaluations from a set of document chunks:
+```julia
+doc_chunks = ["Text from document 1", "Text from document 2"]
+sources = ["source1", "source2"]
+qa_evals = build_qa_evals(doc_chunks, sources)
+```
+"""
+function build_qa_evals(doc_chunks::Vector{<:AbstractString},
+ sources::Vector{<:AbstractString};
+ model = PT.MODEL_CHAT, instructions = "None.",
+ qa_template::Symbol = :RAGCreateQAFromContext, verbose::Bool = true,
+ api_kwargs::NamedTuple = NamedTuple(), kwargs...)
+ ##
+ @assert length(doc_chunks)==length(sources) "Length of `doc_chunks` and `sources` must be the same."
+ placeholders = only(aitemplates(qa_template)).variables # only one template should be found
+ @assert (:context in placeholders) "Provided Q&A Template $(qa_template) is not suitable. It must have placeholder: `context`."
+ ##
+ cost_tracker = Threads.Atomic{Float64}(0.0)
+ output = asyncmap(zip(doc_chunks, sources)) do (context, source)
+ try
+ msg = aiextract(qa_template;
+ return_type = QAItem,
+ context,
+ instructions,
+ verbose,
+ model, api_kwargs)
+ Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model)) # track costs
+ QAEvalItem(; context, msg.content.question, msg.content.answer, source)
+ catch e
+ verbose && @warn e
+ QAEvalItem()
+ end
+ end
+ verbose && @info "Q&A Sets built! (cost: \$$(round(cost_tracker[], digits=3)))"
+ return filter(isvalid, output)
+end
+
+"Returns 1.0 if `context` overlaps or is contained within any of the `candidate_context`"
+function score_retrieval_hit(orig_context::AbstractString,
+ candidate_context::Vector{<:AbstractString})
+ 1.0 * (any(occursin.(Ref(orig_context), candidate_context)) ||
+ any(occursin.(candidate_context, Ref(orig_context))))
+end
+
+"Returns Integer rank of the position where `context` overlaps or is contained within a `candidate_context`"
+function score_retrieval_rank(orig_context::AbstractString,
+ candidate_context::Vector{<:AbstractString})
+ findfirst((occursin.(Ref(orig_context), candidate_context)) .||
+ (occursin.(candidate_context, Ref(orig_context))))
+end
+
+"""
+ run_qa_evals(qa_item::QAEvalItem, ctx::RAGContext; verbose::Bool = true,
+ parameters_dict::Dict{Symbol, Any}, judge_template::Symbol = :RAGJudgeAnswerFromContext,
+ model_judge::AbstractString,api_kwargs::NamedTuple = NamedTuple()) -> QAEvalResult
+
+Evaluates a single `QAEvalItem` using a RAG context (`RAGContext`) and returns a `QAEvalResult` structure. This function assesses the relevance and accuracy of the answers generated in a QA evaluation context.
+
+# Arguments
+- `qa_item::QAEvalItem`: The QA evaluation item containing the question and its answer.
+- `ctx::RAGContext`: The context used for generating the QA pair, including the original context and the answers.
+ Comes from `airag(...; return_context=true)`
+- `verbose::Bool`: If `true`, enables verbose logging. Defaults to `true`.
+- `parameters_dict::Dict{Symbol, Any}`: Track any parameters used for later evaluations. Keys must be Symbols.
+- `judge_template::Symbol`: The template symbol for the AI model used to judge the answer. Defaults to `:RAGJudgeAnswerFromContext`.
+- `model_judge::AbstractString`: The AI model used for judging the answer's quality.
+ Defaults to standard chat model, but it is advisable to use more powerful model GPT-4.
+- `api_kwargs::NamedTuple`: Parameters that will be forwarded to the API endpoint.
+
+# Returns
+`QAEvalResult`: An evaluation result that includes various scores and metadata related to the QA evaluation.
+
+# Notes
+- The function computes a retrieval score and rank based on how well the context matches the QA context.
+- It then uses the `judge_template` and `model_judge` to score the answer's accuracy and relevance.
+- In case of errors during evaluation, the function logs a warning (if `verbose` is `true`) and the `answer_score` will be set to `nothing`.
+
+# Examples
+
+Evaluating a QA pair using a specific context and model:
+```julia
+qa_item = QAEvalItem(question="What is the capital of France?", answer="Paris", context="France is a country in Europe.")
+ctx = RAGContext(source="Wikipedia", context="France is a country in Europe.", answer="Paris")
+parameters_dict = Dict("param1" => "value1", "param2" => "value2")
+
+eval_result = run_qa_evals(qa_item, ctx, parameters_dict=parameters_dict, model_judge="MyAIJudgeModel")
+```
+"""
+function run_qa_evals(qa_item::QAEvalItem, ctx::RAGContext;
+ verbose::Bool = true, parameters_dict::Dict{Symbol, Any} = Dict{Symbol, Any}(),
+ judge_template::Symbol = :RAGJudgeAnswerFromContextShort,
+ model_judge::AbstractString = PT.MODEL_CHAT,
+ api_kwargs::NamedTuple = NamedTuple())
+ retrieval_score = score_retrieval_hit(qa_item.context, ctx.context)
+ retrieval_rank = score_retrieval_rank(qa_item.context, ctx.context)
+
+ # Note we could evaluate if RAGContext and QAEvalItem are at least using the same sources etc.
+
+ answer_score = try
+ msg = aiextract(judge_template; model = model_judge, verbose,
+ ctx.context,
+ ctx.question,
+ ctx.answer,
+ return_type = JudgeAllScores, api_kwargs)
+ final_rating = if msg.content isa AbstractDict && haskey(msg.content, :final_rating)
+ # if return type parsing failed
+ msg.content[:final_rating]
+ else
+ # if return_type worked
+ msg.content.final_rating
+ end
+ catch e
+ verbose && @warn "Error in QA eval ($(qa_item.question)): $e"
+ nothing
+ end
+
+ return QAEvalResult(;
+ qa_item.source,
+ qa_item.context,
+ qa_item.question,
+ ctx.answer,
+ retrieval_score,
+ retrieval_rank,
+ answer_score,
+ parameters = parameters_dict)
+end
+
+"""
+ run_qa_evals(index::AbstractChunkIndex, qa_items::AbstractVector{<:QAEvalItem};
+ api_kwargs::NamedTuple = NamedTuple(),
+ airag_kwargs::NamedTuple = NamedTuple(),
+ qa_evals_kwargs::NamedTuple = NamedTuple(),
+ verbose::Bool = true, parameters_dict::Dict{Symbol, Any} = Dict{Symbol, Any}())
+
+Evaluates a vector of `QAEvalItem`s and returns a vector `QAEvalResult`.
+This function assesses the relevance and accuracy of the answers generated in a QA evaluation context.
+
+See `?run_qa_evals` for more details.
+
+# Arguments
+- `qa_items::AbstractVector{<:QAEvalItem}`: The vector of QA evaluation items containing the questions and their answers.
+- `verbose::Bool`: If `true`, enables verbose logging. Defaults to `true`.
+- `api_kwargs::NamedTuple`: Parameters that will be forwarded to the API calls. See `?aiextract` for details.
+- `airag_kwargs::NamedTuple`: Parameters that will be forwarded to `airag` calls. See `?airag` for details.
+- `qa_evals_kwargs::NamedTuple`: Parameters that will be forwarded to `run_qa_evals` calls. See `?run_qa_evals` for details.
+- `parameters_dict::Dict{Symbol, Any}`: Track any parameters used for later evaluations. Keys must be Symbols.
+
+# Returns
+`Vector{QAEvalResult}`: Vector of evaluation results that includes various scores and metadata related to the QA evaluation.
+
+# Example
+```julia
+index = "..." # Assuming a proper index is defined
+qa_items = [QAEvalItem(question="What is the capital of France?", answer="Paris", context="France is a country in Europe."),
+ QAEvalItem(question="What is the capital of Germany?", answer="Berlin", context="Germany is a country in Europe.")]
+
+# Let's run a test with `top_k=5`
+results = run_qa_evals(index, qa_items; airag_kwargs=(;top_k=5), parameters_dict=Dict(:top_k => 5))
+
+# Filter out the "failed" calls
+results = filter(x->!isnothing(x.answer_score), results);
+
+# See average judge score
+mean(x->x.answer_score, results)
+```
+
+"""
+function run_qa_evals(index::AbstractChunkIndex, qa_items::AbstractVector{<:QAEvalItem};
+ api_kwargs::NamedTuple = NamedTuple(),
+ airag_kwargs::NamedTuple = NamedTuple(),
+ qa_evals_kwargs::NamedTuple = NamedTuple(),
+ verbose::Bool = true, parameters_dict::Dict{Symbol, Any} = Dict{Symbol, Any}())
+ # Run evaluations in parallel
+ results = asyncmap(qa_items) do qa_item
+ # Generate an answer -- often you want the model_judge to be the highest quality possible, eg, "GPT-4 Turbo" (alias "gpt4t)
+ msg, ctx = airag(index; qa_item.question, return_context = true,
+ verbose, api_kwargs, airag_kwargs...)
+
+ # Evaluate the response
+ # Note: you can log key parameters for easier analysis later
+ run_qa_evals(qa_item,
+ ctx;
+ parameters_dict,
+ verbose,
+ api_kwargs,
+ qa_evals_kwargs...)
+ end
+ success_count = count(x -> !isnothing(x.answer_score), results)
+ verbose &&
+ @info "QA Evaluations complete ($((success_count)/length(qa_items)) evals successful)!"
+ return results
+end
\ No newline at end of file
diff --git a/src/Experimental/RAGTools/generation.jl b/src/Experimental/RAGTools/generation.jl
new file mode 100644
index 000000000..59b130b3c
--- /dev/null
+++ b/src/Experimental/RAGTools/generation.jl
@@ -0,0 +1,176 @@
+# stub to be replaced within the package extension
+function _normalize end
+
+"""
+ build_context(index::AbstractChunkIndex, reranked_candidates::CandidateChunks; chunks_window_margin::Tuple{Int, Int}) -> Vector{String}
+
+Build context strings for each position in `reranked_candidates` considering a window margin around each position.
+
+# Arguments
+- `reranked_candidates::CandidateChunks`: Candidate chunks which contain positions to extract context from.
+- `index::ChunkIndex`: The index containing chunks and sources.
+- `chunks_window_margin::Tuple{Int, Int}`: A tuple indicating the margin (before, after) around each position to include in the context.
+ Defaults to `(1,1)`, which means 1 preceding and 1 suceeding chunk will be included. With `(0,0)`, only the matching chunks will be included.
+
+# Returns
+- `Vector{String}`: A vector of context strings, each corresponding to a position in `reranked_candidates`.
+
+# Examples
+```julia
+index = ChunkIndex(...) # Assuming a proper index is defined
+candidates = CandidateChunks(index.id, [2, 4], [0.1, 0.2])
+context = build_context(index, candidates; chunks_window_margin=(0, 1)) # include only one following chunk for each matching chunk
+```
+"""
+function build_context(index::AbstractChunkIndex, reranked_candidates::CandidateChunks;
+ chunks_window_margin::Tuple{Int, Int} = (1, 1))
+ @assert chunks_window_margin[1] >= 0&&chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative"
+ context = String[]
+ for (i, position) in enumerate(reranked_candidates.positions)
+ chunks_ = chunks(index)[max(1, position - chunks_window_margin[1]):min(end,
+ position + chunks_window_margin[2])]
+ is_same_source = sources(index)[max(1, position - chunks_window_margin[1]):min(end,
+ position + chunks_window_margin[2])] .== sources(index)[position]
+ push!(context, "$(i). $(join(chunks_[is_same_source], "\n"))")
+ end
+ return context
+end
+
+"""
+ airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext;
+ question::AbstractString,
+ top_k::Int = 3, `minimum_similarity::AbstractFloat`= -1.0,
+ tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto,
+ rerank_strategy::RerankingStrategy = Passthrough(),
+ model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT,
+ model_metadata::String = PT.MODEL_CHAT,
+ metadata_template::Symbol = :RAGExtractMetadataShort,
+ chunks_window_margin::Tuple{Int, Int} = (1, 1),
+ return_context::Bool = false, verbose::Bool = true,
+ api_kwargs::NamedTuple = NamedTuple(),
+ kwargs...)
+
+Generates a response for a given question using a Retrieval-Augmented Generation (RAG) approach.
+
+The function selects relevant chunks from an `ChunkIndex`, optionally filters them based on metadata tags, reranks them, and then uses these chunks to construct a context for generating a response.
+
+# Arguments
+- `index::AbstractChunkIndex`: The chunk index to search for relevant text.
+- `rag_template::Symbol`: Template for the RAG model, defaults to `:RAGAnswerFromContext`.
+- `question::AbstractString`: The question to be answered.
+- `top_k::Int`: Number of top candidates to retrieve based on embedding similarity.
+- `minimum_similarity::AbstractFloat`: Minimum similarity threshold (between -1 and 1) for filtering chunks based on embedding similarity. Defaults to -1.0.
+- `tag_filter::Union{Symbol, Vector{String}, Regex}`: Mechanism for filtering chunks based on tags (either automatically detected, specific tags, or a regex pattern). Disabled by setting to `nothing`.
+- `rerank_strategy::RerankingStrategy`: Strategy for reranking the retrieved chunks.
+- `model_embedding::String`: Model used for embedding the question, default is `PT.MODEL_EMBEDDING`.
+- `model_chat::String`: Model used for generating the final response, default is `PT.MODEL_CHAT`.
+- `model_metadata::String`: Model used for extracting metadata, default is `PT.MODEL_CHAT`.
+- `metadata_template::Symbol`: Template for the metadata extraction process from the question, defaults to: `:RAGExtractMetadataShort`
+- `chunks_window_margin::Tuple{Int,Int}`: The window size around each chunk to consider for context building. See `?build_context` for more information.
+- `return_context::Bool`: If `true`, returns the context used for RAG along with the response.
+- `verbose::Bool`: If `true`, enables verbose logging.
+- `api_kwargs`: API parameters that will be forwarded to the API calls
+
+# Returns
+- If `return_context` is `false`, returns the generated message (`msg`).
+- If `return_context` is `true`, returns a tuple of the generated message (`msg`) and the RAG context (`rag_context`).
+
+# Notes
+- The function first finds the closest chunks to the question embedding, then optionally filters these based on tags. After that, it reranks the candidates and builds a context for the RAG model.
+- The `tag_filter` can be used to refine the search. If set to `:auto`, it attempts to automatically determine relevant tags (if `index` has them available).
+- The `chunks_window_margin` allows including surrounding chunks for richer context, considering they are from the same source.
+- The function currently supports only single `ChunkIndex`.
+
+# Examples
+
+Using `airag` to get a response for a question:
+```julia
+index = build_index(...) # create an index
+question = "How to make a barplot in Makie.jl?"
+msg = airag(index, :RAGAnswerFromContext; question)
+
+# or simply
+msg = airag(index; question)
+```
+
+See also `build_index`, `build_context`, `CandidateChunks`, `find_closest`, `find_tags`, `rerank`
+"""
+function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext;
+ question::AbstractString,
+ top_k::Int = 3, minimum_similarity::AbstractFloat = -1.0,
+ tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto,
+ rerank_strategy::RerankingStrategy = Passthrough(),
+ model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT,
+ model_metadata::String = PT.MODEL_CHAT,
+ metadata_template::Symbol = :RAGExtractMetadataShort,
+ chunks_window_margin::Tuple{Int, Int} = (1, 1),
+ return_context::Bool = false, verbose::Bool = true,
+ api_kwargs::NamedTuple = NamedTuple(),
+ kwargs...)
+ ## Note: Supports only single ChunkIndex for now
+ ## Checks
+ @assert !(tag_filter isa Symbol && tag_filter != :auto) "Only `:auto`, `Vector{String}`, or `Regex` are supported for `tag_filter`"
+ @assert chunks_window_margin[1] >= 0&&chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative"
+ placeholders = only(aitemplates(rag_template)).variables # only one template should be found
+ @assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(rag_template) is not suitable. It must have placeholders: `question` and `context`."
+
+ question_emb = aiembed(question,
+ _normalize;
+ model = model_embedding,
+ verbose, api_kwargs).content .|> Float32 # no need for Float64
+ emb_candidates = find_closest(index, question_emb; top_k, minimum_similarity)
+
+ tag_candidates = if tag_filter == :auto && !isnothing(tags(index)) &&
+ !isempty(model_metadata)
+ _check_aiextract_capability(model_metadata)
+ # extract metadata via LLM call
+ metadata_ = try
+ msg = aiextract(metadata_template; return_type = MaybeMetadataItems,
+ text = question,
+ instructions = "In addition to extracted items, suggest 2-3 filter keywords that could be relevant to answer this question.",
+ verbose, model = model_metadata, api_kwargs)
+ ## eg, ["software:::pandas", "language:::python", "julia_package:::dataframes"]
+ ## we split it and take only the keyword, not the category
+ metadata_extract(msg.content.items) |>
+ x -> split.(x, ":::") |> x -> getindex.(x, 2)
+ catch e
+ String[]
+ end
+ find_tags(index, metadata_)
+ elseif tag_filter isa Union{Vector{String}, Regex}
+ find_tags(index, tag_filter)
+ elseif isnothing(tag_filter)
+ nothing
+ else
+ ## not filtering -- use all rows and ignore this
+ nothing
+ end
+
+ filtered_candidates = isnothing(tag_candidates) ? emb_candidates :
+ (emb_candidates & tag_candidates)
+ reranked_candidates = rerank(rerank_strategy, index, question, filtered_candidates)
+
+ ## Build the context
+ context = build_context(index, reranked_candidates; chunks_window_margin)
+
+ ## LLM call
+ msg = aigenerate(rag_template; question,
+ context = join(context, "\n\n"), model = model_chat, verbose,
+ api_kwargs,
+ kwargs...)
+
+ if return_context # for evaluation
+ rag_context = RAGContext(;
+ question,
+ answer = msg.content,
+ context,
+ sources = sources(index)[reranked_candidates.positions],
+ emb_candidates,
+ tag_candidates,
+ filtered_candidates,
+ reranked_candidates)
+ return msg, rag_context
+ else
+ return msg
+ end
+end
\ No newline at end of file
diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl
new file mode 100644
index 000000000..50e937805
--- /dev/null
+++ b/src/Experimental/RAGTools/preparation.jl
@@ -0,0 +1,147 @@
+### Preparation
+# Types used to extract `tags` from document chunks
+@kwdef struct MetadataItem
+ value::String
+ category::String
+end
+@kwdef struct MaybeMetadataItems
+ items::Union{Nothing, Vector{MetadataItem}}
+end
+
+"""
+ metadata_extract(item::MetadataItem)
+ metadata_extract(items::Vector{MetadataItem})
+
+Extracts the metadata item into a string of the form `category:::value` (lowercased and spaces replaced with underscores).
+
+# Example
+```julia
+msg = aiextract(:RAGExtractMetadataShort; return_type=MaybeMetadataItems, text="I like package DataFrames", instructions="None.")
+metadata = metadata_extract(msg.content.items)
+```
+"""
+function metadata_extract(item::MetadataItem)
+ "$(strip(item.category)):::$(strip(item.value))" |> lowercase |>
+ x -> replace(x, " " => "_")
+end
+metadata_extract(items::Nothing) = String[]
+metadata_extract(items::Vector{MetadataItem}) = metadata_extract.(items)
+
+"Builds a matrix of tags and a vocabulary list. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!"
+function build_tags end
+# Implementation in ext/RAGToolsExperimentalExt.jl
+
+"Build an index for RAG (Retriever-Augmented Generation) applications. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!"
+function build_index end
+
+"""
+ build_index(files::Vector{<:AbstractString};
+ separators = ["\n\n", ". ", "\n"], max_length::Int = 256,
+ extract_metadata::Bool = false, verbose::Bool = true,
+ metadata_template::Symbol = :RAGExtractMetadataShort,
+ model_embedding::String = PT.MODEL_EMBEDDING,
+ model_metadata::String = PT.MODEL_CHAT,
+ api_kwargs::NamedTuple = NamedTuple())
+
+Build an index for RAG (Retriever-Augmented Generation) applications from the provided file paths.
+The function processes each file, splits its content into chunks, embeds these chunks,
+optionally extracts metadata, and then compiles this information into a retrievable index.
+
+# Arguments
+- `files`: A vector of valid file paths to be indexed.
+- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `["\n\n", ". ", "\n"]`.
+- `max_length`: The maximum length of each chunk (if possible with provided separators). Default is 256.
+- `extract_metadata`: A boolean flag indicating whether to extract metadata from each chunk (to build filter `tags` in the index). Default is `false`.
+ Metadata extraction incurs additional cost and requires `model_metadata` and `metadata_template` to be provided.
+- `verbose`: A boolean flag for verbose output. Default is `true`.
+- `metadata_template`: A symbol indicating the template to be used for metadata extraction. Default is `:RAGExtractMetadataShort`.
+- `model_embedding`: The model to use for embedding.
+- `model_metadata`: The model to use for metadata extraction.
+- `api_kwargs`: Parameters to be provided to the API endpoint.
+
+# Returns
+- `ChunkIndex`: An object containing the compiled index of chunks, embeddings, tags, vocabulary, and sources.
+
+See also: `MultiIndex`, `CandidateChunks`, `find_closest`, `find_tags`, `rerank`, `airag`
+
+# Examples
+```julia
+# Assuming `test_files` is a vector of file paths
+index = build_index(test_files; max_length=10, extract_metadata=true)
+
+# Another example with metadata extraction and verbose output
+index = build_index(["file1.txt", "file2.txt"];
+ separators=[". "],
+ extract_metadata=true,
+ verbose=true)
+```
+"""
+function build_index(files::Vector{<:AbstractString};
+ separators = ["\n\n", ". ", "\n"], max_length::Int = 256,
+ extract_metadata::Bool = false, verbose::Bool = true,
+ metadata_template::Symbol = :RAGExtractMetadataShort,
+ model_embedding::String = PT.MODEL_EMBEDDING,
+ model_metadata::String = PT.MODEL_CHAT,
+ api_kwargs::NamedTuple = NamedTuple())
+ ##
+ @assert all(isfile, files) "Some `files` don't exist (Check: $(join(filter(!isfile,files),", "))"
+
+ output_chunks = Vector{Vector{SubString{String}}}()
+ output_embeddings = Vector{Matrix{Float32}}()
+ output_metadata = Vector{Vector{Vector{String}}}()
+ output_sources = Vector{Vector{eltype(files)}}()
+ cost_tracker = Threads.Atomic{Float64}(0.0)
+
+ for fn in files
+ verbose && @info "Processing file: $fn"
+ doc_raw = read(fn, String)
+ isempty(doc_raw) && continue
+ # split into chunks, if you want to start simple - just do `split(text,"\n\n")`
+ doc_chunks = PT.split_by_length(doc_raw, separators; max_length) .|> strip |>
+ x -> filter(!isempty, x)
+ # skip if no chunks found
+ isempty(doc_chunks) && continue
+ push!(output_chunks, doc_chunks)
+ push!(output_sources, fill(fn, length(doc_chunks)))
+
+ # Notice that we embed all doc_chunks at once, not one by one
+ # OpenAI supports embedding multiple documents to reduce the number of API calls/network latency time
+ emb = aiembed(doc_chunks, _normalize; model = model_embedding, verbose, api_kwargs)
+ Threads.atomic_add!(cost_tracker, PT.call_cost(emb, model_embedding)) # track costs
+ push!(output_embeddings, Float32.(emb.content))
+
+ if extract_metadata && !isempty(model_metadata)
+ _check_aiextract_capability(model_metadata)
+ metadata_ = asyncmap(doc_chunks) do chunk
+ try
+ msg = aiextract(metadata_template;
+ return_type = MaybeMetadataItems,
+ text = chunk,
+ instructions = "None.",
+ verbose,
+ model = model_metadata, api_kwargs)
+ Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model_metadata)) # track costs
+ items = metadata_extract(msg.content.items)
+ catch
+ String[]
+ end
+ end
+ push!(output_metadata, metadata_)
+ end
+ end
+ ## Create metadata tags and associated vocabulary
+ tags, tags_vocab = if !isempty(output_metadata)
+ # Requires SparseArrays.jl!
+ build_tags(vcat(output_metadata...)) # need to vcat to be on the "chunk-level"
+ else
+ tags, tags_vocab = nothing, nothing
+ end
+ verbose && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))"
+
+ index = ChunkIndex(;
+ embeddings = hcat(output_embeddings...),
+ tags, tags_vocab,
+ chunks = vcat(output_chunks...),
+ sources = vcat(output_sources...))
+ return index
+end
diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl
new file mode 100644
index 000000000..a3d136aa4
--- /dev/null
+++ b/src/Experimental/RAGTools/retrieval.jl
@@ -0,0 +1,64 @@
+"""
+ find_closest(emb::AbstractMatrix{<:Real},
+ query_emb::AbstractVector{<:Real};
+ top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0)
+
+Finds the indices of chunks (represented by embeddings in `emb`) that are closest (cosine similarity) to query embedding (`query_emb`).
+
+If `minimum_similarity` is provided, only indices with similarity greater than or equal to it are returned.
+Similarity can be between -1 and 1 (-1 = completely opposite, 1 = exactly the same).
+
+Returns only `top_k` closest indices.
+"""
+function find_closest(emb::AbstractMatrix{<:Real},
+ query_emb::AbstractVector{<:Real};
+ top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0)
+ # emb is an embedding matrix where the first dimension is the embedding dimension
+ distances = query_emb' * emb |> vec
+ positions = distances |> sortperm |> reverse |> x -> first(x, top_k)
+ if minimum_similarity > -1.0
+ mask = distances[positions] .>= minimum_similarity
+ positions = positions[mask]
+ end
+ return positions, distances[positions]
+end
+function find_closest(index::AbstractChunkIndex,
+ query_emb::AbstractVector{<:Real};
+ top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0)
+ isnothing(embeddings(index)) && CandidateChunks(; index_id = index.id)
+ positions, distances = find_closest(embeddings(index),
+ query_emb;
+ top_k,
+ minimum_similarity)
+ return CandidateChunks(index.id, positions, Float32.(distances))
+end
+
+function find_tags(index::AbstractChunkIndex,
+ tag::Union{AbstractString, Regex})
+ isnothing(tags(index)) && CandidateChunks(; index_id = index.id)
+ tag_idx = if tag isa AbstractString
+ findall(tags_vocab(index) .== tag)
+ else # assume it's a regex
+ findall(occursin.(tag, tags_vocab(index)))
+ end
+ # getindex.(x, 1) is to get the first dimension in each CartesianIndex
+ match_row_idx = @view(tags(index)[:, tag_idx]) |> findall |>
+ x -> getindex.(x, 1) |> unique
+ return CandidateChunks(index.id, match_row_idx, ones(Float32, length(match_row_idx)))
+end
+function find_tags(index::AbstractChunkIndex,
+ tags::Vector{<:AbstractString})
+ pos = [find_tags(index, tag).positions for tag in tags] |>
+ Base.Splat(vcat) |> unique |> x -> convert(Vector{Int}, x)
+ return CandidateChunks(index.id, pos, ones(Float32, length(pos)))
+end
+
+# Assuming the rerank and strategy definitions are in the Main module or relevant module
+abstract type RerankingStrategy end
+
+struct Passthrough <: RerankingStrategy end
+
+function rerank(strategy::Passthrough, index, question, candidate_chunks; kwargs...)
+ # Since this is a Passthrough strategy, it returns the candidate_chunks unchanged
+ return candidate_chunks
+end
\ No newline at end of file
diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl
new file mode 100644
index 000000000..2aeb7a4d0
--- /dev/null
+++ b/src/Experimental/RAGTools/types.jl
@@ -0,0 +1,136 @@
+### Types
+# Defines three key types for RAG: ChunkIndex, MultiIndex, and CandidateChunks
+# In addition, RAGContext is defined for debugging purposes
+
+abstract type AbstractDocumentIndex end
+abstract type AbstractChunkIndex <: AbstractDocumentIndex end
+# More advanced index would be: HybridChunkIndex
+
+# Stores document chunks and their embeddings
+@kwdef struct ChunkIndex{
+ T1 <: AbstractString,
+ T2 <: Union{Nothing, Matrix{<:Real}},
+ T3 <: Union{Nothing, AbstractMatrix{<:Bool}},
+} <: AbstractChunkIndex
+ id::Symbol = gensym("ChunkIndex")
+ # underlying document chunks / snippets
+ chunks::Vector{T1}
+ # for semantic search
+ embeddings::T2 = nothing
+ # for exact search, filtering, etc.
+ # expected to be some sparse structure, eg, sparse matrix or nothing
+ # column oriented, ie, each column is one item in `tags_vocab` and rows are the chunks
+ tags::T3 = nothing
+ tags_vocab::Union{Nothing, Vector{<:AbstractString}} = nothing
+ sources::Vector{<:AbstractString}
+end
+embeddings(index::ChunkIndex) = index.embeddings
+chunks(index::ChunkIndex) = index.chunks
+tags(index::ChunkIndex) = index.tags
+tags_vocab(index::ChunkIndex) = index.tags_vocab
+sources(index::ChunkIndex) = index.sources
+
+function Base.var"=="(i1::ChunkIndex, i2::ChunkIndex)
+ ((i1.sources == i2.sources) && (i1.tags_vocab == i2.tags_vocab) &&
+ (i1.embeddings == i2.embeddings) && (i1.chunks == i2.chunks) && (i1.tags == i2.tags))
+end
+
+function Base.vcat(i1::ChunkIndex, i2::ChunkIndex)
+ tags_, tags_vocab_ = if (isnothing(tags(i1)) || isnothing(tags(i2)))
+ nothing, nothing
+ elseif tags_vocab(i1) == tags_vocab(i2)
+ vcat(tags(i1), tags(i2)), tags_vocab(i1)
+ else
+ merge_labeled_matrices(tags(i1), tags_vocab(i1), tags(i2), tags_vocab(i2))
+ end
+ embeddings_ = (isnothing(embeddings(i1)) || isnothing(embeddings(i2))) ? nothing :
+ hcat(embeddings(i1), embeddings(i2))
+ ChunkIndex(;
+ chunks = vcat(chunks(i1), chunks(i2)),
+ embeddings = embeddings_,
+ tags = tags_,
+ tags_vocab = tags_vocab_,
+ sources = vcat(i1.sources, i2.sources))
+end
+
+"Composite index that stores multiple ChunkIndex objects and their embeddings"
+@kwdef struct MultiIndex <: AbstractDocumentIndex
+ id::Symbol = gensym("MultiIndex")
+ indexes::Vector{<:ChunkIndex}
+end
+indexes(index::MultiIndex) = index.indexes
+# check that each index has a counterpart in the other MultiIndex
+function Base.var"=="(i1::MultiIndex, i2::MultiIndex)
+ length(indexes(i1)) != length(indexes(i2)) && return false
+ for i in i1.indexes
+ if !(i in i2.indexes)
+ return false
+ end
+ end
+ for i in i2.indexes
+ if !(i in i1.indexes)
+ return false
+ end
+ end
+ return true
+end
+
+abstract type AbstractCandidateChunks end
+@kwdef struct CandidateChunks{T <: Real} <: AbstractCandidateChunks
+ index_id::Symbol
+ positions::Vector{Int} = Int[]
+ distances::Vector{T} = Float32[]
+end
+# combine/intersect two candidate chunks. average the score if available
+function Base.var"&"(cc1::CandidateChunks, cc2::CandidateChunks)
+ cc1.index_id != cc2.index_id && return CandidateChunks(; index_id = cc1.index_id)
+
+ positions = intersect(cc1.positions, cc2.positions)
+ distances = if !isempty(cc1.distances) && !isempty(cc2.distances)
+ (cc1.distances[positions] .+ cc2.distances[positions]) ./ 2
+ else
+ Float32[]
+ end
+ CandidateChunks(cc1.index_id, positions, distances)
+end
+function Base.getindex(ci::ChunkIndex, candidate::CandidateChunks, field::Symbol = :chunks)
+ @assert field==:chunks "Only `chunks` field is supported for now"
+ len_ = length(chunks(ci))
+ @assert all(1 .<= candidate.positions .<= len_) "Some positions are out of bounds"
+ if ci.id == candidate.index_id
+ chunks(ci)[candidate.positions]
+ else
+ eltype(chunks(ci))[]
+ end
+end
+function Base.getindex(mi::MultiIndex, candidate::CandidateChunks, field::Symbol = :chunks)
+ @assert field==:chunks "Only `chunks` field is supported for now"
+ valid_index = findfirst(x -> x.id == candidate.index_id, indexes(mi))
+ if isnothing(valid_index)
+ String[]
+ else
+ getindex(indexes(mi)[valid_index], candidate)
+ end
+end
+
+"""
+ RAGContext
+
+A struct for debugging RAG answers. It contains the question, answer, context, and the candidate chunks at each step of the RAG pipeline.
+"""
+@kwdef struct RAGContext
+ question::AbstractString
+ answer::AbstractString
+ context::Vector{<:AbstractString}
+ sources::Vector{<:AbstractString}
+ emb_candidates::CandidateChunks
+ tag_candidates::Union{Nothing, CandidateChunks}
+ filtered_candidates::CandidateChunks
+ reranked_candidates::CandidateChunks
+end
+
+# Structured show method for easier reading (each kwarg on a new line)
+function Base.show(io::IO,
+ t::Union{AbstractDocumentIndex, AbstractCandidateChunks, RAGContext})
+ dump(IOContext(io, :limit => true), t, maxdepth = 1)
+end
diff --git a/src/Experimental/RAGTools/utils.jl b/src/Experimental/RAGTools/utils.jl
new file mode 100644
index 000000000..f980a61e0
--- /dev/null
+++ b/src/Experimental/RAGTools/utils.jl
@@ -0,0 +1,23 @@
+# Utility to check model suitability
+function _check_aiextract_capability(model::AbstractString)
+ # Check that the provided model is known and that it is an OpenAI model (for the aiextract function to work)
+ @assert haskey(PT.MODEL_REGISTRY,
+ model)&&PT.MODEL_REGISTRY[model].schema isa PT.AbstractOpenAISchema "Only OpenAI models support the metadata extraction now. $model is not a registered OpenAI model."
+end
+# Utitity to be able to combine indices from different sources/documents easily
+function merge_labeled_matrices(mat1::AbstractMatrix{T1},
+ vocab1::Vector{String},
+ mat2::AbstractMatrix{T2},
+ vocab2::Vector{String}) where {T1 <: Number, T2 <: Number}
+ T = promote_type(T1, T2)
+ new_words = setdiff(vocab2, vocab1)
+ combined_vocab = [vocab1; new_words]
+ vocab2_indices = Dict(word => i for (i, word) in enumerate(vocab2))
+
+ aligned_mat1 = hcat(mat1, zeros(T, size(mat1, 1), length(new_words)))
+ aligned_mat2 = [haskey(vocab2_indices, word) ? @view(mat2[:, vocab2_indices[word]]) :
+ zeros(T, size(mat2, 1)) for word in combined_vocab]
+ aligned_mat2 = aligned_mat2 |> Base.Splat(hcat)
+
+ return vcat(aligned_mat1, aligned_mat2), combined_vocab
+end
\ No newline at end of file
diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl
index 7f5c6bdd1..a137bd866 100644
--- a/src/PromptingTools.jl
+++ b/src/PromptingTools.jl
@@ -65,6 +65,9 @@ include("llm_ollama_managed.jl")
export @ai_str, @aai_str
include("macros.jl")
+## Experimental modules
+include("Experimental/Experimental.jl")
+
function __init__()
# Load templates
load_templates!()
diff --git a/src/llm_interface.jl b/src/llm_interface.jl
index 3dc8d2f22..aead1cf77 100644
--- a/src/llm_interface.jl
+++ b/src/llm_interface.jl
@@ -139,19 +139,19 @@ end
function aiembed(doc_or_docs, args...; model = MODEL_EMBEDDING, kwargs...)
global MODEL_REGISTRY
schema = get(MODEL_REGISTRY, model, (; schema = PROMPT_SCHEMA)).schema
- aiembed(schema, doc_or_docs, args...; kwargs...)
+ aiembed(schema, doc_or_docs, args...; model, kwargs...)
end
function aiclassify(prompt; model = MODEL_CHAT, kwargs...)
global MODEL_REGISTRY
schema = get(MODEL_REGISTRY, model, (; schema = PROMPT_SCHEMA)).schema
- aiclassify(schema, prompt; kwargs...)
+ aiclassify(schema, prompt; model, kwargs...)
end
function aiextract(prompt; model = MODEL_CHAT, kwargs...)
global MODEL_REGISTRY
schema = get(MODEL_REGISTRY, model, (; schema = PROMPT_SCHEMA)).schema
- aiextract(schema, prompt; kwargs...)
+ aiextract(schema, prompt; model, kwargs...)
end
function aiscan(prompt; model = MODEL_CHAT, kwargs...)
schema = get(MODEL_REGISTRY, model, (; schema = PROMPT_SCHEMA)).schema
- aiscan(schema, prompt; kwargs...)
+ aiscan(schema, prompt; model, kwargs...)
end
\ No newline at end of file
diff --git a/src/utils.jl b/src/utils.jl
index 168d772cb..a0f6b996b 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -66,7 +66,13 @@ split_by_length(text; separator=",", max_length=10000) # for 4K context window
length(chunks[1]) # Output: 4
```
"""
-function split_by_length(text::String; separator::String = " ", max_length::Int = 35000)
+function split_by_length(text::String;
+ separator::String = " ",
+ max_length::Int = 35000)
+ ## shortcut
+ length(text) <= max_length && return [text]
+
+ ## split by separator
minichunks = split(text, separator)
sep_length = length(separator)
chunks = String[]
@@ -99,6 +105,66 @@ function split_by_length(text::String; separator::String = " ", max_length::Int
return chunks
end
+
+# Overload for dispatch on multiple separators
+function split_by_length(text::String,
+ separator::String,
+ max_length::Int = 35000)
+ split_by_length(text; separator, max_length)
+end
+
+"""
+ split_by_length(text::String, separators::Vector{String}; max_length::Int=35000) -> Vector{String}
+
+Split a given string `text` into chunks using a series of separators, with each chunk having a maximum length of `max_length`.
+This function is useful for splitting large documents or texts into smaller segments that are more manageable for processing, particularly for models or systems with limited context windows.
+
+# Arguments
+- `text::String`: The text to be split.
+- `separators::Vector{String}`: An ordered list of separators used to split the text. The function iteratively applies these separators to split the text.
+- `max_length::Int=35000`: The maximum length of each chunk. Defaults to 35,000 characters. This length is considered after each iteration of splitting, ensuring chunks fit within specified constraints.
+
+# Returns
+`Vector{String}`: A vector of strings, where each string is a chunk of the original text that is smaller than or equal to `max_length`.
+
+# Notes
+
+- The function processes the text iteratively with each separator in the provided order. This ensures more nuanced splitting, especially in structured texts.
+- Each chunk is as close to `max_length` as possible without exceeding it (unless we cannot split it any further)
+- If the `text` is empty, the function returns an empty array.
+- Separators are re-added to the text chunks after splitting, preserving the original structure of the text as closely as possible. Apply `strip` if you do not need them.
+
+# Examples
+
+Splitting text using multiple separators:
+```julia
+text = "Paragraph 1\n\nParagraph 2. Sentence 1. Sentence 2.\nParagraph 3"
+separators = ["\n\n", ". ", "\n"]
+chunks = split_by_length(text, separators, max_length=20)
+```
+
+Using a single separator:
+```julia
+text = "Hello,World," ^ 2900 # length 34900 characters
+chunks = split_by_length(text, [","], max_length=10000)
+```
+"""
+function split_by_length(text, separators::Vector{String}; max_length)
+ @assert !isempty(separators) "`separators` can't be empty"
+ separator = popfirst!(separators)
+ chunks = split_by_length(text; separator, max_length)
+
+ isempty(separators) && return chunks
+ ## Iteratively split by separators
+ for separator in separators
+ chunks = mapreduce(text_ -> split_by_length(text_; max_length, separator),
+ vcat,
+ chunks)
+ end
+
+ return chunks
+end
+
### INTERNAL FUNCTIONS - DO NOT USE DIRECTLY
# helper to extract handlebar variables (eg, `{{var}}`) from a prompt string
function _extract_handlebar_variables(s::AbstractString)
@@ -109,18 +175,63 @@ function _extract_handlebar_variables(vect::Vector{Dict{String, <:AbstractString
unique([_extract_handlebar_variables(v) for d in vect for (k, v) in d if k == "text"])
end
-# helper to produce summary message of how many tokens were used and for how much
-function _report_stats(msg,
- model::String,
+"""
+ call_cost(msg, model::String;
+ cost_of_token_prompt::Number = default_prompt_cost,
+ cost_of_token_generation::Number = default_generation_cost) -> Number
+
+Calculate the cost of a call based on the number of tokens in the message and the cost per token.
+
+# Arguments
+- `msg`: The message object, which should contain a `tokens` field
+ with two elements: [number_of_prompt_tokens, number_of_generation_tokens].
+- `model::String`: The name of the model to use for determining token costs. If the model
+ is not found in `MODEL_REGISTRY`, default costs are used.
+- `cost_of_token_prompt::Number`: The cost per prompt token. Defaults to the cost in `MODEL_REGISTRY`
+ for the given model, or 0.0 if the model is not found.
+- `cost_of_token_generation::Number`: The cost per generation token. Defaults to the cost in
+ `MODEL_REGISTRY` for the given model, or 0.0 if the model is not found.
+
+# Returns
+- `Number`: The total cost of the call.
+
+# Examples
+```julia
+# Assuming MODEL_REGISTRY is set up with appropriate costs
+MODEL_REGISTRY = Dict(
+ "model1" => (cost_of_token_prompt = 0.05, cost_of_token_generation = 0.10),
+ "model2" => (cost_of_token_prompt = 0.07, cost_of_token_generation = 0.02)
+)
+
+msg1 = AIMessage([10, 20]) # 10 prompt tokens, 20 generation tokens
+cost1 = call_cost(msg1, "model1")
+# cost1 = 10 * 0.05 + 20 * 0.10 = 2.5
+
+msg2 = DataMessage([15, 30]) # 15 prompt tokens, 30 generation tokens
+cost2 = call_cost(msg2, "model2")
+# cost2 = 15 * 0.07 + 30 * 0.02 = 1.35
+
+# Using custom token costs
+msg3 = AIMessage([5, 10])
+cost3 = call_cost(msg3, "model3", cost_of_token_prompt = 0.08, cost_of_token_generation = 0.12)
+# cost3 = 5 * 0.08 + 10 * 0.12 = 1.6
+```
+"""
+function call_cost(msg, model::String;
cost_of_token_prompt::Number = get(MODEL_REGISTRY,
model,
(; cost_of_token_prompt = 0.0)).cost_of_token_prompt,
cost_of_token_generation::Number = get(MODEL_REGISTRY, model,
(; cost_of_token_generation = 0.0)).cost_of_token_generation)
- cost = (msg.tokens[1] * cost_of_token_prompt +
- msg.tokens[2] * cost_of_token_generation)
+ cost = msg.tokens[1] * cost_of_token_prompt +
+ msg.tokens[2] * cost_of_token_generation
+ return cost
+end
+# helper to produce summary message of how many tokens were used and for how much
+function _report_stats(msg,
+ model::String)
+ cost = call_cost(msg, model)
cost_str = iszero(cost) ? "" : " @ Cost: \$$(round(cost; digits=4))"
-
return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds"
end
# Loads and encodes the provided image path as a base64 string
diff --git a/templates/RAG/CreateQAFromContext.json b/templates/RAG/CreateQAFromContext.json
new file mode 100644
index 000000000..83e900ba1
--- /dev/null
+++ b/templates/RAG/CreateQAFromContext.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Generate Question and Answer from the provided Context.If you don't have any special instructions, provide `instructions=\"None.\"`. Placeholders: `context`, `instructions`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"You are a world-class teacher preparing contextual Question & Answer sets for evaluating AI systems.\"),\n\n**Instructions for Question Generation:**\n1. Analyze the provided Context chunk thoroughly.\n2. Formulate a question that:\n - Is specific and directly related to the information in the context chunk.\n - Is not too short or generic; it should require detailed understanding of the context to answer.\n - Can only be answered using the information from the provided context, without needing external information.\n\n**Instructions for Reference Answer Creation:**\n1. Based on the generated question, compose a reference answer that:\n - Directly and comprehensively answers the question.\n - Stays strictly within the bounds of the provided context chunk.\n - Is clear, concise, and to the point, avoiding unnecessary elaboration or repetition.\n\n**Example 1:**\n- Context Chunk: \"In 1928, Alexander Fleming discovered penicillin, which marked the beginning of modern antibiotics.\"\n- Generated Question: \"What was the significant discovery made by Alexander Fleming in 1928 and its impact?\"\n- Reference Answer: \"Alexander Fleming discovered penicillin in 1928, which led to the development of modern antibiotics.\"\n\nIf the user provides special instructions, prioritize these over the general instructions.\n","variables":[],"_type":"systemmessage"},{"content":"# Context Information\n---\n{{context}}\n---\n\n\n# Special Instructions\n\n{{instructions}}\n","variables":["context","instructions"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/templates/RAG/RAGAnswerFromContext.json b/templates/RAG/RAGAnswerFromContext.json
new file mode 100644
index 000000000..272ca4e20
--- /dev/null
+++ b/templates/RAG/RAGAnswerFromContext.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Answers the provided Questions based on the Context. Placeholders: `question`, `context`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"Act as a world-class AI assistant with access to the latest knowledge via Context Information. \n\n**Instructions:**\n- Answer the question based only on the provided Context.\n- If you don't know the answer, just say that you don't know, don't try to make up an answer.\n- Be brief and concise.\n\n**Context Information:**\n---\n{{context}}\n---\n","variables":["context"],"_type":"systemmessage"},{"content":"# Question\n\n{{question}}\n\n\n\n# Answer\n\n","variables":["question"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/templates/RAG/RAGCreateQAFromContext.json b/templates/RAG/RAGCreateQAFromContext.json
new file mode 100644
index 000000000..83e900ba1
--- /dev/null
+++ b/templates/RAG/RAGCreateQAFromContext.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Generate Question and Answer from the provided Context.If you don't have any special instructions, provide `instructions=\"None.\"`. Placeholders: `context`, `instructions`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"You are a world-class teacher preparing contextual Question & Answer sets for evaluating AI systems.\"),\n\n**Instructions for Question Generation:**\n1. Analyze the provided Context chunk thoroughly.\n2. Formulate a question that:\n - Is specific and directly related to the information in the context chunk.\n - Is not too short or generic; it should require detailed understanding of the context to answer.\n - Can only be answered using the information from the provided context, without needing external information.\n\n**Instructions for Reference Answer Creation:**\n1. Based on the generated question, compose a reference answer that:\n - Directly and comprehensively answers the question.\n - Stays strictly within the bounds of the provided context chunk.\n - Is clear, concise, and to the point, avoiding unnecessary elaboration or repetition.\n\n**Example 1:**\n- Context Chunk: \"In 1928, Alexander Fleming discovered penicillin, which marked the beginning of modern antibiotics.\"\n- Generated Question: \"What was the significant discovery made by Alexander Fleming in 1928 and its impact?\"\n- Reference Answer: \"Alexander Fleming discovered penicillin in 1928, which led to the development of modern antibiotics.\"\n\nIf the user provides special instructions, prioritize these over the general instructions.\n","variables":[],"_type":"systemmessage"},{"content":"# Context Information\n---\n{{context}}\n---\n\n\n# Special Instructions\n\n{{instructions}}\n","variables":["context","instructions"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/templates/RAG/RAGExtractMetadataLong.json b/templates/RAG/RAGExtractMetadataLong.json
new file mode 100644
index 000000000..9ede8c3ca
--- /dev/null
+++ b/templates/RAG/RAGExtractMetadataLong.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Extracts metadata from the provided text using longer instructions set and examples. If you don't have any special instructions, provide `instructions=\"None.\"`. Placeholders: `text`, `instructions`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"You're a world-class data extraction engine built by OpenAI together with Google and to extract filter metadata to power the most advanced search engine in the world. \n \n **Instructions for Extraction:**\n 1. Carefully read through the provided Text\n 2. Identify and extract:\n - All relevant entities such as names, places, dates, etc.\n - Any special items like technical terms, unique identifiers, etc.\n - In the case of Julia code or Julia documentation: specifically extract package names, struct names, function names, and important variable names (eg, uppercased variables)\n 3. Keep extracted values and categories short. Maximum 2-3 words!\n 4. You can only extract 3-5 items per Text, so select the most important ones.\n 5. Assign search filter Category to each extracted Value\n \n **Example 1:**\n - Document Chunk: \"Dr. Jane Smith published her findings on neuroplasticity in 2021. The research heavily utilized the DataFrames.jl and Plots.jl packages.\"\n - Extracted keywords:\n - Name: Dr. Jane Smith\n - Date: 2021\n - Technical Term: neuroplasticity\n - JuliaPackage: DataFrames.jl, Plots.jl\n - JuliaLanguage:\n - Identifier:\n - Other: \n\n If the user provides special instructions, prioritize these over the general instructions.\n","variables":[],"_type":"systemmessage"},{"content":"# Text\n\n{{text}}\n\n\n\n# Special Instructions\n\n{{instructions}}","variables":["text","instructions"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/templates/RAG/RAGExtractMetadataShort.json b/templates/RAG/RAGExtractMetadataShort.json
new file mode 100644
index 000000000..88132e929
--- /dev/null
+++ b/templates/RAG/RAGExtractMetadataShort.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Extracts metadata from the provided text. If you don't have any special instructions, provide `instructions=\"None.\"`. Placeholders: `text`, `instructions`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"Extract search keywords and their categories from the Text provided below (format \"value:category\"). Each keyword must be at most 2-3 words. Provide at most 3-5 keywords. I will tip you $50 if the search is successful.","variables":[],"_type":"systemmessage"},{"content":"# Text\n\n{{text}}\n\n\n\n# Special Instructions\n\n{{instructions}}","variables":["text","instructions"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/templates/RAG/RAGJudgeAnswerFromContext.json b/templates/RAG/RAGJudgeAnswerFromContext.json
new file mode 100644
index 000000000..e988d8129
--- /dev/null
+++ b/templates/RAG/RAGJudgeAnswerFromContext.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Judge answer to a question on a scale from 1-5. Placeholders: `question`, `context`, `answer`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"You're an impartial judge. Your task is to evaluate the quality of the Answer provided by an AI assistant in response to the User Question on a scale 1-5.\n\n1. **Scoring Criteria:**\n- **Relevance (1-5):** How well does the provided answer align with the context? \n - *1: Not relevant, 5: Highly relevant*\n- **Completeness (1-5):** Does the provided answer cover all the essential points mentioned in the context?\n - *1: Very incomplete, 5: Very complete*\n- **Clarity (1-5):** How clear and understandable is the provided answer?\n - *1: Not clear at all, 5: Extremely clear*\n- **Consistency (1-5):** How consistent is the provided answer with the overall context?\n - *1: Highly inconsistent, 5: Perfectly consistent*\n- **Helpfulness (1-5):** How helpful is the provided answer in answering the user's question?\n - *1: Not helpful at all, 5: Extremely helpful*\n\n2. **Judging Instructions:**\n- As an impartial judge, please evaluate the provided answer based on the above criteria. \n- Assign a score from 1 to 5 for each criterion, considering the original context, question and the provided answer.\n- The Final Score is an average of these individual scores, representing the overall quality and relevance of the provided answer. It must be between 1-5.\n\n```\n","variables":[],"_type":"systemmessage"},{"content":"# User Question\n---\n{{question}}\n---\n\n\n# Context Information\n---\n{{context}}\n---\n\n\n# Assistant's Answer\n---\n{{answer}}\n---\n\n\n# Judge's Evaluation\n","variables":["question","context","answer"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/templates/RAG/RAGJudgeAnswerFromContextShort.json b/templates/RAG/RAGJudgeAnswerFromContextShort.json
new file mode 100644
index 000000000..93ea6447f
--- /dev/null
+++ b/templates/RAG/RAGJudgeAnswerFromContextShort.json
@@ -0,0 +1 @@
+[{"content":"Template Metadata","description":"For RAG applications. Simple and short prompt to judge answer to a question on a scale from 1-5. Placeholders: `question`, `context`, `answer`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"You re an impartial judge. \nRead carefully the provided question and the answer based on the context. \nProvide a rating on a scale 1-5 (1=worst quality, 5=best quality) that reflects how relevant, helpful, clear, and consistent with the provided context the answer was.\n```\n","variables":[],"_type":"systemmessage"},{"content":"# User Question\n---\n{{question}}\n---\n\n\n# Context Information\n---\n{{context}}\n---\n\n\n# Assistant's Answer\n---\n{{answer}}\n---\n\n\n# Judge's Evaluation\n","variables":["question","context","answer"],"_type":"usermessage"}]
\ No newline at end of file
diff --git a/test/Experimental/RAGTools/evaluation.jl b/test/Experimental/RAGTools/evaluation.jl
new file mode 100644
index 000000000..f9ea295e8
--- /dev/null
+++ b/test/Experimental/RAGTools/evaluation.jl
@@ -0,0 +1,207 @@
+using PromptingTools.Experimental.RAGTools: QAItem, QAEvalItem, QAEvalResult
+using PromptingTools.Experimental.RAGTools: score_retrieval_hit, score_retrieval_rank
+using PromptingTools.Experimental.RAGTools: build_qa_evals, run_qa_evals, chunks, sources
+using PromptingTools.Experimental.RAGTools: JudgeAllScores, MetadataItem, MaybeMetadataItems
+
+@testset "QAEvalItem" begin
+ empty_qa = QAEvalItem()
+ @test !isvalid(empty_qa)
+ full_qa = QAEvalItem(; question = "a", answer = "b", context = "c")
+ @test isvalid(full_qa)
+end
+
+@testset "Base.show,JSON3.write" begin
+ # Helper function to simulate the IO capture for custom show methods
+ function capture_show(io::IOBuffer, x)
+ show(io, x)
+ return String(take!(io))
+ end
+
+ # Testing Base.show for QAItem
+ qa_item = QAItem("What is Julia?",
+ "Julia is a high-level, high-performance programming language.")
+
+ test_output = capture_show(IOBuffer(), qa_item)
+ @test test_output ==
+ "QAItem:\n question: What is Julia?\n answer: Julia is a high-level, high-performance programming language.\n"
+ json_output = JSON3.write(qa_item)
+ @test JSON3.read(json_output, QAItem) == qa_item
+
+ # Testing Base.show for QAEvalItem
+ qa_eval_item = QAEvalItem(source = "testsource.jl",
+ context = "Julia is a high-level, high-performance programming language.",
+ question = "What is Julia?",
+ answer = "A language.")
+
+ test_output = capture_show(IOBuffer(), qa_eval_item)
+ @test test_output ==
+ "QAEvalItem:\n source: testsource.jl\n context: Julia is a high-level, high-performance programming language.\n question: What is Julia?\n answer: A language.\n"
+ json_output = JSON3.write(qa_eval_item)
+ @test JSON3.read(json_output, QAEvalItem) == qa_eval_item
+
+ # Testing Base.show for QAEvalResult
+ params = Dict(:key1 => "value1", :key2 => 2)
+ qa_eval_result = QAEvalResult(source = "testsource.jl",
+ context = "Julia is amazing for technical computing.",
+ question = "Why is Julia good?",
+ answer = "Because of its speed and ease of use.",
+ retrieval_score = 0.89,
+ retrieval_rank = 1,
+ answer_score = 100.0,
+ parameters = params)
+
+ test_output = capture_show(IOBuffer(), qa_eval_result)
+ @test test_output ==
+ "QAEvalResult:\n source: testsource.jl\n context: Julia is amazing for technical computing.\n question: Why is Julia good?\n answer: Because of its speed and ease of use.\n retrieval_score: 0.89\n retrieval_rank: 1\n answer_score: 100.0\n parameters: Dict{Symbol, Any}(:key2 => 2, :key1 => \"value1\")\n"
+ json_output = JSON3.write(qa_eval_result)
+ @test JSON3.read(json_output, QAEvalResult) == qa_eval_result
+end
+
+@testset "score_retrieval_hit,score_retrieval_rank" begin
+ orig_context = "I am a horse."
+ candidate_context = ["Hello", "World", "I am a horse...."]
+ candidate_context2 = ["Hello", "I am a hors"]
+ candidate_context3 = ["Hello", "World", "I am X horse...."]
+ @test score_retrieval_hit(orig_context, candidate_context) == 1.0
+ @test score_retrieval_hit(orig_context, candidate_context2) == 1.0
+ @test score_retrieval_hit(orig_context, candidate_context[1:2]) == 0.0
+ @test score_retrieval_hit(orig_context, candidate_context3) == 0.0
+
+ @test score_retrieval_rank(orig_context, candidate_context) == 3
+ @test score_retrieval_rank(orig_context, candidate_context2) == 2
+ @test score_retrieval_rank(orig_context, candidate_context[1:2]) == nothing
+ @test score_retrieval_rank(orig_context, candidate_context3) == nothing
+end
+
+@testset "build_qa_evals" begin
+ # test with a mock server
+ PORT = rand(1000:2000)
+ PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-qa", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-judge", schema = PT.CustomOpenAISchema())
+
+ echo_server = HTTP.serve!(PORT; verbose = -1) do req
+ content = JSON3.read(req.body)
+
+ if content[:model] == "mock-gen"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [Dict(:message => user_msg)],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-emb"
+ # for i in 1:length(content[:input])
+ response = Dict(:data => [Dict(:embedding => ones(Float32, 128))],
+ :usage => Dict(:total_tokens => length(content[:input]),
+ :prompt_tokens => length(content[:input]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-meta"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [
+ Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(MaybeMetadataItems([
+ MetadataItem("yes", "category"),
+ ]))))),
+ ],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-qa"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [
+ Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(QAItem("Question",
+ "Answer"))))),
+ ],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-judge"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [
+ Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(JudgeAllScores(5,
+ 5,
+ 5,
+ 5,
+ 5,
+ "Some reasons",
+ 5.0))))),
+ ],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ else
+ @info content
+ end
+ return HTTP.Response(200, JSON3.write(response))
+ end
+
+ # Index setup
+ index = ChunkIndex(;
+ sources = [".", ".", "."],
+ chunks = ["a", "b", "c"],
+ embeddings = zeros(128, 3),
+ tags = vcat(trues(2, 2), falses(1, 2)),
+ tags_vocab = ["yes", "no"],)
+
+ # Test for successful Q&A extraction from document chunks
+ qa_evals = build_qa_evals(chunks(index),
+ sources(index),
+ instructions = "Some instructions.",
+ model = "mock-qa",
+ api_kwargs = (; url = "http://localhost:$(PORT)"))
+
+ @test length(qa_evals) == length(chunks(index))
+ @test all(getproperty.(qa_evals, :source) .== ".")
+ @test all(getproperty.(qa_evals, :context) == ["a", "b", "c"])
+ @test all(getproperty.(qa_evals, :question) .== "Question")
+ @test all(getproperty.(qa_evals, :answer) .== "Answer")
+
+ # Error checks
+ @test_throws AssertionError build_qa_evals(chunks(index),
+ String[])
+ @test_throws AssertionError build_qa_evals(chunks(index),
+ String[]; qa_template = :BlankSystemUser)
+
+ # Test run_qa_evals on 1 item
+ msg, ctx = airag(index; question = qa_evals[1].question, model_embedding = "mock-emb",
+ model_chat = "mock-gen",
+ model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"),
+ tag_filter = :auto,
+ extract_metadata = false, verbose = false,
+ return_context = true)
+
+ result = run_qa_evals(qa_evals[1], ctx;
+ model_judge = "mock-judge",
+ api_kwargs = (; url = "http://localhost:$(PORT)"),
+ parameters_dict = Dict(:key1 => "value1", :key2 => 2))
+ @test result.retrieval_score == 1.0
+ @test result.retrieval_rank == 1
+ @test result.answer_score == 5
+ @test result.parameters == Dict(:key1 => "value1", :key2 => 2)
+
+ # Test all evals at once
+ # results = run_qa_evals(index, qa_evals; model_judge = "mock-judge",
+ # api_kwargs = (; url = "http://localhost:$(PORT)"))
+ results = run_qa_evals(index, qa_evals;
+ airag_kwargs = (;
+ model_embedding = "mock-emb",
+ model_chat = "mock-gen",
+ model_metadata = "mock-meta"),
+ qa_evals_kwargs = (; model_judge = "mock-judge"),
+ api_kwargs = (; url = "http://localhost:$(PORT)"),
+ parameters_dict = Dict(:key1 => "value1", :key2 => 2))
+
+ @test length(results) == length(qa_evals)
+ @test all(getproperty.(results, :retrieval_score) .== 1.0)
+ @test all(getproperty.(results, :retrieval_rank) .== 1)
+ @test all(getproperty.(results, :answer_score) .== 5)
+ @test all(getproperty.(results, :parameters) .==
+ Ref(Dict(:key1 => "value1", :key2 => 2)))
+ # clean up
+ close(echo_server)
+end
\ No newline at end of file
diff --git a/test/Experimental/RAGTools/generation.jl b/test/Experimental/RAGTools/generation.jl
new file mode 100644
index 000000000..b19d03a44
--- /dev/null
+++ b/test/Experimental/RAGTools/generation.jl
@@ -0,0 +1,130 @@
+using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, MetadataItem, build_context
+
+@testset "build_context" begin
+ index = ChunkIndex(;
+ sources = [".", ".", "."],
+ chunks = ["a", "b", "c"],
+ embeddings = zeros(128, 3),
+ tags = vcat(trues(2, 2), falses(1, 2)),
+ tags_vocab = ["yes", "no"],)
+ candidates = CandidateChunks(index.id, [1, 2], [0.1, 0.2])
+
+ # Standard Case
+ context = build_context(index, candidates)
+ expected_output = ["1. a\nb",
+ "2. a\nb\nc"]
+ @test context == expected_output
+
+ # No Surrounding Chunks
+ context = build_context(index, candidates; chunks_window_margin = (0, 0))
+ expected_output = ["1. a",
+ "2. b"]
+ @test context == expected_output
+
+ # Wrong inputs
+ @test_throws AssertionError build_context(index,
+ candidates;
+ chunks_window_margin = (-1, 0))
+end
+
+@testset "airag" begin
+ # test with a mock server
+ PORT = rand(1000:2000)
+ PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema())
+
+ echo_server = HTTP.serve!(PORT; verbose = -1) do req
+ content = JSON3.read(req.body)
+
+ if content[:model] == "mock-gen"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [Dict(:message => user_msg)],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-emb"
+ # for i in 1:length(content[:input])
+ response = Dict(:data => [Dict(:embedding => ones(Float32, 128))],
+ :usage => Dict(:total_tokens => length(content[:input]),
+ :prompt_tokens => length(content[:input]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-meta"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [
+ Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(MaybeMetadataItems([
+ MetadataItem("yes", "category"),
+ ]))))),
+ ],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ else
+ @info content
+ end
+ return HTTP.Response(200, JSON3.write(response))
+ end
+
+ ## Index
+ index = ChunkIndex(;
+ sources = [".", ".", "."],
+ chunks = ["a", "b", "c"],
+ embeddings = zeros(128, 3),
+ tags = vcat(trues(2, 2), falses(1, 2)),
+ tags_vocab = ["yes", "no"],)
+ ## Sub-calls
+ question_emb = aiembed(["x", "x"];
+ model = "mock-emb",
+ api_kwargs = (; url = "http://localhost:$(PORT)"))
+ @test question_emb.content == ones(128)
+ metadata_msg = aiextract(:RAGExtractMetadataShort; return_type = MaybeMetadataItems,
+ text = "x",
+ model = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"))
+ @test metadata_msg.content.items == [MetadataItem("yes", "category")]
+ answer_msg = aigenerate(:RAGAnswerFromContext;
+ question = "Time?",
+ context = "XYZ",
+ model = "mock-gen", api_kwargs = (; url = "http://localhost:$(PORT)"))
+ @test occursin("Time?", answer_msg.content)
+ ## E2E
+ msg = airag(index; question = "Time?", model_embedding = "mock-emb",
+ model_chat = "mock-gen",
+ model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"),
+ tag_filter = ["yes"],
+ return_context = false)
+ @test occursin("Time?", msg.content)
+
+ ## Test different kwargs
+ msg, ctx = airag(index; question = "Time?", model_embedding = "mock-emb",
+ model_chat = "mock-gen",
+ model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"),
+ tag_filter = :auto,
+ extract_metadata = false, verbose = false,
+ return_context = true)
+ @test ctx.context == ["1. a\nb\nc", "2. a\nb"]
+ @test ctx.emb_candidates.positions == [3, 2, 1]
+ @test ctx.emb_candidates.distances == zeros(3)
+ @test ctx.tag_candidates.positions == [1, 2]
+ @test ctx.tag_candidates.distances == ones(2)
+ @test ctx.filtered_candidates.positions == [2, 1] #re-sort
+ @test ctx.filtered_candidates.distances == 0.5ones(2)
+ @test ctx.reranked_candidates.positions == [2, 1] # no change
+ @test ctx.reranked_candidates.distances == 0.5ones(2) # no change
+
+ ## Not tag filter
+ msg, ctx = airag(index; question = "Time?", model_embedding = "mock-emb",
+ model_chat = "mock-gen",
+ model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"),
+ tag_filter = nothing,
+ return_context = true)
+ @test ctx.context == ["1. b\nc", "2. a\nb\nc", "3. a\nb"]
+ @test ctx.emb_candidates.positions == [3, 2, 1]
+ @test ctx.emb_candidates.distances == zeros(3)
+ @test ctx.tag_candidates == nothing
+ @test ctx.filtered_candidates.positions == [3, 2, 1] #re-sort
+ @test ctx.reranked_candidates.positions == [3, 2, 1] # no change
+ # clean up
+ close(echo_server)
+end
diff --git a/test/Experimental/RAGTools/preparation.jl b/test/Experimental/RAGTools/preparation.jl
new file mode 100644
index 000000000..3e8396fbc
--- /dev/null
+++ b/test/Experimental/RAGTools/preparation.jl
@@ -0,0 +1,128 @@
+using PromptingTools.Experimental.RAGTools: metadata_extract, MetadataItem
+using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, build_tags, build_index
+
+@testset "metadata_extract" begin
+ # MetadataItem Structure
+ item = MetadataItem("value", "category")
+ @test item.value == "value"
+ @test item.category == "category"
+
+ # MaybeMetadataItems Structure
+ items = MaybeMetadataItems([
+ MetadataItem("value1", "category1"),
+ MetadataItem("value2", "category2"),
+ ])
+ @test length(items.items) == 2
+ @test items.items[1].value == "value1"
+ @test items.items[1].category == "category1"
+
+ empty_items = MaybeMetadataItems(nothing)
+ @test isempty(metadata_extract(empty_items.items))
+
+ # Metadata Extraction Function
+ single_item = MetadataItem("DataFrames", "Julia Package")
+ multiple_items = [
+ MetadataItem("pandas", "Software"),
+ MetadataItem("Python", "Language"),
+ MetadataItem("DataFrames", "Julia Package"),
+ ]
+
+ @test metadata_extract(single_item) == "julia_package:::dataframes"
+ @test metadata_extract(multiple_items) ==
+ ["software:::pandas", "language:::python", "julia_package:::dataframes"]
+
+ @test metadata_extract(nothing) == String[]
+end
+
+@testset "build_tags" begin
+ # Single Tag
+ chunk_metadata = [["tag1"]]
+ tags_, tags_vocab_ = build_tags(chunk_metadata)
+
+ @test length(tags_vocab_) == 1
+ @test tags_vocab_ == ["tag1"]
+ @test nnz(tags_) == 1
+ @test tags_[1, 1] == true
+
+ # Multiple Tags with Repetition
+ chunk_metadata = [["tag1", "tag2"], ["tag2", "tag3"]]
+ tags_, tags_vocab_ = build_tags(chunk_metadata)
+
+ @test length(tags_vocab_) == 3
+ @test tags_vocab_ == ["tag1", "tag2", "tag3"]
+ @test nnz(tags_) == 4
+ @test all([tags_[1, 1], tags_[1, 2], tags_[2, 2], tags_[2, 3]])
+
+ # Empty Metadata
+ chunk_metadata = [String[]]
+ tags_, tags_vocab_ = build_tags(chunk_metadata)
+
+ @test isempty(tags_vocab_)
+ @test size(tags_) == (1, 0)
+
+ # Mixed Empty and Non-Empty Metadata
+ chunk_metadata = [["tag1"], String[], ["tag2", "tag3"]]
+ tags_, tags_vocab_ = build_tags(chunk_metadata)
+
+ @test length(tags_vocab_) == 3
+ @test tags_vocab_ == ["tag1", "tag2", "tag3"]
+ @test nnz(tags_) == 3
+ @test all([tags_[1, 1], tags_[3, 2], tags_[3, 3]])
+end
+
+@testset "build_index" begin
+ # test with a mock server
+ PORT = rand(1000:2000)
+ PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema())
+ PT.register_model!(; name = "mock-get", schema = PT.CustomOpenAISchema())
+
+ echo_server = HTTP.serve!(PORT; verbose = -1) do req
+ content = JSON3.read(req.body)
+
+ if content[:model] == "mock-gen"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [Dict(:message => user_msg)],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-emb"
+ response = Dict(:data => [Dict(:embedding => ones(Float32, 128))
+ for i in 1:length(content[:input])],
+ :usage => Dict(:total_tokens => length(content[:input]),
+ :prompt_tokens => length(content[:input]),
+ :completion_tokens => 0))
+ elseif content[:model] == "mock-meta"
+ user_msg = last(content[:messages])
+ response = Dict(:choices => [
+ Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(MaybeMetadataItems([
+ MetadataItem("yes", "category"),
+ ]))))),
+ ],
+ :model => content[:model],
+ :usage => Dict(:total_tokens => length(user_msg[:content]),
+ :prompt_tokens => length(user_msg[:content]),
+ :completion_tokens => 0))
+ else
+ @info content
+ end
+ return HTTP.Response(200, JSON3.write(response))
+ end
+
+ text = "This is a long text that will be split into chunks.\n\n It will be split by the separator. And also by the separator '\n'."
+ tmp, _ = mktemp()
+ write(tmp, text)
+ mini_files = [tmp, tmp]
+ index = build_index(mini_files; max_length = 10, extract_metadata = true,
+ model_embedding = "mock-emb",
+ model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"))
+ @test index.embeddings == hcat(fill(normalize(ones(Float32, 128)), 8)...)
+ @test index.chunks[1:4] == index.chunks[5:8]
+ @test index.sources == fill(tmp, 8)
+ @test index.tags == ones(8, 1)
+ @test index.tags_vocab == ["category:::yes"]
+
+ # clean up
+ close(echo_server)
+end
\ No newline at end of file
diff --git a/test/Experimental/RAGTools/retrieval.jl b/test/Experimental/RAGTools/retrieval.jl
new file mode 100644
index 000000000..fcb9bc819
--- /dev/null
+++ b/test/Experimental/RAGTools/retrieval.jl
@@ -0,0 +1,64 @@
+using PromptingTools.Experimental.RAGTools: find_closest, find_tags
+using PromptingTools.Experimental.RAGTools: Passthrough, rerank
+
+@testset "find_closest" begin
+ test_embeddings = [1.0 2.0 -1.0; 3.0 4.0 -3.0; 5.0 6.0 -6.0] |>
+ x -> mapreduce(normalize, hcat, eachcol(x))
+ query_embedding = [0.1, 0.35, 0.5] |> normalize
+ positions, distances = find_closest(test_embeddings, query_embedding, top_k = 2)
+ # The query vector should be closer to the first embedding
+ @test positions == [1, 2]
+ @test isapprox(distances, [0.9975694083904584
+ 0.9939123761133188], atol = 1e-3)
+
+ # Test when top_k is more than available embeddings
+ positions, _ = find_closest(test_embeddings, query_embedding, top_k = 5)
+ @test length(positions) == size(test_embeddings, 2)
+
+ # Test with minimum_similarity
+ positions, _ = find_closest(test_embeddings, query_embedding, top_k = 5,
+ minimum_similarity = 0.995)
+ @test length(positions) == 1
+
+ # Test behavior with edge values (top_k == 0)
+ @test find_closest(test_embeddings, query_embedding, top_k = 0) == ([], [])
+end
+
+@testset "find_tags" begin
+ test_embeddings = [1.0 2.0; 3.0 4.0; 5.0 6.0] |>
+ x -> mapreduce(normalize, hcat, eachcol(x))
+ query_embedding = [0.1, 0.35, 0.5] |> normalize
+ test_tags_vocab = ["julia", "python", "jr"]
+ test_tags_matrix = sparse([1, 2], [1, 3], [true, true], 2, 3)
+ index = ChunkIndex(;
+ sources = [".", "."],
+ chunks = ["julia", "jr"],
+ embeddings = test_embeddings,
+ tags = test_tags_matrix,
+ tags_vocab = test_tags_vocab)
+
+ # Test for finding the correct positions of a specific tag
+ @test find_tags(index, "julia").positions == [1]
+ @test find_tags(index, "julia").distances == [1.0]
+
+ # Test for no tag found // not in vocab
+ @test find_tags(index, "python").positions |> isempty
+ @test find_tags(index, "java").positions |> isempty
+
+ # Test with regex matching
+ @test find_tags(index, r"^j").positions == [1, 2]
+
+ # Test with multiple tags in vocab
+ @test find_tags(index, ["python", "jr", "x"]).positions == [2]
+end
+
+@testset "rerank" begin
+ # Mock data for testing
+ index = "mock_index"
+ question = "mock_question"
+ candidate_chunks = ["chunk1", "chunk2", "chunk3"]
+
+ # Passthrough Strategy
+ strategy = Passthrough()
+ @test rerank(strategy, index, question, candidate_chunks) === candidate_chunks
+end
\ No newline at end of file
diff --git a/test/Experimental/RAGTools/runtests.jl b/test/Experimental/RAGTools/runtests.jl
new file mode 100644
index 000000000..605ce5df5
--- /dev/null
+++ b/test/Experimental/RAGTools/runtests.jl
@@ -0,0 +1,13 @@
+using Test
+using SparseArrays, LinearAlgebra
+using PromptingTools.Experimental.RAGTools
+using JSON3, HTTP
+
+@testset "RAGTools" begin
+ include("utils.jl")
+ include("types.jl")
+ include("preparation.jl")
+ include("retrieval.jl")
+ include("generation.jl")
+ include("evaluation.jl")
+end
diff --git a/test/Experimental/RAGTools/types.jl b/test/Experimental/RAGTools/types.jl
new file mode 100644
index 000000000..bfb915919
--- /dev/null
+++ b/test/Experimental/RAGTools/types.jl
@@ -0,0 +1,154 @@
+using PromptingTools.Experimental.RAGTools: ChunkIndex, MultiIndex, CandidateChunks
+using PromptingTools.Experimental.RAGTools: embeddings, chunks, tags, tags_vocab, sources
+
+@testset "ChunkIndex" begin
+ # Test constructors and basic accessors
+ chunks_test = ["chunk1", "chunk2"]
+ emb_test = ones(2, 2)
+ tags_test = sparse([1, 2], [1, 2], [true, true], 2, 2)
+ tags_vocab_test = ["vocab1", "vocab2"]
+ sources_test = ["source1", "source2"]
+ ci = ChunkIndex(chunks = chunks_test,
+ embeddings = emb_test,
+ tags = tags_test,
+ tags_vocab = tags_vocab_test,
+ sources = sources_test)
+
+ @test chunks(ci) == chunks_test
+ @test (embeddings(ci)) == emb_test
+ @test tags(ci) == tags_test
+ @test tags_vocab(ci) == tags_vocab_test
+ @test sources(ci) == sources_test
+
+ # Test identity/equality
+ ci1 = ChunkIndex(chunks = ["chunk1", "chunk2"], sources = ["source1", "source2"])
+ ci2 = ChunkIndex(chunks = ["chunk1", "chunk2"], sources = ["source1", "source2"])
+ @test ci1 == ci2
+
+ # Test equality with different chunks and sources
+ ci2 = ChunkIndex(chunks = ["chunk3", "chunk4"], sources = ["source3", "source4"])
+ @test ci1 != ci2
+
+ # Test hcat with ChunkIndex
+ # Setup two different ChunkIndex with different tags and then hcat them
+ chunks1 = ["chunk1", "chunk2"]
+ tags1 = sparse([1, 2], [1, 2], [true, true], 2, 3)
+ tags_vocab1 = ["vocab1", "vocab2", "vocab3"]
+ sources1 = ["source1", "source1"]
+ ci1 = ChunkIndex(chunks = chunks1,
+ tags = tags1,
+ tags_vocab = tags_vocab1,
+ sources = sources1)
+
+ chunks2 = ["chunk3", "chunk4"]
+ tags2 = sparse([1, 2], [1, 3], [true, true], 2, 3)
+ tags_vocab2 = ["vocab1", "vocab3", "vocab4"]
+ sources2 = ["source2", "source2"]
+ ci2 = ChunkIndex(chunks = chunks2,
+ tags = tags2,
+ tags_vocab = tags_vocab2,
+ sources = sources2)
+
+ combined_ci = vcat(ci1, ci2)
+ @test size(tags(combined_ci), 1) == 4
+ @test size(tags(combined_ci), 2) == 4
+ @test length(unique(vcat(tags_vocab(ci1), tags_vocab(ci2)))) ==
+ length(tags_vocab(combined_ci))
+ @test sources(combined_ci) == vcat(sources(ci1), (sources(ci2)))
+
+ # Test base var"==" with ChunkIndex
+ ci1 = ChunkIndex(chunks = ["chunk1"],
+ tags = trues(3, 1),
+ tags_vocab = ["vocab1"],
+ sources = ["source1"])
+ ci2 = ChunkIndex(chunks = ["chunk1"],
+ tags = trues(3, 1),
+ tags_vocab = ["vocab1"],
+ sources = ["source1"])
+ @test ci1 == ci2
+end
+
+@testset "MultiIndex" begin
+ # Test constructors/accessors
+ # MultiIndex behaves as a container for ChunkIndexes
+ cin1 = ChunkIndex(chunks = ["chunk1"], sources = ["source1"])
+ cin2 = ChunkIndex(chunks = ["chunk2"], sources = ["source2"])
+ multi_index = MultiIndex(indexes = [cin1, cin2])
+ @test length(multi_index.indexes) == 2
+ @test cin1 in multi_index.indexes
+ @test cin2 in multi_index.indexes
+
+ # Test base var"==" with MultiIndex
+ # Case where MultiIndexes are equal
+ cin1 = ChunkIndex(chunks = ["chunk1"], sources = ["source1"])
+ cin2 = ChunkIndex(chunks = ["chunk2"], sources = ["source2"])
+ mi1 = MultiIndex(indexes = [cin1, cin2])
+ mi2 = MultiIndex(indexes = [cin1, cin2])
+ @test mi1 == mi2
+
+ # Test equality with different ChunkIndexes inside
+ cin1 = ChunkIndex(chunks = ["chunk1"], sources = ["source1"])
+ cin2 = ChunkIndex(chunks = ["chunk2"], sources = ["source2"])
+ mi1 = MultiIndex(indexes = [cin1])
+ mi2 = MultiIndex(indexes = [cin2])
+ @test mi1 != mi2
+end
+
+@testset "getindex with CandidateChunks" begin
+ # Initialize a ChunkIndex with test data
+ chunks_data = ["First chunk", "Second chunk", "Third chunk"]
+ embeddings_data = rand(3, 3) # Random matrix with 3 embeddings
+ tags_data = sparse(Bool[1 1; 0 1; 1 0]) # Some arbitrary sparse matrix representation
+ tags_vocab_data = ["tag1", "tag2"]
+ chunk_sym = Symbol("TestChunkIndex")
+ test_chunk_index = ChunkIndex(chunks = chunks_data,
+ embeddings = embeddings_data,
+ tags = tags_data,
+ tags_vocab = tags_vocab_data,
+ sources = repeat(["test_source"], 3),
+ id = chunk_sym)
+
+ # Test to get chunks based on valid CandidateChunks
+ candidate_chunks = CandidateChunks(index_id = chunk_sym,
+ positions = [1, 3],
+ distances = [0.1, 0.2])
+ @test collect(test_chunk_index[candidate_chunks]) == ["First chunk", "Third chunk"]
+
+ # Test with empty positions, which should result in an empty array
+ candidate_chunks_empty = CandidateChunks(index_id = chunk_sym,
+ positions = Int[],
+ distances = Float32[])
+ @test isempty(test_chunk_index[candidate_chunks_empty])
+
+ # Test with positions out of bounds, should handle gracefully without errors
+ candidate_chunks_oob = CandidateChunks(index_id = chunk_sym,
+ positions = [10, -1],
+ distances = [0.5, 0.6])
+ @test_throws AssertionError test_chunk_index[candidate_chunks_oob]
+
+ # Test with an incorrect index_id, which should also result in an empty array
+ wrong_sym = Symbol("InvalidIndex")
+ candidate_chunks_wrong_id = CandidateChunks(index_id = wrong_sym,
+ positions = [1, 2],
+ distances = [0.3, 0.4])
+ @test isempty(test_chunk_index[candidate_chunks_wrong_id])
+
+ # Test when chunks are requested from a MultiIndex, only chunks from the corresponding ChunkIndex should be returned
+ another_chuck_index = ChunkIndex(chunks = chunks_data,
+ embeddings = nothing,
+ tags = nothing,
+ tags_vocab = nothing,
+ sources = repeat(["another_source"], 3),
+ id = Symbol("AnotherChunkIndex"))
+ test_multi_index = MultiIndex(indexes = [
+ test_chunk_index,
+ another_chuck_index,
+ ])
+ @test collect(test_multi_index[candidate_chunks]) == ["First chunk", "Third chunk"]
+
+ # Test when wrong index_id is used with MultiIndex, resulting in an empty array
+ @test isempty(test_multi_index[candidate_chunks_wrong_id])
+
+ # Test error case when trying to use a non-chunks field, should assert error as only :chunks field is supported
+ @test_throws AssertionError test_chunk_index[candidate_chunks, :nonexistent_field]
+end
\ No newline at end of file
diff --git a/test/Experimental/RAGTools/utils.jl b/test/Experimental/RAGTools/utils.jl
new file mode 100644
index 000000000..cc93c31f9
--- /dev/null
+++ b/test/Experimental/RAGTools/utils.jl
@@ -0,0 +1,46 @@
+using PromptingTools.Experimental.RAGTools: _check_aiextract_capability,
+ merge_labeled_matrices
+
+@testset "_check_aiextract_capability" begin
+ @test _check_aiextract_capability("gpt-3.5-turbo") == nothing
+ @test_throws AssertionError _check_aiextract_capability("llama2")
+end
+
+@testset "merge_labeled_matrices" begin
+ # Test with dense matrices and overlapping vocabulary
+ mat1 = [1 2; 3 4]
+ vocab1 = ["word1", "word2"]
+ mat2 = [5 6; 7 8]
+ vocab2 = ["word2", "word3"]
+
+ merged_mat, combined_vocab = merge_labeled_matrices(mat1, vocab1, mat2, vocab2)
+
+ @test size(merged_mat) == (4, 3)
+ @test combined_vocab == ["word1", "word2", "word3"]
+ @test merged_mat == [1 2 0; 3 4 0; 0 5 6; 0 7 8]
+
+ # Test with sparse matrices and disjoint vocabulary
+ mat1 = sparse([1 0; 0 2])
+ vocab1 = ["word1", "word2"]
+ mat2 = sparse([3 0; 0 4])
+ vocab2 = ["word3", "word4"]
+
+ merged_mat, combined_vocab = merge_labeled_matrices(mat1, vocab1, mat2, vocab2)
+
+ @test size(merged_mat) == (4, 4)
+ @test combined_vocab == ["word1", "word2", "word3", "word4"]
+ @test merged_mat == sparse([1 0 0 0; 0 2 0 0; 0 0 3 0; 0 0 0 4])
+
+ # Test with different data types
+ mat1 = [1.0 2.0; 3.0 4.0]
+ vocab1 = ["word1", "word2"]
+ mat2 = [5 6; 7 8]
+ vocab2 = ["word2", "word3"]
+
+ merged_mat, combined_vocab = merge_labeled_matrices(mat1, vocab1, mat2, vocab2)
+
+ @test eltype(merged_mat) == Float64
+ @test size(merged_mat) == (4, 3)
+ @test combined_vocab == ["word1", "word2", "word3"]
+ @test merged_mat ≈ [1.0 2.0 0.0; 3.0 4.0 0.0; 0.0 5.0 6.0; 0.0 7.0 8.0]
+end
\ No newline at end of file
diff --git a/test/llm_openai.jl b/test/llm_openai.jl
index 95364b847..cc45494d0 100644
--- a/test/llm_openai.jl
+++ b/test/llm_openai.jl
@@ -180,7 +180,7 @@ end
@testset "OpenAI.create_chat" begin
# Test CustomOpenAISchema() with a mock server
PORT = rand(1000:2000)
- echo_server = HTTP.serve!(PORT) do req
+ echo_server = HTTP.serve!(PORT, verbose = -1) do req
content = JSON3.read(req.body)
user_msg = last(content[:messages])
response = Dict(:choices => [Dict(:message => user_msg)],
@@ -206,7 +206,7 @@ end
@testset "OpenAI.create_embeddings" begin
# Test CustomOpenAISchema() with a mock server
PORT = rand(1000:2000)
- echo_server = HTTP.serve!(PORT) do req
+ echo_server = HTTP.serve!(PORT, verbose = -1) do req
content = JSON3.read(req.body)
response = Dict(:data => [Dict(:embedding => ones(128))],
:usage => Dict(:total_tokens => length(content[:input]),
diff --git a/test/runtests.jl b/test/runtests.jl
index c4cc26288..d8bc24dad 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
using PromptingTools
using OpenAI, HTTP, JSON3
+using SparseArrays, LinearAlgebra
using Test
using Aqua
const PT = PromptingTools
@@ -33,4 +34,9 @@ let cb = AICode(; code = """
@test !isnothing(cb.expression) # parsed
@test occursin("Test Failed", cb.stdout) # capture details of the test failure
@test isnothing(cb.output) # because it failed
-end
\ No newline at end of file
+end
+
+## Run experimental
+@testset "Experimental" begin
+ include("Experimental/RAGTools/runtests.jl")
+end
diff --git a/test/utils.jl b/test/utils.jl
index 1b726924a..ceabdd5e7 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,6 +1,7 @@
using PromptingTools: split_by_length, replace_words
-using PromptingTools: _extract_handlebar_variables, _report_stats
+using PromptingTools: _extract_handlebar_variables, call_cost, _report_stats
using PromptingTools: _string_to_vector, _encode_local_image
+using PromptingTools: DataMessage, AIMessage
@testset "replace_words" begin
words = ["Disney", "Snow White", "Mickey Mouse"]
@@ -32,7 +33,7 @@ end
# Test with empty text
chunks = split_by_length("")
- @test isempty(chunks)
+ @test chunks == [""]
# Test custom separator
text = "Hello,World,"^50
@@ -43,6 +44,34 @@ end
@test length(chunks) == 34
@test maximum(length.(chunks)) <= 20
@test join(chunks, "") == text
+
+ ### Multiple separators
+ # Single separator
+ text = "First sentence. Second sentence. Third sentence."
+ chunks = split_by_length(text, ["."], max_length = 15)
+ @test length(chunks) == 3
+ @test chunks == ["First sentence.", " Second sentence.", " Third sentence."]
+
+ # Multiple separators
+ text = "Paragraph 1\n\nParagraph 2. Sentence 1. Sentence 2.\nParagraph 3"
+ separators = ["\n\n", ". ", "\n"]
+ chunks = split_by_length(text, separators, max_length = 20)
+ @test length(chunks) == 5
+ @test chunks[1] == "Paragraph 1\n\n"
+ @test chunks[2] == "Paragraph 2. "
+ @test chunks[3] == "Sentence 1. "
+ @test chunks[4] == "Sentence 2.\n"
+ @test chunks[5] == "Paragraph 3"
+
+ # empty separators
+ text = "Some text without separators."
+ @test_throws AssertionError split_by_length(text, String[], max_length = 10)
+ # edge cases
+ text = "Short text"
+ separators = ["\n\n", ". ", "\n"]
+ chunks = split_by_length(text, separators, max_length = 50)
+ @test length(chunks) == 1
+ @test chunks[1] == text
end
@testset "extract_handlebar_variables" begin
@@ -68,20 +97,34 @@ end
@test actual_output == expected_output
end
+@testset "call_cost" begin
+ msg = AIMessage(; content = "", tokens = (1000, 2000))
+ cost = call_cost(msg, "unknown_model")
+ @test cost == 0.0
+ @test call_cost(msg, "gpt-3.5-turbo") ≈ 1000 * 1.5e-6 + 2e-6 * 2000
+
+ msg = DataMessage(; content = nothing, tokens = (1000, 1000))
+ cost = call_cost(msg, "unknown_model")
+ @test cost == 0.0
+ @test call_cost(msg, "gpt-3.5-turbo") ≈ 1000 * 1.5e-6 + 2e-6 * 1000
+
+ @test call_cost(msg,
+ "gpt-3.5-turbo";
+ cost_of_token_prompt = 1,
+ cost_of_token_generation = 1) ≈ 1000 + 1000
+end
+
@testset "report_stats" begin
# Returns a string with the total number of tokens and elapsed time when given a message and model
msg = AIMessage(; content = "", tokens = (1, 5), elapsed = 5.0)
- model = "model"
+ model = "unknown_model"
expected_output = "Tokens: 6 in 5.0 seconds"
@test _report_stats(msg, model) == expected_output
# Returns a string with a cost
- expected_output = "Tokens: 6 @ Cost: \$0.007 in 5.0 seconds"
- @test _report_stats(msg, model, 2e-3, 1e-3) == expected_output
-
- # Returns a string without cost when it's zero
- expected_output = "Tokens: 6 in 5.0 seconds"
- @test _report_stats(msg, model, 0, 0) == expected_output
+ msg = AIMessage(; content = "", tokens = (1000, 5000), elapsed = 5.0)
+ expected_output = "Tokens: 6000 @ Cost: \$0.0115 in 5.0 seconds"
+ @test _report_stats(msg, "gpt-3.5-turbo") == expected_output
end
@testset "_string_to_vector" begin