Skip to content

Commit

Permalink
ChainOfThought + Groq
Browse files Browse the repository at this point in the history
  • Loading branch information
thmsmlr committed Sep 25, 2024
1 parent 0db54d0 commit 1abd847
Show file tree
Hide file tree
Showing 10 changed files with 501 additions and 22 deletions.
4 changes: 3 additions & 1 deletion .formatter.exs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Used by "mix format"
[
inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"]
# Dependencies whose exported formatter configuration is imported by `mix format`.
import_deps: [:ecto, :phoenix, :phoenix_live_view],
# Formatter plugin module; presumably wanted for templates in the cookbook pages — confirm.
plugins: [Phoenix.LiveView.HTMLFormatter],
# Glob patterns of files to format; the last entry adds the cookbook pages.
inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}", "pages/cookbook/**/*.{ex,exs}"]
]
19 changes: 2 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,26 +123,12 @@ end

## TODO

- [x] Top-level array support
- [x] Gemini
- [x] tools mode
- [ ] json mode
- [x] json_schema mode
- [x] Figure out a solution for OpenAI's json_schema mode not supporting arbitrary maps.
- [ ] Partial Schemaless doesn't work since fields are set to required in Ecto.


- [ ] llamacpp adapter broken, needs to support openai input/output API
- [ ] GBNF should enforce required properties on objects, currently they're optional.
- [ ] GBNF limit the number of digits in number tokens -- small models can sometimes run off to infinite digits
- [ ] Add instructor tests against llamacpp interface using mocks, there's non-trivial logic in there
- [ ] Groq adapter
- [ ] ChainOfThought doesn't work with max_retries
- [ ] Logging for Distillation / Finetuning
- [ ] Add a Bumblebee adapter
- [ ] Add llamacpp_ex adapter
- [ ] Support naked ecto types by auto-wrapping, not just maps of ecto types, do not wrap if we don't need to... Current codepaths are muddled
- [x] Support Streaming
- [ ] Verify schemaless support `{:array, %{name: :string}}`
- [ ] Support typespec style support for array streaming `[MySchema]`
- [ ] Optional/Maybe types
- [ ] Add Livebook Tutorials, include in Hexdocs
- [x] Text Classification
Expand All @@ -160,7 +146,6 @@ end
- [ ] Multi-File Code Generation
- [ ] PII Data Sanitization
- [x] Update hexdocs homepage to include example for tutorial
- [ ] Setup Github CI for testing, add badge to README

## Blog Posts

Expand Down
35 changes: 35 additions & 0 deletions lib/instructor/adapters/groq.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
defmodule Instructor.Adapters.Groq do
  @moduledoc """
  Adapter for Groq Cloud API. Using the OpenAI API compatible endpoint.
  """
  alias Instructor.Adapters

  @behaviour Instructor.Adapter

  # Modes this adapter accepts; any other mode makes `chat_completion/2` raise.
  @supported_modes [:tools]

  # Baseline settings; caller- or application-provided keys override these.
  @default_config [
    api_url: "https://api.groq.com/openai",
    http_options: [receive_timeout: 60_000]
  ]

  # Resolves the configuration, verifies that `params[:mode]` is supported and
  # forwards the request to the OpenAI adapter with the resolved configuration.
  @impl true
  def chat_completion(params, user_config \\ nil) do
    # Resolve the configuration before validating the mode, as the request
    # is only forwarded once both steps succeeded.
    resolved_config = resolve_config(user_config)

    case params[:mode] do
      mode when mode in @supported_modes ->
        Adapters.OpenAI.chat_completion(params, resolved_config)

      mode ->
        raise "Unsupported mode #{mode} for Groq"
    end
  end

  # Re-ask handling is forwarded unchanged to the OpenAI adapter.
  @impl true
  def reask_messages(raw_response, params, config) do
    Adapters.OpenAI.reask_messages(raw_response, params, config)
  end

  # Without an explicit config, fall back to the `:groq` key of the
  # `:instructor` application environment (empty list when unset).
  defp resolve_config(nil) do
    :instructor
    |> Application.get_env(:groq, [])
    |> resolve_config()
  end

  # Layers the given overrides on top of the defaults.
  defp resolve_config(overrides), do: Keyword.merge(@default_config, overrides)
