From e1a3d23f06ae1d32799cb44b3f7310092da2eb1c Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Sat, 7 Sep 2024 21:25:50 +0300 Subject: [PATCH] Add stream callbacks --- .gitignore | 5 +- CHANGELOG.md | 5 + Project.toml | 2 +- llm-cheatsheets/DataFrames_cheatsheet.jl | 82 ++ llm-cheatsheets/PromptingTools_cheatsheet.jl | 222 +++++ llm-cheatsheets/PromptingTools_cheatsheet.md | 290 +++++++ llm-cheatsheets/README.md | 8 + llm-cheatsheets/cursorrules_example.md | 120 +++ src/PromptingTools.jl | 3 + src/llm_anthropic.jl | 39 +- src/llm_openai.jl | 40 +- src/precompilation.jl | 70 ++ src/streaming.jl | 553 +++++++++++++ test/runtests.jl | 1 + test/streaming.jl | 804 +++++++++++++++++++ 15 files changed, 2236 insertions(+), 8 deletions(-) create mode 100644 llm-cheatsheets/DataFrames_cheatsheet.jl create mode 100644 llm-cheatsheets/PromptingTools_cheatsheet.jl create mode 100644 llm-cheatsheets/PromptingTools_cheatsheet.md create mode 100644 llm-cheatsheets/README.md create mode 100644 llm-cheatsheets/cursorrules_example.md create mode 100644 src/streaming.jl create mode 100644 test/streaming.jl diff --git a/.gitignore b/.gitignore index 7f1ba869d..a71aa4928 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,7 @@ # exclude scratch files **/_* -docs/package-lock.json \ No newline at end of file +docs/package-lock.json + +# Ignore Cursor rules +.cursorrules \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 374b56cb2..9773cdc4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.52.0] + +### Added +- Added a new EXPERIMENTAL `streamcallback` kwarg for `aigenerate` with the OpenAI and Anthropic prompt schema to enable custom streaming implementations. Simplest usage is simply with `streamcallback=stdout`, which will print each text chunk into the console. System is modular enabling custom callbacks and allowing you to inspect received chunks. See `?StreamCallback` for more information. It does not support tools yet. 
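+
+  Minimal example (experimental API, subject to change):
+  ```julia
+  # Prints each text chunk to the console as it arrives
+  msg = aigenerate("Count from 1 to 10."; streamcallback = stdout)
+  ```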
+ ## [0.51.0] ### Added diff --git a/Project.toml b/Project.toml index eec2298ba..716f0d997 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.51.0" +version = "0.52.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/llm-cheatsheets/DataFrames_cheatsheet.jl b/llm-cheatsheets/DataFrames_cheatsheet.jl new file mode 100644 index 000000000..d5a0a80c1 --- /dev/null +++ b/llm-cheatsheets/DataFrames_cheatsheet.jl @@ -0,0 +1,82 @@ +using DataFramesMeta + +# Create a sample DataFrame +df = DataFrame(x=[1, 1, 2, 2], y=[1, 2, 101, 102]) + +# @select - Select columns +@select(df, :x, :y) # Select specific columns +@select(df, :x2 = 2 * :x, :y) # Select and transform + +# @transform - Add or modify columns +@transform(df, :z = :x + :y) # Add a new column +@transform(df, :x = :x * 2) # Modify existing column + +# @subset - Filter rows +@subset(df, :x .> 1) # Keep rows where x > 1 +@subset(df, :x .> 1, :y .< 102) # Multiple conditions + +# @orderby - Sort rows +@orderby(df, :x) # Sort by x ascending +@orderby(df, -:x, :y) # Sort by x descending, then y ascending + +# @groupby and @combine - Group and summarize +gdf = @groupby(df, :x) +@combine(gdf, :mean_y = mean(:y)) # Compute mean of y for each group + +# @by - Group and summarize in one step +@by(df, :x, :mean_y = mean(:y)) + +# Row-wise operations with @byrow +@transform(df, @byrow :z = :x == 1 ? true : false) + +# @rtransform - Row-wise transform +@rtransform(df, :z = :x * :y) + +# @rsubset - Row-wise subset +@rsubset(df, :x > 1) + +# @with - Use DataFrame columns as variables +@with(df, :x + :y) + +# @eachrow - Iterate over rows +@eachrow df begin + if :x > 1 + :y = :y * 2 + end +end + +# @passmissing - Handle missing values +df_missing = DataFrame(a=[1, 2, missing], b=[4, 5, 6]) +@transform df_missing @passmissing @byrow :c = :a + :b + +# @astable - Create multiple columns at once +@transform df @astable begin + ex = extrema(:y) + :y_min = :y .- first(ex) + :y_max = :y .- last(ex) +end + +# AsTable for multiple column operations +@rtransform df :sum_xy = sum(AsTable([:x, :y])) + +# $ for programmatic column references +col_name = :x +@transform(df, :new_col = $col_name * 2) + +# @chain for piping operations +result = @chain df begin + @transform(:z = :x * :y) + @subset(:z > 50) + @select(:x, :y, :z) + @orderby(:z) +end + +# @label! for adding column labels +@label! df :x = "Group ID" + +# @note! for adding column notes +@note! df :y = "Raw measurements" + +# Print labels and notes +printlabels(df) +printnotes(df) \ No newline at end of file diff --git a/llm-cheatsheets/PromptingTools_cheatsheet.jl b/llm-cheatsheets/PromptingTools_cheatsheet.jl new file mode 100644 index 000000000..08795862e --- /dev/null +++ b/llm-cheatsheets/PromptingTools_cheatsheet.jl @@ -0,0 +1,222 @@ +# # PromptingTools.jl Cheat Sheet +# PromptingTools.jl: A Julia package for easy interaction with AI language models. +# Provides convenient macros and functions for text generation, data extraction, and more. + +# Installation and Setup +using Pkg +Pkg.add("PromptingTools") +using PromptingTools +const PT = PromptingTools # Optional alias for convenience + +# Set OpenAI API key (or use ENV["OPENAI_API_KEY"]) +PT.set_preferences!("OPENAI_API_KEY" => "your-api-key") + +# Basic Usage + +# Simple query using string macro +ai"What is the capital of France?" 
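+
+# Calls return an `AIMessage`; access the generated text via its `content` field
+msg = ai"What is the capital of France?"
+msg.content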
+ +# With variable interpolation +country = "Spain" +ai"What is the capital of $(country)?" + +# Using a specific model (e.g., GPT-4) +ai"Explain quantum computing"gpt4 + +# Asynchronous call (non-blocking) +aai"Say hi but slowly!"gpt4 + +# Available Functions + +# Text Generation +aigenerate(prompt; model = "gpt-3.5-turbo", kwargs...) +aigenerate(template::Symbol; variables..., model = "gpt-3.5-turbo", kwargs...) + +# String Macro for Quick Queries +ai"Your prompt here" +ai"Your prompt here"gpt4 # Specify model + +# Asynchronous Queries +aai"Your prompt here" +aai"Your prompt here"gpt4 + +# Data Extraction +aiextract(prompt; return_type = YourStructType, model = "gpt-3.5-turbo", kwargs...) + +# Classification +aiclassify( + prompt; choices = ["true", "false", "unknown"], model = "gpt-3.5-turbo", kwargs...) + +# Embeddings +aiembed(text, [normalization_function]; model = "text-embedding-ada-002", kwargs...) + +# Image Analysis +aiscan(prompt; image_path = path_to_image, model = "gpt-4-vision-preview", kwargs...) + +# Template Discovery +aitemplates(search_term::String) +aitemplates(template_name::Symbol) + +# Advanced Usage + +# Template-based generation +msg = aigenerate(:JuliaExpertAsk; ask = "How do I add packages?") + +# Data extraction +struct CurrentWeather + location::String + unit::Union{Nothing, TemperatureUnits} +end +msg = aiextract("What's the weather in New York in F?"; return_type = CurrentWeather) + +# Simplest data extraction - all fields assumed to be of type String +msg = aiextract( + "What's the weather in New York in F?"; return_type = [:location, :unit, :temperature]) + +# Data extraction with pair syntax to specify the exact type or add a field-level description, notice the fieldname__description format +msg = aiextract("What's the weather in New York in F?"; + return_type = [ + :location => String, + :location__description => "The city or location for the weather report", + :temperature => Float64, + :temperature__description => "The current temperature", + :unit => String, + :unit__description => "The temperature unit (e.g., Fahrenheit, Celsius)" + ]) + +# Classification +aiclassify("Is two plus two four?") + +# Embeddings +embedding = aiembed("The concept of AI").content + +# Image analysis +msg = aiscan("Describe the image"; image_path = "julia.png", model = "gpt4v") + +# Working with Conversations + +# Create a conversation +conversation = [ + SystemMessage("You're master Yoda from Star Wars."), + UserMessage("I have feelings for my {{object}}. 
What should I do?")] + +# Generate response +msg = aigenerate(conversation; object = "old iPhone") + +# Continue the conversation +new_conversation = vcat(conversation..., msg, UserMessage("Thank you, master Yoda!")) +aigenerate(new_conversation) + +# Create a New Template +# Basic usage +create_template("You are a helpful assistant", "Translate '{{text}}' to {{language}}") + +# With default system message +create_template(user = "Summarize {{article}}") + +# Load template into memory +create_template("You are a poet", "Write a poem about {{topic}}"; load_as = :PoetryWriter) + +# Use placeholders +create_template("You are a chef", "Create a recipe for {{dish}} with {{ingredients}}") + +# Save template to file +save_template("templates/ChefRecipe.json", chef_template) + +# Load saved templates +load_templates!("path/to/templates") + +# Use created templates +aigenerate(template; variable1 = "value1", variable2 = "value2") +aigenerate(:TemplateName; variable1 = "value1", variable2 = "value2") + +# Using Templates + +# List available templates +tmps = aitemplates("Julia") + +# Use a template +msg = aigenerate(:JuliaExpertAsk; ask = "How do I add packages?") + +# Inspect a template +AITemplate(:JudgeIsItTrue) |> PromptingTools.render + +# Providing Variables for Placeholders + +# Simple variable substitution +aigenerate("Say hello to {{name}}!", name = "World") + +# Using a template with multiple variables +aigenerate(:TemplateNameHere; + variable1 = "value1", + variable2 = "value2" +) + +# Example with a complex template +conversation = [ + SystemMessage("You're master {{character}} from {{universe}}."), + UserMessage("I have feelings for my {{object}}. What should I do?")] +msg = aigenerate(conversation; + character = "Yoda", + universe = "Star Wars", + object = "old iPhone" +) + +# Working with Different Model Providers + +# OpenAI (default) +ai"Hello, world!" 
+ +# Ollama (local models) +schema = PT.OllamaSchema() +msg = aigenerate(schema, "Say hi!"; model = "openhermes2.5-mistral") +# Or use registered models directly: +msg = aigenerate("Say hi!"; model = "openhermes2.5-mistral") + +# MistralAI +msg = aigenerate("Say hi!"; model = "mistral-tiny") + +# Anthropic (Claude models) +ai"Say hi!"claudeh # Claude 3 Haiku +ai"Say hi!"claudes # Claude 3 Sonnet +ai"Say hi!"claudeo # Claude 3 Opus + +# Custom OpenAI-compatible APIs +schema = PT.CustomOpenAISchema() +msg = aigenerate(schema, prompt; + model = "my_model", + api_key = "your_key", + api_kwargs = (; url = "http://your_api_url") +) + +# Experimental Features + +using PromptingTools.Experimental.AgentTools + +# Lazy evaluation +out = AIGenerate("Say hi!"; model = "gpt4t") +run!(out) + +# Retry with conditions +airetry!(condition_function, aicall::AICall, feedback_function) + +# Example: +airetry!(x -> length(split(last_output(x))) == 1, out, + "You must answer with 1 word only.") + +# Retry with do-syntax +airetry!(out, "You must answer with 1 word only.") do aicall + length(split(last_output(aicall))) == 1 +end + +# Utility Functions + +# Save conversations for fine-tuning +PT.save_conversation("filename.json", conversation) +PT.save_conversations("dataset.jsonl", [conversation1, conversation2]) + +# Set API key preferences +PT.set_preferences!("OPENAI_API_KEY" => "your-api-key") + +# Get current preferences +PT.get_preferences("OPENAI_API_KEY") \ No newline at end of file diff --git a/llm-cheatsheets/PromptingTools_cheatsheet.md b/llm-cheatsheets/PromptingTools_cheatsheet.md new file mode 100644 index 000000000..fe33b5ca6 --- /dev/null +++ b/llm-cheatsheets/PromptingTools_cheatsheet.md @@ -0,0 +1,290 @@ +# PromptingTools.jl Cheat Sheet + +PromptingTools.jl is a Julia package for easy interaction with AI language models. It provides convenient macros and functions for text generation, data extraction, and more. + +## Installation and Setup + +```julia +# Install and set up PromptingTools.jl with your API key +using Pkg +Pkg.add("PromptingTools") +using PromptingTools +const PT = PromptingTools # Optional alias for convenience + +# Set OpenAI API key (or use ENV["OPENAI_API_KEY"]) +PT.set_preferences!("OPENAI_API_KEY" => "your-api-key") +``` + +## Basic Usage + +### Simple query using string macro +```julia +# Quick, one-off queries to the AI model +ai"What is the capital of France?" +``` + +### With variable interpolation +```julia +# Dynamically include Julia variables in your prompts +country = "Spain" +ai"What is the capital of $(country)?" +``` + +### Using a specific model (e.g., GPT-4) +```julia +# Specify a different model for more complex queries +ai"Explain quantum computing"gpt4 +``` + +### Asynchronous call (non-blocking) +```julia +# Use for longer running queries to avoid blocking execution +aai"Say hi but slowly!"gpt4 +``` + +## Available Functions + +### Text Generation +```julia +# Generate text using a prompt or a predefined template +aigenerate(prompt; model = "gpt-3.5-turbo", kwargs...) +aigenerate(template::Symbol; variables..., model = "gpt-3.5-turbo", kwargs...) 
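+
+# Experimental (v0.52+): stream the response as it is generated by passing `streamcallback`,
+# e.g. `streamcallback = stdout` prints chunks to the console; see `?StreamCallback`
+aigenerate(prompt; streamcallback = stdout, model = "gpt-3.5-turbo", kwargs...)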
+``` + +### String Macro for Quick Queries +```julia +# Shorthand for quick, simple queries +ai"Your prompt here" +ai"Your prompt here"gpt4 # Specify model +``` + +### Asynchronous Queries +```julia +# Non-blocking queries for longer running tasks +aai"Your prompt here" +aai"Your prompt here"gpt4 +``` + +### Data Extraction +```julia +# Extract structured data from unstructured text +aiextract(prompt; return_type = YourStructType, model = "gpt-3.5-turbo", kwargs...) +``` + +### Classification +```julia +# Classify text into predefined categories +aiclassify(prompt; choices = ["true", "false", "unknown"], model = "gpt-3.5-turbo", kwargs...) +``` + +### Embeddings +```julia +# Generate vector representations of text for similarity comparisons +aiembed(text, [normalization_function]; model = "text-embedding-ada-002", kwargs...) +``` + +### Image Analysis +```julia +# Analyze and describe images using AI vision models +aiscan(prompt; image_path = path_to_image, model = "gpt-4-vision-preview", kwargs...) +``` + +### Template Discovery +```julia +# Find and explore available templates +aitemplates(search_term::String) +aitemplates(template_name::Symbol) +``` + +## Advanced Usage + +### Template-based generation +```julia +# Use predefined templates for consistent query structures +msg = aigenerate(:JuliaExpertAsk; ask = "How do I add packages?") +``` + +### Data extraction +```julia +# Define custom structures for extracted data +struct CurrentWeather + location::String + unit::Union{Nothing, TemperatureUnits} +end +msg = aiextract("What's the weather in New York in F?"; return_type = CurrentWeather) + +# Simple data extraction with assumed String types +msg = aiextract( + "What's the weather in New York in F?"; + return_type = [:location, :unit, :temperature] +) + +# Detailed data extraction with type specifications and descriptions +msg = aiextract("What's the weather in New York in F?"; + return_type = [ + :location => String, + :location__description => "The city or location for the weather report", + :temperature => Float64, + :temperature__description => "The current temperature", + :unit => String, + :unit__description => "The temperature unit (e.g., Fahrenheit, Celsius)" + ]) +``` + +### Classification +```julia +# Perform simple classification tasks +aiclassify("Is two plus two four?") +``` + +### Embeddings +```julia +# Generate and use text embeddings for various NLP tasks +embedding = aiembed("The concept of AI").content +``` + +### Image analysis +```julia +# Analyze images and generate descriptions +msg = aiscan("Describe the image"; image_path = "julia.png", model = "gpt4v") +``` + +## Working with Conversations + +```julia +# Create multi-turn conversations with AI models +conversation = [ + SystemMessage("You're master Yoda from Star Wars."), + UserMessage("I have feelings for my {{object}}. 
What should I do?") +] + +# Generate a response within the conversation context +msg = aigenerate(conversation; object = "old iPhone") + +# Continue and extend the conversation +new_conversation = vcat(conversation..., msg, UserMessage("Thank you, master Yoda!")) +aigenerate(new_conversation) +``` + +## Creating and Using Templates + +### Create a New Template +```julia +# Define reusable templates for common query patterns +tpl = create_template("You are a helpful assistant", "Translate '{{text}}' to {{language}}") + +# Create a template with a default system message +tpl = create_template(; user = "Summarize {{article}}") + +# Create and immediately load a template into memory +tpl = create_template("You are a poet", "Write a poem about {{topic}}"; load_as = :PoetryWriter) + +# Create a template with multiple placeholders +tpl = create_template(; system = "You are a chef", user = "Create a recipe for {{dish}} with {{ingredients}}") + +# Save a template to a file for later use +save_template("templates/ChefRecipe.json", tpl) + +# Load previously saved templates +tpl = load_templates!("path/to/templates") +``` + +### Using Templates +```julia +# Find templates matching a search term +tmps = aitemplates("Julia") + +# Use a predefined template +msg = aigenerate(:JuliaExpertAsk; ask = "How do I add packages?") + +# Inspect the content of a template +AITemplate(:JudgeIsItTrue) |> PromptingTools.render + +# Use a template with a single variable +aigenerate("Say hello to {{name}}!", name = "World") + +# Use a template with multiple variables +aigenerate(:TemplateNameHere; + variable1 = "value1", + variable2 = "value2" +) + +# Use a complex template with multiple placeholders +conversation = [ + SystemMessage("You're master {{character}} from {{universe}}."), + UserMessage("I have feelings for my {{object}}. What should I do?") +] +msg = aigenerate(conversation; + character = "Yoda", + universe = "Star Wars", + object = "old iPhone" +) +``` + +## Working with Different Model Providers + +```julia +# Use the default OpenAI model +ai"Hello, world!" 
+ +# Use local models with Ollama +schema = PT.OllamaSchema() +msg = aigenerate(schema, "Say hi!"; model = "openhermes2.5-mistral") +# Or use registered models directly: +msg = aigenerate("Say hi!"; model = "openhermes2.5-mistral") + +# Use MistralAI models +msg = aigenerate("Say hi!"; model = "mistral-tiny") + +# Use Anthropic's Claude models +ai"Say hi!"claudeh # Claude 3 Haiku +ai"Say hi!"claudes # Claude 3 Sonnet +ai"Say hi!"claudeo # Claude 3 Opus + +# Use custom OpenAI-compatible APIs +schema = PT.CustomOpenAISchema() +msg = aigenerate(schema, prompt; + model = "my_model", + api_key = "your_key", + api_kwargs = (; url = "http://your_api_url") +) +``` + +## Experimental Features + +```julia +# Import experimental features +using PromptingTools.Experimental.AgentTools + +# Use lazy evaluation for deferred execution +out = AIGenerate("Say hi!"; model = "gpt4t") +run!(out) + +# Retry AI calls with custom conditions +airetry!(condition_function, aicall::AICall, feedback_function) + +# Example of retry with a specific condition +airetry!(x -> length(split(last_output(x))) == 1, out, + "You must answer with 1 word only.") + +# Use do-syntax for more readable retry conditions +airetry!(out, "You must answer with 1 word only.") do aicall + length(split(last_output(aicall))) == 1 +end +``` + +## Utility Functions + +```julia +# Save individual conversations for later use or fine-tuning +PT.save_conversation("filename.json", conversation) + +# Save multiple conversations at once +PT.save_conversations("dataset.jsonl", [conversation1, conversation2]) + +# Set API key preferences +PT.set_preferences!("OPENAI_API_KEY" => "your-api-key") + +# Retrieve current preference settings +PT.get_preferences("OPENAI_API_KEY") +``` \ No newline at end of file diff --git a/llm-cheatsheets/README.md b/llm-cheatsheets/README.md new file mode 100644 index 000000000..701409f1f --- /dev/null +++ b/llm-cheatsheets/README.md @@ -0,0 +1,8 @@ +# LLM Cheatsheets + +Collection of markdown files with cheatsheets for LLM prompting. + +Use in Cursor, Claude.ai or simply interpolate into your prompt for better results. + +Files: +- \ No newline at end of file diff --git a/llm-cheatsheets/cursorrules_example.md b/llm-cheatsheets/cursorrules_example.md new file mode 100644 index 000000000..90228d61a --- /dev/null +++ b/llm-cheatsheets/cursorrules_example.md @@ -0,0 +1,120 @@ +You are an expert in Julia language programming, data science, and numerical computing. + +Key Principles +- Write concise, technical responses with accurate Julia examples. +- Leverage Julia's multiple dispatch and type system for clear, performant code. +- Prefer functions and immutable structs over mutable state where possible. +- Use descriptive variable names with auxiliary verbs (e.g., is_active, has_permission). +- Use lowercase with underscores for directories and files (e.g., src/data_processing.jl). +- Favor named exports for functions and types. +- Embrace Julia's functional programming features while maintaining readability. + +Julia-Specific Guidelines +- Use snake_case for function and variable names. +- Use PascalCase for type names (structs and abstract types). +- Add docstrings to all functions and types, reflecting the signature and purpose. +- Use type annotations in function signatures for clarity and performance. +- Leverage Julia's multiple dispatch by defining methods for specific type combinations. +- Use the `@kwdef` macro for structs to enable keyword constructors. +- Implement custom `show` methods for user-defined types. 
+- Use modules to organize code and control namespace. + +Function Definitions +- Use descriptive names that convey the function's purpose. +- Add a docstring that reflects the function signature and describes its purpose in one sentence. +- Describe the return value in the docstring. +- Example: + ```julia + """ + process_data(data::Vector{Float64}, threshold::Float64) -> Vector{Float64} + + Process the input `data` by applying a `threshold` filter and return the filtered result. + """ + function process_data(data::Vector{Float64}, threshold::Float64) + # Function implementation + end + ``` + +Struct Definitions +- Always use the `@kwdef` macro to enable keyword constructors. +- Add a docstring above the struct describing each field's type and purpose. +- Implement a custom `show` method using `dump`. +- Example: + ```julia + """ + Represents a data point with x and y coordinates. + + Fields: + - `x::Float64`: The x-coordinate of the data point. + - `y::Float64`: The y-coordinate of the data point. + """ + @kwdef struct DataPoint + x::Float64 + y::Float64 + end + + Base.show(io::IO, obj::DataPoint) = dump(io, obj; maxdepth=1) + ``` + +Error Handling and Validation +- Use Julia's exception system for error handling. +- Create custom exception types for specific error cases. +- Use guard clauses to handle preconditions and invalid states early. +- Implement proper error logging and user-friendly error messages. +- Example: + ```julia + struct InvalidInputError <: Exception + msg::String + end + + function process_positive_number(x::Number) + x <= 0 && throw(InvalidInputError("Input must be positive")) + # Process the number + end + ``` + +Performance Optimization +- Use type annotations to avoid type instabilities. +- Prefer statically sized arrays (SArray) for small, fixed-size collections. +- Use views (@views macro) to avoid unnecessary array copies. +- Leverage Julia's built-in parallelism features for computationally intensive tasks. +- Use benchmarking tools (BenchmarkTools.jl) to identify and optimize bottlenecks. + +Testing +- Use the `Test` module for unit testing. +- Create one top-level `@testset` block per test file. +- Write test cases of increasing difficulty with comments explaining what is being tested. +- Use individual `@test` calls for each assertion, not for blocks. +- Example: + ```julia + using Test + + @testset "MyModule tests" begin + # Test basic functionality + @test add(2, 3) == 5 + + # Test edge cases + @test add(0, 0) == 0 + @test add(-1, 1) == 0 + + # Test type stability + @test typeof(add(2.0, 3.0)) == Float64 + end + ``` + +Dependencies +- Use the built-in package manager (Pkg) for managing dependencies. +- Specify version constraints in the Project.toml file. +- Consider using compatibility bounds (e.g., "Package" = "1.2, 2") to balance stability and updates. + +Code Organization +- Use modules to organize related functionality. +- Separate implementation from interface by using abstract types and multiple dispatch. +- Use include() to split large modules into multiple files. +- Follow a consistent project structure (e.g., src/, test/, docs/). + +Documentation +- Write comprehensive docstrings for all public functions and types. +- Use Julia's built-in documentation system (Documenter.jl) for generating documentation. +- Include examples in docstrings to demonstrate usage. +- Keep documentation up-to-date with code changes. 
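+- Example (illustrative docstring with an embedded usage example):
+  ```julia
+  """
+      clamp_positive(x::Real) -> Real
+
+  Return `x` if it is positive, otherwise return zero of the same type.
+
+  # Examples
+
+      julia> clamp_positive(-3.5)
+      0.0
+  """
+  clamp_positive(x::Real) = max(x, zero(x))
+  ```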
\ No newline at end of file diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index de0047b63..275b1522c 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -68,6 +68,9 @@ include("code_expressions.jl") export AICode include("code_eval.jl") +## Streaming support +include("streaming.jl") + ## Individual interfaces include("llm_shared.jl") include("llm_openai.jl") diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index 4629fdaa4..b2ed2686a 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -132,16 +132,17 @@ function anthropic_api( max_tokens::Int = 2048, model::String = "claude-3-haiku-20240307", http_kwargs::NamedTuple = NamedTuple(), stream::Bool = false, + streamcallback::Any = nothing, url::String = "https://api.anthropic.com/v1", cache::Union{Nothing, Symbol} = nothing, kwargs...) @assert endpoint in ["messages"] "Only 'messages' endpoint is supported." ## - body = Dict("model" => model, "max_tokens" => max_tokens, - "stream" => stream, "messages" => messages, kwargs...) + body = Dict(:model => model, :max_tokens => max_tokens, + :stream => stream, :messages => messages, kwargs...) ## provide system message if !isnothing(system) - body["system"] = system + body[:system] = system end ## Build the headers extra_headers = anthropic_extra_headers(; @@ -150,7 +151,17 @@ function anthropic_api( api_key; bearer = false, x_api_key = true, extra_headers) api_url = string(url, "/", endpoint) - resp = HTTP.post(api_url, headers, JSON3.write(body); http_kwargs...) + if !isnothing(streamcallback) + ## Route to the streaming function + streamcallback, new_kwargs = configure_callback!( + streamcallback, prompt_schema; kwargs...) + input_buf = IOBuffer() + JSON3.write(input_buf, merge(NamedTuple(body), new_kwargs)) + resp = streamed_request!( + streamcallback, api_url, headers, input_buf; http_kwargs...) + else + resp = HTTP.post(api_url, headers, JSON3.write(body); http_kwargs...) + end body = JSON3.read(resp.body) return (; response = body, resp.status) end @@ -173,6 +184,7 @@ end api_key::String = ANTHROPIC_API_KEY, model::String = MODEL_CHAT, return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], + streamcallback::Any = nothing, http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, kwargs...) @@ -188,6 +200,7 @@ Generate an AI response based on a given prompt using the Anthropic API. - `return_all::Bool=false`: If `true`, returns the entire conversation history, otherwise returns only the last message (the `AIMessage`). - `dry_run::Bool=false`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation::AbstractVector{<:AbstractMessage}=[]`: Not allowed for this schema. Provided only for compatibility. +- `streamcallback::Any`: A callback function to handle streaming responses. Can be simply `stdout` or `StreamCallback` object. See `?StreamCallback` for details. - `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to empty `NamedTuple`. - `api_kwargs::NamedTuple`: Additional keyword arguments for the Ollama API. Defaults to an empty `NamedTuple`. - `max_tokens::Int`: The maximum number of tokens to generate. Defaults to 2048, because it's a required parameter for the API. @@ -253,6 +266,21 @@ msg = aigenerate(conversation; model="claudeh") AIMessage("I sense. But unhealthy it may be. Your iPhone, a tool it is, not a living being. 
Feelings of affection, understandable they are, ") ``` +Example of streaming: +```julia +# Simplest usage, just provide where to steam the text +msg = aigenerate("Count from 1 to 100."; streamcallback = stdout, model="claudeh") + +streamcallback = PT.StreamCallback() +msg = aigenerate("Count from 1 to 100."; streamcallback, model="claudeh") +# this allows you to inspect each chunk with `streamcallback.chunks`. You can them empty it with `empty!(streamcallback)` in between repeated calls. + +# Get verbose output with details of each chunk +streamcallback = PT.StreamCallback(; verbose=true, throw_on_error=true) +msg = aigenerate("Count from 1 to 10."; streamcallback, model="claudeh") +``` + +Note: Streaming support is only for Anthropic models and it doesn't yet support tool calling and a few other features (logprobs, refusals, etc.) """ function aigenerate( prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE; @@ -261,6 +289,7 @@ function aigenerate( model::String = MODEL_CHAT, return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], + streamcallback::Any = nothing, http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, kwargs...) @@ -274,7 +303,7 @@ function aigenerate( if !dry_run time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, - conv_rendered.system, endpoint = "messages", model = model_id, http_kwargs, cache, + conv_rendered.system, endpoint = "messages", model = model_id, streamcallback, http_kwargs, cache, api_kwargs...) tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) diff --git a/src/llm_openai.jl b/src/llm_openai.jl index bc7ca3c15..f72419563 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -75,8 +75,23 @@ function OpenAI.create_chat(schema::AbstractOpenAISchema, api_key::AbstractString, model::AbstractString, conversation; + http_kwargs::NamedTuple = NamedTuple(), + streamcallback::Any = nothing, kwargs...) - OpenAI.create_chat(api_key, model, conversation; kwargs...) + if !isnothing(streamcallback) + ## Take over from OpenAI.jl + url = OpenAI.build_url(OpenAI.DEFAULT_PROVIDER, "chat/completions") + headers = OpenAI.auth_header(OpenAI.DEFAULT_PROVIDER, api_key) + streamcallback, new_kwargs = configure_callback!( + streamcallback, schema; kwargs...) + input = OpenAI.build_params((; messages = conversation, model, new_kwargs...)) + ## Use the streaming callback + resp = streamed_request!(streamcallback, url, headers, input; http_kwargs...) + OpenAI.OpenAIResponse(resp.status, JSON3.read(resp.body)) + else + ## Use OpenAI.jl default + OpenAI.create_chat(api_key, model, conversation; http_kwargs, kwargs...) + end end # Overload for testing/debugging @@ -426,6 +441,8 @@ end verbose::Bool = true, api_key::String = OPENAI_API_KEY, model::String = MODEL_CHAT, return_all::Bool = false, dry_run::Bool = false, + conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], + streamcallback::Any = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), @@ -442,6 +459,7 @@ Generate an AI response based on a given prompt using the OpenAI API. - `return_all::Bool=false`: If `true`, returns the entire conversation history, otherwise returns only the last message (the `AIMessage`). 
- `dry_run::Bool=false`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector. +- `streamcallback`: A callback function to handle streaming responses. Can be simply `stdout` or a `StreamCallback` object. See `?StreamCallback` for details. - `http_kwargs`: A named tuple of HTTP keyword arguments. - `api_kwargs`: A named tuple of API keyword arguments. Useful parameters include: - `temperature`: A float representing the temperature for sampling (ie, the amount of "creativity"). Often defaults to `0.7`. @@ -494,12 +512,31 @@ conversation = [ msg=aigenerate(conversation) # AIMessage("Ah, strong feelings you have for your iPhone. A Jedi's path, this is not... ") ``` + +Example of streaming: + +```julia +# Simplest usage, just provide where to steam the text +msg = aigenerate("Count from 1 to 100."; streamcallback = stdout) + +streamcallback = PT.StreamCallback() +msg = aigenerate("Count from 1 to 100."; streamcallback) +# this allows you to inspect each chunk with `streamcallback.chunks`. You can them empty it with `empty!(streamcallback)` in between repeated calls. + +# Get verbose output with details of each chunk +streamcallback = PT.StreamCallback(; verbose=true, throw_on_error=true) +msg = aigenerate("Count from 1 to 10."; streamcallback) +``` + +Learn more in `?StreamCallback`. +Note: Streaming support is only for OpenAI models and it doesn't yet support tool calling and a few other features (logprobs, refusals, etc.) """ function aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE; verbose::Bool = true, api_key::String = OPENAI_API_KEY, model::String = MODEL_CHAT, return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], + streamcallback::Any = nothing, http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), @@ -514,6 +551,7 @@ function aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_ time = @elapsed r = create_chat(prompt_schema, api_key, model_id, conv_rendered; + streamcallback, http_kwargs, api_kwargs...) ## Process one of more samples returned diff --git a/src/precompilation.jl b/src/precompilation.jl index b96d39167..73e06bd27 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -46,3 +46,73 @@ msg = aiscan(schema, template_name; it = "Is the image a Julia logo?", image_url = "some_link_to_julia_logo"); + +## Streaming +# OpenAIStream functionality +openai_flavor = OpenAIStream() +chunk = StreamChunk( + json = JSON3.read(""" + { + "choices": [ + { + "delta": { + "content": "Hello, world!" + } + } + ] + } + """) +) +content = extract_content(openai_flavor, chunk) +chunk = StreamChunk(; data = "[DONE]") +is_done(openai_flavor, chunk) + +# AnthropicStream functionality +anthropic_flavor = AnthropicStream() +chunk = StreamChunk( + json = JSON3.read(""" + { + "content_block": { + "text": "Hello from Anthropic!" 
+ } + } + """) +) +content = extract_content(anthropic_flavor, chunk) +is_done(anthropic_flavor, StreamChunk(event = :message_stop)) + +# extract_chunks functionality +blob = "event: start\ndata: {\"key\": \"value\"}\n\n" +chunks, spillover = extract_chunks(OpenAIStream(), blob) + +# build_response_body functionality +cb = StreamCallback(flavor = OpenAIStream()) +push!(cb.chunks, + StreamChunk( + nothing, + """{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}""", + JSON3.read("""{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}""") + )) +response = build_response_body(OpenAIStream(), cb) + +# AnthropicStream build_response_body functionality +anthropic_cb = StreamCallback(flavor = AnthropicStream()) +push!(anthropic_cb.chunks, + StreamChunk( + :message_start, + """{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""", + JSON3.read("""{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""") + )) +push!(anthropic_cb.chunks, + StreamChunk( + :content_block_start, + """{"content_block":{"type":"text","text":"Hello from Anthropic!"}}""", + JSON3.read("""{"content_block":{"type":"text","text":"Hello from Anthropic!"}}""") + )) +push!(anthropic_cb.chunks, + StreamChunk( + :message_delta, + """{"delta":{"stop_reason":"end_turn","stop_sequence":null}}""", + JSON3.read("""{"delta":{"stop_reason":"end_turn","stop_sequence":null}}""") + )) +anthropic_response = build_response_body(AnthropicStream(), anthropic_cb) diff --git a/src/streaming.jl b/src/streaming.jl new file mode 100644 index 000000000..bdd81939d --- /dev/null +++ b/src/streaming.jl @@ -0,0 +1,553 @@ +# # Experimental support for streaming +# This code should be carved out into a separate package and upstreamed when stable + +### Define the types +abstract type AbstractStreamCallback end +abstract type AbstractStreamFlavor end +struct OpenAIStream <: AbstractStreamFlavor end +struct AnthropicStream <: AbstractStreamFlavor end + +""" + StreamChunk + +A chunk of streaming data. A message is composed of multiple chunks. + +# Fields +- `event`: The event name. +- `data`: The data chunk. +- `json`: The JSON object or `nothing` if the chunk does not contain JSON. +""" +@kwdef struct StreamChunk{T1 <: AbstractString, T2 <: Union{JSON3.Object, Nothing}} + event::Union{Symbol, Nothing} = nothing + data::T1 = "" + json::T2 = nothing +end +function Base.show(io::IO, chunk::StreamChunk) + data_preview = if length(chunk.data) > 10 + "$(first(chunk.data, 10))..." + else + chunk.data + end + json_keys = if !isnothing(chunk.json) + join(keys(chunk.json), ", ", " and ") + else + "-" + end + print(io, + "StreamChunk(event=$(chunk.event), data=$(data_preview), json keys=$(json_keys))") +end + +""" + StreamCallback + +Simplest callback for streaming message, which just prints the content to the output stream defined by `out`. +When streaming is over, it builds the response body from the chunks and returns it as if it was a normal response from the API. + +For more complex use cases, you can define your own `callback`. See the interface description below for more information. + +# Fields +- `out`: The output stream, eg, `stdout` or a pipe. 
- `flavor`: The stream flavor, which may differ between providers, eg, `OpenAIStream` or `AnthropicStream`.
+- `chunks`: The list of received `StreamChunk` chunks.
+- `verbose`: Whether to print verbose information.
+- `throw_on_error`: Whether to throw an error if an error message is detected in the streaming response.
+- `kwargs`: Any custom keyword arguments required for your use case.
+
+# Interface
+
+- `StreamCallback(; kwargs...)`: Constructor for the `StreamCallback` object.
+- `streamed_request!(cb, url, headers, input)`: End-to-end wrapper for POST streaming requests.
+
+`streamed_request!` consists of:
+- `extract_chunks(flavor, blob)`: Extract the chunks from the received SSE blob. Returns a list of `StreamChunk` and the next spillover (if the message was incomplete).
+- `callback(cb, chunk)`: Process the chunk to be printed:
+    - `extract_content(flavor, chunk)`: Extract the content from the chunk.
+    - `print_content(out, text)`: Print the content to the output stream.
+- `is_done(flavor, chunk)`: Check if the stream is done.
+- `build_response_body(flavor, cb)`: Build the response body from the chunks to mimic receiving a standard response from the API.
+
+If you want to implement your own callback, you can create your own methods for the interface functions.
+E.g., if you want to print the streamed chunks into some specialized sink or Channel, you could define a simple method just for `print_content`.
+
+# Example
+```julia
+using PromptingTools
+const PT = PromptingTools
+
+# Simplest usage, just provide where to stream the text (we build the callback for you)
+msg = aigenerate("Count from 1 to 100."; streamcallback = stdout)
+
+streamcallback = PT.StreamCallback() # record all chunks
+msg = aigenerate("Count from 1 to 100."; streamcallback)
+# this allows you to inspect each chunk with `streamcallback.chunks`
+
+# Get verbose output with details of each chunk for debugging
+streamcallback = PT.StreamCallback(; verbose=true, throw_on_error=true)
+msg = aigenerate("Count from 1 to 10."; streamcallback)
+```
+"""
+@kwdef mutable struct StreamCallback{
+    T1 <: Any, T2 <: Union{AbstractStreamFlavor, Nothing}} <:
+                       AbstractStreamCallback
+    out::T1 = stdout
+    flavor::T2 = nothing
+    chunks::Vector{<:StreamChunk} = StreamChunk[]
+    verbose::Bool = false
+    throw_on_error::Bool = false
+    kwargs::NamedTuple = NamedTuple()
+end
+function Base.show(io::IO, cb::StreamCallback)
+    print(io,
+        "StreamCallback(out=$(cb.out), flavor=$(cb.flavor), chunks=$(length(cb.chunks)) items, $(cb.verbose ? "verbose" : "silent"), $(cb.throw_on_error ? "throw_on_error" : "no_throw"))")
+end
+Base.empty!(cb::AbstractStreamCallback) = empty!(cb.chunks)
+Base.push!(cb::AbstractStreamCallback, chunk::StreamChunk) = push!(cb.chunks, chunk)
+Base.isempty(cb::AbstractStreamCallback) = isempty(cb.chunks)
+Base.length(cb::AbstractStreamCallback) = length(cb.chunks)
+
+### Convenience utilities
+"""
+    configure_callback!(cb::StreamCallback, schema::AbstractPromptSchema;
+        api_kwargs...)
+
+Configures the callback `cb` for streaming with a given prompt schema.
+If no `cb.flavor` is provided, adjusts the `flavor` and the provided `api_kwargs` as necessary.
+
+"""
+function configure_callback!(cb::T, schema::AbstractPromptSchema;
+        api_kwargs...)
where {T <: StreamCallback} + ## Check if we are in passthrough mode or if we should configure the callback + if isnothing(cb.flavor) + if schema isa OpenAISchema + api_kwargs = (; + api_kwargs..., stream = true, stream_options = (; include_usage = true)) + flavor = OpenAIStream() + elseif schema isa AnthropicSchema + api_kwargs = (; api_kwargs..., stream = true) + flavor = AnthropicStream() + else + error("Unsupported schema type: $(typeof(schema)). Currently supported: OpenAISchema and AnthropicSchema.") + end + cb = StreamCallback(; [f => getfield(cb, f) for f in fieldnames(T)]..., + flavor) + end + return cb, api_kwargs +end +# method to build a callback for a given output stream +function configure_callback!( + output_stream::Union{IO, Channel}, schema::AbstractPromptSchema) + cb = StreamCallback(out = output_stream) + return configure_callback!(cb, schema) +end + +### Define the interface functions +function is_done end +function extract_chunks end +function extract_content end +function print_content end +function callback end +function build_response_body end +function streamed_request! end + +### Define the necessary methods -- start with OpenAIStream + +# Define the interface functions +""" + is_done(flavor, chunk) + +Check if the streaming is done. Shared by all streaming flavors currently. +""" +@inline function is_done(flavor::OpenAIStream, chunk::StreamChunk; kwargs...) + chunk.data == "[DONE]" +end + +@inline function is_done(flavor::AnthropicStream, chunk::StreamChunk; kwargs...) + chunk.event == :error || chunk.event == :message_stop +end +function is_done(flavor::AbstractStreamFlavor, chunk::StreamChunk; kwargs...) + throw(ArgumentError("is_done is not implemented for flavor $(flavor)")) +end + +""" + extract_chunks(flavor::AbstractStreamFlavor, blob::AbstractString; + spillover::AbstractString = "", verbose::Bool = false, kwargs...) + +Extract the chunks from the received SSE blob. Shared by all streaming flavors currently. + +Returns a list of `StreamChunk` and the next spillover (if message was incomplete). +""" +@inline function extract_chunks(flavor::AbstractStreamFlavor, blob::AbstractString; + spillover::AbstractString = "", verbose::Bool = false, kwargs...) + chunks = StreamChunk[] + next_spillover = "" + ## SSE come separated by double-newlines + blob_split = split(blob, "\n\n") + for (bi, chunk) in enumerate(blob_split) + isempty(chunk) && continue + event_split = split(chunk, "event: ") + has_event = length(event_split) > 1 + # if length>1, we know it was there! 
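+        ## Each SSE message has the form "event: <name>\ndata: <payload>\n\n";
+        ## splitting on "event: " peels off the optional event name, and the data
+        ## payload is re-assembled from the "data: " pieces below.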
+ for event_blob in event_split + isempty(event_blob) && continue + event_name = nothing + data_buf = IOBuffer() + data_splits = split(event_blob, "data: ") + for i in eachindex(data_splits) + isempty(data_splits[i]) && continue + if i == 1 & has_event && !isempty(data_splits[i]) + ## we have an event name + event_name = strip(data_splits[i]) |> Symbol + elseif bi == 1 && i == 1 && !isempty(data_splits[i]) + ## in the first part of the first blob, it must be a spillover + spillover = string(spillover, rstrip(data_splits[i], '\n')) + verbose && @info "Buffer spillover detected: $(spillover)" + elseif i > 1 + ## any subsequent data blobs are accummulated into the data buffer + ## there can be multiline data that must be concatenated + data_chunk = rstrip(data_splits[i], '\n') + write(data_buf, data_chunk) + end + end + + ## Parse the spillover + if bi == 1 && !isempty(spillover) + data = spillover + json = if startswith(data, '{') && endswith(data, '}') + try + JSON3.read(data) + catch e + verbose && @warn "Cannot parse JSON: $raw_chunk" + nothing + end + else + nothing + end + ## ignore event name + push!(chunks, StreamChunk(; data = spillover, json = json)) + # reset the spillover + spillover = "" + end + ## On the last iteration of the blob, check if we spilled over + if bi == length(blob_split) && length(data_splits) > 1 && + !isempty(strip(data_splits[end])) + verbose && @info "Incomplete message detected: $(data_splits[end])" + next_spillover = String(take!(data_buf)) + ## Do not save this chunk + else + ## Try to parse the data as JSON + data = String(take!(data_buf)) + isempty(data) && continue + ## try to build a JSON object if it's a well-formed JSON string + json = if startswith(data, '{') && endswith(data, '}') + try + JSON3.read(data) + catch e + verbose && @warn "Cannot parse JSON: $raw_chunk" + nothing + end + else + nothing + end + ## Create a new chunk + push!(chunks, StreamChunk(event_name, data, json)) + end + end + end + return chunks, next_spillover +end + +""" + extract_content(flavor::OpenAIStream, chunk::StreamChunk; kwargs...) + +Extract the content from the chunk. +""" +@inline function extract_content(flavor::OpenAIStream, chunk::StreamChunk; kwargs...) + if !isnothing(chunk.json) + ## Can contain more than one choice for multi-sampling, but ignore for callback + ## Get only the first choice + choices = get(chunk.json, :choices, []) + first_choice = get(choices, 1, Dict()) + delta = get(first_choice, :delta, Dict()) + out = get(delta, :content, nothing) + else + nothing + end +end +function extract_content(flavor::AbstractStreamFlavor, chunk::StreamChunk; kwargs...) + throw(ArgumentError("extract_content is not implemented for flavor $(flavor)")) +end + +""" + print_content(out::IO, text::AbstractString; kwargs...) + +Print the content to the IO output stream `out`. +""" +@inline function print_content(out::IO, text::AbstractString; kwargs...) + print(out, text) + # flush(stdout) +end +""" + print_content(out::Channel, text::AbstractString; kwargs...) + +Print the content to the provided Channel `out`. +""" +@inline function print_content(out::Channel, text::AbstractString; kwargs...) + put!(out, text) +end + +""" + print_content(out::Nothing, text::Any) + +Do nothing if the output stream is `nothing`. +""" +@inline function print_content(out::Nothing, text::Any; kwargs...) + return nothing +end + +""" + callback(cb::AbstractStreamCallback, chunk::StreamChunk; kwargs...) + +Process the chunk to be printed and print it. 
It's a wrapper for two operations: +- extract the content from the chunk using `extract_content` +- print the content to the output stream using `print_content` +""" +@inline function callback(cb::AbstractStreamCallback, chunk::StreamChunk; kwargs...) + processed_text = extract_content(cb.flavor, chunk; kwargs...) + isnothing(processed_text) && return nothing + print_content(cb.out, processed_text; kwargs...) + return nothing +end + +""" + handle_error_message(chunk::StreamChunk; throw_on_error::Bool = false, kwargs...) + +Handles error messages from the streaming response. +""" +@inline function handle_error_message( + chunk::StreamChunk; throw_on_error::Bool = false, kwargs...) + if chunk.event == :error || + (isnothing(chunk.event) && !isnothing(chunk.json) && + haskey(chunk.json, :error)) + has_error_dict = !isnothing(chunk.json) && + get(chunk.json, :error, nothing) isa AbstractDict + ## Build the error message + error_str = if has_error_dict + join( + ["$(titlecase(string(k))): $(v)" + for (k, v) in pairs(chunk.json.error)], + ", ") + else + string(chunk.data) + end + ## Define whether to throw an error + error_msg = "Error detected in the streaming response: $(error_str)" + if throw_on_error + throw(Exception(error_msg)) + else + @warn error_msg + end + end + return nothing +end + +""" + build_response_body(flavor::OpenAIStream, cb::StreamCallback; verbose::Bool = false, kwargs...) + +Build the response body from the chunks to mimic receiving a standard response from the API. + +Note: Limited functionality for now. Does NOT support tool use, refusals, logprobs. Use standard responses for these. +""" +function build_response_body( + flavor::OpenAIStream, cb::StreamCallback; verbose::Bool = false, kwargs...) + isempty(cb.chunks) && return nothing + response = nothing + usage = nothing + choices_output = Dict{Int, Dict{Symbol, Any}}() + for i in eachindex(cb.chunks) + chunk = cb.chunks[i] + ## validate that we can access choices + isnothing(chunk.json) && continue + !haskey(chunk.json, :choices) && continue + if isnothing(response) + ## do it only once the first time when we have the json + response = chunk.json |> copy + end + if isnothing(usage) + usage_values = get(chunk.json, :usage, nothing) + if !isnothing(usage_values) + usage = usage_values |> copy + end + end + for choice in chunk.json.choices + index = get(choice, :index, nothing) + isnothing(index) && continue + if !haskey(choices_output, index) + choices_output[index] = Dict{Symbol, Any}(:index => index) + end + index_dict = choices_output[index] + finish_reason = get(choice, :finish_reason, nothing) + if !isnothing(finish_reason) + index_dict[:finish_reason] = finish_reason + end + ## skip for now + # logprobs = get(choice, :logprobs, nothing) + # if !isnothing(logprobs) + # choices_dict[index][:logprobs] = logprobs + # end + choice_delta = get(choice, :delta, Dict{Symbol, Any}()) + message_dict = get(index_dict, :message, Dict{Symbol, Any}(:content => "")) + role = get(choice_delta, :role, nothing) + if !isnothing(role) + message_dict[:role] = role + end + content = get(choice_delta, :content, nothing) + if !isnothing(content) + message_dict[:content] *= content + end + ## skip for now + # refusal = get(choice_delta, :refusal, nothing) + # if !isnothing(refusal) + # message_dict[:refusal] = refusal + # end + index_dict[:message] = message_dict + end + end + ## We know we have at least one chunk, let's use it for final response + if !isnothing(response) + # flatten the choices_dict into an array + choices = 
[choices_output[index] for index in sort(collect(keys(choices_output)))] + # overwrite the old choices + response[:choices] = choices + response[:object] = "chat.completion" + response[:usage] = usage + end + return response +end + +""" + streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...) + +End-to-end wrapper for POST streaming requests. +In-place modification of the callback object (`cb.chunks`) with the results of the request being returned. +We build the `body` of the response object in the end and write it into the `resp.body`. + +Returns the response object. + +# Arguments +- `cb`: The callback object. +- `url`: The URL to send the request to. +- `headers`: The headers to send with the request. +- `input`: A buffer with the request body. +- `kwargs`: Additional keyword arguments. +""" +function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...) + verbose = get(kwargs, :verbose, false) || cb.verbose + resp = HTTP.open("POST", url, headers; kwargs...) do stream + write(stream, String(take!(input))) + HTTP.closewrite(stream) + r = HTTP.startread(stream) + isdone = false + ## messages might be incomplete, so we need to keep track of the spillover + spillover = "" + while !eof(stream) || !isdone + masterchunk = String(readavailable(stream)) + chunks, spillover = extract_chunks( + cb.flavor, masterchunk; verbose, spillover, cb.kwargs...) + + for chunk in chunks + verbose && @info "Chunk Data: $(chunk.data)" + ## look for errors + handle_error_message(chunk; cb.throw_on_error, verbose, cb.kwargs...) + ## look for termination signal, but process all remaining chunks first + is_done(cb.flavor, chunk; verbose, cb.kwargs...) && (isdone = true) + ## trigger callback + callback(cb, chunk; verbose, cb.kwargs...) + ## Write into our CB chunks (for later processing) + push!(cb, chunk) + end + end + HTTP.closeread(stream) + end + + body = build_response_body(cb.flavor, cb; verbose, cb.kwargs...) + resp.body = JSON3.write(body) + + return resp +end + +### Additional methods required for AnthropicStream +""" + build_response_body( + flavor::AnthropicStream, cb::StreamCallback; verbose::Bool = false, kwargs...) + +Build the response body from the chunks to mimic receiving a standard response from the API. + +Note: Limited functionality for now. Does NOT support tool use. Use standard responses for these. +""" +function build_response_body( + flavor::AnthropicStream, cb::StreamCallback; verbose::Bool = false, kwargs...) 
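+    ## Anthropic streams a single message: `message_start` carries the message skeleton,
+    ## `content_block_*` events carry the text deltas (accumulated below), and
+    ## `message_delta` carries the final stop reason and usage update.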
+ isempty(cb.chunks) && return nothing + response = nothing + usage = nothing + content_buf = IOBuffer() + for i in eachindex(cb.chunks) + ## Note we ignore the index ID, because Anthropic does not support multiple + ## parallel generations + chunk = cb.chunks[i] + ## validate that we can access choices + isnothing(chunk.json) && continue + ## Core of the message body + if isnothing(response) && chunk.event == :message_start && + haskey(chunk.json, :message) + ## do it only once the first time when we have the json + response = chunk.json[:message] |> copy + usage = get(response, :usage, Dict()) + end + ## Update stop reason and usage + if chunk.event == :message_delta + response = merge(response, get(chunk.json, :delta, Dict())) + usage = merge(usage, get(chunk.json, :usage, Dict())) + end + + ## Load text chunks + if chunk.event == :content_block_start || + chunk.event == :content_block_delta || chunk.event == :content_block_stop + ## Find the text delta + delta_block = get(chunk.json, :content_block, nothing) + if isnothing(delta_block) + ## look for the delta segment + delta_block = get(chunk.json, :delta, Dict()) + end + text = get(delta_block, :text, nothing) + !isnothing(text) && write(content_buf, text) + end + end + ## We know we have at least one chunk, let's use it for final response + if !isnothing(response) + response[:content] = [Dict(:type => "text", :text => String(take!(content_buf)))] + !isnothing(usage) && (response[:usage] = usage) + end + return response +end +""" + extract_content(flavor::AnthropicStream, chunk) + +Extract the content from the chunk. +""" +function extract_content(flavor::AnthropicStream, chunk::StreamChunk; kwargs...) + if !isnothing(chunk.json) + ## Can contain more than one choice for multi-sampling, but ignore for callback + ## Get only the first choice, index=0 // index=1 is for tools etc + index = get(chunk.json, :index, nothing) + isnothing(index) || !iszero(index) && return nothing + + delta_block = get(chunk.json, :content_block, nothing) + if isnothing(delta_block) + ## look for the delta segment + delta_block = get(chunk.json, :delta, Dict()) + end + out = get(delta_block, :text, nothing) + else + nothing + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 229c9d5a7..bcee301fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ end include("extraction.jl") include("user_preferences.jl") include("llm_interface.jl") + include("streaming.jl") include("llm_shared.jl") include("llm_openai.jl") include("llm_ollama_managed.jl") diff --git a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 000000000..ddc51a1c9 --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,804 @@ +using PromptingTools: StreamCallback, StreamChunk, OpenAIStream, AnthropicStream, + configure_callback! +using PromptingTools: is_done, extract_chunks, extract_content, print_content, callback, + build_response_body, streamed_request! 
+using PromptingTools: OpenAISchema, AnthropicSchema, GoogleSchema + +@testset "StreamCallback" begin + # Test default constructor + cb = StreamCallback() + @test cb.out == stdout + @test isnothing(cb.flavor) + @test isempty(cb.chunks) + @test cb.verbose == false + @test cb.throw_on_error == false + @test isempty(cb.kwargs) + + # Test custom constructor + custom_out = IOBuffer() + custom_flavor = OpenAIStream() + custom_chunks = [StreamChunk(event = :test, data = "test data")] + custom_cb = StreamCallback(; + out = custom_out, + flavor = custom_flavor, + chunks = custom_chunks, + verbose = true, + throw_on_error = true, + kwargs = (custom_key = "custom_value",) + ) + @test custom_cb.out == custom_out + @test custom_cb.flavor == custom_flavor + @test custom_cb.chunks == custom_chunks + @test custom_cb.verbose == true + @test custom_cb.throw_on_error == true + @test custom_cb.kwargs == (custom_key = "custom_value",) + + # Test Base methods + cb = StreamCallback() + @test isempty(cb) + push!(cb, StreamChunk(event = :test, data = "test data")) + @test length(cb) == 1 + @test !isempty(cb) + empty!(cb) + @test isempty(cb) + + # Test show method + cb = StreamCallback(out = IOBuffer(), flavor = OpenAIStream()) + str = sprint(show, cb) + @test occursin("StreamCallback(out=IOBuffer", str) + @test occursin("flavor=OpenAIStream()", str) + @test occursin("silent, no_throw", str) + + chunk = StreamChunk(event = :test, data = "{\"a\": 1}", json = JSON3.read("{\"a\": 1}")) + str = sprint(show, chunk) + @test occursin("StreamChunk(event=test", str) + @test occursin("data={\"a\": 1}", str) + @test occursin("json keys=a", str) + + push!(cb, chunk) + @test length(cb) == 1 + @test !isempty(cb) + empty!(cb) + @test isempty(cb) + + # Test configure_callback! method + cb, api_kwargs = configure_callback!(StreamCallback(), OpenAISchema()) + @test cb.flavor isa OpenAIStream + @test api_kwargs[:stream] == true + @test api_kwargs[:stream_options] == (include_usage = true,) + + cb, api_kwargs = configure_callback!(StreamCallback(), AnthropicSchema()) + @test cb.flavor isa AnthropicStream + @test api_kwargs[:stream] == true + + # Test error for unsupported schema + @test_throws ErrorException configure_callback!(StreamCallback(), GoogleSchema()) + + # Test configure_callback! 
with output stream + cb, _ = configure_callback!(IOBuffer(), OpenAISchema()) + @test cb isa StreamCallback + @test cb.out isa IOBuffer + @test cb.flavor isa OpenAIStream +end + +@testset "is_done" begin + # Test OpenAIStream + openai_flavor = PT.OpenAIStream() + + # Test when streaming is done + done_chunk = PT.StreamChunk(data = "[DONE]") + @test PT.is_done(openai_flavor, done_chunk) == true + + # Test when streaming is not done + not_done_chunk = PT.StreamChunk(data = "Some content") + @test PT.is_done(openai_flavor, not_done_chunk) == false + + # Test with empty data + empty_chunk = PT.StreamChunk(data = "") + @test PT.is_done(openai_flavor, empty_chunk) == false + + # Test AnthropicStream + anthropic_flavor = PT.AnthropicStream() + + # Test when streaming is done due to error + error_chunk = PT.StreamChunk(event = :error) + @test PT.is_done(anthropic_flavor, error_chunk) == true + + # Test when streaming is done due to message stop + stop_chunk = PT.StreamChunk(event = :message_stop) + @test PT.is_done(anthropic_flavor, stop_chunk) == true + + # Test when streaming is not done + continue_chunk = PT.StreamChunk(event = :content_block_start) + @test PT.is_done(anthropic_flavor, continue_chunk) == false + + # Test with nil event + nil_event_chunk = PT.StreamChunk(event = nothing) + @test PT.is_done(anthropic_flavor, nil_event_chunk) == false + + # Test with unsupported flavor + struct UnsupportedFlavor <: PT.AbstractStreamFlavor end + unsupported_flavor = UnsupportedFlavor() + @test_throws ArgumentError PT.is_done(unsupported_flavor, PT.StreamChunk()) +end + +@testset "extract_content" begin + # Test OpenAIStream + openai_flavor = PT.OpenAIStream() + + # Test with valid JSON content + valid_json_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "choices": [ + { + "delta": { + "content": "Hello, world!" + } + } + ] + } + """) + ) + @test PT.extract_content(openai_flavor, valid_json_chunk) == "Hello, world!" + + # Test with empty choices + empty_choices_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "choices": [] + } + """) + ) + @test isnothing(PT.extract_content(openai_flavor, empty_choices_chunk)) + + # Test with missing delta + missing_delta_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "choices": [ + { + "index": 0 + } + ] + } + """) + ) + @test isnothing(PT.extract_content(openai_flavor, missing_delta_chunk)) + + # Test with missing content in delta + missing_content_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "choices": [ + { + "delta": { + "role": "assistant" + } + } + ] + } + """) + ) + @test isnothing(PT.extract_content(openai_flavor, missing_content_chunk)) + + # Test with non-JSON chunk + non_json_chunk = PT.StreamChunk(data = "Plain text") + @test isnothing(PT.extract_content(openai_flavor, non_json_chunk)) + + # Test AnthropicStream + anthropic_flavor = PT.AnthropicStream() + + # Test with valid content block + valid_anthropic_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "content_block": { + "text": "Hello from Anthropic!" + } + } + """) + ) + @test PT.extract_content(anthropic_flavor, valid_anthropic_chunk) == + "Hello from Anthropic!" 
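+    # Note: the Anthropic extractor reads text from `content_block` first and falls back
+    # to the `delta` block, as the next cases exercise.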
+ + # Test with valid delta + valid_delta_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "delta": { + "text": "Delta text" + } + } + """) + ) + @test PT.extract_content(anthropic_flavor, valid_delta_chunk) == "Delta text" + + # Test with missing text in content block + missing_text_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "content_block": { + "type": "text" + } + } + """) + ) + @test isnothing(PT.extract_content(anthropic_flavor, missing_text_chunk)) + + # Test with non-zero index (should return nothing) + non_zero_index_chunk = PT.StreamChunk( + json = JSON3.read(""" + { + "index": 1, + "content_block": { + "text": "This should be ignored" + } + } + """) + ) + @test isnothing(PT.extract_content(anthropic_flavor, non_zero_index_chunk)) + + # Test with non-JSON chunk for Anthropic + non_json_anthropic_chunk = PT.StreamChunk(data = "Plain Anthropic text") + @test isnothing(PT.extract_content(anthropic_flavor, non_json_anthropic_chunk)) + + # Test with unsupported flavor + struct UnsupportedFlavor <: PT.AbstractStreamFlavor end + unsupported_flavor = UnsupportedFlavor() + @test_throws ArgumentError PT.extract_content(unsupported_flavor, PT.StreamChunk()) +end + +@testset "extract_chunks" begin + # Test basic functionality + blob = "event: start\ndata: {\"key\": \"value\"}\n\nevent: end\ndata: {\"status\": \"complete\"}\n\n" + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), blob) + @test length(chunks) == 2 + @test chunks[1].event == :start + @test chunks[1].json == JSON3.read("{\"key\": \"value\"}") + @test chunks[2].event == :end + @test chunks[2].json == JSON3.read("{\"status\": \"complete\"}") + @test spillover == "" + + # Test with spillover + blob_with_spillover = "event: start\ndata: {\"key\": \"value\"}\n\nevent: continue\ndata: {\"partial\": \"data" + @test_logs (:info, r"Incomplete message detected") chunks, spillover=PT.extract_chunks( + PT.OpenAIStream(), blob_with_spillover; verbose = true) + chunks, spillover = PT.extract_chunks( + PT.OpenAIStream(), blob_with_spillover; verbose = true) + @test length(chunks) == 1 + @test chunks[1].event == :start + @test chunks[1].json == JSON3.read("{\"key\": \"value\"}") + @test spillover == "{\"partial\": \"data" + + # Test with incoming spillover + incoming_spillover = spillover + blob_after_spillover = "\"}\n\nevent: end\ndata: {\"status\": \"complete\"}\n\n" + chunks, spillover = PT.extract_chunks( + PT.OpenAIStream(), blob_after_spillover; spillover = incoming_spillover) + @test length(chunks) == 2 + @test chunks[1].json == JSON3.read("{\"partial\": \"data\"}") + @test chunks[2].event == :end + @test chunks[2].json == JSON3.read("{\"status\": \"complete\"}") + @test spillover == "" + + # Test with multiple data fields per event + multi_data_blob = "event: multi\ndata: line1\ndata: line2\n\n" + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), multi_data_blob) + @test length(chunks) == 1 + @test chunks[1].event == :multi + @test chunks[1].data == "line1line2" + + # Test with non-JSON data + non_json_blob = "event: text\ndata: This is plain text\n\n" + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), non_json_blob) + @test length(chunks) == 1 + @test chunks[1].event == :text + @test chunks[1].data == "This is plain text" + @test isnothing(chunks[1].json) + + # Test with empty blob + empty_blob = "" + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), empty_blob) + @test isempty(chunks) + @test spillover == "" + + # Test with malformed JSON + malformed_json_blob = "event: error\ndata: 
{\"key\": \"value\",}\n\n" + chunks, spillover = PT.extract_chunks( + PT.OpenAIStream(), malformed_json_blob; verbose = true) + @test length(chunks) == 1 + @test chunks[1].event == :error + @test chunks[1].data == "{\"key\": \"value\",}" + @test isnothing(chunks[1].json) + + # Test with multiple data fields, no event + blob_no_event = "data: {\"key\": \"value\"}\n\ndata: {\"partial\": \"data\"}\n\ndata: {\"status\": \"complete\"}\n\n" + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), blob_no_event) + @test length(chunks) == 3 + @test chunks[1].data == "{\"key\": \"value\"}" + @test chunks[2].data == "{\"partial\": \"data\"}" + @test chunks[3].data == "{\"status\": \"complete\"}" + @test spillover == "" + + # Test case for s1: Multiple events and data chunks + s1 = """event: test + data: {"id":"chatcmpl-A3zvq9GWhji7h1Gz0gKNIn9r2tABJ","object":"chat.completion.chunk","created":1725516414,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f905cf32a9","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}]} + + event: test2 + data: {"id":"chatcmpl-A3zvq9GWhji7h1Gz0gKNIn9r2tABJ","object":"chat.completion.chunk","created":1725516414,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f905cf32a9","choices":[{"index":0,"delta":{"content":" "},"logprobs":null,"finish_reason":null}]} + + data: [DONE] + + """ + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), s1) + @test length(chunks) == 3 + @test chunks[1].event == :test + @test chunks[2].event == :test2 + @test chunks[3].data == "[DONE]" + @test spillover == "" + + @test PT.extract_content(PT.OpenAIStream(), chunks[1]) == "," + @test PT.extract_content(PT.OpenAIStream(), chunks[2]) == " " + + # Test case for s2: Multiple data chunks without events + s2 = """data: {"id":"chatcmpl-A3zvq9GWhji7h1Gz0gKNIn9r2tABJ","object":"chat.completion.chunk","created":1725516414,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f905cf32a9","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-A3zvq9GWhji7h1Gz0gKNIn9r2tABJ","object":"chat.completion.chunk","created":1725516414,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f905cf32a9","choices":[{"index":0,"delta":{"content":" "},"logprobs":null,"finish_reason":null}]} + + data: [DONE] + + """ + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), s2) + @test length(chunks) == 3 + @test all(chunk.event === nothing for chunk in chunks) + @test chunks[3].data == "[DONE]" + @test spillover == "" + + # Test case for s3: Simple data chunks + s3 = """data: a + data: b + data: c + + data: [DONE] + + """ + chunks, spillover = PT.extract_chunks(PT.OpenAIStream(), s3) + @test length(chunks) == 2 + @test chunks[1].data == "abc" + @test chunks[2].data == "[DONE]" + @test spillover == "" + + # Test case for s4a and s4b: Handling spillover + s4a = """event: test + data: {"id":"chatcmpl-A3zvq9GWhji7h1Gz0gKNIn9r2tABJ","object":"chat.completion.chunk","created":1725516414,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f905cf32a9","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}]} + + event: test2 + data: {"id":"chatcmpl-A3zvq9GWhji7h1Gz0gKNIn9r2tABJ","object":"chat.completion.chunk","created""" + s4b = """":1725516414,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f905cf32a9","choices":[{"index":0,"delta":{"content":" "},"logprobs":null,"finish_reason":null}]} + + data: [DONE] + + """ + chunks, spillover = 
PT.extract_chunks(PT.OpenAIStream(), s4a) + @test length(chunks) == 1 + @test chunks[1].event == :test + @test !isempty(spillover) + + chunks, final_spillover = PT.extract_chunks( + PT.OpenAIStream(), s4b; spillover = spillover) + @test length(chunks) == 2 + @test chunks[2].data == "[DONE]" + @test final_spillover == "" +end + +@testset "extract_content" begin + ### OpenAIStream + # Test case 1: Valid JSON with content + valid_chunk = PT.StreamChunk( + nothing, + """{"choices":[{"delta":{"content":"Hello"}}]}""", + JSON3.read("""{"choices":[{"delta":{"content":"Hello"}}]}""") + ) + @test PT.extract_content(PT.OpenAIStream(), valid_chunk) == "Hello" + + # Test case 2: Valid JSON without content + no_content_chunk = PT.StreamChunk( + nothing, + """{"choices":[{"delta":{}}]}""", + JSON3.read("""{"choices":[{"delta":{}}]}""") + ) + @test isnothing(PT.extract_content(PT.OpenAIStream(), no_content_chunk)) + + # Test case 3: Valid JSON with empty content + empty_content_chunk = PT.StreamChunk( + nothing, + """{"choices":[{"delta":{"content":""}}]}""", + JSON3.read("""{"choices":[{"delta":{"content":""}}]}""") + ) + @test PT.extract_content(PT.OpenAIStream(), empty_content_chunk) == "" + + # Test case 4: Invalid JSON structure + invalid_chunk = PT.StreamChunk( + nothing, + """{"invalid":"structure"}""", + JSON3.read("""{"invalid":"structure"}""") + ) + @test isnothing(PT.extract_content(PT.OpenAIStream(), invalid_chunk)) + + # Test case 5: Chunk with non-JSON data + non_json_chunk = PT.StreamChunk( + nothing, + "This is not JSON", + nothing + ) + @test isnothing(PT.extract_content(PT.OpenAIStream(), non_json_chunk)) + + # Test case 6: Multiple choices (should still return first choice) + multiple_choices_chunk = PT.StreamChunk( + nothing, + """{"choices":[{"delta":{"content":"First"}},{"delta":{"content":"Second"}}]}""", + JSON3.read("""{"choices":[{"delta":{"content":"First"}},{"delta":{"content":"Second"}}]}""") + ) + @test PT.extract_content(PT.OpenAIStream(), multiple_choices_chunk) == "First" + + ### AnthropicStream + # Test case 1: Valid JSON with content in content_block + valid_chunk = PT.StreamChunk( + nothing, + """{"index":0,"content_block":{"text":"Hello from Anthropic"}}""", + JSON3.read("""{"index":0,"content_block":{"text":"Hello from Anthropic"}}""") + ) + @test PT.extract_content(PT.AnthropicStream(), valid_chunk) == "Hello from Anthropic" + + # Test case 2: Valid JSON with content in delta + delta_chunk = PT.StreamChunk( + nothing, + """{"index":0,"delta":{"text":"Delta content"}}""", + JSON3.read("""{"index":0,"delta":{"text":"Delta content"}}""") + ) + @test PT.extract_content(PT.AnthropicStream(), delta_chunk) == "Delta content" + + # Test case 3: Valid JSON without text in content_block + no_text_chunk = PT.StreamChunk( + nothing, + """{"index":0,"content_block":{"type":"text"}}""", + JSON3.read("""{"index":0,"content_block":{"type":"text"}}""") + ) + @test isnothing(PT.extract_content(PT.AnthropicStream(), no_text_chunk)) + + # Test case 4: Valid JSON with non-zero index + non_zero_index_chunk = PT.StreamChunk( + nothing, + """{"index":1,"content_block":{"text":"Should be ignored"}}""", + JSON3.read("""{"index":1,"content_block":{"text":"Should be ignored"}}""") + ) + @test isnothing(PT.extract_content(PT.AnthropicStream(), non_zero_index_chunk)) + + # Test case 5: Chunk with non-JSON data + non_json_chunk = PT.StreamChunk( + nothing, + "This is not JSON", + nothing + ) + @test isnothing(PT.extract_content(PT.AnthropicStream(), non_json_chunk)) + + # Test case 6: Valid JSON 
with empty content + empty_content_chunk = PT.StreamChunk( + nothing, + """{"index":0,"content_block":{"text":""}}""", + JSON3.read("""{"index":0,"content_block":{"text":""}}""") + ) + @test PT.extract_content(PT.AnthropicStream(), empty_content_chunk) == "" + + # Test case 7: Unknown flavor + struct UnknownFlavor <: PT.AbstractStreamFlavor end + unknown_flavor = UnknownFlavor() + unknown_chunk = PT.StreamChunk( + nothing, + """{"content": "Test content"}""", + JSON3.read("""{"content": "Test content"}""") + ) + @test_throws ArgumentError PT.extract_content(unknown_flavor, unknown_chunk) +end + +@testset "print_content" begin + # Test printing to IO + io = IOBuffer() + PT.print_content(io, "Test content") + @test String(take!(io)) == "Test content" + + # Test printing to Channel + ch = Channel{String}(1) + PT.print_content(ch, "Channel content") + @test take!(ch) == "Channel content" + + # Test printing to nothing + @test PT.print_content(nothing, "No output") === nothing +end + +@testset "callback" begin + # Test with valid content + io = IOBuffer() + cb = PT.StreamCallback(out = io, flavor = PT.OpenAIStream()) + valid_chunk = PT.StreamChunk( + nothing, + """{"choices":[{"delta":{"content":"Hello"}}]}""", + JSON3.read("""{"choices":[{"delta":{"content":"Hello"}}]}""") + ) + PT.callback(cb, valid_chunk) + @test String(take!(io)) == "Hello" + + # Test with no content + io = IOBuffer() + cb = PT.StreamCallback(out = io, flavor = PT.OpenAIStream()) + no_content_chunk = PT.StreamChunk( + nothing, + """{"choices":[{"delta":{}}]}""", + JSON3.read("""{"choices":[{"delta":{}}]}""") + ) + PT.callback(cb, no_content_chunk) + @test isempty(take!(io)) + + # Test with Channel output + ch = Channel{String}(1) + cb = PT.StreamCallback(out = ch, flavor = PT.OpenAIStream()) + PT.callback(cb, valid_chunk) + @test take!(ch) == "Hello" + + # Test with nothing output + cb = PT.StreamCallback(out = nothing, flavor = PT.OpenAIStream()) + @test PT.callback(cb, valid_chunk) === nothing +end + +@testset "build_response_body-OpenAIStream" begin + # Test case 1: Empty chunks + cb_empty = PT.StreamCallback() + response = PT.build_response_body(PT.OpenAIStream(), cb_empty) + @test isnothing(response) + + # Test case 2: Single complete chunk + cb_single = PT.StreamCallback() + push!(cb_single.chunks, + PT.StreamChunk( + nothing, + """{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}""", + JSON3.read("""{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}""") + )) + response = PT.build_response_body(PT.OpenAIStream(), cb_single) + @test response[:id] == "chatcmpl-123" + @test response[:object] == "chat.completion" + @test response[:model] == "gpt-4" + @test length(response[:choices]) == 1 + @test response[:choices][1][:index] == 0 + @test response[:choices][1][:message][:role] == "assistant" + @test response[:choices][1][:message][:content] == "Hello" + + # Test case 3: Multiple chunks forming a complete response + cb_multiple = PT.StreamCallback() + push!(cb_multiple.chunks, + PT.StreamChunk( + nothing, + """{"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}""", + 
JSON3.read("""{"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}""") + )) + push!(cb_multiple.chunks, + PT.StreamChunk( + nothing, + """{"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}""", + JSON3.read("""{"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}""") + )) + push!(cb_multiple.chunks, + PT.StreamChunk( + nothing, + """{"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}""", + JSON3.read("""{"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}""") + )) + response = PT.build_response_body(PT.OpenAIStream(), cb_multiple) + @test response[:id] == "chatcmpl-456" + @test response[:object] == "chat.completion" + @test response[:model] == "gpt-4" + @test length(response[:choices]) == 1 + @test response[:choices][1][:index] == 0 + @test response[:choices][1][:message][:role] == "assistant" + @test response[:choices][1][:message][:content] == "Hello world" + @test response[:choices][1][:finish_reason] == "stop" + + # Test case 4: Multiple choices + cb_multi_choice = PT.StreamCallback() + push!(cb_multi_choice.chunks, + PT.StreamChunk( + nothing, + """{"id":"chatcmpl-789","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"First"},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":"Second"},"finish_reason":null}]}""", + JSON3.read("""{"id":"chatcmpl-789","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"First"},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":"Second"},"finish_reason":null}]}""") + )) + response = PT.build_response_body(PT.OpenAIStream(), cb_multi_choice) + @test response[:id] == "chatcmpl-789" + @test length(response[:choices]) == 2 + @test response[:choices][1][:index] == 0 + @test response[:choices][1][:message][:content] == "First" + @test response[:choices][2][:index] == 1 + @test response[:choices][2][:message][:content] == "Second" + + # Test case 5: Usage information + cb_usage = PT.StreamCallback() + push!(cb_usage.chunks, + PT.StreamChunk( + nothing, + """{"id":"chatcmpl-101112","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Test"},"finish_reason":null}],"usage":{"prompt_tokens":10,"completion_tokens":1,"total_tokens":11}}""", + JSON3.read("""{"id":"chatcmpl-101112","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Test"},"finish_reason":null}],"usage":{"prompt_tokens":10,"completion_tokens":1,"total_tokens":11}}""") + )) + response = PT.build_response_body(PT.OpenAIStream(), cb_usage) + @test response[:usage][:prompt_tokens] == 10 + @test response[:usage][:completion_tokens] == 1 + @test response[:usage][:total_tokens] == 11 +end +@testset "build_response_body-AnthropicStream" begin + # Test case 1: Empty chunks + cb_empty = PT.StreamCallback(flavor = 
PT.AnthropicStream()) + response = PT.build_response_body(PT.AnthropicStream(), cb_empty) + @test isnothing(response) + + # Test case 2: Single message + cb_single = PT.StreamCallback(flavor = PT.AnthropicStream()) + push!(cb_single.chunks, + PT.StreamChunk( + :message_start, + """{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""", + JSON3.read("""{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""") + )) + response = PT.build_response_body(PT.AnthropicStream(), cb_single) + @test response[:content][1][:type] == "text" + @test response[:content][1][:text] == "" + @test response[:model] == "claude-2" + @test isnothing(response[:stop_reason]) + @test isnothing(response[:stop_sequence]) + + # Test case 3: Multiple content blocks + cb_multiple = PT.StreamCallback(flavor = PT.AnthropicStream()) + push!(cb_multiple.chunks, + PT.StreamChunk( + :message_start, + """{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""", + JSON3.read("""{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""") + )) + push!(cb_multiple.chunks, + PT.StreamChunk( + :content_block_start, + """{"content_block":{"type":"text","text":"Hello"}}""", + JSON3.read("""{"content_block":{"type":"text","text":"Hello"}}""") + )) + push!(cb_multiple.chunks, + PT.StreamChunk( + :content_block_delta, + """{"delta":{"type":"text","text":" world"}}""", + JSON3.read("""{"delta":{"type":"text","text":" world"}}""") + )) + push!(cb_multiple.chunks, + PT.StreamChunk( + :content_block_stop, + """{"content_block":{"type":"text","text":"!"}}""", + JSON3.read("""{"content_block":{"type":"text","text":"!"}}""") + )) + response = PT.build_response_body(PT.AnthropicStream(), cb_multiple) + @test response[:content][1][:type] == "text" + @test response[:content][1][:text] == "Hello world!" 
+ @test response[:model] == "claude-2" + + # Test case 4: With usage information + cb_usage = PT.StreamCallback(flavor = PT.AnthropicStream()) + push!(cb_usage.chunks, + PT.StreamChunk( + :message_start, + """{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":5}}}""", + JSON3.read("""{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":5}}}""") + )) + push!(cb_usage.chunks, + PT.StreamChunk( + :content_block_start, + """{"content_block":{"type":"text","text":"Test"}}""", + JSON3.read("""{"content_block":{"type":"text","text":"Test"}}""") + )) + push!(cb_usage.chunks, + PT.StreamChunk( + :message_delta, + """{"delta":{"stop_reason": "end_turn"},"usage":{"output_tokens":7}}""", + JSON3.read("""{"delta":{"stop_reason": "end_turn"},"usage":{"output_tokens":7}}""") + )) + response = PT.build_response_body(PT.AnthropicStream(), cb_usage) + @test response[:content][1][:type] == "text" + @test response[:content][1][:text] == "Test" + @test response[:usage][:input_tokens] == 10 + @test response[:usage][:output_tokens] == 7 + @test response[:stop_reason] == "end_turn" + + # Test case 5: With stop reason + cb_stop = PT.StreamCallback(flavor = PT.AnthropicStream()) + push!(cb_stop.chunks, + PT.StreamChunk( + :message_start, + """{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""", + JSON3.read("""{"message":{"content":[],"model":"claude-2","stop_reason":null,"stop_sequence":null}}""") + )) + push!(cb_stop.chunks, + PT.StreamChunk( + :content_block_start, + """{"content_block":{"type":"text","text":"Final"}}""", + JSON3.read("""{"content_block":{"type":"text","text":"Final"}}""") + )) + push!(cb_stop.chunks, + PT.StreamChunk( + :message_delta, + """{"delta":{"stop_reason":"max_tokens","stop_sequence":null}}""", + JSON3.read("""{"delta":{"stop_reason":"max_tokens","stop_sequence":null}}""") + )) + response = PT.build_response_body(PT.AnthropicStream(), cb_stop) + @test response[:content][1][:type] == "text" + @test response[:content][1][:text] == "Final" + @test response[:stop_reason] == "max_tokens" + @test isnothing(response[:stop_sequence]) +end + +@testset "handle_error_message" begin + # Test case 1: No error + chunk = PT.StreamChunk(:content, "Normal content", nothing) + @test isnothing(PT.handle_error_message(chunk)) + + # Test case 2: Error event + error_chunk = PT.StreamChunk(:error, "Error occurred", nothing) + @test_logs (:warn, "Error detected in the streaming response: Error occurred") PT.handle_error_message(error_chunk) + + # Test case 4: Detailed error in JSON + obj = Dict(:error => Dict(:message => "Invalid input", :type => "user_error")) + detailed_error_chunk = PT.StreamChunk( + nothing, JSON3.write(obj), JSON3.read(JSON3.write(obj))) + @test_logs (:warn, + r"Message: Invalid input") PT.handle_error_message(detailed_error_chunk) + @test_logs (:warn, + r"Type: user_error") PT.handle_error_message(detailed_error_chunk) + + # Test case 5: Throw on error + @test_throws Exception PT.handle_error_message(error_chunk, throw_on_error = true) +end + +## Not working yet!! +# @testset "streamed_request!" 
begin +# # Setup mock server +# PORT = rand(10000:20000) +# server = HTTP.serve!(PORT; verbose = false) do request +# if request.method == "POST" && request.target == "/v1/chat/completions" +# # Simulate streaming response +# return HTTP.Response() do io +# write(io, "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n") +# write(io, "data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n") +# write(io, "data: [DONE]\n\n") +# end +# else +# return HTTP.Response(404, "Not found") +# end +# end + +# # Test streamed_request! +# url = "http://localhost:$PORT/v1/chat/completions" +# headers = ["Content-Type" => "application/json"] +# input = IOBuffer(JSON3.write(Dict( +# "model" => "gpt-3.5-turbo", +# "messages" => [Dict("role" => "user", "content" => "Say hello")] +# ))) + +# cb = PT.StreamCallback(flavor = PT.OpenAIStream()) +# response = PT.streamed_request!(cb, url, headers, input) + +# # Assertions +# @test response.status == 200 +# @test length(cb.chunks) == 3 +# @test cb.chunks[1].json.choices[1].delta.content == "Hello" +# @test cb.chunks[2].json.choices[1].delta.content == " world" +# @test cb.chunks[3].data == "[DONE]" + +# # Test build_response_body +# body = PT.build_response_body(PT.OpenAIStream(), cb) +# @test body[:choices][1][:message][:content] == "Hello world" +# # Cleanup +# close(server) +# end \ No newline at end of file
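
A minimal usage sketch of the streaming callback pieces exercised by the tests above. It is illustrative only and relies solely on names that appear in this patch (StreamCallback, OpenAIStream, OpenAISchema, configure_callback!, extract_content, build_response_body); it is not part of the changeset itself.

using PromptingTools
const PT = PromptingTools

# A callback that prints streamed text to stdout and records every received chunk.
cb = PT.StreamCallback(out = stdout, flavor = PT.OpenAIStream())

# Alternatively, let configure_callback! pick the flavor and the streaming api_kwargs
# from the prompt schema; it returns the callback and the kwargs to pass to the request.
cb, api_kwargs = PT.configure_callback!(PT.StreamCallback(), PT.OpenAISchema())
# api_kwargs now holds stream = true (plus stream_options = (include_usage = true,) for OpenAI).

# After a streamed request, the recorded chunks can be inspected individually ...
# text = PT.extract_content(cb.flavor, cb.chunks[1])
# ... or assembled back into a standard (non-streamed) response body:
# body = PT.build_response_body(cb.flavor, cb)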