diff --git a/packages/kurt-cache/spec/KurtCache.spec.ts b/packages/kurt-cache/spec/KurtCache.spec.ts index 1a288c0..7b79510 100644 --- a/packages/kurt-cache/spec/KurtCache.spec.ts +++ b/packages/kurt-cache/spec/KurtCache.spec.ts @@ -33,11 +33,12 @@ const cacheDirRetain = `${cacheDir}-retain` async function gen( kurt: Kurt, prompt: string, - schema?: KurtSchema + schema?: KurtSchema, + sampling?: KurtSamplingOptions ) { const stream = schema - ? kurt.generateStructuredData({ prompt, schema }) - : kurt.generateNaturalLanguage({ prompt }) + ? kurt.generateStructuredData({ prompt, schema, sampling }) + : kurt.generateNaturalLanguage({ prompt, sampling }) const result = await stream.result return schema ? result.data : result.text } @@ -59,6 +60,7 @@ describe("KurtCache", () => { return new StubAdapter([ ["World ", random], ["bar ", random], + ["bar2 ", random], ["never ", random], ]) }) @@ -71,13 +73,17 @@ describe("KurtCache", () => { expect(await gen(kurt, `foo ${random}`)).toEqual(`bar ${random}`) expect(await gen(kurt, `Hello ${random}`)).toEqual(`World ${random}`) expect(await gen(kurt, `foo ${random}`)).toEqual(`bar ${random}`) - expect(await gen(kurt, `foo ${random}`)).toEqual(`bar ${random}`) + expect( + await gen(kurt, `foo ${random}`, undefined, { + forceSchemaConstrainedTokens: true, + }) + ).toEqual(`bar2 ${random}`) // Expect that the adapter setup function was called just once. expect(adapterFnCallCount).toEqual(1) - // Expect that the cache dir contains exactly two files. - expect(readdirSync(cacheDir)).toHaveLength(2) + // Expect that the cache dir contains exactly three files. + expect(readdirSync(cacheDir)).toHaveLength(3) }) test("when cache hits, works without running the adapter fn", async () => { diff --git a/packages/kurt-cache/src/KurtCache.ts b/packages/kurt-cache/src/KurtCache.ts index 00d4116..1b637eb 100644 --- a/packages/kurt-cache/src/KurtCache.ts +++ b/packages/kurt-cache/src/KurtCache.ts @@ -305,6 +305,11 @@ function hashSamplingOptions(digest: Hash, options: KurtSamplingOptions): Hash { mayHash(digest, "maxOutputTokens", options.maxOutputTokens) mayHash(digest, "temperature", options.temperature) mayHash(digest, "topP", options.topP) + mayHash( + digest, + "forceSchemaConstrainedTokens", + options.forceSchemaConstrainedTokens + ) return digest } @@ -357,10 +362,11 @@ function hashSchema(digest: Hash, schema: KurtSchema) { function mayHash( digest: Hash, key: string, - value: string | number | undefined + value: string | number | boolean | undefined ) { - if (value === undefined) return + if (value === undefined || value === false) return digest.update(key) + if (value === true) return if (typeof value === "string") digest.update(value) else digest.update(value.toString()) }