diff --git a/src/Kurt.ts b/src/Kurt.ts index 98d677e..79013ad 100644 --- a/src/Kurt.ts +++ b/src/Kurt.ts @@ -1,14 +1,18 @@ import type { KurtStream } from "./KurtStream" -import type { KurtSchema, KurtSchemaInner } from "./KurtSchema" +import type { + KurtSchema, + KurtSchemaInner, + KurtSchemaResult, +} from "./KurtSchema" export interface Kurt { generateNaturalLanguage( options: KurtGenerateNaturalLanguageOptions ): KurtStream - generateStructuredData( - options: KurtGenerateStructuredDataOptions - ): KurtStream + generateStructuredData( + options: KurtGenerateStructuredDataOptions + ): KurtStream> } export interface KurtMessage { @@ -26,7 +30,7 @@ export interface KurtGenerateNaturalLanguageOptions { extraMessages?: KurtMessage[] } -export type KurtGenerateStructuredDataOptions = +export type KurtGenerateStructuredDataOptions = KurtGenerateNaturalLanguageOptions & { - schema: KurtSchema + schema: KurtSchema } diff --git a/src/KurtOpenAI.ts b/src/KurtOpenAI.ts index b48c88d..f1500b8 100644 --- a/src/KurtOpenAI.ts +++ b/src/KurtOpenAI.ts @@ -8,9 +8,12 @@ import type { } from "./Kurt" import { KurtStream, type KurtStreamEvent } from "./KurtStream" import type { + KurtSchema, KurtSchemaInner, KurtSchemaInnerMaybe, KurtSchemaMaybe, + KurtSchemaResult, + KurtSchemaResultMaybe, } from "./KurtSchema" import type { OpenAI, OpenAIMessage, OpenAIResponse } from "./OpenAI.types" @@ -38,94 +41,49 @@ export class KurtOpenAI implements Kurt { generateNaturalLanguage( options: KurtGenerateNaturalLanguageOptions ): KurtStream { - return this.handleStream( - undefined, - this.options.openAI.chat.completions.create({ - stream: true, - model: this.options.model, - messages: this.toOpenAIMessages(options), - }) + return new KurtStream( + transformStream( + undefined, + this.options.openAI.chat.completions.create({ + stream: true, + model: this.options.model, + messages: this.toOpenAIMessages(options), + }) + ) ) } - generateStructuredData( - options: KurtGenerateStructuredDataOptions - ): KurtStream { + generateStructuredData( + options: KurtGenerateStructuredDataOptions + ): KurtStream> { const schema = options.schema - return this.handleStream( - schema as KurtSchemaMaybe, - this.options.openAI.chat.completions.create({ - stream: true, - model: this.options.model, - messages: this.toOpenAIMessages(options), - tool_choice: { - type: "function", - function: { name: "structured_data" }, - }, - tools: [ - { + return new KurtStream( + transformStream( + schema, + this.options.openAI.chat.completions.create({ + stream: true, + model: this.options.model, + messages: this.toOpenAIMessages(options), + tool_choice: { type: "function", - function: { - name: "structured_data", - description: schema.description, - parameters: zodToJsonSchema(schema), - }, + function: { name: "structured_data" }, }, - ], - }) + tools: [ + { + type: "function", + function: { + name: "structured_data", + description: schema.description, + parameters: zodToJsonSchema(schema), + }, + }, + ], + }) + ) ) } - private handleStream( - schema: KurtSchemaMaybe, - response: OpenAIResponse - ): KurtStream { - async function* generator() { - const stream = await response - const chunks: string[] = [] - - for await (const streamChunk of stream) { - const choice = streamChunk.choices[0] - if (!choice) continue - - const textChunk = choice.delta.content - if (textChunk) { - yield { chunk: textChunk } as KurtStreamEvent - chunks.push(textChunk) - } - - const dataChunk = choice.delta.tool_calls?.at(0)?.function?.arguments - if (dataChunk) { - yield { chunk: dataChunk } as KurtStreamEvent - chunks.push(dataChunk) - } - - const isFinal = choice.finish_reason !== null - - if (isFinal) { - const text = chunks.join("") - if (schema) { - const data = schema?.parse(JSON.parse(chunks.join(""))) - yield { - finished: true, - text, - data, - } as KurtStreamEvent - } else { - yield { - finished: true, - text, - data: undefined, - } as KurtStreamEvent - } - } - } - } - - return new KurtStream(generator()) - } - private toOpenAIMessages = ({ prompt, systemPrompt = this.options.systemPrompt, @@ -153,3 +111,41 @@ const openAIRoleMapping = { system: "system", user: "user", } as const satisfies Record + +async function* transformStream< + I extends KurtSchemaInnerMaybe, + S extends KurtSchemaMaybe, + D extends KurtSchemaResultMaybe, +>(schema: S, response: OpenAIResponse): AsyncGenerator> { + const stream = await response + const chunks: string[] = [] + + for await (const streamChunk of stream) { + const choice = streamChunk.choices[0] + if (!choice) continue + + const textChunk = choice.delta.content + if (textChunk) { + chunks.push(textChunk) + yield { chunk: textChunk } + } + + const dataChunk = choice.delta.tool_calls?.at(0)?.function?.arguments + if (dataChunk) { + chunks.push(dataChunk) + yield { chunk: dataChunk } + } + + const isFinal = choice.finish_reason !== null + + if (isFinal) { + const text = chunks.join("") + if (schema) { + const data = schema.parse(JSON.parse(text)) as D + yield { finished: true, text, data } + } else { + yield { finished: true, text, data: undefined } as KurtStreamEvent + } + } + } +} diff --git a/src/KurtSchema.ts b/src/KurtSchema.ts index 28b8b04..555dc82 100644 --- a/src/KurtSchema.ts +++ b/src/KurtSchema.ts @@ -1,11 +1,11 @@ import type { ZodObject, ZodRawShape, infer as zodInfer } from "zod" export type KurtSchemaInner = ZodRawShape -export type KurtSchema = ZodObject -export type KurtSchemaResult = zodInfer> +export type KurtSchema = ZodObject +export type KurtSchemaResult = zodInfer> export type KurtSchemaInnerMaybe = KurtSchemaInner | undefined -export type KurtSchemaMaybe = - T extends KurtSchemaInner ? KurtSchema : undefined -export type KurtSchemaResultMaybe = - T extends KurtSchemaInner ? KurtSchemaResult : undefined +export type KurtSchemaMaybe = + I extends KurtSchemaInner ? KurtSchema : undefined +export type KurtSchemaResultMaybe = + I extends KurtSchemaInner ? KurtSchemaResult : undefined diff --git a/src/KurtStream.ts b/src/KurtStream.ts index 07517ce..cb65e19 100644 --- a/src/KurtStream.ts +++ b/src/KurtStream.ts @@ -1,18 +1,17 @@ import type { Promisable } from "type-fest" -import type { KurtSchemaInnerMaybe, KurtSchemaResultMaybe } from "./KurtSchema" export type KurtStreamEventChunk = { chunk: string } -export type KurtResult = { +export type KurtResult = { finished: true text: string - data: KurtSchemaResultMaybe + data: D } -export type KurtStreamEvent = +export type KurtStreamEvent = | KurtStreamEventChunk - | KurtResult + | KurtResult -type _AdditionalListener = ( - event: KurtStreamEvent | { uncaughtError: unknown } +type _AdditionalListener = ( + event: KurtStreamEvent | { uncaughtError: unknown } ) => void // This class represents the result of a call to an LLM. @@ -27,20 +26,20 @@ type _AdditionalListener = ( // // It also exposes a `result` convenience getter for callers who are only // interested in the final result event. -export class KurtStream - implements AsyncIterable> +export class KurtStream + implements AsyncIterable> { private started = false private finished = false - private seenEvents: KurtStreamEvent[] = [] + private seenEvents: KurtStreamEvent[] = [] private finalError?: { uncaughtError: unknown } - private additionalListeners = new Set<_AdditionalListener>() + private additionalListeners = new Set<_AdditionalListener>() // Create a new result stream, from the given underlying stream generator. - constructor(private gen: AsyncGenerator>) {} + constructor(private gen: AsyncGenerator>) {} // Get the final event from the end of the result stream, when it is ready. - get result(): Promise> { + get result(): Promise> { return toFinal(this) } @@ -107,10 +106,10 @@ export class KurtStream // To make this generator work, we need to set up a replaceable promise // that will receive the next event (or error) via the listener callback. - let nextEventResolve: (value: Promisable>) => void + let nextEventResolve: (value: Promisable>) => void let nextEventReject: (reason?: unknown) => void const createNextEventPromise = () => { - return new Promise>((resolve, reject) => { + return new Promise>((resolve, reject) => { nextEventResolve = resolve nextEventReject = reject }) @@ -123,7 +122,7 @@ export class KurtStream // Each time we receive an event we're going to resolve (or reject) // the current promise, then we will replace the promise (and its // associated resolve/reject functions), using closures. - const listener: _AdditionalListener = (event) => { + const listener: _AdditionalListener = (event) => { if ("uncaughtError" in event) nextEventReject(event.uncaughtError) else nextEventResolve(event) nextEvent = createNextEventPromise() @@ -145,9 +144,9 @@ export class KurtStream } } -async function toFinal( - stream: KurtStream -): Promise> { +async function toFinal( + stream: KurtStream +): Promise> { for await (const event of stream) { if ("finished" in event) { return event diff --git a/src/KurtVertexAI.ts b/src/KurtVertexAI.ts index 1eee83b..c11509c 100644 --- a/src/KurtVertexAI.ts +++ b/src/KurtVertexAI.ts @@ -8,12 +8,14 @@ import type { KurtGenerateStructuredDataOptions, KurtMessage, } from "./Kurt" -import { KurtStream, type KurtStreamEvent } from "./KurtStream" +import { KurtResult, KurtStream, type KurtStreamEvent } from "./KurtStream" import type { KurtSchema, KurtSchemaInner, KurtSchemaInnerMaybe, KurtSchemaMaybe, + KurtSchemaResult, + KurtSchemaResultMaybe, } from "./KurtSchema" import type { VertexAI, @@ -43,95 +45,47 @@ export class KurtVertexAI implements Kurt { model: this.options.model, }) as VertexAIGenerativeModel - return this.handleStream( - undefined, - llm.generateContentStreamPATCHED({ - contents: this.toVertexAIMessages(options), - }) + return new KurtStream( + transformStream( + undefined, + llm.generateContentStreamPATCHED({ + contents: this.toVertexAIMessages(options), + }) + ) ) } - generateStructuredData( - options: KurtGenerateStructuredDataOptions - ): KurtStream { + generateStructuredData( + options: KurtGenerateStructuredDataOptions + ): KurtStream> { const schema = options.schema const llm = this.options.vertexAI.getGenerativeModel({ model: this.options.model, }) as VertexAIGenerativeModel - return this.handleStream( - schema as KurtSchemaMaybe, - llm.generateContentStreamPATCHED({ - contents: this.toVertexAIMessages(options), - tool_config: { function_calling_config: { mode: "ANY" } }, - tools: [ - { - functionDeclarations: [ - { - name: "structured_data", - description: schema.description, - parameters: jsonSchemaForVertexAI(schema), - }, - ], - }, - ], - }) + return new KurtStream( + transformStream( + schema, + llm.generateContentStreamPATCHED({ + contents: this.toVertexAIMessages(options), + tool_config: { function_calling_config: { mode: "ANY" } }, + tools: [ + { + functionDeclarations: [ + { + name: "structured_data", + description: schema.description, + parameters: jsonSchemaForVertexAI(schema), + }, + ], + }, + ], + }) + ) ) } - private handleStream( - schema: KurtSchemaMaybe, - response: VertexAIResponse - ): KurtStream { - async function* generator() { - const { stream } = await response - const chunks: string[] = [] - - for await (const streamChunk of stream) { - const choice = streamChunk.candidates?.at(0) - if (!choice) continue - - const isContentFinal = choice.finishReason !== undefined - const { parts } = choice.content - - for (const [partIndex, part] of parts.entries()) { - const chunk = part.text - const isFinal = isContentFinal && partIndex === parts.length - 1 - const data = isFinal - ? applySchemaToFuzzyStructure(schema, part.functionCall) - : undefined - - if (chunk) { - yield { chunk } - chunks.push(chunk) - } - if (isFinal) { - if (data) { - const text = JSON.stringify(data) - yield { chunk: text } - yield { - finished: true, - text, - data, - } as KurtStreamEvent - } else { - const text = chunks.join("") - const data = undefined - yield { - finished: true, - text, - data, - } as KurtStreamEvent - } - } - } - } - } - - return new KurtStream(generator()) - } - private toVertexAIMessages = ({ prompt, systemPrompt = this.options.systemPrompt, @@ -190,19 +144,63 @@ function jsonSchemaForVertexAI( return schema as VertexAISchema } +async function* transformStream< + I extends KurtSchemaInnerMaybe, + S extends KurtSchemaMaybe, + D extends KurtSchemaResultMaybe, +>(schema: S, response: VertexAIResponse): AsyncGenerator> { + const { stream } = await response + const chunks: string[] = [] + + for await (const streamChunk of stream) { + const choice = streamChunk.candidates?.at(0) + if (!choice) continue + + const isContentFinal = choice.finishReason !== undefined + const { parts } = choice.content + + for (const [partIndex, part] of parts.entries()) { + const chunk = part.text + const isFinal = isContentFinal && partIndex === parts.length - 1 + + if (chunk) { + chunks.push(chunk) + yield { chunk } + } + if (isFinal) { + if (schema) { + const { functionCall } = part + if (!functionCall) { + throw new Error( + `Expected function call in final chunk, but got ${JSON.stringify( + part + )}` + ) + } + const data = applySchemaToFuzzyStructure(schema, functionCall) as D + const text = JSON.stringify(data) + yield { chunk: text } + yield { finished: true, text, data } + } else { + const text = chunks.join("") + yield { finished: true, text, data: undefined } as KurtStreamEvent + } + } + } + } +} + // Vertex AI sometimes gives wonky results that are nested weirdly. // This function tries to account for the different scenarios we've seen. // // If a new scenario is seen, we can add a test for it in KurtVertexAI.spec.ts // and then add new logic here as needed to handle the new scenario. -function applySchemaToFuzzyStructure( - schema: KurtSchemaMaybe, - input: { name: string; args: object } | undefined - // biome-ignore lint/suspicious/noExplicitAny: TODO: no any -): any { - if (schema === undefined || input === undefined) return undefined - +function applySchemaToFuzzyStructure( + schema: KurtSchema, + input: { name: string; args: object } +): KurtSchemaResult { const { name, args } = input + try { // First, try the most obvious case. return schema.parse(args)