diff --git a/packages/kurt-open-ai/src/KurtOpenAI.ts b/packages/kurt-open-ai/src/KurtOpenAI.ts index e0675e2..5a54373 100644 --- a/packages/kurt-open-ai/src/KurtOpenAI.ts +++ b/packages/kurt-open-ai/src/KurtOpenAI.ts @@ -302,9 +302,9 @@ async function* transformStreamWithOptionalTools< } as D }) - // biome-ignore lint/style/noNonNullAssertion: we already validated above that length > 0 - const data = allData[0]! - const additionalData = allData.slice(1) + if (!isNonEmptyArray(allData)) + throw new Error("Empty here is impossible but TS doesn't know it") + const [data, ...additionalData] = allData if (additionalData.length > 0) { yield { finished: true, text, data, additionalData } @@ -317,3 +317,11 @@ async function* transformStreamWithOptionalTools< } } } + +/** + * Return true if this array has at least one element, also refining the + * Typescript type to indicate that the first element won't be undefined. + */ +function isNonEmptyArray(array: T[]): array is [T, ...T[]] { + return array.length > 0 +} diff --git a/packages/kurt-vertex-ai/spec/generateWithOptionalTools.spec.ts b/packages/kurt-vertex-ai/spec/generateWithOptionalTools.spec.ts index a5334de..66d52a3 100644 --- a/packages/kurt-vertex-ai/spec/generateWithOptionalTools.spec.ts +++ b/packages/kurt-vertex-ai/spec/generateWithOptionalTools.spec.ts @@ -2,26 +2,28 @@ import { describe, test, expect } from "@jest/globals" import { z } from "zod" import { snapshotAndMock } from "./snapshots" +const calculatorTools = { + subtract: z + .object({ + minuend: z.number().describe("The number to subtract from"), + subtrahend: z.number().describe("The number to subtract by"), + }) + .describe("Calculate a subtraction"), + divide: z + .object({ + dividend: z.number().describe("The number to be divided"), + divisor: z.number().describe("The number to divide by"), + }) + .describe("Calculate a division"), +} + describe("KurtVertexAI generateWithOptionalTools", () => { test("calculator (with tool call)", async () => { const result = await snapshotAndMock((kurt) => kurt.generateWithOptionalTools({ prompt: "What's 9876356 divided by 30487, rounded to the nearest integer?", - tools: { - subtract: z - .object({ - minuend: z.number().describe("The number to subtract from"), - subtrahend: z.number().describe("The number to subtract by"), - }) - .describe("Calculate a subtraction"), - divide: z - .object({ - dividend: z.number().describe("The number to be divided"), - divisor: z.number().describe("The number to divide by"), - }) - .describe("Calculate a division"), - }, + tools: calculatorTools, }) ) expect(result.data).toEqual({ @@ -35,20 +37,7 @@ describe("KurtVertexAI generateWithOptionalTools", () => { kurt.generateWithOptionalTools({ prompt: "What's 9876356 divided by 30487, rounded to the nearest integer?", - tools: { - subtract: z - .object({ - minuend: z.number().describe("The number to subtract from"), - subtrahend: z.number().describe("The number to subtract by"), - }) - .describe("Calculate a subtraction"), - divide: z - .object({ - dividend: z.number().describe("The number to be divided"), - divisor: z.number().describe("The number to divide by"), - }) - .describe("Calculate a division"), - }, + tools: calculatorTools, extraMessages: [ { role: "model" as const, @@ -63,4 +52,80 @@ describe("KurtVertexAI generateWithOptionalTools", () => { ) expect(result.text).toEqual("That's about 324.") }) + + test("calculator (with parallel tool calls)", async () => { + const result = await snapshotAndMock((kurt) => + kurt.generateWithOptionalTools({ + prompt: [ + "Calculate each of the following:", + "1. 8026256882 divided by 3402398", + "2. 1185835515 divided by 348263", + "3. 90135094495 minus 89944954350", + ].join("\n"), + tools: calculatorTools, + }) + ) + expect(result.data).toEqual({ + name: "divide", + args: { dividend: 8026256882, divisor: 3402398 }, + }) + expect(result.additionalData).toEqual([ + { + name: "divide", + args: { dividend: 1185835515, divisor: 348263 }, + }, + { + name: "subtract", + args: { minuend: 90135094495, subtrahend: 89944954350 }, + }, + ]) + }) + + test("calculator (after parallel tool calls)", async () => { + const result = await snapshotAndMock((kurt) => + kurt.generateWithOptionalTools({ + prompt: [ + "Calculate each of the following:", + "1. 8026256882 divided by 3402398", + "2. 1185835515 divided by 348263", + "3. 90135094495 minus 89944954350", + ].join("\n"), + tools: calculatorTools, + extraMessages: [ + { + role: "model", + toolCall: { + name: "divide", + args: { dividend: 8026256882, divisor: 3402398 }, + result: { quotient: 2359 }, + }, + }, + { + role: "model", + toolCall: { + name: "divide", + args: { dividend: 1185835515, divisor: 348263 }, + result: { quotient: 3405 }, + }, + }, + { + role: "model", + toolCall: { + name: "subtract", + args: { minuend: 90135094495, subtrahend: 89944954350 }, + result: { quotient: 190140145 }, + }, + }, + ], + }) + ) + expect(result.text).toEqual( + [ + "1. 8026256882 divided by 3402398 is 2359.", + "2. 1185835515 divided by 348263 is 3405.", + "3. 90135094495 minus 89944954350 is 190140145.", + "", + ].join("\n") + ) + }) }) diff --git a/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateWithOptionalTools_calculator_(after_parallel_tool_calls).yaml b/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateWithOptionalTools_calculator_(after_parallel_tool_calls).yaml new file mode 100644 index 0000000..9540298 --- /dev/null +++ b/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateWithOptionalTools_calculator_(after_parallel_tool_calls).yaml @@ -0,0 +1,141 @@ +step1Request: + generationConfig: + maxOutputTokens: 4096 + temperature: 0.5 + topP: 0.95 + contents: + - role: user + parts: + - text: |- + Calculate each of the following: + 1. 8026256882 divided by 3402398 + 2. 1185835515 divided by 348263 + 3. 90135094495 minus 89944954350 + - role: model + parts: + - functionCall: + name: divide + args: + dividend: 8026256882 + divisor: 3402398 + - role: model + parts: + - functionResponse: + name: divide + response: + quotient: 2359 + - role: model + parts: + - functionCall: + name: divide + args: + dividend: 1185835515 + divisor: 348263 + - role: model + parts: + - functionResponse: + name: divide + response: + quotient: 3405 + - role: model + parts: + - functionCall: + name: subtract + args: + minuend: 90135094495 + subtrahend: 89944954350 + - role: model + parts: + - functionResponse: + name: subtract + response: + quotient: 190140145 + tools: + - functionDeclarations: + - name: subtract + description: Calculate a subtraction + parameters: + type: object + properties: + minuend: + type: number + description: The number to subtract from + subtrahend: + type: number + description: The number to subtract by + required: + - minuend + - subtrahend + - name: divide + description: Calculate a division + parameters: + type: object + properties: + dividend: + type: number + description: The number to be divided + divisor: + type: number + description: The number to divide by + required: + - dividend + - divisor +step2RawChunks: + - content: + role: model + parts: + - text: "1" + index: 0 + - content: + role: model + parts: + - text: . 8026256882 divided by 3 + index: 0 + - content: + role: model + parts: + - text: |- + 402398 is 2359. + 2. + index: 0 + - content: + role: model + parts: + - text: |2- + 1185835515 divided by 348263 is 3405. + 3. 9 + index: 0 + - content: + role: model + parts: + - text: 0135094495 minus 89944954350 is 1901401 + index: 0 + - content: + role: model + parts: + - text: | + 45. + index: 0 + - content: + role: model + parts: + - text: "" + finishReason: STOP + index: 0 +step3KurtEvents: + - chunk: "1" + - chunk: . 8026256882 divided by 3 + - chunk: |- + 402398 is 2359. + 2. + - chunk: |2- + 1185835515 divided by 348263 is 3405. + 3. 9 + - chunk: 0135094495 minus 89944954350 is 1901401 + - chunk: | + 45. + - finished: true + text: | + 1. 8026256882 divided by 3402398 is 2359. + 2. 1185835515 divided by 348263 is 3405. + 3. 90135094495 minus 89944954350 is 190140145. diff --git a/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateWithOptionalTools_calculator_(with_parallel_tool_calls).yaml b/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateWithOptionalTools_calculator_(with_parallel_tool_calls).yaml new file mode 100644 index 0000000..64443e6 --- /dev/null +++ b/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateWithOptionalTools_calculator_(with_parallel_tool_calls).yaml @@ -0,0 +1,88 @@ +step1Request: + generationConfig: + maxOutputTokens: 4096 + temperature: 0.5 + topP: 0.95 + contents: + - role: user + parts: + - text: |- + Calculate each of the following: + 1. 8026256882 divided by 3402398 + 2. 1185835515 divided by 348263 + 3. 90135094495 minus 89944954350 + tools: + - functionDeclarations: + - name: subtract + description: Calculate a subtraction + parameters: + type: object + properties: + minuend: + type: number + description: The number to subtract from + subtrahend: + type: number + description: The number to subtract by + required: + - minuend + - subtrahend + - name: divide + description: Calculate a division + parameters: + type: object + properties: + dividend: + type: number + description: The number to be divided + divisor: + type: number + description: The number to divide by + required: + - dividend + - divisor +step2RawChunks: + - content: + role: model + parts: + - functionCall: + name: divide + args: + dividend: 8026256882 + divisor: 3402398 + - functionCall: + name: divide + args: + dividend: 1185835515 + divisor: 348263 + - functionCall: + name: subtract + args: + minuend: 90135094495 + subtrahend: 89944954350 + index: 0 +step3KurtEvents: + - chunk: '{"dividend":8026256882,"divisor":3402398}' + - chunk: "\n" + - chunk: '{"dividend":1185835515,"divisor":348263}' + - chunk: "\n" + - chunk: '{"minuend":90135094495,"subtrahend":89944954350}' + - finished: true + text: |- + {"dividend":8026256882,"divisor":3402398} + {"dividend":1185835515,"divisor":348263} + {"minuend":90135094495,"subtrahend":89944954350} + data: + name: divide + args: + dividend: 8026256882 + divisor: 3402398 + additionalData: + - name: divide + args: + dividend: 1185835515 + divisor: 348263 + - name: subtract + args: + minuend: 90135094495 + subtrahend: 89944954350 diff --git a/packages/kurt-vertex-ai/src/KurtVertexAI.ts b/packages/kurt-vertex-ai/src/KurtVertexAI.ts index 0a6fb4c..feebc79 100644 --- a/packages/kurt-vertex-ai/src/KurtVertexAI.ts +++ b/packages/kurt-vertex-ai/src/KurtVertexAI.ts @@ -265,31 +265,60 @@ async function* transformStreamWithOptionalTools< for (const [partIndex, part] of parts.entries()) { const chunk = part.text const isFinal = - (isContentFinal && partIndex === parts.length - 1) || part.functionCall + (isContentFinal || part.functionCall) && partIndex === parts.length - 1 if (chunk) { chunks.push(chunk) yield { chunk } } if (isFinal) { - const { functionCall } = part - if (functionCall) { - const { name } = functionCall - const schema = tools[name] - if (!schema) { - throw new Error( - `Vertex AI tried to call tool ${name} which isn't in the tool set ${JSON.stringify( - Object.keys(tools) - )}}` - ) + if (part.functionCall) { + const allData = parts.map((part) => { + if (!part.functionCall) { + throw new Error( + `Vertex AI mixed function calls with non-function calls in the same raw stream event: ${JSON.stringify( + rawEvent + )}` + ) + } + + const { name } = part.functionCall + + const schema = tools[name] + if (!schema) { + throw new Error( + `Vertex AI tried to call tool ${name} which isn't in the tool set ${JSON.stringify( + Object.keys(tools) + )}}` + ) + } + return { + name, + args: applySchemaToFuzzyStructure(schema, part.functionCall), + } as D + }) + + // Emit a text chunk for each tool call (with line breaks in between). + for (const [dataIndex, data] of allData.entries()) { + if (dataIndex > 0) { + chunks.push("\n") + yield { chunk: "\n" } + } + const text = JSON.stringify(data.args) + chunks.push(text) + yield { chunk: text } + } + + if (!isNonEmptyArray(allData)) + throw new Error("Empty here is impossible but TS doesn't know it") + const [data, ...additionalData] = allData + const text = chunks.join("") + + if (additionalData.length > 0) { + yield { finished: true, text, data: data as D, additionalData } + } else { + yield { finished: true, text, data } } - const data = { - name, - args: applySchemaToFuzzyStructure(schema, functionCall), - } as D - const text = JSON.stringify(data.args) - yield { chunk: text } - yield { finished: true, text, data } } else { const text = chunks.join("") yield { finished: true, text, data: undefined } @@ -328,3 +357,11 @@ function applySchemaToFuzzyStructure( throw firstParseError } } + +/** + * Return true if this array has at least one element, also refining the + * Typescript type to indicate that the first element won't be undefined. + */ +function isNonEmptyArray(array: T[]): array is [T, ...T[]] { + return array.length > 0 +}