Skip to content

Commit

Permalink
Merge pull request #8 from FormulaMonks/refactor/stream-type-args
Browse files Browse the repository at this point in the history
refactor: Simplify type params of KurtStream (and refactor surrounding to match)
  • Loading branch information
jemc authored May 6, 2024
2 parents 4f4e0bb + a5fcdf7 commit 11b9ff0
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 194 deletions.
16 changes: 10 additions & 6 deletions src/Kurt.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import type { KurtStream } from "./KurtStream"
import type { KurtSchema, KurtSchemaInner } from "./KurtSchema"
import type {
KurtSchema,
KurtSchemaInner,
KurtSchemaResult,
} from "./KurtSchema"

export interface Kurt {
generateNaturalLanguage(
options: KurtGenerateNaturalLanguageOptions
): KurtStream

generateStructuredData<T extends KurtSchemaInner>(
options: KurtGenerateStructuredDataOptions<T>
): KurtStream<T>
generateStructuredData<I extends KurtSchemaInner>(
options: KurtGenerateStructuredDataOptions<I>
): KurtStream<KurtSchemaResult<I>>
}

export interface KurtMessage {
Expand All @@ -26,7 +30,7 @@ export interface KurtGenerateNaturalLanguageOptions {
extraMessages?: KurtMessage[]
}

export type KurtGenerateStructuredDataOptions<T extends KurtSchemaInner> =
export type KurtGenerateStructuredDataOptions<I extends KurtSchemaInner> =
KurtGenerateNaturalLanguageOptions & {
schema: KurtSchema<T>
schema: KurtSchema<I>
}
152 changes: 74 additions & 78 deletions src/KurtOpenAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ import type {
} from "./Kurt"
import { KurtStream, type KurtStreamEvent } from "./KurtStream"
import type {
KurtSchema,
KurtSchemaInner,
KurtSchemaInnerMaybe,
KurtSchemaMaybe,
KurtSchemaResult,
KurtSchemaResultMaybe,
} from "./KurtSchema"
import type { OpenAI, OpenAIMessage, OpenAIResponse } from "./OpenAI.types"

Expand Down Expand Up @@ -38,94 +41,49 @@ export class KurtOpenAI implements Kurt {
generateNaturalLanguage(
options: KurtGenerateNaturalLanguageOptions
): KurtStream {
return this.handleStream(
undefined,
this.options.openAI.chat.completions.create({
stream: true,
model: this.options.model,
messages: this.toOpenAIMessages(options),
})
return new KurtStream(
transformStream(
undefined,
this.options.openAI.chat.completions.create({
stream: true,
model: this.options.model,
messages: this.toOpenAIMessages(options),
})
)
)
}

generateStructuredData<T extends KurtSchemaInner>(
options: KurtGenerateStructuredDataOptions<T>
): KurtStream<T> {
generateStructuredData<I extends KurtSchemaInner>(
options: KurtGenerateStructuredDataOptions<I>
): KurtStream<KurtSchemaResult<I>> {
const schema = options.schema

return this.handleStream(
schema as KurtSchemaMaybe<T>,
this.options.openAI.chat.completions.create({
stream: true,
model: this.options.model,
messages: this.toOpenAIMessages(options),
tool_choice: {
type: "function",
function: { name: "structured_data" },
},
tools: [
{
return new KurtStream(
transformStream(
schema,
this.options.openAI.chat.completions.create({
stream: true,
model: this.options.model,
messages: this.toOpenAIMessages(options),
tool_choice: {
type: "function",
function: {
name: "structured_data",
description: schema.description,
parameters: zodToJsonSchema(schema),
},
function: { name: "structured_data" },
},
],
})
tools: [
{
type: "function",
function: {
name: "structured_data",
description: schema.description,
parameters: zodToJsonSchema(schema),
},
},
],
})
)
)
}

private handleStream<T extends KurtSchemaInnerMaybe>(
schema: KurtSchemaMaybe<T>,
response: OpenAIResponse
): KurtStream<T> {
async function* generator<T extends KurtSchemaInnerMaybe>() {
const stream = await response
const chunks: string[] = []

for await (const streamChunk of stream) {
const choice = streamChunk.choices[0]
if (!choice) continue

const textChunk = choice.delta.content
if (textChunk) {
yield { chunk: textChunk } as KurtStreamEvent<T>
chunks.push(textChunk)
}

const dataChunk = choice.delta.tool_calls?.at(0)?.function?.arguments
if (dataChunk) {
yield { chunk: dataChunk } as KurtStreamEvent<T>
chunks.push(dataChunk)
}

const isFinal = choice.finish_reason !== null

if (isFinal) {
const text = chunks.join("")
if (schema) {
const data = schema?.parse(JSON.parse(chunks.join("")))
yield {
finished: true,
text,
data,
} as KurtStreamEvent<T>
} else {
yield {
finished: true,
text,
data: undefined,
} as KurtStreamEvent<T>
}
}
}
}

return new KurtStream<T>(generator())
}

private toOpenAIMessages = ({
prompt,
systemPrompt = this.options.systemPrompt,
Expand Down Expand Up @@ -153,3 +111,41 @@ const openAIRoleMapping = {
system: "system",
user: "user",
} as const satisfies Record<KurtMessage["role"], OpenAIMessage["role"]>

async function* transformStream<
I extends KurtSchemaInnerMaybe,
S extends KurtSchemaMaybe<I>,
D extends KurtSchemaResultMaybe<I>,
>(schema: S, response: OpenAIResponse): AsyncGenerator<KurtStreamEvent<D>> {
const stream = await response
const chunks: string[] = []

for await (const streamChunk of stream) {
const choice = streamChunk.choices[0]
if (!choice) continue

const textChunk = choice.delta.content
if (textChunk) {
chunks.push(textChunk)
yield { chunk: textChunk }
}

const dataChunk = choice.delta.tool_calls?.at(0)?.function?.arguments
if (dataChunk) {
chunks.push(dataChunk)
yield { chunk: dataChunk }
}

const isFinal = choice.finish_reason !== null

if (isFinal) {
const text = chunks.join("")
if (schema) {
const data = schema.parse(JSON.parse(text)) as D
yield { finished: true, text, data }
} else {
yield { finished: true, text, data: undefined } as KurtStreamEvent<D>
}
}
}
}
12 changes: 6 additions & 6 deletions src/KurtSchema.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import type { ZodObject, ZodRawShape, infer as zodInfer } from "zod"

export type KurtSchemaInner = ZodRawShape
export type KurtSchema<T extends KurtSchemaInner> = ZodObject<T>
export type KurtSchemaResult<T extends KurtSchemaInner> = zodInfer<ZodObject<T>>
export type KurtSchema<I extends KurtSchemaInner> = ZodObject<I>
export type KurtSchemaResult<I extends KurtSchemaInner> = zodInfer<ZodObject<I>>

export type KurtSchemaInnerMaybe = KurtSchemaInner | undefined
export type KurtSchemaMaybe<T extends KurtSchemaInnerMaybe> =
T extends KurtSchemaInner ? KurtSchema<T> : undefined
export type KurtSchemaResultMaybe<T extends KurtSchemaInnerMaybe> =
T extends KurtSchemaInner ? KurtSchemaResult<T> : undefined
export type KurtSchemaMaybe<I extends KurtSchemaInnerMaybe> =
I extends KurtSchemaInner ? KurtSchema<I> : undefined
export type KurtSchemaResultMaybe<I extends KurtSchemaInnerMaybe> =
I extends KurtSchemaInner ? KurtSchemaResult<I> : undefined
37 changes: 18 additions & 19 deletions src/KurtStream.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import type { Promisable } from "type-fest"
import type { KurtSchemaInnerMaybe, KurtSchemaResultMaybe } from "./KurtSchema"

export type KurtStreamEventChunk = { chunk: string }
export type KurtResult<T extends KurtSchemaInnerMaybe = undefined> = {
export type KurtResult<D = undefined> = {
finished: true
text: string
data: KurtSchemaResultMaybe<T>
data: D
}
export type KurtStreamEvent<T extends KurtSchemaInnerMaybe = undefined> =
export type KurtStreamEvent<D = undefined> =
| KurtStreamEventChunk
| KurtResult<T>
| KurtResult<D>

type _AdditionalListener<T extends KurtSchemaInnerMaybe = undefined> = (
event: KurtStreamEvent<T> | { uncaughtError: unknown }
type _AdditionalListener<D = undefined> = (
event: KurtStreamEvent<D> | { uncaughtError: unknown }
) => void

// This class represents the result of a call to an LLM.
Expand All @@ -27,20 +26,20 @@ type _AdditionalListener<T extends KurtSchemaInnerMaybe = undefined> = (
//
// It also exposes a `result` convenience getter for callers who are only
// interested in the final result event.
export class KurtStream<T extends KurtSchemaInnerMaybe = undefined>
implements AsyncIterable<KurtStreamEvent<T>>
export class KurtStream<D = undefined>
implements AsyncIterable<KurtStreamEvent<D>>
{
private started = false
private finished = false
private seenEvents: KurtStreamEvent<T>[] = []
private seenEvents: KurtStreamEvent<D>[] = []
private finalError?: { uncaughtError: unknown }
private additionalListeners = new Set<_AdditionalListener<T>>()
private additionalListeners = new Set<_AdditionalListener<D>>()

// Create a new result stream, from the given underlying stream generator.
constructor(private gen: AsyncGenerator<KurtStreamEvent<T>>) {}
constructor(private gen: AsyncGenerator<KurtStreamEvent<D>>) {}

// Get the final event from the end of the result stream, when it is ready.
get result(): Promise<KurtResult<T>> {
get result(): Promise<KurtResult<D>> {
return toFinal(this)
}

Expand Down Expand Up @@ -107,10 +106,10 @@ export class KurtStream<T extends KurtSchemaInnerMaybe = undefined>

// To make this generator work, we need to set up a replaceable promise
// that will receive the next event (or error) via the listener callback.
let nextEventResolve: (value: Promisable<KurtStreamEvent<T>>) => void
let nextEventResolve: (value: Promisable<KurtStreamEvent<D>>) => void
let nextEventReject: (reason?: unknown) => void
const createNextEventPromise = () => {
return new Promise<KurtStreamEvent<T>>((resolve, reject) => {
return new Promise<KurtStreamEvent<D>>((resolve, reject) => {
nextEventResolve = resolve
nextEventReject = reject
})
Expand All @@ -123,7 +122,7 @@ export class KurtStream<T extends KurtSchemaInnerMaybe = undefined>
// Each time we receive an event we're going to resolve (or reject)
// the current promise, then we will replace the promise (and its
// associated resolve/reject functions), using closures.
const listener: _AdditionalListener<T> = (event) => {
const listener: _AdditionalListener<D> = (event) => {
if ("uncaughtError" in event) nextEventReject(event.uncaughtError)
else nextEventResolve(event)
nextEvent = createNextEventPromise()
Expand All @@ -145,9 +144,9 @@ export class KurtStream<T extends KurtSchemaInnerMaybe = undefined>
}
}

async function toFinal<T extends KurtSchemaInnerMaybe = undefined>(
stream: KurtStream<T>
): Promise<KurtResult<T>> {
async function toFinal<D = undefined>(
stream: KurtStream<D>
): Promise<KurtResult<D>> {
for await (const event of stream) {
if ("finished" in event) {
return event
Expand Down
Loading

0 comments on commit 11b9ff0

Please sign in to comment.