Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for schema-constrained tokens in KurtOpenAI #63

Merged: 1 commit merged on Dec 9, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/basic/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"@formula-monks/kurt-open-ai": "workspace:*",
"@formula-monks/kurt-vertex-ai": "workspace:*",
"@google-cloud/vertexai": "1.1.0",
"openai": "4.66.1",
"openai": "^4.76.0",
"zod": "^3.23.8"
}
}
2 changes: 1 addition & 1 deletion packages/kurt-open-ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
},
"dependencies": {
"@formula-monks/kurt": "^1.4.0",
"openai": "4.66.1",
"openai": "4.76.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.23.3"
},
Expand Down
7 changes: 4 additions & 3 deletions packages/kurt-open-ai/spec/generateNaturalLanguage.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { KurtResultLimitError } from "@formula-monks/kurt"

describe("KurtOpenAI generateNaturalLanguage", () => {
test("says hello", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Say hello!",
})
Expand All @@ -13,7 +13,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {
})

test("writes a haiku with high temperature", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Compose a haiku about a mountain stream at night.",
sampling: {
Expand All @@ -34,6 +34,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {

test("throws a limit error", async () => {
await snapshotAndMockWithError(
"gpt-4o-2024-05-13",
(kurt) =>
kurt.generateNaturalLanguage({
prompt: "Compose a haiku about content length limitations.",
Expand All @@ -50,7 +51,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {
})

test("describes a base64-encoded image", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Describe this emoji, in two words.",
extraMessages: [
Expand Down
62 changes: 60 additions & 2 deletions packages/kurt-open-ai/spec/generateStructuredData.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { describe, test, expect } from "@jest/globals"
import { z } from "zod"
import { snapshotAndMock, snapshotAndMockWithError } from "./snapshots"
import { KurtResultValidateError } from "@formula-monks/kurt"
import {
KurtCapabilityError,
KurtResultValidateError,
} from "@formula-monks/kurt"

describe("KurtOpenAI generateStructuredData", () => {
test("says hello", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
Expand All @@ -18,8 +21,63 @@ describe("KurtOpenAI generateStructuredData", () => {
expect(result.data).toEqual({ say: "hello" })
})

test("says hello with system prompt", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
systemPrompt: "Be nice.",
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
})
)
expect(result.data).toEqual({ say: "hello" })
})

test("says hello with schema constrained tokens", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
sampling: { forceSchemaConstrainedTokens: true },
})
)
expect(result.data).toEqual({ say: "hello" })
})

test("throws a capability error for schema constrained tokens in an older model", async () => {
await snapshotAndMockWithError(
"gpt-4o-2024-05-13",
(kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
sampling: { forceSchemaConstrainedTokens: true },
}),
(errorAny) => {
expect(errorAny).toBeInstanceOf(KurtCapabilityError)
const error = errorAny as KurtCapabilityError
expect(error.missingCapability).toEqual(
"forceSchemaConstrainedTokens is not available for older models, including gpt-4o-2024-05-13"
)
expect(error.message).toContain(error.missingCapability)
}
)
})

test("throws a validate error from an impossible schema", async () => {
await snapshotAndMockWithError(
"gpt-4o-mini-2024-07-18",
(kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
Expand Down
26 changes: 21 additions & 5 deletions packages/kurt-open-ai/spec/generateWithOptionalTools.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const calculatorTools = {

describe("KurtOpenAI generateWithOptionalTools", () => {
test("calculator (with tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
Expand All @@ -34,8 +34,24 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
expect(result.additionalData).toBeUndefined() // no parallel tool calls
})

test("calculator (with strict tool call)", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
tools: calculatorTools,
sampling: { forceSchemaConstrainedTokens: true },
})
)
expect(result.data).toEqual({
name: "divide",
args: { dividend: 9876356, divisor: 30487 },
})
expect(result.additionalData).toBeUndefined() // no parallel tool calls
})

test("calculator (after tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
Expand All @@ -58,7 +74,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
})

test("calculator (with parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
Expand Down Expand Up @@ -86,7 +102,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
})

test("calculator (after parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
Expand Down Expand Up @@ -125,7 +141,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
)
expect(result.text).toEqual(
[
"Here are the results of the calculations:",
"Here are the results:",
"",
"1. 8026256882 divided by 3402398 is 2359.",
"2. 1185835515 divided by 348263 is 3405.",
Expand Down
8 changes: 5 additions & 3 deletions packages/kurt-open-ai/spec/snapshots.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import type {
OpenAIResponse,
OpenAIResponseChunk,
} from "../src/OpenAI.types"
import { KurtOpenAI } from "../src/KurtOpenAI"
import { KurtOpenAI, type KurtOpenAISupportedModel } from "../src/KurtOpenAI"

function snapshotFilenameFor(testName: string | undefined) {
return `${__dirname}/snapshots/${testName?.replace(/ /g, "_")}.yaml`
Expand All @@ -29,6 +29,7 @@ function dumpYaml(filename: string, data: object) {
}

export async function snapshotAndMock<T>(
model: KurtOpenAISupportedModel,
testCaseFn: (kurt: Kurt) => KurtStream<T>
) {
// Here's the data structure we will use to snapshot a request/response cycle.
Expand Down Expand Up @@ -91,7 +92,7 @@ export async function snapshotAndMock<T>(
} as unknown as OpenAI

// Run the test case function with a new instance of Kurt.
const kurt = new Kurt(new KurtOpenAI({ openAI, model: "gpt-4o-2024-05-13" }))
const kurt = new Kurt(new KurtOpenAI({ openAI, model }))
const stream = testCaseFn(kurt)

// Save the final stream of Kurt events.
Expand All @@ -114,11 +115,12 @@ export async function snapshotAndMock<T>(
}

export async function snapshotAndMockWithError<T>(
model: KurtOpenAISupportedModel,
testCaseFn: (kurt: Kurt) => KurtStream<T>,
errorCheckFn: (error: Error) => void
) {
try {
await snapshotAndMock(testCaseFn)
await snapshotAndMock(model, testCaseFn)
expectedErrorToBeThrownBeforeThisPoint()
} catch (error: unknown) {
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ step1Request:
stream: true
stream_options:
include_usage: true
model: gpt-4o-2024-05-13
model: gpt-4o-mini-2024-07-18
max_tokens: 4096
temperature: 0.5
top_p: 0.95
messages:
- role: system
content:
- type: text
text: Respond with JSON.
- role: user
content:
- type: text
Expand All @@ -31,6 +35,8 @@ step1Request:
type: function
function:
name: structured_data
response_format:
type: json_object
step2RawChunks:
- choices:
- index: 0
Expand All @@ -39,14 +45,15 @@ step2RawChunks:
content: null
tool_calls:
- index: 0
id: call_oZj1FnPJSZCVbFYtpNYPAm7P
id: call_9x7qX8eO6DgWYP8h1xc5kHsl
type: function
function:
name: structured_data
arguments: ""
refusal: null
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -57,7 +64,7 @@ step2RawChunks:
arguments: '{"'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -68,7 +75,7 @@ step2RawChunks:
arguments: say
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -79,7 +86,7 @@ step2RawChunks:
arguments: '":"'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -90,7 +97,7 @@ step2RawChunks:
arguments: hello
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -101,21 +108,29 @@ step2RawChunks:
arguments: '"}'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
delta: {}
logprobs: null
finish_reason: stop
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices: []
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage:
prompt_tokens: 66
prompt_tokens: 70
completion_tokens: 5
total_tokens: 71
total_tokens: 75
prompt_tokens_details:
cached_tokens: 0
audio_tokens: 0
completion_tokens_details:
reasoning_tokens: 0
audio_tokens: 0
accepted_prediction_tokens: 0
rejected_prediction_tokens: 0
step3KurtEvents:
- chunk: '{"'
- chunk: say
Expand All @@ -127,6 +142,6 @@ step3KurtEvents:
data:
say: hello
metadata:
totalInputTokens: 66
totalInputTokens: 70
totalOutputTokens: 5
systemFingerprint: fp_5bf7397cd3
systemFingerprint: fp_bba3c8e70b
Loading
Loading