Skip to content

Commit

Permalink
feat: add support for schema-constrained tokens in KurtOpenAI
Browse files Browse the repository at this point in the history
This commit also updates the set of supported models to include the newer model snapshots. We also
update the tests to allow for specifying different models per test, and to mostly use the newer
models.
  • Loading branch information
jemc committed Dec 9, 2024
1 parent 9ba8bdb commit a98877b
Show file tree
Hide file tree
Showing 17 changed files with 1,019 additions and 228 deletions.
2 changes: 1 addition & 1 deletion examples/basic/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"@formula-monks/kurt-open-ai": "workspace:*",
"@formula-monks/kurt-vertex-ai": "workspace:*",
"@google-cloud/vertexai": "1.1.0",
"openai": "4.66.1",
"openai": "^4.76.0",
"zod": "^3.23.8"
}
}
2 changes: 1 addition & 1 deletion packages/kurt-open-ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
},
"dependencies": {
"@formula-monks/kurt": "^1.4.0",
"openai": "4.66.1",
"openai": "4.76.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.23.3"
},
Expand Down
7 changes: 4 additions & 3 deletions packages/kurt-open-ai/spec/generateNaturalLanguage.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { KurtResultLimitError } from "@formula-monks/kurt"

describe("KurtOpenAI generateNaturalLanguage", () => {
test("says hello", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Say hello!",
})
Expand All @@ -13,7 +13,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {
})

