Add Cerebras API + node validation for airetry!
svilupp authored Oct 9, 2024
1 parent ea8d51a commit a66da99
Showing 7 changed files with 133 additions and 38 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.58.0]

### Added
- Added support for [Cerebras](https://cloud.cerebras.ai) hosted models (set your ENV `CEREBRAS_API_KEY`). Available model aliases: `cl3` (Llama3.1 8bn), `cl70` (Llama3.1 70bn).
- Added a kwarg to `aiclassify` to provide a custom token ID mapping (`token_ids_map`) to work with custom tokenizers.

### Updated
- Improved the implementation of `airetry!` to concatenate feedback from all ancestor nodes ONLY IF `feedback_inplace=true` (otherwise the LLM can already see the feedback in the message history).

### Fixed
- Fixed a potential bug in `airetry!` where the `aicall` object was not properly validated to ensure it has been `run!` first.

## [0.57.0]

### Added
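As a quick illustration of the headline additions above — a minimal sketch, assuming `CEREBRAS_API_KEY` is set in your environment (the prompts are illustrative):

```julia
using PromptingTools

# New Cerebras aliases (registered in user_preferences.jl below):
msg = aigenerate("Say hello in one word."; model = "cl3")             # Llama3.1 8b
msg = aigenerate("Summarize Julia in one sentence."; model = "cl70")  # Llama3.1 70b
```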
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.57.0"
version = "0.58.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
15 changes: 12 additions & 3 deletions src/Experimental/AgentTools/retry.jl
@@ -4,10 +4,12 @@
verbose::Bool = true, throw::Bool = false, evaluate_all::Bool = true, feedback_expensive::Bool = false,
max_retries::Union{Nothing, Int} = nothing, retry_delay::Union{Nothing, Int} = nothing)
Evaluates the condition `f_cond` on the `aicall` object.
If the condition is not met, it will return the best sample to retry from and provide `feedback` (string or function) to `aicall`. That's why it's mutating.
It will retry at most `max_retries` times; with `throw=true`, an error is thrown if the condition is not met after `max_retries` retries.
Note: `aicall` must be run first via `run!(aicall)` before calling `airetry!`.
Function signatures
- `f_cond(aicall::AICallBlock) -> Bool`, ie, it must accept the aicall object and return a boolean value.
- `feedback` can be a string or `feedback(aicall::AICallBlock) -> String`, ie, it must accept the aicall object and return a string.
@@ -286,6 +288,9 @@ function airetry!(f_cond::Function, aicall::AICallBlock,
(; config) = aicall
(; max_calls, feedback_inplace, feedback_template) = aicall.config

## Validate that the aicall has been run first
@assert aicall.success isa Bool "Provided `aicall` has not been run yet. Use `run!(aicall)` first, before calling `airetry!` to check the condition."

max_retries = max_retries isa Nothing ? config.max_retries : max_retries
retry_delay = retry_delay isa Nothing ? config.retry_delay : retry_delay
verbose = min(verbose, get(aicall.kwargs, :verbose, 99))
@@ -505,8 +510,12 @@ conversation[end].content ==
function add_feedback!(conversation::AbstractVector{<:PT.AbstractMessage},
sample::SampleNode; feedback_inplace::Bool = false,
feedback_template::Symbol = :FeedbackFromEvaluator)
##
all_feedback = collect_all_feedback(sample)
## If using in-place feedback, collect all feedback from ancestors (the LLM won't see the message history otherwise)
all_feedback = if feedback_inplace
collect_all_feedback(sample)
else
sample.feedback
end
## short circuit if no feedback
if strip(all_feedback) == ""
return conversation
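To tie the retry.jl changes together, a minimal usage sketch (the prompt, condition, and model alias are illustrative assumptions, not part of this commit):

```julia
using PromptingTools
using PromptingTools.Experimental.AgentTools  # AIGenerate, run!, airetry!

aicall = AIGenerate("In one word: what is the capital of France?"; model = "gpt4om")
run!(aicall)  # required first -- airetry! now asserts the call has been run

# Retry until the condition holds; on failure, the feedback string is sent to the model.
# If the aicall's config sets `feedback_inplace = true`, feedback from all ancestor
# samples is now concatenated (the model cannot see it in the replaced history otherwise).
airetry!(x -> occursin("Paris", last_output(x)), aicall,
    "The answer must mention Paris."; max_retries = 2)
```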
14 changes: 14 additions & 0 deletions src/llm_interface.jl
@@ -233,6 +233,20 @@ Requires one environment variable to be set:
"""
struct OpenRouterOpenAISchema <: AbstractOpenAISchema end

"""
CerebrasOpenAISchema
Schema to call the [Cerebras](https://cerebras.ai/) API.
Links:
- [Get your API key](https://cloud.cerebras.ai)
- [API Reference](https://inference-docs.cerebras.ai/api-reference/chat-completions)
Requires one environment variable to be set:
- `CEREBRAS_API_KEY`: Your API key
"""
struct CerebrasOpenAISchema <: AbstractOpenAISchema end

abstract type AbstractOllamaSchema <: AbstractPromptSchema end

"""
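A short sketch of how the new schema slots into the OpenAI-compatible machinery (the model ID follows the Cerebras docs; the registration call is illustrative, since the registry entries below already cover it):

```julia
using PromptingTools
const PT = PromptingTools

# Direct schema usage, bypassing the model registry (model must be a valid Cerebras ID):
msg = aigenerate(PT.CerebrasOpenAISchema(), "Hello!"; model = "llama3.1-8b")

# Or register a model yourself:
PT.register_model!(; name = "llama3.1-8b", schema = PT.CerebrasOpenAISchema(),
    description = "Llama3.1 8b hosted by Cerebras")
```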
93 changes: 63 additions & 30 deletions src/llm_openai.jl
@@ -237,6 +237,15 @@ function OpenAI.create_chat(schema::OpenRouterOpenAISchema,
api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY
OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::CerebrasOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String = "https://api.cerebras.ai/v1",
kwargs...)
api_key = isempty(CEREBRAS_API_KEY) ? api_key : CEREBRAS_API_KEY
OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::DatabricksOpenAISchema,
api_key::AbstractString,
model::AbstractString,
@@ -272,19 +281,20 @@ function OpenAI.create_chat(schema::DatabricksOpenAISchema,
end
end
function OpenAI.create_chat(schema::AzureOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
api_version::String = "2023-03-15-preview",
http_kwargs::NamedTuple = NamedTuple(),
streamcallback::Any = nothing,
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)
api_key::AbstractString,
model::AbstractString,
conversation;
api_version::String = "2023-03-15-preview",
http_kwargs::NamedTuple = NamedTuple(),
streamcallback::Any = nothing,
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)

# Build the corresponding provider object
provider = OpenAI.AzureProvider(;
api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY,
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) * "/openai/deployments/$model",
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) *
"/openai/deployments/$model",
api_version = api_version
)
# Override standard OpenAI request endpoint
@@ -297,7 +307,7 @@ function OpenAI.create_chat(schema::AzureOpenAISchema,
query = Dict("api-version" => provider.api_version),
streamcallback = streamcallback,
kwargs...
)
)
end

# Extend OpenAI create_embeddings to allow for testing
@@ -396,17 +406,18 @@ function OpenAI.create_embeddings(schema::FireworksOpenAISchema,
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::AzureOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
api_version::String = "2023-03-15-preview",
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)
api_key::AbstractString,
docs,
model::AbstractString;
api_version::String = "2023-03-15-preview",
url::String = "https://<resource-name>.openai.azure.com",
kwargs...)

# Build the corresponding provider object
provider = OpenAI.AzureProvider(;
api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY,
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) * "/openai/deployments/$model",
base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) *
"/openai/deployments/$model",
api_version = api_version)
# Override standard OpenAI request endpoint
OpenAI.openai_request(
@@ -851,11 +862,15 @@ const OPENAI_TOKEN_IDS_GPT4O = Dict(
"38" => 3150,
"39" => 3255,
"40" => 1723)
## Note: You can provide your own token ID map to `encode_choices` via the `token_ids_map` kwarg to use a custom mapping

function pick_tokenizer(model::AbstractString)
function pick_tokenizer(model::AbstractString;
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing)
global OPENAI_TOKEN_IDS_GPT35_GPT4, OPENAI_TOKEN_IDS_GPT4O
OPENAI_TOKEN_IDS = if model == "gpt-4" || startswith(model, "gpt-3.5") ||
startswith(model, "gpt-4-")
OPENAI_TOKEN_IDS = if !isnothing(token_ids_map)
token_ids_map
elseif (model == "gpt-4" || startswith(model, "gpt-3.5") ||
startswith(model, "gpt-4-"))
OPENAI_TOKEN_IDS_GPT35_GPT4
elseif startswith(model, "gpt-4o")
OPENAI_TOKEN_IDS_GPT4O
@@ -866,10 +881,15 @@ function pick_tokenizer(model::AbstractString)
end
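For clarity, the precedence the new kwarg introduces — a user-supplied map wins over the model-based defaults. A sketch (the model name and IDs are hypothetical, and `pick_tokenizer` is an internal, unexported helper):

```julia
# Hypothetical mapping: choice string => single-token ID in the target tokenizer.
my_ids = Dict("1" => 101, "2" => 102, "3" => 103)

# The custom map is checked first, so it wins regardless of the model name:
ids = PromptingTools.pick_tokenizer("my-custom-model"; token_ids_map = my_ids)
@assert ids === my_ids
```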

"""
encode_choices(schema::OpenAISchema, choices::AbstractVector{<:AbstractString}; kwargs...)
encode_choices(schema::OpenAISchema, choices::AbstractVector{<:AbstractString};
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
encode_choices(schema::OpenAISchema, choices::AbstractVector{T};
kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
Encodes the choices into an enumerated list that can be interpolated into the prompt, and creates the corresponding logit biases (so the model can pick only from the selected tokens).
@@ -880,6 +900,8 @@ There can be at most 40 choices provided.
# Arguments
- `schema::OpenAISchema`: The OpenAISchema object.
- `choices::AbstractVector{<:Union{AbstractString,Tuple{<:AbstractString, <:AbstractString}}}`: The choices to be encoded, represented as a vector of the choices directly, or tuples where each tuple contains a choice and its description.
- `model::AbstractString`: The model to use for encoding. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
- `token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing`: A dictionary mapping token strings to their corresponding token IDs. If `nothing`, the default token IDs for the given model are used.
- `kwargs...`: Additional keyword arguments.
# Returns
Expand Down Expand Up @@ -908,8 +930,9 @@ logit_bias # Output: Dict(16 => 100, 17 => 100, 18 => 100)
function encode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString};
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
OPENAI_TOKEN_IDS = pick_tokenizer(model)
OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
## if all choices are in the dictionary, use the dictionary
if all(Base.Fix1(haskey, OPENAI_TOKEN_IDS), choices)
choices_prompt = ["$c for \"$c\"" for c in choices]
@@ -927,8 +950,9 @@ end
function encode_choices(schema::OpenAISchema,
choices::AbstractVector{T};
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <: Tuple{<:AbstractString, <:AbstractString}}
OPENAI_TOKEN_IDS = pick_tokenizer(model)
OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
## if all choices are in the dictionary, use the dictionary
if all(Base.Fix1(haskey, OPENAI_TOKEN_IDS), first.(choices))
choices_prompt = ["$c for \"$desc\"" for (c, desc) in choices]
@@ -958,6 +982,7 @@ end

function decode_choices(schema::OpenAISchema, choices, conv::AbstractVector;
model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
if length(conv) > 0 && last(conv) isa AIMessage && hasproperty(last(conv), :run_id)
## if it is a multi-sample response,
Expand All @@ -966,7 +991,7 @@ function decode_choices(schema::OpenAISchema, choices, conv::AbstractVector;
for i in eachindex(conv)
msg = conv[i]
if isaimessage(msg) && msg.run_id == run_id
conv[i] = decode_choices(schema, choices, msg; model)
conv[i] = decode_choices(schema, choices, msg; model, token_ids_map)
end
end
end
@@ -976,16 +1001,20 @@ end
"""
decode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString},
msg::AIMessage; model::AbstractString, kwargs...)
msg::AIMessage; model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
Decodes the underlying AIMessage against the original choices to lookup what the category name was.
If decoding fails, it returns the message with `msg.content == nothing`.
"""
function decode_choices(schema::OpenAISchema,
choices::AbstractVector{<:AbstractString},
msg::AIMessage; model::AbstractString, kwargs...)
OPENAI_TOKEN_IDS = pick_tokenizer(model)
msg::AIMessage; model::AbstractString,
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...)
OPENAI_TOKEN_IDS = pick_tokenizer(model; token_ids_map)
parsed_digit = tryparse(Int, strip(msg.content))
if !isnothing(parsed_digit) && haskey(OPENAI_TOKEN_IDS, strip(msg.content))
## It's encoded
@@ -1006,6 +1035,7 @@ end
choices::AbstractVector{T} = ["true", "false", "unknown"],
model::AbstractString = MODEL_CHAT,
api_kwargs::NamedTuple = NamedTuple(),
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <: Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
Classifies the given prompt/statement into an arbitrary list of `choices`, provided either as the choices alone (vector of strings) or as choices with descriptions (vector of tuples, ie, `("choice","description")`).
@@ -1025,6 +1055,7 @@ It uses Logit bias trick and limits the output to 1 token to force the model to
- `choices::AbstractVector{T}`: The choices to be classified into. It can be a vector of strings or a vector of tuples, where the first element is the choice and the second is the description.
- `model::AbstractString = MODEL_CHAT`: The model to use for classification. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
- `api_kwargs::NamedTuple = NamedTuple()`: Additional keyword arguments for the API call.
- `token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing`: A dictionary mapping token strings to their corresponding token IDs. If `nothing`, the default token IDs for the given model are used.
- `kwargs`: Additional keyword arguments for the prompt template.
# Example
@@ -1085,12 +1116,13 @@ function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_
choices::AbstractVector{T} = ["true", "false", "unknown"],
model::AbstractString = MODEL_CHAT,
api_kwargs::NamedTuple = NamedTuple(),
token_ids_map::Union{Nothing, Dict{<:AbstractString, <:Integer}} = nothing,
kwargs...) where {T <:
Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
## Encode the choices and the corresponding prompt
model_id = get(MODEL_ALIASES, model, model)
choices_prompt, logit_bias, decode_ids = encode_choices(
prompt_schema, choices; model = model_id)
prompt_schema, choices; model = model_id, token_ids_map)
## We want only 1 token
api_kwargs = merge(api_kwargs, (; logit_bias, max_tokens = 1, temperature = 0))
msg_or_conv = aigenerate(prompt_schema,
Expand All @@ -1099,7 +1131,8 @@ function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_
model = model_id,
api_kwargs,
kwargs...)
return decode_choices(prompt_schema, decode_ids, msg_or_conv; model = model_id)
return decode_choices(
prompt_schema, decode_ids, msg_or_conv; model = model_id, token_ids_map)
end

function response_to_message(schema::AbstractOpenAISchema,
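Putting the `token_ids_map` plumbing together, an end-to-end sketch (the model name and token IDs are placeholders; real values must be the single-token IDs of "1", "2", "3" in the target model's tokenizer):

```julia
using PromptingTools

# Placeholder IDs -- substitute the actual single-token IDs from your tokenizer.
token_ids_map = Dict("1" => 16, "2" => 17, "3" => 18)

msg = aiclassify("Is two plus two equal to four?";
    choices = ["true", "false", "unknown"],
    model = "my-openai-compatible-model",  # hypothetical custom-tokenizer model
    token_ids_map)
msg.content  # expected: "true"
```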
24 changes: 23 additions & 1 deletion src/user_preferences.jl
@@ -24,6 +24,7 @@ Check your preferences by calling `get_preferences(key::String)`.
- `GROQ_API_KEY`: The API key for the Groq API. Free in beta! Get yours from [here](https://console.groq.com/keys).
- `DEEPSEEK_API_KEY`: The API key for the DeepSeek API. Get \$5 credit when you join. Get yours from [here](https://platform.deepseek.com/api_keys).
- `OPENROUTER_API_KEY`: The API key for the OpenRouter API. Get yours from [here](https://openrouter.ai/keys).
- `CEREBRAS_API_KEY`: The API key for the Cerebras API. Get yours from [here](https://cloud.cerebras.ai/).
- `MODEL_CHAT`: The default model to use for aigenerate and most ai* calls. See `MODEL_REGISTRY` for a list of available models or define your own.
- `MODEL_EMBEDDING`: The default model to use for aiembed (embedding documents). See `MODEL_REGISTRY` for a list of available models or define your own.
- `PROMPT_SCHEMA`: The default prompt schema to use for aigenerate and most ai* calls (if not specified in `MODEL_REGISTRY`). Set as a string, eg, `"OpenAISchema"`.
@@ -55,6 +56,7 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava
- `GROQ_API_KEY`: The API key for the Groq API. Free in beta! Get yours from [here](https://console.groq.com/keys).
- `DEEPSEEK_API_KEY`: The API key for the DeepSeek API. Get \$5 credit when you join. Get yours from [here](https://platform.deepseek.com/api_keys).
- `OPENROUTER_API_KEY`: The API key for the OpenRouter API. Get yours from [here](https://openrouter.ai/keys).
- `CEREBRAS_API_KEY`: The API key for the Cerebras API.
- `LOG_DIR`: The directory to save the logs to, eg, when using `SaverSchema <: AbstractTracerSchema`. Defaults to `joinpath(pwd(), "log")`. Refer to `?SaverSchema` for more information on how it works and examples.
Preferences.jl takes priority over ENV variables, so if you set a preference, it will take precedence over the ENV variable.
Expand All @@ -78,6 +80,7 @@ const ALLOWED_PREFERENCES = ["MISTRALAI_API_KEY",
"GROQ_API_KEY",
"DEEPSEEK_API_KEY",
"OPENROUTER_API_KEY", # Added OPENROUTER_API_KEY
"CEREBRAS_API_KEY",
"MODEL_CHAT",
"MODEL_EMBEDDING",
"MODEL_ALIASES",
@@ -159,6 +162,7 @@ global VOYAGE_API_KEY::String = ""
global GROQ_API_KEY::String = ""
global DEEPSEEK_API_KEY::String = ""
global OPENROUTER_API_KEY::String = ""
global CEREBRAS_API_KEY::String = ""
global LOCAL_SERVER::String = ""
global LOG_DIR::String = ""

@@ -216,6 +220,9 @@ function load_api_keys!()
global OPENROUTER_API_KEY # Added OPENROUTER_API_KEY
OPENROUTER_API_KEY = @load_preference("OPENROUTER_API_KEY",
default=get(ENV, "OPENROUTER_API_KEY", ""))
global CEREBRAS_API_KEY
CEREBRAS_API_KEY = @load_preference("CEREBRAS_API_KEY",
default=get(ENV, "CEREBRAS_API_KEY", ""))
global LOCAL_SERVER
LOCAL_SERVER = @load_preference("LOCAL_SERVER",
default=get(ENV, "LOCAL_SERVER", ""))
@@ -410,6 +417,11 @@ aliases = merge(
"gll" => "llama-3.1-405b-reasoning", #l for large
"gmixtral" => "mixtral-8x7b-32768",
"ggemma9" => "gemma2-9b-it",
## Cerebras
"cl3" => "llama3.1-8b",
"cllama3" => "llama3.1-8b",
"cl70" => "llama3.1-70b",
"cllama70" => "llama3.1-70b",
## DeepSeek
"dschat" => "deepseek-chat",
"dscode" => "deepseek-coder",
@@ -885,7 +897,17 @@ registry = Dict{String, ModelSpec}(
OpenRouterOpenAISchema(),
2e-6,
2e-6,
"Meta's Llama3.1 405b, hosted by OpenRouter. This is a BASE model!! Max output 32K tokens, 131K context. See details [here](https://openrouter.ai/models/meta-llama/llama-3.1-405b)")
"Meta's Llama3.1 405b, hosted by OpenRouter. This is a BASE model!! Max output 32K tokens, 131K context. See details [here](https://openrouter.ai/models/meta-llama/llama-3.1-405b)"),
"llama3.1-8b" => ModelSpec("llama3.1-8b",
CerebrasOpenAISchema(),
1e-7,
1e-7,
"Meta's Llama3.1 8b, hosted by Cerebras.ai. Max 8K context."),
"llama3.1-70b" => ModelSpec("llama3.1-70b",
CerebrasOpenAISchema(),
6e-7,
6e-7,
"Meta's Llama3.1 70b, hosted by Cerebras.ai. Max 8K context.")
)

"""
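Finally, a sketch of wiring up the key and the new aliases, per the preferences docs above (the key value is a placeholder):

```julia
using PromptingTools

# Persist the key via Preferences.jl; preferences take precedence over ENV variables:
PromptingTools.set_preferences!("CEREBRAS_API_KEY" => "<your-api-key>")

# The new aliases resolve through MODEL_ALIASES to the Cerebras registry entries:
aigenerate("What is 2 + 2?"; model = "cl70")  # => llama3.1-70b
```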