end
11 changes: 8 additions & 3 deletions lib/instructor/adapters/openai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,15 @@ defmodule Instructor.Adapters.OpenAI do

# Extracts the text fragment from a streamed `:tools`-mode chunk that carries
# exactly one choice with a "delta" entry.
#
# Returns the delta's "content" when it is a binary, and "" when the delta is
# `nil`, has no "content" key, or carries a non-binary (e.g. null) content.
defp parse_stream_chunk_for_mode(:tools, %{"choices" => [%{"delta" => delta}]}) do
  case delta do
    # This clause has to come before `%{}`: an empty-map pattern matches EVERY
    # map, so listing it first would shadow this clause and drop all content.
    %{"content" => chunk} when is_binary(chunk) -> chunk
    nil -> ""
    %{} -> ""
  end
end

defp parse_stream_chunk_for_mode(_, %{"choices" => [%{"finish_reason" => "stop"}]}), do: ""

Expand Down
118 changes: 118 additions & 0 deletions lib/instructor/extras/chain_of_thought.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
defmodule Instructor.Extras.ChainOfThought do
  @moduledoc """
  Chain-of-thought wrapper around `Instructor.chat_completion/2`.

  The model is first asked for a sequence of `ReasoningStep` structs and then,
  once it signals `:final_answer`, for an answer in the caller's own
  `:response_model`. See `chat_completion/2`.
  """

  defmodule ReasoningStep do
    use Ecto.Schema

    @doc """
    For each step, provide a title that describes what you're doing in that step, along with the content.
    Decide if you need another step or if you're ready to give the final answer.
    Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys.
    """
    @primary_key false
    embedded_schema do
      field(:title, :string)
      field(:content, :string)
      field(:next_action, Ecto.Enum, values: [:final_answer, :continue])
    end
  end

  @default_reasoning_steps 3

  @doc """
  Returns a lazy stream that drives a step-by-step reasoning conversation.

  ## Parameters

    * `params` - keyword list forwarded to `Instructor.chat_completion/2`.
      `:messages` and `:response_model` are read here. The extra option
      `:reasoning_steps` (default `#{@default_reasoning_steps}`) caps the number
      of reasoning steps and is removed before the params are forwarded.
    * `config` - passed through unchanged to `Instructor.chat_completion/2`.

  ## Stream elements

    * one `%ReasoningStep{}` per completed reasoning step;
    * then the final answer, i.e. the value returned in `{:ok, value}` for the
      caller's `:response_model`;
    * or `{:error, reason}` as the last element when a request fails or when
      no final answer was reached within `:reasoning_steps` steps.

  Nothing is requested until the stream is consumed.
  """
  def chat_completion(params, config \\ nil) do
    # Keyword.pop/3 returns `{value, rest}`; both parts must be destructured,
    # otherwise the limit would be a tuple and never compare as intended.
    {reasoning_steps, params} =
      Keyword.pop(params, :reasoning_steps, @default_reasoning_steps)

    response_model = params[:response_model]

    params =
      params
      |> Keyword.put(:messages, initial_messages(Keyword.get(params, :messages, [])))
      |> Keyword.put(:response_model, ReasoningStep)

    Stream.resource(
      fn -> {params, 0} end,
      fn
        :halt ->
          {:halt, nil}

        # Must precede the step-limit clause: the model asked to give its
        # final answer, which is allowed regardless of the step count.
        {:final_answer, params} ->
          request_final_answer(params, response_model, config)

        # Enforce the limit before issuing another request, so no completion
        # is requested only to be thrown away.
        {_params, step_count} when is_integer(step_count) and step_count >= reasoning_steps ->
          {[{:error, "No final answer within #{reasoning_steps} reasoning steps"}], :halt}

        {params, step_count} ->
          request_reasoning_step(params, step_count, config)
      end,
      fn _ -> nil end
    )
  end

  # Wraps the caller's messages between the reasoning system prompt and a
  # primed assistant turn.
  defp initial_messages(messages) do
    [
      %{
        role: "system",
        content: """
        You are an expert AI assistant that explains your reasoning step by step.
        For each step, provide a title that describes what you're doing in that step, along with the content.
        Decide if you need another step or if you're ready to give the final answer.
        Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys.
        USE AS MANY REASONING STEPS AS POSSIBLE.
        AT LEAST 3.
        # ... (rest of the system message)
        """
      }
    ] ++
      messages ++
      [
        %{
          role: "assistant",
          content: """
          Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem.
          """
        }
      ]
  end

  # Requests one reasoning step, appends it to the conversation as an assistant
  # message and decides the next state from the step's `next_action`.
  defp request_reasoning_step(params, step_count, config) do
    case Instructor.chat_completion(params, config) do
      {:ok, %ReasoningStep{} = step} ->
        assistant_message = %{
          role: "assistant",
          content: step |> Map.from_struct() |> Jason.encode!()
        }

        params = Keyword.put(params, :messages, params[:messages] ++ [assistant_message])

        next_state =
          case step.next_action do
            :final_answer -> {:final_answer, params}
            :continue -> {params, step_count + 1}
          end

        {[step], next_state}

      {:error, reason} ->
        # Surface the failure to the consumer instead of printing it and
        # ending the stream silently.
        {[{:error, reason}], :halt}
    end
  end

  # Asks for the final answer using the caller's original response model and
  # emits it (or the error) as the last element of the stream.
  defp request_final_answer(params, response_model, config) do
    final_prompt = %{
      role: "user",
      content: """
      Please provide the final answer based solely on your reasoning above.
      Only provide the text response without any titles or preambles.
      Retain any formatting as instructed by the original prompt, such as exact formatting for free response or multiple choice.
      """
    }

    params =
      params
      |> Keyword.put(:messages, params[:messages] ++ [final_prompt])
      |> Keyword.put(:response_model, response_model)

    case Instructor.chat_completion(params, config) do
      {:ok, final_answer} -> {[final_answer], :halt}
      {:error, reason} -> {[{:error, reason}], :halt}
    end
  end