test("writes a haiku with high temperature", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Compose a haiku about a mountain stream at night.",
sampling: {
Expand All @@ -34,6 +34,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {

test("throws a limit error", async () => {
await snapshotAndMockWithError(
"gpt-4o-2024-05-13",
(kurt) =>
kurt.generateNaturalLanguage({
prompt: "Compose a haiku about content length limitations.",
Expand All @@ -50,7 +51,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {
})

test("describes a base64-encoded image", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Describe this emoji, in two words.",
extraMessages: [
Expand Down
62 changes: 60 additions & 2 deletions packages/kurt-open-ai/spec/generateStructuredData.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { describe, test, expect } from "@jest/globals"
import { z } from "zod"
import { snapshotAndMock, snapshotAndMockWithError } from "./snapshots"
import { KurtResultValidateError } from "@formula-monks/kurt"
import {
KurtCapabilityError,
KurtResultValidateError,
} from "@formula-monks/kurt"

describe("KurtOpenAI generateStructuredData", () => {
test("says hello", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
Expand All @@ -18,8 +21,63 @@ describe("KurtOpenAI generateStructuredData", () => {
expect(result.data).toEqual({ say: "hello" })
})

// Structured-data generation where a system prompt ("Be nice.") is supplied
// in addition to the user prompt. The test passes the
// "gpt-4o-mini-2024-07-18" model name to the snapshotAndMock helper.
test("says hello with system prompt", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
systemPrompt: "Be nice.",
prompt: "Say hello!",
// Schema: an object with a single string field `say`; the `.describe(...)`
// calls attach human-readable descriptions to the field and to the object.
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
})
)
// The final structured result is expected to match the schema exactly.
expect(result.data).toEqual({ say: "hello" })
})

// Structured-data generation with `sampling.forceSchemaConstrainedTokens`
// turned on, using the "gpt-4o-mini-2024-07-18" model name.
test("says hello with schema constrained tokens", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
// Schema: an object with a single string field `say`.
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
// The option under test in this case.
sampling: { forceSchemaConstrainedTokens: true },
})
)
// The final structured result is expected to be { say: "hello" }.
expect(result.data).toEqual({ say: "hello" })
})

// Requesting `forceSchemaConstrainedTokens` with the "gpt-4o-2024-05-13"
// model name is expected to fail with a KurtCapabilityError rather than
// produce a result.
test("throws a capability error for schema constrained tokens in an older model", async () => {
await snapshotAndMockWithError(
"gpt-4o-2024-05-13",
(kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
sampling: { forceSchemaConstrainedTokens: true },
}),
// Error-check callback: receives the thrown error for inspection.
(errorAny) => {
expect(errorAny).toBeInstanceOf(KurtCapabilityError)
// Narrow the type after the instanceof expectation above so the
// capability-specific field can be read.
const error = errorAny as KurtCapabilityError
expect(error.missingCapability).toEqual(
"forceSchemaConstrainedTokens is not available for older models, including gpt-4o-2024-05-13"
)
// The error message is expected to embed the missing-capability text.
expect(error.message).toContain(error.missingCapability)
}
)
})

test("throws a validate error from an impossible schema", async () => {
await snapshotAndMockWithError(
"gpt-4o-mini-2024-07-18",
(kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
Expand Down
26 changes: 21 additions & 5 deletions packages/kurt-open-ai/spec/generateWithOptionalTools.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const calculatorTools = {

describe("KurtOpenAI generateWithOptionalTools", () => {
test("calculator (with tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
Expand All @@ -34,8 +34,24 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
expect(result.additionalData).toBeUndefined() // no parallel tool calls
})

// Tool-calling with `sampling.forceSchemaConstrainedTokens` turned on, using
// the "gpt-4o-mini-2024-07-18" model name and the shared `calculatorTools`.
test("calculator (with strict tool call)", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
tools: calculatorTools,
sampling: { forceSchemaConstrainedTokens: true },
})
)
// The result is expected to be a single call to the `divide` tool carrying
// the two numbers from the prompt as its arguments.
expect(result.data).toEqual({
name: "divide",
args: { dividend: 9876356, divisor: 30487 },
})
expect(result.additionalData).toBeUndefined() // no parallel tool calls
})

test("calculator (after tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
Expand All @@ -58,7 +74,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
})

test("calculator (with parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
Expand Down Expand Up @@ -86,7 +102,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
})

test("calculator (after parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
Expand Down Expand Up @@ -125,7 +141,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
)
expect(result.text).toEqual(
[
"Here are the results of the calculations:",
"Here are the results:",
"",
"1. 8026256882 divided by 3402398 is 2359.",
"2. 1185835515 divided by 348263 is 3405.",
Expand Down
8 changes: 5 additions & 3 deletions packages/kurt-open-ai/spec/snapshots.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import type {
OpenAIResponse,
OpenAIResponseChunk,
} from "../src/OpenAI.types"
import { KurtOpenAI } from "../src/KurtOpenAI"
import { KurtOpenAI, type KurtOpenAISupportedModel } from "../src/KurtOpenAI"

function snapshotFilenameFor(testName: string | undefined) {
return `${__dirname}/snapshots/${testName?.replace(/ /g, "_")}.yaml`
Expand All @@ -29,6 +29,7 @@ function dumpYaml(filename: string, data: object) {
}

export async function snapshotAndMock<T>(
model: KurtOpenAISupportedModel,
testCaseFn: (kurt: Kurt) => KurtStream<T>
) {
// Here's the data structure we will use to snapshot a request/response cycle.
Expand Down Expand Up @@ -91,7 +92,7 @@ export async function snapshotAndMock<T>(
} as unknown as OpenAI

// Run the test case function with a new instance of Kurt.
const kurt = new Kurt(new KurtOpenAI({ openAI, model: "gpt-4o-2024-05-13" }))
const kurt = new Kurt(new KurtOpenAI({ openAI, model }))
const stream = testCaseFn(kurt)

// Save the final stream of Kurt events.
Expand All @@ -114,11 +115,12 @@ export async function snapshotAndMock<T>(
}

export async function snapshotAndMockWithError<T>(
model: KurtOpenAISupportedModel,
testCaseFn: (kurt: Kurt) => KurtStream<T>,
errorCheckFn: (error: Error) => void
) {
try {
await snapshotAndMock(testCaseFn)
await snapshotAndMock(model, testCaseFn)
expectedErrorToBeThrownBeforeThisPoint()
} catch (error: unknown) {
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ step1Request:
stream: true
stream_options:
include_usage: true
model: gpt-4o-2024-05-13
model: gpt-4o-mini-2024-07-18
max_tokens: 4096
temperature: 0.5
top_p: 0.95
messages:
- role: system
content:
- type: text
text: Respond with JSON.
- role: user
content:
- type: text
Expand All @@ -31,6 +35,8 @@ step1Request:
type: function
function:
name: structured_data
response_format:
type: json_object
step2RawChunks:
- choices:
- index: 0
Expand All @@ -39,14 +45,15 @@ step2RawChunks:
content: null
tool_calls:
- index: 0
id: call_oZj1FnPJSZCVbFYtpNYPAm7P
id: call_9x7qX8eO6DgWYP8h1xc5kHsl
type: function
function:
name: structured_data
arguments: ""
refusal: null
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -57,7 +64,7 @@ step2RawChunks:
arguments: '{"'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -68,7 +75,7 @@ step2RawChunks:
arguments: say
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -79,7 +86,7 @@ step2RawChunks:
arguments: '":"'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -90,7 +97,7 @@ step2RawChunks:
arguments: hello
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -101,21 +108,29 @@ step2RawChunks:
arguments: '"}'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
delta: {}
logprobs: null
finish_reason: stop
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices: []
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage:
prompt_tokens: 66
prompt_tokens: 70
completion_tokens: 5
total_tokens: 71
total_tokens: 75
prompt_tokens_details:
cached_tokens: 0
audio_tokens: 0
completion_tokens_details:
reasoning_tokens: 0
audio_tokens: 0
accepted_prediction_tokens: 0
rejected_prediction_tokens: 0
step3KurtEvents:
- chunk: '{"'
- chunk: say
Expand All @@ -127,6 +142,6 @@ step3KurtEvents:
data:
say: hello
metadata:
totalInputTokens: 66
totalInputTokens: 70
totalOutputTokens: 5
systemFingerprint: fp_5bf7397cd3
systemFingerprint: fp_bba3c8e70b
Loading

0 comments on commit a98877b

Please sign in to comment.