Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support parallel tool calls in KurtVertexAI with generateWithOptionalTools #37

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions packages/kurt-open-ai/src/KurtOpenAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ async function* transformStreamWithOptionalTools<
} as D
})

// biome-ignore lint/style/noNonNullAssertion: we already validated above that length > 0
const data = allData[0]!
const additionalData = allData.slice(1)
if (!isNonEmptyArray(allData))
throw new Error("Empty here is impossible but TS doesn't know it")
const [data, ...additionalData] = allData

if (additionalData.length > 0) {
yield { finished: true, text, data, additionalData }
Expand All @@ -317,3 +317,11 @@ async function* transformStreamWithOptionalTools<
}
}
}

/**
* Return true if this array has at least one element, also refining the
* Typescript type to indicate that the first element won't be undefined.
*/
function isNonEmptyArray<T>(array: T[]): array is [T, ...T[]] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make sense to export this from a utils package at some point

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's something I'd love to avoid if we can 🙃

return array.length > 0
}
121 changes: 93 additions & 28 deletions packages/kurt-vertex-ai/spec/generateWithOptionalTools.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,28 @@ import { describe, test, expect } from "@jest/globals"
import { z } from "zod"
import { snapshotAndMock } from "./snapshots"

const calculatorTools = {
subtract: z
.object({
minuend: z.number().describe("The number to subtract from"),
subtrahend: z.number().describe("The number to subtract by"),
})
.describe("Calculate a subtraction"),
divide: z
.object({
dividend: z.number().describe("The number to be divided"),
divisor: z.number().describe("The number to divide by"),
})
.describe("Calculate a division"),
}

describe("KurtVertexAI generateWithOptionalTools", () => {
test("calculator (with tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
tools: {
subtract: z
.object({
minuend: z.number().describe("The number to subtract from"),
subtrahend: z.number().describe("The number to subtract by"),
})
.describe("Calculate a subtraction"),
divide: z
.object({
dividend: z.number().describe("The number to be divided"),
divisor: z.number().describe("The number to divide by"),
})
.describe("Calculate a division"),
},
tools: calculatorTools,
})
)
expect(result.data).toEqual({
Expand All @@ -35,20 +37,7 @@ describe("KurtVertexAI generateWithOptionalTools", () => {
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
tools: {
subtract: z
.object({
minuend: z.number().describe("The number to subtract from"),
subtrahend: z.number().describe("The number to subtract by"),
})
.describe("Calculate a subtraction"),
divide: z
.object({
dividend: z.number().describe("The number to be divided"),
divisor: z.number().describe("The number to divide by"),
})
.describe("Calculate a division"),
},
tools: calculatorTools,
extraMessages: [
{
role: "model" as const,
Expand All @@ -63,4 +52,80 @@ describe("KurtVertexAI generateWithOptionalTools", () => {
)
expect(result.text).toEqual("That's about 324.")
})

test("calculator (with parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
"1. 8026256882 divided by 3402398",
"2. 1185835515 divided by 348263",
"3. 90135094495 minus 89944954350",
].join("\n"),
tools: calculatorTools,
})
)
expect(result.data).toEqual({
name: "divide",
args: { dividend: 8026256882, divisor: 3402398 },
})
expect(result.additionalData).toEqual([
{
name: "divide",
args: { dividend: 1185835515, divisor: 348263 },
},
{
name: "subtract",
args: { minuend: 90135094495, subtrahend: 89944954350 },
},
])
})

test("calculator (after parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
"1. 8026256882 divided by 3402398",
"2. 1185835515 divided by 348263",
"3. 90135094495 minus 89944954350",
].join("\n"),
tools: calculatorTools,
extraMessages: [
{
role: "model",
toolCall: {
name: "divide",
args: { dividend: 8026256882, divisor: 3402398 },
result: { quotient: 2359 },
},
},
{
role: "model",
toolCall: {
name: "divide",
args: { dividend: 1185835515, divisor: 348263 },
result: { quotient: 3405 },
},
},
{
role: "model",
toolCall: {
name: "subtract",
args: { minuend: 90135094495, subtrahend: 89944954350 },
result: { quotient: 190140145 },
},
},
],
})
)
expect(result.text).toEqual(
[
"1. 8026256882 divided by 3402398 is 2359.",
"2. 1185835515 divided by 348263 is 3405.",
"3. 90135094495 minus 89944954350 is 190140145.",
"",
].join("\n")
)
})
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
step1Request:
generationConfig:
maxOutputTokens: 4096
temperature: 0.5
topP: 0.95
contents:
- role: user
parts:
- text: |-
Calculate each of the following:
1. 8026256882 divided by 3402398
2. 1185835515 divided by 348263
3. 90135094495 minus 89944954350
- role: model
parts:
- functionCall:
name: divide
args:
dividend: 8026256882
divisor: 3402398
- role: model
parts:
- functionResponse:
name: divide
response:
quotient: 2359
- role: model
parts:
- functionCall:
name: divide
args:
dividend: 1185835515
divisor: 348263
- role: model
parts:
- functionResponse:
name: divide
response:
quotient: 3405
- role: model
parts:
- functionCall:
name: subtract
args:
minuend: 90135094495
subtrahend: 89944954350
- role: model
parts:
- functionResponse:
name: subtract
response:
quotient: 190140145
tools:
- functionDeclarations:
- name: subtract
description: Calculate a subtraction
parameters:
type: object
properties:
minuend:
type: number
description: The number to subtract from
subtrahend:
type: number
description: The number to subtract by
required:
- minuend
- subtrahend
- name: divide
description: Calculate a division
parameters:
type: object
properties:
dividend:
type: number
description: The number to be divided
divisor:
type: number
description: The number to divide by
required:
- dividend
- divisor
step2RawChunks:
- content:
role: model
parts:
- text: "1"
index: 0
- content:
role: model
parts:
- text: . 8026256882 divided by 3
index: 0
- content:
role: model
parts:
- text: |-
402398 is 2359.
2.
index: 0
- content:
role: model
parts:
- text: |2-
1185835515 divided by 348263 is 3405.
3. 9
index: 0
- content:
role: model
parts:
- text: 0135094495 minus 89944954350 is 1901401
index: 0
- content:
role: model
parts:
- text: |
45.
index: 0
- content:
role: model
parts:
- text: ""
finishReason: STOP
index: 0
step3KurtEvents:
- chunk: "1"
- chunk: . 8026256882 divided by 3
- chunk: |-
402398 is 2359.
2.
- chunk: |2-
1185835515 divided by 348263 is 3405.
3. 9
- chunk: 0135094495 minus 89944954350 is 1901401
- chunk: |
45.
- finished: true
text: |
1. 8026256882 divided by 3402398 is 2359.
2. 1185835515 divided by 348263 is 3405.
3. 90135094495 minus 89944954350 is 190140145.
Loading
Loading