end
1 change: 1 addition & 0 deletions lib/instructor/sse_stream_parser.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ defmodule Instructor.SSEStreamParser do
Jason.decode!(json_string)
end)
end)
# |> Stream.each(&IO.inspect/1)
end
end
4 changes: 3 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ defmodule Instructor.MixProject do
{:req, "~> 0.5 or ~> 1.0"},
{:jaxon, "~> 2.0"},
{:ex_doc, "~> 0.31", only: :dev, runtime: false},
{:mox, "~> 1.1.0", only: :test}
{:mox, "~> 1.1.0", only: :test},
{:phoenix, "~> 1.7", only: :test},
{:phoenix_live_view, "~> 0.20.17", only: :test}
]
end
end
9 changes: 9 additions & 0 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
"nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
"phoenix": {:hex, :phoenix, "1.7.14", "a7d0b3f1bc95987044ddada111e77bd7f75646a08518942c72a8440278ae7825", [:mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:phoenix_pubsub, "~> 2.1", [hex: :phoenix_pubsub, repo: "hexpm", optional: false]}, {:phoenix_template, "~> 1.0", [hex: :phoenix_template, repo: "hexpm", optional: false]}, {:phoenix_view, "~> 2.0", [hex: :phoenix_view, repo: "hexpm", optional: true]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.7", [hex: :plug_cowboy, repo: "hexpm", optional: true]}, {:plug_crypto, "~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:websock_adapter, "~> 0.5.3", [hex: :websock_adapter, repo: "hexpm", optional: false]}], "hexpm", "c7859bc56cc5dfef19ecfc240775dae358cbaa530231118a9e014df392ace61a"},
"phoenix_html": {:hex, :phoenix_html, "4.1.1", "4c064fd3873d12ebb1388425a8f2a19348cef56e7289e1998e2d2fa758aa982e", [:mix], [], "hexpm", "f2f2df5a72bc9a2f510b21497fd7d2b86d932ec0598f0210fed4114adc546c6f"},
"phoenix_live_view": {:hex, :phoenix_live_view, "0.20.17", "f396bbdaf4ba227b82251eb75ac0afa6b3da5e509bc0d030206374237dfc9450", [:mix], [{:floki, "~> 0.36", [hex: :floki, repo: "hexpm", optional: true]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:phoenix, "~> 1.6.15 or ~> 1.7.0", [hex: :phoenix, repo: "hexpm", optional: false]}, {:phoenix_html, "~> 3.3 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:phoenix_template, "~> 1.0", [hex: :phoenix_template, repo: "hexpm", optional: false]}, {:phoenix_view, "~> 2.0", [hex: :phoenix_view, repo: "hexpm", optional: true]}, {:plug, "~> 1.15", [hex: :plug, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.2 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "a61d741ffb78c85fdbca0de084da6a48f8ceb5261a79165b5a0b59e5f65ce98b"},
"phoenix_pubsub": {:hex, :phoenix_pubsub, "2.1.3", "3168d78ba41835aecad272d5e8cd51aa87a7ac9eb836eabc42f6e57538e3731d", [:mix], [], "hexpm", "bba06bc1dcfd8cb086759f0edc94a8ba2bc8896d5331a1e2c2902bf8e36ee502"},
"phoenix_template": {:hex, :phoenix_template, "1.0.4", "e2092c132f3b5e5b2d49c96695342eb36d0ed514c5b252a77048d5969330d639", [:mix], [{:phoenix_html, "~> 2.14.2 or ~> 3.0 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: true]}], "hexpm", "2c0c81f0e5c6753faf5cca2f229c9709919aba34fab866d3bc05060c9c444206"},
"plug": {:hex, :plug, "1.16.1", "40c74619c12f82736d2214557dedec2e9762029b2438d6d175c5074c933edc9d", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "a13ff6b9006b03d7e33874945b2755253841b238c34071ed85b0e86057f8cddc"},
"plug_crypto": {:hex, :plug_crypto, "2.1.0", "f44309c2b06d249c27c8d3f65cfe08158ade08418cf540fd4f72d4d6863abb7b", [:mix], [], "hexpm", "131216a4b030b8f8ce0f26038bc4421ae60e4bb95c5cf5395e1421437824c4fa"},
"req": {:hex, :req, "0.5.0", "6d8a77c25cfc03e06a439fb12ffb51beade53e3fe0e2c5e362899a18b50298b3", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.17", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "dda04878c1396eebbfdec6db6f3d4ca609e5c8846b7ee88cc56eb9891406f7a3"},
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
"websock": {:hex, :websock, "0.5.3", "2f69a6ebe810328555b6fe5c831a851f485e303a7c8ce6c5f675abeb20ebdadc", [:mix], [], "hexpm", "6105453d7fac22c712ad66fab1d45abdf049868f253cf719b625151460b8b453"},
"websock_adapter": {:hex, :websock_adapter, "0.5.7", "65fa74042530064ef0570b75b43f5c49bb8b235d6515671b3d250022cb8a1f9e", [:mix], [{:bandit, ">= 0.6.0", [hex: :bandit, repo: "hexpm", optional: true]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.6", [hex: :plug_cowboy, repo: "hexpm", optional: true]}, {:websock, "~> 0.5", [hex: :websock, repo: "hexpm", optional: false]}], "hexpm", "d0f478ee64deddfec64b800673fd6e0c8888b079d9f3444dd96d2a98383bdbd1"},
}
Loading

0 comments on commit 1abd847

Please sign in to comment.