diff --git a/dist/buildinfo.json b/dist/buildinfo.json index cd0e8e7e..e58b5c0d 100644 --- a/dist/buildinfo.json +++ b/dist/buildinfo.json @@ -1 +1 @@ -{"sha":"d1d9307","timestamp":1735019140} \ No newline at end of file +{"sha":"284e897","timestamp":1735026022} \ No newline at end of file diff --git a/dist/index.js b/dist/index.js index 8f01d079..488b2804 100644 --- a/dist/index.js +++ b/dist/index.js @@ -69,6 +69,7 @@ class WorkersConfig { WORKERS_CHAT_MODEL = "@cf/qwen/qwen1.5-7b-chat-awq"; WORKERS_IMAGE_MODEL = "@cf/black-forest-labs/flux-1-schnell"; WORKERS_CHAT_MODELS_LIST = ""; + WORKERS_IMAGE_MODELS_LIST = ""; } class GeminiConfig { GOOGLE_API_KEY = null; @@ -193,8 +194,8 @@ class ConfigMerger { } } } -const BUILD_TIMESTAMP = 1735019140; -const BUILD_VERSION = "d1d9307"; +const BUILD_TIMESTAMP = 1735026022; +const BUILD_VERSION = "284e897"; function createAgentUserConfig() { return Object.assign( {}, @@ -814,6 +815,75 @@ class MessageSender { return this.api.sendPhoto(params); } } +function extractTextContent(history) { + if (typeof history.content === "string") { + return history.content; + } + if (Array.isArray(history.content)) { + return history.content.map((item) => { + if (item.type === "text") { + return item.text; + } + return ""; + }).join(""); + } + return ""; +} +function extractImageContent(imageData) { + if (imageData instanceof URL) { + return { url: imageData.href }; + } + if (typeof imageData === "string") { + if (imageData.startsWith("http")) { + return { url: imageData }; + } else { + return { base64: imageData }; + } + } + if (typeof Buffer !== "undefined") { + if (imageData instanceof Uint8Array) { + return { base64: Buffer.from(imageData).toString("base64") }; + } + if (Buffer.isBuffer(imageData)) { + return { base64: Buffer.from(imageData).toString("base64") }; + } + } + return {}; +} +async function convertStringToResponseMessages(input) { + const text = await input; + return { + text, + responses: [{ role: "assistant", content: await input }] + }; +} +async function loadModelsList(raw, remoteLoader) { + if (!raw) { + return []; + } + if (raw.startsWith("[") && raw.endsWith("]")) { + try { + return JSON.parse(raw); + } catch (e) { + console.error(e); + return []; + } + } + if (raw.startsWith("http") && remoteLoader) { + return await remoteLoader(raw); + } + return [raw]; +} +function bearerHeader(token, stream) { + const res = { + "Authorization": `Bearer ${token}`, + "Content-Type": "application/json" + }; + if (stream !== void 0) { + res.Accept = stream ? 
"text/event-stream" : "application/json"; + } + return res; +} class Cache { maxItems; maxAge; @@ -894,6 +964,66 @@ async function imageToBase64String(url) { function renderBase64DataURI(params) { return `data:${params.format};base64,${params.data}`; } +var ImageSupportFormat = ((ImageSupportFormat2) => { + ImageSupportFormat2["URL"] = "url"; + ImageSupportFormat2["BASE64"] = "base64"; + return ImageSupportFormat2; +})(ImageSupportFormat || {}); +async function renderOpenAIMessage(item, supportImage) { + const res = { + role: item.role, + content: item.content + }; + if (Array.isArray(item.content)) { + const contents = []; + for (const content of item.content) { + switch (content.type) { + case "text": + contents.push({ type: "text", text: content.text }); + break; + case "image": + if (supportImage) { + const isSupportURL = supportImage.includes("url" ); + const isSupportBase64 = supportImage.includes("base64" ); + const data = extractImageContent(content.image); + if (data.url) { + if (ENV.TELEGRAM_IMAGE_TRANSFER_MODE === "base64" && isSupportBase64) { + contents.push(await imageToBase64String(data.url).then((data2) => { + return { type: "image_url", image_url: { url: renderBase64DataURI(data2) } }; + })); + } else if (isSupportURL) { + contents.push({ type: "image_url", image_url: { url: data.url } }); + } + } else if (data.base64 && isSupportBase64) { + contents.push({ type: "image_base64", image_base64: { base64: data.base64 } }); + } + } + break; + } + } + res.content = contents; + } + return res; +} +async function renderOpenAIMessages(prompt, items, supportImage) { + const messages = await Promise.all(items.map((r) => renderOpenAIMessage(r, supportImage))); + if (prompt) { + if (messages.length > 0 && messages[0].role === "system") { + messages.shift(); + } + messages.unshift({ role: "system", content: prompt }); + } + return messages; +} +function loadOpenAIModelList(list, base, headers) { + if (list === "") { + list = `${base}/models`; + } + return loadModelsList(list, async (url) => { + const data = await fetch(url, { headers }).then((res) => res.json()); + return data.data?.map((model) => model.id) || []; + }); +} class Stream { response; controller; @@ -1193,72 +1323,6 @@ async function requestChatCompletions(url, header, body, onStream, options) { } return await mapResponseToAnswer(resp, controller, options, onStream); } -function extractTextContent(history) { - if (typeof history.content === "string") { - return history.content; - } - if (Array.isArray(history.content)) { - return history.content.map((item) => { - if (item.type === "text") { - return item.text; - } - return ""; - }).join(""); - } - return ""; -} -function extractImageContent(imageData) { - if (imageData instanceof URL) { - return { url: imageData.href }; - } - if (typeof imageData === "string") { - if (imageData.startsWith("http")) { - return { url: imageData }; - } else { - return { base64: imageData }; - } - } - if (typeof Buffer !== "undefined") { - if (imageData instanceof Uint8Array) { - return { base64: Buffer.from(imageData).toString("base64") }; - } - if (Buffer.isBuffer(imageData)) { - return { base64: Buffer.from(imageData).toString("base64") }; - } - } - return {}; -} -async function convertStringToResponseMessages(input) { - const text = await input; - return { - text, - responses: [{ role: "assistant", content: await input }] - }; -} -async function loadModelsList(raw, remoteLoader) { - if (!raw) { - return []; - } - if (raw.startsWith("[") && raw.endsWith("]")) { - try { - return 
JSON.parse(raw); - } catch (e) { - console.error(e); - return []; - } - } - if (raw.startsWith("http") && remoteLoader) { - return await remoteLoader(raw); - } - return []; -} -function bearerHeader(token, stream = false) { - return { - "Authorization": `Bearer ${token}`, - "Content-Type": "application/json", - "Accept": stream ? "text/event-stream" : "application/json" - }; -} function anthropicHeader(context) { return { "x-api-key": context.ANTHROPIC_API_KEY || "", @@ -1269,10 +1333,10 @@ function anthropicHeader(context) { class Anthropic { name = "anthropic"; modelKey = "ANTHROPIC_CHAT_MODEL"; - enable = (context) => { - return !!context.ANTHROPIC_API_KEY; - }; - render = async (item) => { + enable = (ctx) => !!ctx.ANTHROPIC_API_KEY; + model = (ctx) => ctx.ANTHROPIC_CHAT_MODEL; + modelList = (ctx) => loadOpenAIModelList(ctx.ANTHROPIC_CHAT_MODELS_LIST, ctx.ANTHROPIC_API_BASE, anthropicHeader(ctx)); + static render = async (item) => { const res = { role: item.role, content: item.content @@ -1304,9 +1368,6 @@ class Anthropic { } return res; }; - model = (ctx) => { - return ctx.ANTHROPIC_CHAT_MODEL; - }; static parser(sse) { switch (sse.event) { case "content_block_delta": @@ -1336,7 +1397,7 @@ class Anthropic { const body = { system: prompt, model: context.ANTHROPIC_CHAT_MODEL, - messages: (await Promise.all(messages.map((item) => this.render(item)))).filter((i) => i !== null), + messages: (await Promise.all(messages.map((item) => Anthropic.render(item)))).filter((i) => i !== null), stream: onStream != null, max_tokens: ENV.MAX_TOKEN_LENGTH > 0 ? ENV.MAX_TOKEN_LENGTH : 2048 }; @@ -1358,138 +1419,6 @@ class Anthropic { }; return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, options)); }; - modelList = async (context) => { - if (context.ANTHROPIC_CHAT_MODELS_LIST === "") { - context.ANTHROPIC_CHAT_MODELS_LIST = `${context.ANTHROPIC_API_BASE}/models`; - } - return loadModelsList(context.ANTHROPIC_CHAT_MODELS_LIST, async (url) => { - const data = await fetch(url, { - headers: anthropicHeader(context) - }).then((res) => res.json()); - return data?.data?.map((model) => model.id) || []; - }); - }; -} -var ImageSupportFormat = ((ImageSupportFormat2) => { - ImageSupportFormat2["URL"] = "url"; - ImageSupportFormat2["BASE64"] = "base64"; - return ImageSupportFormat2; -})(ImageSupportFormat || {}); -async function renderOpenAIMessage(item, supportImage) { - const res = { - role: item.role, - content: item.content - }; - if (Array.isArray(item.content)) { - const contents = []; - for (const content of item.content) { - switch (content.type) { - case "text": - contents.push({ type: "text", text: content.text }); - break; - case "image": - if (supportImage) { - const isSupportURL = supportImage.includes("url" ); - const isSupportBase64 = supportImage.includes("base64" ); - const data = extractImageContent(content.image); - if (data.url) { - if (ENV.TELEGRAM_IMAGE_TRANSFER_MODE === "base64" && isSupportBase64) { - contents.push(await imageToBase64String(data.url).then((data2) => { - return { type: "image_url", image_url: { url: renderBase64DataURI(data2) } }; - })); - } else if (isSupportURL) { - contents.push({ type: "image_url", image_url: { url: data.url } }); - } - } else if (data.base64 && isSupportBase64) { - contents.push({ type: "image_base64", image_base64: { base64: data.base64 } }); - } - } - break; - } - } - res.content = contents; - } - return res; -} -async function renderOpenAIMessages(prompt, items, supportImage) { - const messages = await 
Promise.all(items.map((r) => renderOpenAIMessage(r, supportImage))); - if (prompt) { - if (messages.length > 0 && messages[0].role === "system") { - messages.shift(); - } - messages.unshift({ role: "system", content: prompt }); - } - return messages; -} -function openAIApiKey(context) { - const length = context.OPENAI_API_KEY.length; - return context.OPENAI_API_KEY[Math.floor(Math.random() * length)]; -} -class OpenAI { - name = "openai"; - modelKey = "OPENAI_CHAT_MODEL"; - enable = (context) => { - return context.OPENAI_API_KEY.length > 0; - }; - model = (ctx) => { - return ctx.OPENAI_CHAT_MODEL; - }; - request = async (params, context, onStream) => { - const { prompt, messages } = params; - const url = `${context.OPENAI_API_BASE}/chat/completions`; - const header = bearerHeader(openAIApiKey(context)); - const body = { - model: context.OPENAI_CHAT_MODEL, - ...context.OPENAI_API_EXTRA_PARAMS, - messages: await renderOpenAIMessages(prompt, messages, ["url" , "base64" ]), - stream: onStream != null - }; - return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, null)); - }; - modelList = async (context) => { - if (context.OPENAI_CHAT_MODELS_LIST === "") { - context.OPENAI_CHAT_MODELS_LIST = `${context.OPENAI_API_BASE}/models`; - } - return loadModelsList(context.OPENAI_CHAT_MODELS_LIST, async (url) => { - const data = await fetch(url, { - headers: bearerHeader(openAIApiKey(context)) - }).then((res) => res.json()); - return data.data?.map((model) => model.id) || []; - }); - }; -} -class Dalle { - name = "openai"; - modelKey = "OPENAI_DALLE_API"; - enable = (context) => { - return context.OPENAI_API_KEY.length > 0; - }; - model = (ctx) => { - return ctx.DALL_E_MODEL; - }; - request = async (prompt, context) => { - const url = `${context.OPENAI_API_BASE}/images/generations`; - const header = bearerHeader(openAIApiKey(context)); - const body = { - prompt, - n: 1, - size: context.DALL_E_IMAGE_SIZE, - model: context.DALL_E_MODEL - }; - if (body.model === "dall-e-3") { - body.quality = context.DALL_E_IMAGE_QUALITY; - body.style = context.DALL_E_IMAGE_STYLE; - } - const resp = await fetch(url, { - method: "POST", - headers: header, - body: JSON.stringify(body) - }).then((res) => res.json()); - if (resp.error?.message) { - throw new Error(resp.error.message); - } - return resp?.data?.at(0)?.url; - }; } function azureHeader(context) { return { @@ -1500,12 +1429,8 @@ function azureHeader(context) { class AzureChatAI { name = "azure"; modelKey = "AZURE_CHAT_MODEL"; - enable = (context) => { - return !!(context.AZURE_API_KEY && context.AZURE_RESOURCE_NAME); - }; - model = (ctx) => { - return ctx.AZURE_CHAT_MODEL; - }; + enable = (ctx) => !!(ctx.AZURE_API_KEY && ctx.AZURE_RESOURCE_NAME); + model = (ctx) => ctx.AZURE_CHAT_MODEL; request = async (params, context, onStream) => { const { prompt, messages } = params; const url = `https://${context.AZURE_RESOURCE_NAME}.openai.azure.com/openai/deployments/${context.AZURE_CHAT_MODEL}/chat/completions?api-version=${context.AZURE_API_VERSION}`; @@ -1532,12 +1457,9 @@ class AzureChatAI { class AzureImageAI { name = "azure"; modelKey = "AZURE_DALLE_API"; - enable = (context) => { - return !!(context.AZURE_API_KEY && context.AZURE_DALLE_API); - }; - model = (ctx) => { - return ctx.AZURE_IMAGE_MODEL; - }; + enable = (ctx) => !!(ctx.AZURE_API_KEY && ctx.AZURE_DALLE_API); + model = (ctx) => ctx.AZURE_IMAGE_MODEL; + modelList = (ctx) => Promise.resolve([ctx.AZURE_IMAGE_MODEL]); request = async (prompt, context) => { const url = 
`https://${context.AZURE_RESOURCE_NAME}.openai.azure.com/openai/deployments/${context.AZURE_IMAGE_MODEL}/images/generations?api-version=${context.AZURE_API_VERSION}`; const header = azureHeader(context); @@ -1566,12 +1488,8 @@ class AzureImageAI { class Cohere { name = "cohere"; modelKey = "COHERE_CHAT_MODEL"; - enable = (context) => { - return !!context.COHERE_API_KEY; - }; - model = (ctx) => { - return ctx.COHERE_CHAT_MODEL; - }; + enable = (ctx) => !!ctx.COHERE_API_KEY; + model = (ctx) => ctx.COHERE_CHAT_MODEL; request = async (params, context, onStream) => { const { prompt, messages } = params; const url = `${context.COHERE_API_BASE}/chat`; @@ -1600,7 +1518,7 @@ class Cohere { } return loadModelsList(context.COHERE_CHAT_MODELS_LIST, async (url) => { const data = await fetch(url, { - headers: bearerHeader(context.COHERE_API_KEY, false) + headers: bearerHeader(context.COHERE_API_KEY) }).then((res) => res.json()); return data.models?.filter((model) => model.endpoints?.includes("chat")).map((model) => model.name) || []; }); @@ -1609,12 +1527,8 @@ class Cohere { class Gemini { name = "gemini"; modelKey = "GOOGLE_COMPLETIONS_MODEL"; - enable = (context) => { - return !!context.GOOGLE_API_KEY; - }; - model = (ctx) => { - return ctx.GOOGLE_COMPLETIONS_MODEL; - }; + enable = (ctx) => !!ctx.GOOGLE_API_KEY; + model = (ctx) => ctx.GOOGLE_COMPLETIONS_MODEL; request = async (params, context, onStream) => { const { prompt, messages } = params; const url = `${context.GOOGLE_API_BASE}/openai/chat/completions`; @@ -1639,12 +1553,9 @@ class Gemini { class Mistral { name = "mistral"; modelKey = "MISTRAL_CHAT_MODEL"; - enable = (context) => { - return !!context.MISTRAL_API_KEY; - }; - model = (ctx) => { - return ctx.MISTRAL_CHAT_MODEL; - }; + enable = (ctx) => !!ctx.MISTRAL_API_KEY; + model = (ctx) => ctx.MISTRAL_CHAT_MODEL; + modelList = (ctx) => loadOpenAIModelList(ctx.MISTRAL_CHAT_MODELS_LIST, ctx.MISTRAL_API_BASE, bearerHeader(ctx.MISTRAL_API_KEY)); request = async (params, context, onStream) => { const { prompt, messages } = params; const url = `${context.MISTRAL_API_BASE}/chat/completions`; @@ -1656,27 +1567,59 @@ class Mistral { }; return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, null)); }; - modelList = async (context) => { - if (context.MISTRAL_CHAT_MODELS_LIST === "") { - context.MISTRAL_CHAT_MODELS_LIST = `${context.MISTRAL_API_BASE}/models`; - } - return loadModelsList(context.MISTRAL_CHAT_MODELS_LIST, async (url) => { - const data = await fetch(url, { - headers: bearerHeader(context.MISTRAL_API_KEY) - }).then((res) => res.json()); - return data.data?.map((model) => model.id) || []; - }); +} +function openAIApiKey(context) { + const length = context.OPENAI_API_KEY.length; + return context.OPENAI_API_KEY[Math.floor(Math.random() * length)]; +} +class OpenAI { + name = "openai"; + modelKey = "OPENAI_CHAT_MODEL"; + enable = (ctx) => ctx.OPENAI_API_KEY.length > 0; + model = (ctx) => ctx.OPENAI_CHAT_MODEL; + modelList = (ctx) => loadOpenAIModelList(ctx.OPENAI_CHAT_MODELS_LIST, ctx.OPENAI_API_BASE, bearerHeader(openAIApiKey(ctx))); + request = async (params, context, onStream) => { + const { prompt, messages } = params; + const url = `${context.OPENAI_API_BASE}/chat/completions`; + const header = bearerHeader(openAIApiKey(context)); + const body = { + model: context.OPENAI_CHAT_MODEL, + ...context.OPENAI_API_EXTRA_PARAMS, + messages: await renderOpenAIMessages(prompt, messages, [ImageSupportFormat.URL, ImageSupportFormat.BASE64]), + stream: onStream != null + 
}; + return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, null)); }; } -async function sendWorkerRequest(model, body, id, token) { - return await fetch( - `https://api.cloudflare.com/client/v4/accounts/${id}/ai/run/${model}`, - { - headers: { Authorization: `Bearer ${token}` }, +class Dalle { + name = "openai"; + modelKey = "DALL_E_MODEL"; + enable = (ctx) => ctx.OPENAI_API_KEY.length > 0; + model = (ctx) => ctx.DALL_E_MODEL; + modelList = (ctx) => Promise.resolve([ctx.DALL_E_MODEL]); + request = async (prompt, context) => { + const url = `${context.OPENAI_API_BASE}/images/generations`; + const header = bearerHeader(openAIApiKey(context)); + const body = { + prompt, + n: 1, + size: context.DALL_E_IMAGE_SIZE, + model: context.DALL_E_MODEL + }; + if (body.model === "dall-e-3") { + body.quality = context.DALL_E_IMAGE_QUALITY; + body.style = context.DALL_E_IMAGE_STYLE; + } + const resp = await fetch(url, { method: "POST", + headers: header, body: JSON.stringify(body) + }).then((res) => res.json()); + if (resp.error?.message) { + throw new Error(resp.error.message); } - ); + return resp?.data?.at(0)?.url; + }; } function isWorkerAIEnable(context) { if (ENV.AI_BINDING) { @@ -1684,44 +1627,29 @@ function isWorkerAIEnable(context) { } return !!(context.CLOUDFLARE_ACCOUNT_ID && context.CLOUDFLARE_TOKEN); } -function mapAiTextGenerationOutput2Response(output, stream) { - if (stream && output instanceof ReadableStream) { - return new Response(output, { - headers: { "content-type": "text/event-stream" } - }); - } else { - return Response.json({ result: output }); - } -} -function mapAiTextToImageOutput2Response(output) { - if (output instanceof ReadableStream) { - return new Response(output, { - headers: { - "content-type": "image/jpg" - } - }); - } else { - return Response.json({ result: output }); - } -} -async function mapResponseToImage(output) { - if (isJsonResponse(output)) { - const { result } = await output.json(); - const image = result?.image; - if (typeof image !== "string") { - throw new TypeError("Invalid image response"); +function loadWorkersModelList(task, loader) { + return async (context) => { + let uri = loader(context); + if (uri === "") { + const id = context.CLOUDFLARE_ACCOUNT_ID; + const taskEncoded = encodeURIComponent(task); + uri = `https://api.cloudflare.com/client/v4/accounts/${id}/ai/models/search?task=${taskEncoded}`; } - return base64StringToBlob(image); - } - return await output.blob(); + return loadModelsList(uri, async (url) => { + const header = { + Authorization: `Bearer ${context.CLOUDFLARE_TOKEN}` + }; + const data = await fetch(url, { headers: header }).then((res) => res.json()); + return data.result?.map((model) => model.name) || []; + }); + }; } class WorkersChat { name = "workers"; modelKey = "WORKERS_CHAT_MODEL"; enable = isWorkerAIEnable; - model = (ctx) => { - return ctx.WORKERS_CHAT_MODEL; - }; + model = (ctx) => ctx.WORKERS_CHAT_MODEL; + modelList = loadWorkersModelList("Text Generation", (ctx) => ctx.WORKERS_CHAT_MODELS_LIST); request = async (params, context, onStream) => { const { prompt, messages } = params; const model = context.WORKERS_CHAT_MODEL; @@ -1741,7 +1669,7 @@ class WorkersChat { }; if (ENV.AI_BINDING) { const answer = await ENV.AI_BINDING.run(model, body); - const response = mapAiTextGenerationOutput2Response(answer, onStream !== null); + const response = WorkersChat.outputToResponse(answer, onStream !== null); return convertStringToResponseMessages(mapResponseToAnswer(response, new AbortController(), 
options, onStream)); } else if (context.CLOUDFLARE_ACCOUNT_ID && context.CLOUDFLARE_TOKEN) { const id = context.CLOUDFLARE_ACCOUNT_ID; @@ -1753,49 +1681,76 @@ class WorkersChat { throw new Error("Cloudflare account ID and token are required"); } }; - modelList = async (context) => { - if (context.WORKERS_CHAT_MODELS_LIST === "") { - const id = context.CLOUDFLARE_ACCOUNT_ID; - context.WORKERS_CHAT_MODELS_LIST = `https://api.cloudflare.com/client/v4/accounts/${id}/ai/models/search?task=Text%20Generation`; + static outputToResponse(output, stream) { + if (stream && output instanceof ReadableStream) { + return new Response(output, { + headers: { "content-type": "text/event-stream" } + }); + } else { + return Response.json({ result: output }); } - return loadModelsList(context.WORKERS_CHAT_MODELS_LIST, async (url) => { - const header = { - Authorization: `Bearer ${context.CLOUDFLARE_TOKEN}` - }; - const data = await fetch(url, { headers: header }).then((res) => res.json()); - return data.result?.map((model) => model.name) || []; - }); - }; + } } class WorkersImage { name = "workers"; modelKey = "WORKERS_IMAGE_MODEL"; enable = isWorkerAIEnable; - model = (ctx) => { - return ctx.WORKERS_IMAGE_MODEL; - }; + model = (ctx) => ctx.WORKERS_IMAGE_MODEL; + modelList = loadWorkersModelList("Text-to-Image", (ctx) => ctx.WORKERS_IMAGE_MODELS_LIST); request = async (prompt, context) => { if (ENV.AI_BINDING) { const answer = await ENV.AI_BINDING.run(context.WORKERS_IMAGE_MODEL, { prompt }); - const raw = mapAiTextToImageOutput2Response(answer); - return await mapResponseToImage(raw); + const raw = WorkersImage.outputToResponse(answer); + return await WorkersImage.responseToImage(raw); } else if (context.CLOUDFLARE_ACCOUNT_ID && context.CLOUDFLARE_TOKEN) { const id = context.CLOUDFLARE_ACCOUNT_ID; const token = context.CLOUDFLARE_TOKEN; - const raw = await sendWorkerRequest(context.WORKERS_IMAGE_MODEL, { prompt }, id, token); - return await mapResponseToImage(raw); + const raw = await WorkersImage.fetch(context.WORKERS_IMAGE_MODEL, { prompt }, id, token); + return await WorkersImage.responseToImage(raw); } else { throw new Error("Cloudflare account ID and token are required"); } }; -} -async function base64StringToBlob(base64String) { - if (typeof Buffer !== "undefined") { - const buffer = Buffer.from(base64String, "base64"); - return new Blob([buffer], { type: "image/png" }); - } else { - const uint8Array = Uint8Array.from(atob(base64String), (c) => c.charCodeAt(0)); - return new Blob([uint8Array], { type: "image/png" }); + static outputToResponse(output) { + if (output instanceof ReadableStream) { + return new Response(output, { + headers: { + "content-type": "image/jpg" + } + }); + } else { + return Response.json({ result: output }); + } + } + static async responseToImage(output) { + if (isJsonResponse(output)) { + const { result } = await output.json(); + const image = result?.image; + if (typeof image !== "string") { + throw new TypeError("Invalid image response"); + } + return WorkersImage.base64StringToBlob(image); + } + return await output.blob(); + } + static async base64StringToBlob(base64String) { + if (typeof Buffer !== "undefined") { + const buffer = Buffer.from(base64String, "base64"); + return new Blob([buffer], { type: "image/png" }); + } else { + const uint8Array = Uint8Array.from(atob(base64String), (c) => c.charCodeAt(0)); + return new Blob([uint8Array], { type: "image/png" }); + } + } + static async fetch(model, body, id, token) { + return await fetch( + 
`https://api.cloudflare.com/client/v4/accounts/${id}/ai/run/${model}`, + { + headers: { Authorization: `Bearer ${token}` }, + method: "POST", + body: JSON.stringify(body) + } + ); } } const CHAT_AGENTS = [ @@ -1919,13 +1874,30 @@ async function requestCompletionsFromLLM(params, context, agent, modifier, onStr return text; } class AgentListCallbackQueryHandler { - prefix = "al:"; + prefix; + changeAgentPrefix; + agentLoader; needAuth = TELEGRAM_AUTH_CHECKER.shareModeGroup; + constructor(prefix, changeAgentPrefix, agentLoader) { + this.agentLoader = agentLoader; + this.prefix = prefix; + this.changeAgentPrefix = changeAgentPrefix; + } + static NewChatAgentListCallbackQueryHandler() { + return new AgentListCallbackQueryHandler("al:", "ca:", () => { + return CHAT_AGENTS.filter((agent) => agent.enable(ENV.USER_CONFIG)).map((agent) => agent.name); + }); + } + static NewImageAgentListCallbackQueryHandler() { + return new AgentListCallbackQueryHandler("ial:", "ica:", () => { + return IMAGE_AGENTS.filter((agent) => agent.enable(ENV.USER_CONFIG)).map((agent) => agent.name); + }); + } handle = async (query, data, context) => { if (!query.message) { throw new Error("no message"); } - const names = CHAT_AGENTS.filter((agent) => agent.enable(ENV.USER_CONFIG)).map((agent) => agent.name); + const names = this.agentLoader(); const sender = MessageSender.fromCallbackQuery(context.SHARE_CONTEXT.botToken, query); const keyboards = []; for (let i = 0; i < names.length; i += 2) { @@ -1937,7 +1909,7 @@ class AgentListCallbackQueryHandler { } row.push({ text: names[index], - callback_data: `ca:${JSON.stringify([names[index], 0])}` + callback_data: `${this.changeAgentPrefix}${JSON.stringify([names[index], 0])}` }); } keyboards.push(row); @@ -1953,24 +1925,50 @@ class AgentListCallbackQueryHandler { return sender.editRawMessage(params); }; } +function changeChatAgentType(conf, agent) { + return { + ...conf, + AI_PROVIDER: agent + }; +} +function changeImageAgentType(conf, agent) { + return { + ...conf, + AI_IMAGE_PROVIDER: agent + }; +} class ModelListCallbackQueryHandler { - prefix = "ca:"; + prefix; + agentListPrefix; + changeModelPrefix; + agentLoader; + changeAgentType; needAuth = TELEGRAM_AUTH_CHECKER.shareModeGroup; + constructor(prefix, agentListPrefix, changeModelPrefix, agentLoader, changeAgentType) { + this.prefix = prefix; + this.agentListPrefix = agentListPrefix; + this.changeModelPrefix = changeModelPrefix; + this.agentLoader = agentLoader; + this.changeAgentType = changeAgentType; + } + static NewChatModelListCallbackQueryHandler() { + return new ModelListCallbackQueryHandler("ca:", "al:", "cm:", loadChatLLM, changeChatAgentType); + } + static NewImageModelListCallbackQueryHandler() { + return new ModelListCallbackQueryHandler("ica:", "ial:", "icm:", loadImageGen, changeImageAgentType); + } async handle(query, data, context) { if (!query.message) { throw new Error("no message"); } const sender = MessageSender.fromCallbackQuery(context.SHARE_CONTEXT.botToken, query); const [agent, page] = JSON.parse(data.substring(this.prefix.length)); - const conf = { - ...ENV.USER_CONFIG, - AI_PROVIDER: agent - }; - const chatAgent = loadChatLLM(conf); - if (!chatAgent) { + const conf = this.changeAgentType(ENV.USER_CONFIG, agent); + const theAgent = this.agentLoader(conf); + if (!theAgent) { throw new Error(`agent not found: ${agent}`); } - const models = await chatAgent.modelList(conf); + const models = await theAgent.modelList(conf); const keyboard = []; const maxRow = 10; const maxCol = Math.max(1, Math.min(5, 
ENV.MODEL_LIST_COLUMNS));
@@ -1979,7 +1977,7 @@
for (let i = page * maxRow * maxCol; i < models.length; i++) {
currentRow.push({
text: models[i],
- callback_data: `cm:${JSON.stringify([agent, models[i]])}`
+ callback_data: `${this.changeModelPrefix}${JSON.stringify([agent, models[i]])}`
});
if (i % maxCol === 0) {
keyboard.push(currentRow);
@@ -1996,19 +1994,19 @@
keyboard.push([
{
text: "<",
- callback_data: `ca:${JSON.stringify([agent, Math.max(page - 1, 0)])}`
+ callback_data: `${this.prefix}${JSON.stringify([agent, Math.max(page - 1, 0)])}`
},
{
text: `${page + 1}/${maxPage}`,
- callback_data: `ca:${JSON.stringify([agent, page])}`
+ callback_data: `${this.prefix}${JSON.stringify([agent, page])}`
},
{
text: ">",
- callback_data: `ca:${JSON.stringify([agent, Math.min(page + 1, maxPage - 1)])}`
+ callback_data: `${this.prefix}${JSON.stringify([agent, Math.min(page + 1, maxPage - 1)])}`
},
{
text: "⇤",
- callback_data: `al:`
+ callback_data: this.agentListPrefix
}
]);
if (models.length > (page + 1) * maxRow * maxCol) {
@@ -2027,8 +2025,21 @@
}
}
class ModelChangeCallbackQueryHandler {
- prefix = "cm:";
+ prefix;
+ agentLoader;
+ changeAgentType;
needAuth = TELEGRAM_AUTH_CHECKER.shareModeGroup;
+ constructor(prefix, agentLoader, changeAgentType) {
+ this.prefix = prefix;
+ this.agentLoader = agentLoader;
+ this.changeAgentType = changeAgentType;
+ }
+ static NewChatModelChangeCallbackQueryHandler() {
+ return new ModelChangeCallbackQueryHandler("cm:", loadChatLLM, changeChatAgentType);
+ }
+ static NewImageModelChangeCallbackQueryHandler() {
+ return new ModelChangeCallbackQueryHandler("icm:", loadImageGen, changeImageAgentType);
+ }
async handle(query, data, context) {
if (!query.message) {
throw new Error("no message");
}
@@ -2039,16 +2050,16 @@ class ModelChangeCallbackQueryHandler {
const sender = MessageSender.fromCallbackQuery(context.SHARE_CONTEXT.botToken, query);
const [agent, model] = JSON.parse(data.substring(this.prefix.length));
const conf = {
...ENV.USER_CONFIG,
AI_PROVIDER: agent
};
- const chatAgent = loadChatLLM(conf);
+ const theAgent = this.agentLoader(conf);
if (!agent) {
throw new Error(`agent not found: ${agent}`);
}
- if (!chatAgent?.modelKey) {
+ if (!theAgent?.modelKey) {
throw new Error(`modelKey not found: ${agent}`);
}
await context.execChangeAndSave({
AI_PROVIDER: agent,
- [chatAgent.modelKey]: model
+ [theAgent.modelKey]: model
});
console.log("Change model:", agent, model);
const message = {
@@ -2060,9 +2071,12 @@
}
}
const QUERY_HANDLERS = [
- new AgentListCallbackQueryHandler(),
- new ModelListCallbackQueryHandler(),
- new ModelChangeCallbackQueryHandler()
+ AgentListCallbackQueryHandler.NewChatAgentListCallbackQueryHandler(),
+ AgentListCallbackQueryHandler.NewImageAgentListCallbackQueryHandler(),
+ ModelListCallbackQueryHandler.NewChatModelListCallbackQueryHandler(),
+ ModelListCallbackQueryHandler.NewImageModelListCallbackQueryHandler(),
+ ModelChangeCallbackQueryHandler.NewChatModelChangeCallbackQueryHandler(),
+ ModelChangeCallbackQueryHandler.NewImageModelChangeCallbackQueryHandler()
];
async function handleCallbackQuery(callbackQuery, context) {
const sender = MessageSender.fromCallbackQuery(context.SHARE_CONTEXT.botToken, callbackQuery);
@@ -2386,7 +2400,19 @@ class ImgCommandHandler {
handle = async (message, subcommand, context) => {
const sender = MessageSender.fromMessage(context.SHARE_CONTEXT.botToken, message);
if (subcommand === "") {
- return sender.sendPlainText(ENV.I18N.command.help.img);
+ const params = {
+ chat_id: message.chat.id,
+ text: ENV.I18N.command.help.img,
+ reply_markup: {
+ inline_keyboard: [[
+ {
+ text: ENV.I18N.callback_query.open_model_list,
+ callback_data: "ial:"
+ }
+ ]]
+ }
+ };
+ return sender.sendRawMessage(params);
}
try {
const api = createTelegramBotAPI(context.SHARE_CONTEXT.botToken);
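The handlers registered above split the chat callback prefixes ("al:"/"ca:"/"cm:") from the new image prefixes ("ial:"/"ica:"/"icm:"). A minimal routing sketch, assuming only that handleCallbackQuery matches handlers by data.startsWith(prefix); the CallbackHandler type, the dispatch helper, and the logging are illustrative and not part of this commit:

type CallbackHandler = { prefix: string; handle: (data: string) => void };

// Same prefix set as QUERY_HANDLERS, in the same order.
const handlers: CallbackHandler[] = [
    { prefix: 'al:', handle: d => console.log('chat agent list', d) },
    { prefix: 'ial:', handle: d => console.log('image agent list', d) },
    { prefix: 'ca:', handle: d => console.log('chat model list', d) },
    { prefix: 'ica:', handle: d => console.log('image model list', d) },
    { prefix: 'cm:', handle: d => console.log('change chat model', d) },
    { prefix: 'icm:', handle: d => console.log('change image model', d) },
];

function dispatch(callbackData: string): void {
    for (const h of handlers) {
        if (callbackData.startsWith(h.prefix)) {
            // Hand the payload (a JSON-encoded [agent, page-or-model] tuple) to the handler.
            h.handle(callbackData.substring(h.prefix.length));
            return;
        }
    }
}

dispatch('ica:["azure",0]'); // -> image model list, page 0 of the azure agent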
diff --git a/packages/lib/core/src/agent/anthropic.ts b/packages/lib/core/src/agent/anthropic.ts
index 46ed6882..eb694501 100644
--- a/packages/lib/core/src/agent/anthropic.ts
+++ b/packages/lib/core/src/agent/anthropic.ts
@@ -4,6 +4,7 @@ import type { SSEMessage, SSEParserResult } from './stream';
import type {
AgentEnable,
AgentModel,
+ AgentModelList,
ChatAgent,
ChatAgentRequest,
ChatAgentResponse,
@@ -11,11 +12,12 @@ import type {
HistoryItem,
LLMChatParams,
} from './types';
+import { loadOpenAIModelList } from '#/agent/openai_compatibility';
import { ENV } from '#/config';
import { imageToBase64String } from '#/utils/image';
import { requestChatCompletions } from './request';
import { Stream } from './stream';
-import { convertStringToResponseMessages, extractImageContent, loadModelsList } from './utils';
+import { convertStringToResponseMessages, extractImageContent } from './utils';
function anthropicHeader(context: AgentUserConfig): Record<string, string> {
return {
@@ -29,11 +31,11 @@ export class Anthropic implements ChatAgent {
readonly name = 'anthropic';
readonly modelKey = 'ANTHROPIC_CHAT_MODEL';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return !!(context.ANTHROPIC_API_KEY);
- };
+ readonly enable: AgentEnable = ctx => !!(ctx.ANTHROPIC_API_KEY);
+ readonly model: AgentModel = ctx => ctx.ANTHROPIC_CHAT_MODEL;
+ readonly modelList: AgentModelList = ctx => loadOpenAIModelList(ctx.ANTHROPIC_CHAT_MODELS_LIST, ctx.ANTHROPIC_API_BASE, anthropicHeader(ctx));
- private render = async (item: HistoryItem): Promise<Record<string, any> | null> => {
+ private static render = async (item: HistoryItem): Promise<Record<string, any> | null> => {
const res: Record<string, any> = {
role: item.role,
content: item.content,
@@ -68,10 +70,6 @@ export class Anthropic implements ChatAgent {
}
return res;
};
- readonly model: AgentModel = (ctx: AgentUserConfig): string | null => {
- return ctx.ANTHROPIC_CHAT_MODEL;
- };
-
private static parser(sse: SSEMessage): SSEParserResult {
// example:
// event: content_block_delta
@@ -109,7 +107,7 @@ export class Anthropic implements ChatAgent {
const body = {
system: prompt,
model: context.ANTHROPIC_CHAT_MODEL,
- messages: (await Promise.all(messages.map(item => this.render(item)))).filter(i => i !== null),
+ messages: (await Promise.all(messages.map(item => Anthropic.render(item)))).filter(i => i !== null),
stream: onStream != null,
max_tokens: ENV.MAX_TOKEN_LENGTH > 0 ? ENV.MAX_TOKEN_LENGTH : 2048,
};
@@ -131,16 +129,4 @@ export class Anthropic implements ChatAgent {
};
return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, options));
};
-
- readonly modelList = async (context: AgentUserConfig): Promise<string[]> => {
- if (context.ANTHROPIC_CHAT_MODELS_LIST === '') {
- context.ANTHROPIC_CHAT_MODELS_LIST = `${context.ANTHROPIC_API_BASE}/models`;
- }
- return loadModelsList(context.ANTHROPIC_CHAT_MODELS_LIST, async (url): Promise<string[]> => {
- const data = await fetch(url, {
- headers: anthropicHeader(context),
- }).then(res => res.json() as any);
- return data?.data?.map((model: any) => model.id) || [];
- });
- };
}
diff --git a/packages/lib/core/src/agent/azure.ts b/packages/lib/core/src/agent/azure.ts
index d4101e93..076c3c33 100644
--- a/packages/lib/core/src/agent/azure.ts
+++ b/packages/lib/core/src/agent/azure.ts
@@ -2,6 +2,7 @@ import type { AgentUserConfig } from '#/config';
import type {
AgentEnable,
AgentModel,
+ AgentModelList,
ChatAgent,
ChatAgentRequest,
ChatAgentResponse,
@@ -10,7 +11,7 @@ import type {
ImageAgentRequest,
LLMChatParams,
} from './types';
-import { ImageSupportFormat, renderOpenAIMessages } from './openai';
+import { ImageSupportFormat, renderOpenAIMessages } from '#/agent/openai_compatibility';
import { requestChatCompletions } from './request';
import { convertStringToResponseMessages, loadModelsList } from './utils';
@@ -25,13 +26,8 @@ export class AzureChatAI implements ChatAgent {
readonly name = 'azure';
readonly modelKey = 'AZURE_CHAT_MODEL';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return !!(context.AZURE_API_KEY && context.AZURE_RESOURCE_NAME);
- };
-
- readonly model: AgentModel = (ctx: AgentUserConfig): string | null => {
- return ctx.AZURE_CHAT_MODEL;
- };
+ readonly enable: AgentEnable = ctx => !!(ctx.AZURE_API_KEY && ctx.AZURE_RESOURCE_NAME);
+ readonly model: AgentModel = ctx => ctx.AZURE_CHAT_MODEL;
readonly request: ChatAgentRequest = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<ChatAgentResponse> => {
const { prompt, messages } = params;
@@ -62,13 +58,9 @@ export class AzureImageAI implements ImageAgent {
readonly name = 'azure';
readonly modelKey = 'AZURE_DALLE_API';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return !!(context.AZURE_API_KEY && context.AZURE_DALLE_API);
- };
-
- readonly model: AgentModel = (ctx: AgentUserConfig) => {
- return ctx.AZURE_IMAGE_MODEL;
- };
+ readonly enable: AgentEnable = ctx => !!(ctx.AZURE_API_KEY && ctx.AZURE_DALLE_API);
+ readonly model: AgentModel = ctx => ctx.AZURE_IMAGE_MODEL;
+ readonly modelList: AgentModelList = ctx => Promise.resolve([ctx.AZURE_IMAGE_MODEL]);
readonly request: ImageAgentRequest = async (prompt: string, context: AgentUserConfig): Promise<string> => {
const url = `https://${context.AZURE_RESOURCE_NAME}.openai.azure.com/openai/deployments/${context.AZURE_IMAGE_MODEL}/images/generations?api-version=${context.AZURE_API_VERSION}`;
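With modelList now required on the base Agent interface (see the types.ts hunk later in this diff), a fixed-model agent can satisfy it with a resolved constant, as AzureImageAI does above. A sketch of that contract against a hypothetical provider — the ExampleImage class and the EXAMPLE_* keys do not exist in AgentUserConfig and are for illustration only:

import type { AgentEnable, AgentModel, AgentModelList, ImageAgent, ImageAgentRequest } from './types';

export class ExampleImage implements ImageAgent {
    readonly name = 'example';
    readonly modelKey = 'EXAMPLE_IMAGE_MODEL';

    readonly enable: AgentEnable = ctx => !!(ctx as any).EXAMPLE_API_KEY;
    readonly model: AgentModel = ctx => (ctx as any).EXAMPLE_IMAGE_MODEL ?? null;
    // A single fixed model satisfies the new requirement without a network call.
    readonly modelList: AgentModelList = ctx => Promise.resolve([(ctx as any).EXAMPLE_IMAGE_MODEL]);

    readonly request: ImageAgentRequest = async (prompt, _context) => {
        throw new Error(`ExampleImage.request not implemented: ${prompt}`);
    };
}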
diff --git a/packages/lib/core/src/agent/cohere.ts b/packages/lib/core/src/agent/cohere.ts
index 14ebea97..e7716000 100644
--- a/packages/lib/core/src/agent/cohere.ts
+++ b/packages/lib/core/src/agent/cohere.ts
@@ -9,7 +9,7 @@ import type {
ChatStreamTextHandler,
LLMChatParams,
} from './types';
-import { renderOpenAIMessages } from './openai';
+import { renderOpenAIMessages } from '#/agent/openai_compatibility';
import { requestChatCompletions } from './request';
import { bearerHeader, convertStringToResponseMessages, loadModelsList } from './utils';
@@ -17,13 +17,8 @@ export class Cohere implements ChatAgent {
readonly name = 'cohere';
readonly modelKey = 'COHERE_CHAT_MODEL';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return !!(context.COHERE_API_KEY);
- };
-
- readonly model: AgentModel = (ctx: AgentUserConfig): string | null => {
- return ctx.COHERE_CHAT_MODEL;
- };
+ readonly enable: AgentEnable = ctx => !!(ctx.COHERE_API_KEY);
+ readonly model: AgentModel = ctx => ctx.COHERE_CHAT_MODEL;
readonly request: ChatAgentRequest = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<ChatAgentResponse> => {
const { prompt, messages } = params;
diff --git a/packages/lib/core/src/agent/gemini.ts b/packages/lib/core/src/agent/gemini.ts
index 773c4109..1382b078 100644
--- a/packages/lib/core/src/agent/gemini.ts
+++ b/packages/lib/core/src/agent/gemini.ts
@@ -8,7 +8,7 @@ import type {
ChatStreamTextHandler,
LLMChatParams,
} from './types';
-import { ImageSupportFormat, renderOpenAIMessages } from './openai';
+import { ImageSupportFormat, renderOpenAIMessages } from '#/agent/openai_compatibility';
import { requestChatCompletions } from './request';
import { bearerHeader, convertStringToResponseMessages, loadModelsList } from './utils';
@@ -16,13 +16,8 @@ export class Gemini implements ChatAgent {
readonly name = 'gemini';
readonly modelKey = 'GOOGLE_COMPLETIONS_MODEL';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return !!(context.GOOGLE_API_KEY);
- };
-
- readonly model: AgentModel = (ctx: AgentUserConfig): string => {
- return ctx.GOOGLE_COMPLETIONS_MODEL;
- };
+ readonly enable: AgentEnable = ctx => !!(ctx.GOOGLE_API_KEY);
+ readonly model: AgentModel = ctx => ctx.GOOGLE_COMPLETIONS_MODEL;
readonly request: ChatAgentRequest = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<ChatAgentResponse> => {
const { prompt, messages } = params;
diff --git a/packages/lib/core/src/agent/mistralai.ts b/packages/lib/core/src/agent/mistralai.ts
index 927f0cf5..4975b3e2 100644
--- a/packages/lib/core/src/agent/mistralai.ts
+++ b/packages/lib/core/src/agent/mistralai.ts
@@ -2,27 +2,24 @@ import type { AgentUserConfig } from '#/config';
import type {
AgentEnable,
AgentModel,
+ AgentModelList,
ChatAgent,
ChatAgentRequest,
ChatAgentResponse,
ChatStreamTextHandler,
LLMChatParams,
} from './types';
-import { ImageSupportFormat, renderOpenAIMessages } from './openai';
+import { ImageSupportFormat, loadOpenAIModelList, renderOpenAIMessages } from '#/agent/openai_compatibility';
import { requestChatCompletions } from './request';
-import { bearerHeader, convertStringToResponseMessages, loadModelsList } from './utils';
+import { bearerHeader, convertStringToResponseMessages } from './utils';
export class Mistral implements ChatAgent {
readonly name = 'mistral';
readonly modelKey = 'MISTRAL_CHAT_MODEL';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return !!(context.MISTRAL_API_KEY);
- };
-
- readonly model: AgentModel = (ctx: AgentUserConfig): string | null => {
- return ctx.MISTRAL_CHAT_MODEL;
- };
+ readonly enable: AgentEnable = ctx => !!(ctx.MISTRAL_API_KEY);
+ readonly model: AgentModel = ctx => ctx.MISTRAL_CHAT_MODEL;
+ readonly modelList: AgentModelList = ctx => loadOpenAIModelList(ctx.MISTRAL_CHAT_MODELS_LIST, ctx.MISTRAL_API_BASE, bearerHeader(ctx.MISTRAL_API_KEY));
readonly request: ChatAgentRequest = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<ChatAgentResponse> => {
const { prompt, messages } = params;
@@ -37,16 +34,4 @@ export class Mistral implements ChatAgent {
return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, null));
};
-
- readonly modelList = async (context: AgentUserConfig): Promise<string[]> => {
- if (context.MISTRAL_CHAT_MODELS_LIST === '') {
- context.MISTRAL_CHAT_MODELS_LIST = `${context.MISTRAL_API_BASE}/models`;
- }
- return loadModelsList(context.MISTRAL_CHAT_MODELS_LIST, async (url): Promise<string[]> => {
- const data = await fetch(url, {
- headers: bearerHeader(context.MISTRAL_API_KEY),
- }).then(res => res.json()) as any;
- return data.data?.map((model: any) => model.id) || [];
- });
- };
}
diff --git a/packages/lib/core/src/agent/openai.ts b/packages/lib/core/src/agent/openai.ts
index 3ebced9c..ff107da2 100644
--- a/packages/lib/core/src/agent/openai.ts
+++ b/packages/lib/core/src/agent/openai.ts
@@ -2,74 +2,18 @@ import type { AgentUserConfig } from '#/config';
import type {
AgentEnable,
AgentModel,
+ AgentModelList,
ChatAgent,
ChatAgentRequest,
ChatAgentResponse,
ChatStreamTextHandler,
- HistoryItem,
ImageAgent,
ImageAgentRequest,
LLMChatParams,
} from './types';
-import { ENV } from '#/config';
-import { imageToBase64String, renderBase64DataURI } from '#/utils/image';
+import { ImageSupportFormat, loadOpenAIModelList, renderOpenAIMessages } from '#/agent/openai_compatibility';
import { requestChatCompletions } from './request';
-import { bearerHeader, convertStringToResponseMessages, extractImageContent, loadModelsList } from './utils';
-
-export enum ImageSupportFormat {
- URL = 'url',
- BASE64 = 'base64',
-}
-
-async function renderOpenAIMessage(item: HistoryItem, supportImage?: ImageSupportFormat[] | null): Promise<any> {
- const res: any = {
- role: item.role,
- content: item.content,
- };
- if (Array.isArray(item.content)) {
- const contents = [];
- for (const content of item.content) {
- switch (content.type) {
- case 'text':
- contents.push({ type: 'text', text: content.text });
- break;
- case 'image':
- if (supportImage) {
- const isSupportURL = supportImage.includes(ImageSupportFormat.URL);
- const isSupportBase64 = supportImage.includes(ImageSupportFormat.BASE64);
- const data = extractImageContent(content.image);
- if (data.url) {
- if (ENV.TELEGRAM_IMAGE_TRANSFER_MODE === 'base64' && isSupportBase64) {
- contents.push(await imageToBase64String(data.url).then((data) => {
- return { type: 'image_url', image_url: { url: renderBase64DataURI(data) } };
- }));
- } else if (isSupportURL) {
- contents.push({ type: 'image_url', image_url: { url: data.url } });
- }
- } else if (data.base64 && isSupportBase64) {
- contents.push({ type: 'image_base64', image_base64: { base64: data.base64 } });
- }
- }
- break;
- default:
- break;
- }
- }
- res.content = contents;
- }
- return res;
-}
-
-export async function renderOpenAIMessages(prompt: string | undefined, items: HistoryItem[], supportImage?: ImageSupportFormat[] | null): Promise<any[]> {
- const messages = await Promise.all(items.map(r => renderOpenAIMessage(r, supportImage)));
- if (prompt) {
- if (messages.length > 0 && messages[0].role === 'system') {
- messages.shift();
- }
- messages.unshift({ role: 'system', content: prompt });
- }
- return messages;
-}
+import { bearerHeader, convertStringToResponseMessages } from './utils';
function openAIApiKey(context: AgentUserConfig): string {
const length = context.OPENAI_API_KEY.length;
return context.OPENAI_API_KEY[Math.floor(Math.random() * length)];
}
@@ -80,13 +24,9 @@ export class OpenAI implements ChatAgent {
readonly name = 'openai';
readonly modelKey = 'OPENAI_CHAT_MODEL';
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return context.OPENAI_API_KEY.length > 0;
- };
-
- readonly model: AgentModel = (ctx: AgentUserConfig): string | null => {
- return ctx.OPENAI_CHAT_MODEL;
- };
+ readonly enable: AgentEnable = ctx => ctx.OPENAI_API_KEY.length > 0;
+ readonly model: AgentModel = ctx => ctx.OPENAI_CHAT_MODEL;
+ readonly modelList: AgentModelList = ctx => loadOpenAIModelList(ctx.OPENAI_CHAT_MODELS_LIST, ctx.OPENAI_API_BASE, bearerHeader(openAIApiKey(ctx)));
readonly request: ChatAgentRequest = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<ChatAgentResponse> => {
const { prompt, messages } = params;
@@ -98,34 +38,17 @@ export class OpenAI implements ChatAgent {
messages: await renderOpenAIMessages(prompt, messages, [ImageSupportFormat.URL, ImageSupportFormat.BASE64]),
stream: onStream != null,
};
-
return convertStringToResponseMessages(requestChatCompletions(url, header, body, onStream, null));
};
-
- readonly modelList = async (context: AgentUserConfig): Promise<string[]> => {
- if (context.OPENAI_CHAT_MODELS_LIST === '') {
- context.OPENAI_CHAT_MODELS_LIST = `${context.OPENAI_API_BASE}/models`;
- }
- return loadModelsList(context.OPENAI_CHAT_MODELS_LIST, async (url): Promise<string[]> => {
- const data = await fetch(url, {
- headers: bearerHeader(openAIApiKey(context)),
- }).then(res => res.json()) as any;
- return data.data?.map((model: any) => model.id) || [];
- });
- };
}
export class Dalle implements ImageAgent {
readonly name = 'openai';
- readonly modelKey = 'OPENAI_DALLE_API';
-
- readonly enable: AgentEnable = (context: AgentUserConfig): boolean => {
- return context.OPENAI_API_KEY.length > 0;
- };
+ readonly modelKey = 'DALL_E_MODEL';
- readonly model: AgentModel = (ctx: AgentUserConfig): string => {
- return ctx.DALL_E_MODEL;
- };
+ readonly enable: AgentEnable = ctx => ctx.OPENAI_API_KEY.length > 0;
+ readonly model: AgentModel = ctx => ctx.DALL_E_MODEL;
+ readonly modelList: AgentModelList = ctx => Promise.resolve([ctx.DALL_E_MODEL]);
readonly request: ImageAgentRequest = async (prompt: string, context: AgentUserConfig): Promise<string> => {
const url = `${context.OPENAI_API_BASE}/images/generations`;
diff --git a/packages/lib/core/src/agent/openai_compatibility.ts b/packages/lib/core/src/agent/openai_compatibility.ts
new file mode 100644
index 00000000..2d2195ad
--- /dev/null
+++ b/packages/lib/core/src/agent/openai_compatibility.ts
@@ -0,0 +1,69 @@
+import type { HistoryItem } from '#/agent/types';
+import { extractImageContent, loadModelsList } from '#/agent/utils';
+import { ENV } from '#/config';
+import { imageToBase64String, renderBase64DataURI } from '#/utils/image';
+
+export enum ImageSupportFormat {
+ URL = 'url',
+ BASE64 = 'base64',
+}
+
+async function renderOpenAIMessage(item: HistoryItem, supportImage?: ImageSupportFormat[] | null): Promise<any> {
+ const res: any = {
+ role: item.role,
+ content: item.content,
+ };
+ if (Array.isArray(item.content)) {
+ const contents = [];
+ for (const content of item.content) {
+ switch (content.type) {
+ case 'text':
+ contents.push({ type: 'text', text: content.text });
+ break;
+ case 'image':
+ if (supportImage) {
+ const isSupportURL = supportImage.includes(ImageSupportFormat.URL);
+ const isSupportBase64 = supportImage.includes(ImageSupportFormat.BASE64);
+ const data = extractImageContent(content.image);
+ if (data.url) {
+ if (ENV.TELEGRAM_IMAGE_TRANSFER_MODE === 'base64' && isSupportBase64) {
+ contents.push(await imageToBase64String(data.url).then((data) => {
+ return { type: 'image_url', image_url: { url: renderBase64DataURI(data) } };
+ }));
+ } else if (isSupportURL) {
+ contents.push({ type: 'image_url', image_url: { url: data.url } });
+ }
+ } else if (data.base64 && isSupportBase64) {
+ contents.push({ type: 'image_base64', image_base64: { base64: data.base64 } });
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ res.content = contents;
+ }
+ return res;
+}
+
+export async function renderOpenAIMessages(prompt: string | undefined, items: HistoryItem[], supportImage?: ImageSupportFormat[] | null): Promise<any[]> {
+ const messages = await Promise.all(items.map(r => renderOpenAIMessage(r, supportImage)));
+ if (prompt) {
+ if (messages.length > 0 && messages[0].role === 'system') {
+ messages.shift();
+ }
+ messages.unshift({ role: 'system', content: prompt });
+ }
+ return messages;
+}
+
+export function loadOpenAIModelList(list: string, base: string, headers: Record<string, string>): Promise<string[]> {
+ if (list === '') {
+ list = `${base}/models`;
+ }
+ return loadModelsList(list, async (url): Promise<string[]> => {
+ const data = await fetch(url, { headers }).then(res => res.json()) as any;
+ return data.data?.map((model: any) => model.id) || [];
+ });
+}
diff --git a/packages/lib/core/src/agent/types.ts b/packages/lib/core/src/agent/types.ts
index 332897b5..1a357ac2 100644
--- a/packages/lib/core/src/agent/types.ts
+++ b/packages/lib/core/src/agent/types.ts
@@ -43,6 +43,7 @@ export type HistoryModifier = (history: HistoryItem[], message: UserMessageItem
export type AgentEnable = (context: AgentUserConfig) => boolean;
export type AgentModel = (ctx: AgentUserConfig) => string | null;
+export type AgentModelList = (ctx: AgentUserConfig) => Promise<string[]>;
export type ChatAgentRequest = (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null) => Promise<ChatAgentResponse>;
export type ImageAgentRequest = (prompt: string, context: AgentUserConfig) => Promise<string | Blob>;
@@ -51,11 +52,10 @@ export interface Agent {
modelKey: string;
enable: AgentEnable;
model: AgentModel;
+ modelList: AgentModelList;
request: AgentRequest;
}
-export interface ChatAgent extends Agent {
- modelList: (ctx: AgentUserConfig) => Promise<string[]>;
-}
+export interface ChatAgent extends Agent {}
export interface ImageAgent extends Agent {}
diff --git a/packages/lib/core/src/agent/utils.ts b/packages/lib/core/src/agent/utils.ts
index 3f22b130..523088ed 100644
--- a/packages/lib/core/src/agent/utils.ts
+++ b/packages/lib/core/src/agent/utils.ts
@@ -52,8 +52,7 @@ export async function convertStringToResponseMessages(input: Promise<string>): P
};
}
-export type RemoteParser = (url: string) => Promise<string[]>;
-export async function loadModelsList(raw: string, remoteLoader?: RemoteParser): Promise<string[]> {
+export async function loadModelsList(raw: string, remoteLoader?: (url: string) => Promise<string[]>): Promise<string[]> {
if (!raw) {
return [];
}
@@ -68,7 +67,7 @@ export async function loadModelsList(raw: string, remoteLoader?: RemoteParser):
if (raw.startsWith('http') && remoteLoader) {
return await remoteLoader(raw);
}
- return [];
+ return [raw];
}
export function bearerHeader(token: string | null, stream?: boolean): Record<string, string> {
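The `return [raw]` change above means loadModelsList now accepts three input shapes instead of silently dropping a bare model name. A behavior sketch; the remoteLoader stub stands in for the fetch-based callbacks such as the one loadOpenAIModelList builds, and the example values are placeholders:

import { loadModelsList } from '#/agent/utils';

async function demo(): Promise<void> {
    // Stand-in for a remote loader; a real one would fetch and parse the URL.
    const remoteLoader = async (_url: string): Promise<string[]> => ['model-a', 'model-b'];

    console.log(await loadModelsList('["m1","m2"]', remoteLoader)); // JSON array -> ['m1', 'm2']
    console.log(await loadModelsList('https://example.com/v1/models', remoteLoader)); // URL -> ['model-a', 'model-b']
    console.log(await loadModelsList('m1', remoteLoader)); // bare name -> ['m1'] (previously [])
    console.log(await loadModelsList('', remoteLoader)); // empty -> []
}

demo().catch(console.error);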
diff --git a/packages/lib/core/src/agent/workersai.ts b/packages/lib/core/src/agent/workersai.ts
index 0dddac79..22c92ec1 100644
--- a/packages/lib/core/src/agent/workersai.ts
+++ b/packages/lib/core/src/agent/workersai.ts
@@ -3,6 +3,7 @@ import type { SseChatCompatibleOptions } from './request';
import type {
AgentEnable,
AgentModel,
+ AgentModelList,
ChatAgent,
ChatAgentRequest,
ChatAgentResponse,
@@ -11,22 +12,11 @@ import type {
ImageAgent,
ImageAgentRequest,
LLMChatParams,
} from './types';
+import { renderOpenAIMessages } from '#/agent/openai_compatibility';
import { ENV } from '#/config';
-import { renderOpenAIMessages } from './openai';
import { isJsonResponse, mapResponseToAnswer, requestChatCompletions } from './request';
import { bearerHeader, convertStringToResponseMessages, loadModelsList } from './utils';
-async function sendWorkerRequest(model: string, body: any, id: string, token: string): Promise<Response> {
- return await fetch(
- `https://api.cloudflare.com/client/v4/accounts/${id}/ai/run/${model}`,
- {
- headers: { Authorization: `Bearer ${token}` },
- method: 'POST',
- body: JSON.stringify(body),
- },
- );
-}
-
function isWorkerAIEnable(context: AgentUserConfig): boolean {
if (ENV.AI_BINDING) {
return true;
@@ -34,38 +24,22 @@ function isWorkerAIEnable(context: AgentUserConfig): boolean {
return !!(context.CLOUDFLARE_ACCOUNT_ID && context.CLOUDFLARE_TOKEN);
}
-function mapAiTextGenerationOutput2Response(output: AiTextGenerationOutput, stream: boolean): Response {
- if (stream && output instanceof ReadableStream) {
- return new Response(output, {
- headers: { 'content-type': 'text/event-stream' },
- });
- } else {
- return Response.json({ result: output });
- }
-}
-
-function mapAiTextToImageOutput2Response(output: AiTextToImageOutput): Response {
- if (output instanceof ReadableStream) {
- return new Response(output, {
- headers: {
- 'content-type': 'image/jpg',
- },
- });
- } else {
- return Response.json({ result: output });
- }
-}
-
-async function mapResponseToImage(output: Response): Promise<Blob> {
- if (isJsonResponse(output)) {
- const { result } = await output.json();
- const image = result?.image;
- if (typeof image !== 'string') {
- throw new TypeError('Invalid image response');
+function loadWorkersModelList(task: string, loader: (context: AgentUserConfig) => string): (context: AgentUserConfig) => Promise<string[]> {
+ return async (context: AgentUserConfig): Promise<string[]> => {
+ let uri = loader(context);
+ if (uri === '') {
+ const id = context.CLOUDFLARE_ACCOUNT_ID;
+ const taskEncoded = encodeURIComponent(task);
+ uri = `https://api.cloudflare.com/client/v4/accounts/${id}/ai/models/search?task=${taskEncoded}`;
}
- return base64StringToBlob(image);
- }
- return await output.blob();
+ return loadModelsList(uri, async (url): Promise<string[]> => {
+ const header = {
+ Authorization: `Bearer ${context.CLOUDFLARE_TOKEN}`,
+ };
+ const data = await fetch(url, { headers: header }).then(res => res.json());
+ return data.result?.map((model: any) => model.name) || [];
+ });
+ };
}
export class WorkersChat implements ChatAgent {
@@ -73,9 +47,8 @@ export class WorkersChat implements ChatAgent {
readonly modelKey = 'WORKERS_CHAT_MODEL';
readonly enable: AgentEnable = isWorkerAIEnable;
- readonly model: AgentModel = (ctx: AgentUserConfig): string | null => {
- return ctx.WORKERS_CHAT_MODEL;
- };
+ readonly model: AgentModel = ctx => ctx.WORKERS_CHAT_MODEL;
+ readonly modelList: AgentModelList = loadWorkersModelList('Text Generation', ctx => ctx.WORKERS_CHAT_MODELS_LIST);
readonly request: ChatAgentRequest = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<ChatAgentResponse> => {
const { prompt, messages } = params;
@@ -97,7 +70,7 @@ export class WorkersChat implements ChatAgent {
if (ENV.AI_BINDING) {
const answer = await ENV.AI_BINDING.run(model, body);
- const response = mapAiTextGenerationOutput2Response(answer, onStream !== null);
+ const response = WorkersChat.outputToResponse(answer, onStream !== null);
return convertStringToResponseMessages(mapResponseToAnswer(response, new AbortController(), options, onStream));
} else if (context.CLOUDFLARE_ACCOUNT_ID && context.CLOUDFLARE_TOKEN) {
const id = context.CLOUDFLARE_ACCOUNT_ID;
@@ -110,18 +83,14 @@ export class WorkersChat implements ChatAgent {
}
};
- readonly modelList = async (context: AgentUserConfig): Promise<string[]> => {
- if (context.WORKERS_CHAT_MODELS_LIST === '') {
- const id = context.CLOUDFLARE_ACCOUNT_ID;
- context.WORKERS_CHAT_MODELS_LIST = `https://api.cloudflare.com/client/v4/accounts/${id}/ai/models/search?task=Text%20Generation`;
+ static outputToResponse(output: AiTextGenerationOutput, stream: boolean): Response {
+ if (stream && output instanceof ReadableStream) {
+ return new Response(output, {
+ headers: { 'content-type': 'text/event-stream' },
+ });
+ } else {
+ return Response.json({ result: output });
}
- return loadModelsList(context.WORKERS_CHAT_MODELS_LIST, async (url): Promise<string[]> => {
- const header = {
- Authorization: `Bearer ${context.CLOUDFLARE_TOKEN}`,
- };
- const data = await fetch(url, { headers: header }).then(res => res.json());
- return data.result?.map((model: any) => model.name) || [];
- });
- };
+ }
}
@@ -130,32 +99,66 @@ export class WorkersImage implements ImageAgent {
readonly modelKey = 'WORKERS_IMAGE_MODEL';
readonly enable: AgentEnable = isWorkerAIEnable;
- readonly model: AgentModel = (ctx: AgentUserConfig): string => {
- return ctx.WORKERS_IMAGE_MODEL;
- };
+ readonly model: AgentModel = ctx => ctx.WORKERS_IMAGE_MODEL;
+ readonly modelList: AgentModelList = loadWorkersModelList('Text-to-Image', ctx => ctx.WORKERS_IMAGE_MODELS_LIST);
readonly request: ImageAgentRequest = async (prompt: string, context: AgentUserConfig): Promise<Blob> => {
if (ENV.AI_BINDING) {
const answer = await ENV.AI_BINDING.run(context.WORKERS_IMAGE_MODEL, { prompt });
- const raw = mapAiTextToImageOutput2Response(answer);
- return await mapResponseToImage(raw);
+ const raw = WorkersImage.outputToResponse(answer);
+ return await WorkersImage.responseToImage(raw);
} else if (context.CLOUDFLARE_ACCOUNT_ID && context.CLOUDFLARE_TOKEN) {
const id = context.CLOUDFLARE_ACCOUNT_ID;
const token = context.CLOUDFLARE_TOKEN;
- const raw = await sendWorkerRequest(context.WORKERS_IMAGE_MODEL, { prompt }, id, token);
- return await mapResponseToImage(raw);
+ const raw = await WorkersImage.fetch(context.WORKERS_IMAGE_MODEL, { prompt }, id, token);
+ return await WorkersImage.responseToImage(raw);
} else {
throw new Error('Cloudflare account ID and token are required');
}
};
-}
-async function base64StringToBlob(base64String: string): Promise<Blob> {
- if (typeof Buffer !== 'undefined') {
- const buffer = Buffer.from(base64String, 'base64');
- return new Blob([buffer], { type: 'image/png' });
- } else {
- const uint8Array = Uint8Array.from(atob(base64String), c => c.charCodeAt(0));
- return new Blob([uint8Array], { type: 'image/png' });
- }
+ static outputToResponse(output: AiTextToImageOutput): Response {
+ if (output instanceof ReadableStream) {
+ return new Response(output, {
+ headers: {
+ 'content-type': 'image/jpg',
+ },
+ });
+ } else {
+ return Response.json({ result: output });
+ }
+ };
+
+ static async responseToImage(output: Response): Promise<Blob> {
+ if (isJsonResponse(output)) {
+ const { result } = await output.json();
+ const image = result?.image;
+ if (typeof image !== 'string') {
+ throw new TypeError('Invalid image response');
+ }
+ return WorkersImage.base64StringToBlob(image);
+    static outputToResponse(output: AiTextToImageOutput): Response {
+        if (output instanceof ReadableStream) {
+            return new Response(output, {
+                headers: {
+                    'content-type': 'image/jpeg',
+                },
+            });
+        } else {
+            return Response.json({ result: output });
+        }
+    };
+
+    static async responseToImage(output: Response): Promise<Blob> {
+        if (isJsonResponse(output)) {
+            const { result } = await output.json();
+            const image = result?.image;
+            if (typeof image !== 'string') {
+                throw new TypeError('Invalid image response');
+            }
+            return WorkersImage.base64StringToBlob(image);
+        }
+        return await output.blob();
+    };
+
+    static async base64StringToBlob(base64String: string): Promise<Blob> {
+        if (typeof Buffer !== 'undefined') {
+            const buffer = Buffer.from(base64String, 'base64');
+            return new Blob([buffer], { type: 'image/png' });
+        } else {
+            const uint8Array = Uint8Array.from(atob(base64String), c => c.charCodeAt(0));
+            return new Blob([uint8Array], { type: 'image/png' });
+        }
+    };
+
+    static async fetch(model: string, body: any, id: string, token: string): Promise<Response> {
+        return await fetch(
+            `https://api.cloudflare.com/client/v4/accounts/${id}/ai/run/${model}`,
+            {
+                headers: { Authorization: `Bearer ${token}` },
+                method: 'POST',
+                body: JSON.stringify(body),
+            },
+        );
+    };
 }
diff --git a/packages/lib/core/src/config/config.ts b/packages/lib/core/src/config/config.ts
index ba6d0fa2..b2955e36 100644
--- a/packages/lib/core/src/config/config.ts
+++ b/packages/lib/core/src/config/config.ts
@@ -154,6 +154,8 @@ export class WorkersConfig {
     WORKERS_IMAGE_MODEL = '@cf/black-forest-labs/flux-1-schnell';
     // Workers Chat Models List, When empty, will use the api to get the list
     WORKERS_CHAT_MODELS_LIST = '';
+    // Workers Image Models List. When empty, the list is fetched from the API
+    WORKERS_IMAGE_MODELS_LIST = '';
 }
 
 // -- Gemini Config --
diff --git a/packages/lib/core/src/config/version.ts b/packages/lib/core/src/config/version.ts
index 1175156c..61bfab4f 100644
--- a/packages/lib/core/src/config/version.ts
+++ b/packages/lib/core/src/config/version.ts
@@ -1,2 +1,2 @@
-export const BUILD_TIMESTAMP = 1735019140;
-export const BUILD_VERSION = 'd1d9307';
+export const BUILD_TIMESTAMP = 1735026022;
+export const BUILD_VERSION = '284e897';
diff --git a/packages/lib/core/src/telegram/callback_query/index.ts b/packages/lib/core/src/telegram/callback_query/index.ts
index 8fb25256..3f5e915f 100644
--- a/packages/lib/core/src/telegram/callback_query/index.ts
+++ b/packages/lib/core/src/telegram/callback_query/index.ts
@@ -5,9 +5,12 @@ import { MessageSender } from '../sender';
 import { AgentListCallbackQueryHandler, ModelChangeCallbackQueryHandler, ModelListCallbackQueryHandler } from './system';
 
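+// Chat pickers keep the original `al:`/`ca:`/`cm:` callback prefixes; the
+// image pickers registered below use the `ial:`/`ica:`/`icm:` variants.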
 const QUERY_HANDLERS = [
-    new AgentListCallbackQueryHandler(),
-    new ModelListCallbackQueryHandler(),
-    new ModelChangeCallbackQueryHandler(),
+    AgentListCallbackQueryHandler.NewChatAgentListCallbackQueryHandler(),
+    AgentListCallbackQueryHandler.NewImageAgentListCallbackQueryHandler(),
+    ModelListCallbackQueryHandler.NewChatModelListCallbackQueryHandler(),
+    ModelListCallbackQueryHandler.NewImageModelListCallbackQueryHandler(),
+    ModelChangeCallbackQueryHandler.NewChatModelChangeCallbackQueryHandler(),
+    ModelChangeCallbackQueryHandler.NewImageModelChangeCallbackQueryHandler(),
 ];
 
 export async function handleCallbackQuery(callbackQuery: Telegram.CallbackQuery, context: WorkerContext): Promise<Response | null> {
diff --git a/packages/lib/core/src/telegram/callback_query/system.ts b/packages/lib/core/src/telegram/callback_query/system.ts
index 40f8184b..8cbfbf69 100644
--- a/packages/lib/core/src/telegram/callback_query/system.ts
+++ b/packages/lib/core/src/telegram/callback_query/system.ts
@@ -1,21 +1,42 @@
+import type { ChatAgent, ImageAgent } from '#/agent';
 import type { AgentUserConfig, WorkerContext } from '#/config';
 import type * as Telegram from 'telegram-bot-api-types';
 import type { CallbackQueryHandler } from './types';
-import { CHAT_AGENTS, loadChatLLM } from '#/agent';
+import { CHAT_AGENTS, IMAGE_AGENTS, loadChatLLM, loadImageGen } from '#/agent';
 import { ENV } from '#/config';
 import { TELEGRAM_AUTH_CHECKER } from '../auth';
 import { MessageSender } from '../sender';
 
 export class AgentListCallbackQueryHandler implements CallbackQueryHandler {
-    prefix = 'al:';
+    prefix: string;
+    changeAgentPrefix: string;
+    agentLoader: () => string[];
+
     needAuth = TELEGRAM_AUTH_CHECKER.shareModeGroup;
 
+    constructor(prefix: string, changeAgentPrefix: string, agentLoader: () => string[]) {
+        this.agentLoader = agentLoader;
+        this.prefix = prefix;
+        this.changeAgentPrefix = changeAgentPrefix;
+    }
+
+    static NewChatAgentListCallbackQueryHandler(): AgentListCallbackQueryHandler {
+        return new AgentListCallbackQueryHandler('al:', 'ca:', () => {
+            return CHAT_AGENTS.filter(agent => agent.enable(ENV.USER_CONFIG)).map(agent => agent.name);
+        });
+    }
+
+    static NewImageAgentListCallbackQueryHandler(): AgentListCallbackQueryHandler {
+        return new AgentListCallbackQueryHandler('ial:', 'ica:', () => {
+            return IMAGE_AGENTS.filter(agent => agent.enable(ENV.USER_CONFIG)).map(agent => agent.name);
+        });
+    }
+
     handle = async (query: Telegram.CallbackQuery, data: string, context: WorkerContext): Promise<Response> => {
         if (!query.message) {
             throw new Error('no message');
         }
-        const names = CHAT_AGENTS.filter(agent => agent.enable(ENV.USER_CONFIG)).map(agent => agent.name);
+        const names = this.agentLoader();
         const sender = MessageSender.fromCallbackQuery(context.SHARE_CONTEXT.botToken, query);
         const keyboards: Telegram.InlineKeyboardButton[][] = [];
         for (let i = 0; i < names.length; i += 2) {
@@ -27,7 +48,7 @@ export class AgentListCallbackQueryHandler implements CallbackQueryHandler {
             }
             row.push({
                 text: names[index],
-                callback_data: `ca:${JSON.stringify([names[index], 0])}`,
+                callback_data: `${this.changeAgentPrefix}${JSON.stringify([names[index], 0])}`,
             });
         }
         keyboards.push(row);
@@ -44,26 +65,63 @@ export class AgentListCallbackQueryHandler implements CallbackQueryHandler {
     };
 }
 
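+// Shared plumbing for the chat and image pickers: an AgentLoader resolves the
+// active agent from a config, and a ChangeAgentType returns a copy of the
+// config that points at the chosen provider (AI_PROVIDER for chat,
+// AI_IMAGE_PROVIDER for images).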
+type AgentLoader = (conf: AgentUserConfig) => ChatAgent | ImageAgent | null;
+type ChangeAgentType = (conf: AgentUserConfig, agent: string) => AgentUserConfig;
+
+function changeChatAgentType(conf: AgentUserConfig, agent: string): AgentUserConfig {
+    return {
+        ...conf,
+        AI_PROVIDER: agent,
+    };
+}
+
+function changeImageAgentType(conf: AgentUserConfig, agent: string): AgentUserConfig {
+    return {
+        ...conf,
+        AI_IMAGE_PROVIDER: agent,
+    };
+}
+
 export class ModelListCallbackQueryHandler implements CallbackQueryHandler {
-    prefix = 'ca:'; // ca:model:page
+    prefix: string;
+    agentListPrefix: string;
+    changeModelPrefix: string;
+
+    agentLoader: AgentLoader;
+    changeAgentType: ChangeAgentType;
 
     needAuth = TELEGRAM_AUTH_CHECKER.shareModeGroup;
 
+    constructor(prefix: string, agentListPrefix: string, changeModelPrefix: string, agentLoader: AgentLoader, changeAgentType: ChangeAgentType) {
+        this.prefix = prefix;
+        this.agentListPrefix = agentListPrefix;
+        this.changeModelPrefix = changeModelPrefix;
+        this.agentLoader = agentLoader;
+        this.changeAgentType = changeAgentType;
+    }
+
+    static NewChatModelListCallbackQueryHandler(): ModelListCallbackQueryHandler {
+        return new ModelListCallbackQueryHandler('ca:', 'al:', 'cm:', loadChatLLM, changeChatAgentType);
+    }
+
+    static NewImageModelListCallbackQueryHandler(): ModelListCallbackQueryHandler {
+        return new ModelListCallbackQueryHandler('ica:', 'ial:', 'icm:', loadImageGen, changeImageAgentType);
+    }
+
    async handle(query: Telegram.CallbackQuery, data: string, context: WorkerContext): Promise<Response> {
         if (!query.message) {
             throw new Error('no message');
         }
         const sender = MessageSender.fromCallbackQuery(context.SHARE_CONTEXT.botToken, query);
+
         const [agent, page] = JSON.parse(data.substring(this.prefix.length));
-        const conf: AgentUserConfig = {
-            ...ENV.USER_CONFIG,
-            AI_PROVIDER: agent,
-        };
-        const chatAgent = loadChatLLM(conf);
-        if (!chatAgent) {
+        const conf: AgentUserConfig = this.changeAgentType(ENV.USER_CONFIG, agent);
+        const theAgent = this.agentLoader(conf);
+        if (!theAgent) {
             throw new Error(`agent not found: ${agent}`);
         }
-        const models = await chatAgent.modelList(conf);
+
+        const models = await theAgent.modelList(conf);
         const keyboard: Telegram.InlineKeyboardButton[][] = [];
         const maxRow = 10;
         const maxCol = Math.max(1, Math.min(5, ENV.MODEL_LIST_COLUMNS));
@@ -73,7 +131,7 @@ export class ModelListCallbackQueryHandler implements CallbackQueryHandler {
         for (let i = page * maxRow * maxCol; i < models.length; i++) {
             currentRow.push({
                 text: models[i],
-                callback_data: `cm:${JSON.stringify([agent, models[i]])}`,
+                callback_data: `${this.changeModelPrefix}${JSON.stringify([agent, models[i]])}`,
             });
             if (i % maxCol === 0) {
                 keyboard.push(currentRow);
@@ -90,19 +148,19 @@ export class ModelListCallbackQueryHandler implements CallbackQueryHandler {
         keyboard.push([
             {
                 text: '<',
-                callback_data: `ca:${JSON.stringify([agent, Math.max(page - 1, 0)])}`,
+                callback_data: `${this.prefix}${JSON.stringify([agent, Math.max(page - 1, 0)])}`,
             },
             {
                 text: `${page + 1}/${maxPage}`,
-                callback_data: `ca:${JSON.stringify([agent, page])}`,
+                callback_data: `${this.prefix}${JSON.stringify([agent, page])}`,
             },
             {
                 text: '>',
-                callback_data: `ca:${JSON.stringify([agent, Math.min(page + 1, maxPage - 1)])}`,
+                callback_data: `${this.prefix}${JSON.stringify([agent, Math.min(page + 1, maxPage - 1)])}`,
             },
             {
                 text: '⇤',
-                callback_data: `al:`,
+                callback_data: this.agentListPrefix,
             },
         ]);
         if (models.length > (page + 1) * maxRow * maxCol) {
@@ -122,10 +180,26 @@ export class ModelListCallbackQueryHandler implements CallbackQueryHandler {
 }
 
 export class ModelChangeCallbackQueryHandler implements CallbackQueryHandler {
-    prefix = 'cm:';
+    prefix: string;
+    agentLoader: AgentLoader;
+    changeAgentType: ChangeAgentType;
 
     needAuth = TELEGRAM_AUTH_CHECKER.shareModeGroup;
 
+    constructor(prefix: string, agentLoader: AgentLoader, changeAgentType: ChangeAgentType) {
+        this.prefix = prefix;
+        this.agentLoader = agentLoader;
+        this.changeAgentType = changeAgentType;
+    }
+
+    static NewChatModelChangeCallbackQueryHandler(): ModelChangeCallbackQueryHandler {
+        return new ModelChangeCallbackQueryHandler('cm:', loadChatLLM, changeChatAgentType);
+    }
+
+    static NewImageModelChangeCallbackQueryHandler(): ModelChangeCallbackQueryHandler {
+        return new ModelChangeCallbackQueryHandler('icm:', loadImageGen, changeImageAgentType);
+    }
+
     async handle(query: Telegram.CallbackQuery, data: string, context: WorkerContext): Promise<Response> {
         if (!query.message) {
             throw new Error('no message');
@@ -136,16 +210,16 @@ export class ModelChangeCallbackQueryHandler implements CallbackQueryHandler {
             ...ENV.USER_CONFIG,
             AI_PROVIDER: agent,
         };
-        const chatAgent = loadChatLLM(conf);
+        const theAgent = this.agentLoader(conf);
         if (!agent) {
             throw new Error(`agent not found: ${agent}`);
         }
-        if (!chatAgent?.modelKey) {
+        if (!theAgent?.modelKey) {
             throw new Error(`modelKey not found: ${agent}`);
         }
         await context.execChangeAndSave({
             AI_PROVIDER: agent,
-            [chatAgent.modelKey]: model,
+            [theAgent.modelKey]: model,
         });
         console.log('Change model:', agent, model);
         const message: Telegram.EditMessageTextParams = {
diff --git a/packages/lib/core/src/telegram/command/system.ts b/packages/lib/core/src/telegram/command/system.ts
index 87702f47..465b6a1f 100644
--- a/packages/lib/core/src/telegram/command/system.ts
+++ b/packages/lib/core/src/telegram/command/system.ts
@@ -15,7 +15,19 @@ export class ImgCommandHandler implements CommandHandler {
     handle = async (message: Telegram.Message, subcommand: string, context: WorkerContext): Promise<Response> => {
         const sender = MessageSender.fromMessage(context.SHARE_CONTEXT.botToken, message);
         if (subcommand === '') {
-            return sender.sendPlainText(ENV.I18N.command.help.img);
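+            // No prompt was given: send the /img help text together with an
+            // inline button that opens the image model list (`ial:` callback).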
+            const params: Telegram.SendMessageParams = {
+                chat_id: message.chat.id,
+                text: ENV.I18N.command.help.img,
+                reply_markup: {
+                    inline_keyboard: [[
+                        {
+                            text: ENV.I18N.callback_query.open_model_list,
+                            callback_data: 'ial:',
+                        },
+                    ]],
+                },
+            };
+            return sender.sendRawMessage(params);
         }
         try {
             const api = createTelegramBotAPI(context.SHARE_CONTEXT.botToken);