From 92321427076ff8c2b8bfc90e282a1147f6bfa168 Mon Sep 17 00:00:00 2001 From: Santiago Kent Date: Sun, 5 May 2024 12:41:02 -0300 Subject: [PATCH] refactor the creation of messages for vertexAI to reduce repetition and increase readability --- src/KurtVertexAI.ts | 49 ++++++++++++++++++++------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/src/KurtVertexAI.ts b/src/KurtVertexAI.ts index 640a278..c0cd9bb 100644 --- a/src/KurtVertexAI.ts +++ b/src/KurtVertexAI.ts @@ -39,10 +39,6 @@ export class KurtVertexAI implements Kurt { generateNaturalLanguage( options: KurtGenerateNaturalLanguageOptions ): KurtResult { - const systemPrompt = options.systemPrompt ?? this.options.systemPrompt - const prompt = options.prompt - const extraMessages = options.extraMessages ?? [] - const llm = this.options.vertexAI.getGenerativeModel({ model: this.options.model, }) as VertexAIGenerativeModel @@ -50,13 +46,7 @@ export class KurtVertexAI implements Kurt { return this.handleStream( undefined, llm.generateContentStreamPATCHED({ - contents: [ - ...(systemPrompt - ? [{ role: "system", parts: [{ text: systemPrompt }] }] - : []), - { role: "user", parts: [{ text: prompt }] }, - ...toVertexAIMessages(extraMessages), - ], + contents: this.toVertexAIMessages(options), }) ) } @@ -64,10 +54,7 @@ export class KurtVertexAI implements Kurt { generateStructuredData( options: KurtGenerateStructuredDataOptions ): KurtResult { - const systemPrompt = options.systemPrompt ?? this.options.systemPrompt - const prompt = options.prompt const schema = options.schema - const extraMessages = options.extraMessages ?? [] const llm = this.options.vertexAI.getGenerativeModel({ model: this.options.model, @@ -76,13 +63,7 @@ export class KurtVertexAI implements Kurt { return this.handleStream( schema as KurtSchemaMaybe, llm.generateContentStreamPATCHED({ - contents: [ - ...(systemPrompt - ? [{ role: "system", parts: [{ text: systemPrompt }] }] - : []), - { role: "user", parts: [{ text: prompt }] }, - ...toVertexAIMessages(extraMessages), - ], + contents: this.toVertexAIMessages(options), tool_config: { function_calling_config: { mode: "ANY" } }, tools: [ { @@ -142,15 +123,29 @@ export class KurtVertexAI implements Kurt { return new KurtResult(generator()) } -} -function toVertexAIMessages(messages: KurtMessage[]): VertexAIMessage[] { - return messages.map((message) => { - const { role, text } = message - return { role, parts: [{ text }] } - }) + private toVertexAIMessages = ({ + prompt, + systemPrompt = this.options.systemPrompt, + extraMessages = [], + }: KurtGenerateNaturalLanguageOptions): VertexAIMessage[] => { + const systemMessage: VertexAIMessage[] = systemPrompt + ? [toVertexAIMessage({ role: "system", text: systemPrompt })] + : [] + + const userMessage = toVertexAIMessage({ role: "user", text: prompt }) + + const extras = extraMessages.map(toVertexAIMessage) + + return systemMessage.concat(userMessage, extras) + } } +const toVertexAIMessage = ({ role, text }: KurtMessage): VertexAIMessage => ({ + role, + parts: [{ text }], +}) + function jsonSchemaForVertexAI( zodSchema: KurtSchema ) {