feat: add a deep thinking reasoner model (o1-preview/mini) #68

Merged
merged 18 commits on Oct 10, 2024
Changes from 5 commits
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ developer = "goose.toolkit.developer:Developer"
github = "goose.toolkit.github:Github"
jira = "goose.toolkit.jira:Jira"
screen = "goose.toolkit.screen:Screen"
reasoner = "goose.toolkit.reasoner:Reasoner"
repo_context = "goose.toolkit.repo_context.repo_context:RepoContext"

[project.entry-points."goose.profile"]
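A quick way for reviewers to confirm the new toolkit is discoverable (a minimal sketch, assuming the hunk above belongs to the `[project.entry-points."goose.toolkit"]` table, which is cut off by the diff context, and Python 3.10+ `importlib.metadata`):

```python
from importlib.metadata import entry_points

# List the toolkits goose can discover via entry points; "reasoner" should now
# appear alongside developer, github, jira, screen and repo_context.
toolkits = entry_points(group="goose.toolkit")
print(sorted(ep.name for ep in toolkits))

# Resolve the new entry point the same way goose would load it.
reasoner_ep = next(ep for ep in toolkits if ep.name == "reasoner")
print(reasoner_ep.load())  # <class 'goose.toolkit.reasoner.Reasoner'>
```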
6 changes: 6 additions & 0 deletions src/goose/toolkit/prompts/reasoner.jinja
@@ -0,0 +1,6 @@
To generate or write code you will have collected instructions and requirements,
and perhaps already set things up in place.

When it is time to write code, debug a stubborn problem, or understand something deeply,
call your tools such as generate_code, deep_debug and deep_understand.
158 changes: 158 additions & 0 deletions src/goose/toolkit/reasoner.py
@@ -0,0 +1,158 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
from exchange import Exchange, Message, Text
from exchange.content import Content
from exchange.providers import Provider, Usage
from exchange.tool import Tool
from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
from exchange.providers.utils import raise_for_status, openai_single_message_context_length_exceeded
from exchange.providers.utils import openai_response_to_message, messages_to_openai_spec
from goose.toolkit.base import Toolkit, tool
from goose.utils.ask import ask_an_ai


class Reasoner(Toolkit):
"""This is a toolkit to add deeper and slower reasoning around code and questions and debugging"""


def message_content(self, content: Content) -> Text:
if isinstance(content, Text):
return content
else:
return Text(str(content))


@tool
def deep_debug(self, problem: str) -> str:
"""
This tool can assist with debugging when there are errors or problems when trying things
and other approaches haven't solved it.
It will take a minute to think about it and consider solutions.

Args:
problem (str): description of problem or errors seen.

Returns:
response (str): A solution, which may include a suggestion or code snippet.
"""
# Create an instance of Exchange with the inlined OpenAI provider
self.notifier.status("thinking...")
provider = self.OpenAiProvider.from_env()

# Create messages list
existing_messages_copy = [
Message(role=msg.role, content=[self.message_content(content) for content in msg.content])
for msg in self.exchange_view.processor.messages]
exchange = Exchange(provider=provider, model="o1-preview", messages=existing_messages_copy, system=None)

response = ask_an_ai(input="Can you help debug this problem: " + problem, exchange=exchange)
return response.content[0].text


@tool
def generate_code(self, instructions: str) -> str:
"""
Use this when enhancing existing code or generating new code, unless the change is simple or needed quickly.
It can generate high-quality code to be considered and used.

Args:
instructions (str): instructions of what code to write or how to modify it.

Returns:
response (str): generated code to be tested or applied as needed.
"""
# Create an instance of Exchange with the inlined OpenAI provider
self.notifier.status("generating code...")
provider = self.OpenAiProvider.from_env()

# clone messages, converting to text for context
existing_messages_copy = [
Message(role=msg.role, content=[self.message_content(content) for content in msg.content])
for msg in self.exchange_view.processor.messages]
exchange = Exchange(provider=provider,
model="o1-preview",
messages=existing_messages_copy, system=None)

response = ask_an_ai(input=instructions,
exchange=exchange)
return response.content[0].text

def system(self) -> str:
"""Retrieve instructions on how to use this reasoning and code generation tool"""
return Message.load("prompts/reasoner.jinja").text

class OpenAiProvider(Provider):
"""Inlined here as o1 model only in preview and supports very little still."""

def __init__(self, client: httpx.Client) -> None:
super().__init__()
self.client = client

@classmethod
def from_env(cls: Type["Reasoner.OpenAiProvider"]) -> "Reasoner.OpenAiProvider":
url = os.environ.get("OPENAI_HOST", "https://api.openai.com/")
try:
key = os.environ["OPENAI_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
client = httpx.Client(
base_url=url,
auth=("Bearer", key),
timeout=httpx.Timeout(60 * 10),
)
return cls(client)

@staticmethod
def get_usage(data: dict) -> Usage:
usage = data.pop("usage")
input_tokens = usage.get("prompt_tokens")
output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")

if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens

return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)

def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
payload = dict(
messages=[
*messages_to_openai_spec(messages),
],
model=model,
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._send_request(payload)

# Check for context_length_exceeded error for single, long input message
if "error" in response.json() and len(messages) == 1:
openai_single_message_context_length_exceeded(response.json()["error"])

data = raise_for_status(response).json()

message = openai_response_to_message(data)
usage = self.get_usage(data)
return message, usage

@retry_httpx_request()
def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401
return self.client.post("v1/chat/completions", json=payload)
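For reviewers, a rough end-to-end sketch of the pattern the toolkit uses internally (hypothetical standalone driver, not part of this diff; it assumes `OPENAI_API_KEY` is set and mirrors the `Exchange`/`ask_an_ai` calls from `deep_debug` above):

```python
from exchange import Exchange, Message, Text
from goose.toolkit.reasoner import Reasoner
from goose.utils.ask import ask_an_ai

# Build the inlined provider from the environment (requires OPENAI_API_KEY).
provider = Reasoner.OpenAiProvider.from_env()

# Seed the conversation with plain-text content, as the toolkit does when it
# copies the processor's messages into Text-only form.
messages = [Message(role="user", content=[Text("pytest fails with a flaky socket timeout in CI")])]
exchange = Exchange(provider=provider, model="o1-preview", messages=messages, system=None)

# Same call shape as deep_debug: hand the problem to o1-preview and print the answer.
response = ask_an_ai(input="Can you help debug this problem: flaky socket timeout", exchange=exchange)
print(response.content[0].text)
```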
137 changes: 137 additions & 0 deletions tests/curves/p256_tests.rs
@@ -0,0 +1,137 @@
// tests/curves/p256_tests.rs

use super::super::curves::p256::P256Curve;
use p256::ecdsa::signature::{Signer, Verifier};
use p256::ecdsa::Signature;
use p256::EncodedPoint;
use hex::decode;

#[test]
fn test_key_generation() {
let curve = P256Curve::new();
let public_key = curve.get_public_key();

// Check that the public key is a valid encoded point
assert!(EncodedPoint::from_bytes(public_key.as_bytes()).is_ok());

// Ensure the public key is in compressed form (if using compressed encoding)
assert_eq!(public_key.len(), 33);
}

#[test]
fn test_sign_and_verify() {
let curve = P256Curve::new();
let message = b"Test message for P256 signing";

let signature: Signature = curve.sign(message);
assert!(curve.verify(message, &signature));
}

#[test]
fn test_invalid_signature() {
let curve = P256Curve::new();
let message = b"Original message";
let tampered_message = b"Tampered message";

let signature: Signature = curve.sign(message);
assert!(!curve.verify(tampered_message, &signature));
}

#[test]
fn test_signature_verification_with_different_key() {
let curve1 = P256Curve::new();
let curve2 = P256Curve::new();
let message = b"Test message for P256 signing";

let signature: Signature = curve1.sign(message);
assert!(!curve2.verify(message, &signature));
}

#[test]
fn test_deterministic_signing() {
// P256 ECDSA can use deterministic signatures as per RFC 6979
// Ensure that signing the same message with the same key produces the same signature
let signing_key = p256::ecdsa::SigningKey::random(&mut rand_core::OsRng);
let message = b"Deterministic signing test message";

let signature1: Signature = signing_key.sign(message);
let signature2: Signature = signing_key.sign(message);

assert_eq!(signature1, signature2);
}

#[test]
fn test_public_key_serialization() {
let curve = P256Curve::new();
let public_key = curve.get_public_key();

// Round-trip the key through its SEC1 byte encoding
let decoded = EncodedPoint::from_bytes(public_key.as_bytes());
assert!(decoded.is_ok());
let decoded = decoded.unwrap();

// Check that the decoded key matches the original
assert_eq!(decoded, public_key);
}

#[test]
fn test_invalid_public_key() {
// Attempt to create a public key from invalid bytes
let invalid_bytes = [0u8; 33];
let result = EncodedPoint::from_bytes(&invalid_bytes);
assert!(result.is_err());
}

#[test]
fn test_known_signatures() {
// Test against known signatures and keys
// Example using test vectors (Replace with real test vectors)

let signing_key = p256::ecdsa::SigningKey::from_bytes(&decode("c9af...").unwrap()).unwrap();
let verifying_key = signing_key.verifying_key();
let message = b"Known message";

let signature: Signature = signing_key.sign(message);
let expected_signature = Signature::from_der(&decode("3045...").unwrap()).unwrap();

assert_eq!(signature, expected_signature);
assert!(verifying_key.verify(message, &signature).is_ok());
}

#[test]
fn test_large_message() {
let curve = P256Curve::new();
// Generate a large message (e.g., 1MB)
let message = vec![0u8; 1_000_000];
let signature = curve.sign(&message);
assert!(curve.verify(&message, &signature));
}

#[test]
fn test_empty_message() {
let curve = P256Curve::new();
let message = b"";
let signature = curve.sign(message);
assert!(curve.verify(message, &signature));
}

#[test]
fn test_multiple_signatures() {
let curve = P256Curve::new();
let messages: Vec<&[u8]> = vec![
b"First message",
b"Second message",
b"Third message",
b"Fourth message",
b"Fifth message",
];

for message in messages {
let signature = curve.sign(message);
assert!(curve.verify(message, &signature));
}
}