From 3daaeb62ec051e443111bf63857a952f465aff6a Mon Sep 17 00:00:00 2001 From: Mohammad Afzal Date: Wed, 11 Sep 2024 12:20:48 -0400 Subject: [PATCH] Added Chat window settings --- app/api/chat/route.ts | 5 +- app/page.tsx | 5 +- components/ChatSettings.tsx | 90 +++++++++++++++++++++ components/ChatWindow.tsx | 152 ++++++++++++++++++++---------------- components/ui/checkbox.tsx | 30 +++++++ constants/index.ts | 11 +++ package-lock.json | 31 ++++++++ package.json | 1 + types.d.ts | 4 +- utils/reasoningPrompt.ts | 8 +- utils/utils.ts | 11 +++ 11 files changed, 275 insertions(+), 73 deletions(-) create mode 100644 components/ChatSettings.tsx create mode 100644 components/ui/checkbox.tsx create mode 100644 utils/utils.ts diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts index 90b6030..ed2430f 100644 --- a/app/api/chat/route.ts +++ b/app/api/chat/route.ts @@ -8,6 +8,7 @@ import { reasoningPrompt } from "@/utils/reasoningPrompt"; import { getStructuredPrompt } from "@/utils/prompts"; import { timestampLambda } from "@/utils/timestampLambda"; import { RunnableSequence } from "@langchain/core/runnables"; +import { MODELS } from "@/constants"; export const runtime = "nodejs"; @@ -15,6 +16,7 @@ export async function POST(req: NextRequest) { try { const body = await req.json(); const messages = body.messages ?? []; + const modelName = body.modelName ?? MODELS["openai"][0]; const contracts = (await contractCollection.get()).filter( (c) => !(body.disabledContractKeys ?? []).includes(c.key), ); @@ -28,6 +30,7 @@ export async function POST(req: NextRequest) { // Reasoning prompt takes the contracts and chat history to asks the llm to reduce the # of abi functions // It returns an object of the contract and abis most appropriate to the chat history const reasoningPromptResponse = await reasoningPrompt({ + modelName: modelName, contracts, input: currentMessageContent, chatHistory: previousMessages, @@ -55,7 +58,7 @@ export async function POST(req: NextRequest) { const tools = getToolsFromContracts(filteredContracts); const model = new ChatOpenAI({ - model: "gpt-4o-mini", + model: modelName, temperature: 0, streaming: true, }).bindTools(tools); diff --git a/app/page.tsx b/app/page.tsx index fdfbe24..3eae0e9 100644 --- a/app/page.tsx +++ b/app/page.tsx @@ -54,10 +54,7 @@ export default function Page() { )} {/* Remove div with id=temp if enabling side nav */} -
+
diff --git a/components/ChatSettings.tsx b/components/ChatSettings.tsx new file mode 100644 index 0000000..699e4a3 --- /dev/null +++ b/components/ChatSettings.tsx @@ -0,0 +1,90 @@ +"use client"; + +import { useState } from "react"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Checkbox } from "./ui/checkbox"; +import { MODELS } from "@/constants"; +import { findModelKey } from "@/utils/utils"; + +type IChatSettingProps = { + clearOnChange: boolean; + onClearOnChange: (value: boolean) => void; + modelName: string; + onModelNameChange: (value: string) => void; +}; + +export function ChatSettings(props: IChatSettingProps) { + const { modelName, onModelNameChange, clearOnChange, onClearOnChange } = + props; + const [inferenceProvider, setInferenceProvider] = useState(() => { + return findModelKey(modelName); + }); + + const handleInferenceProviderChange = (value: keyof typeof MODELS) => { + setInferenceProvider(value); + }; + + return ( +
+

Settings

+
+
+ + +
+ {inferenceProvider && ( +
+ + +
+ )} + +
+ + +
+
+
+ ); +} diff --git a/components/ChatWindow.tsx b/components/ChatWindow.tsx index 9a82495..45b10e8 100644 --- a/components/ChatWindow.tsx +++ b/components/ChatWindow.tsx @@ -2,8 +2,7 @@ import { toast } from "sonner"; import { useChat } from "ai/react"; -import { useEffect, useRef, type FormEvent } from "react"; - +import { useEffect, useRef, useState, type FormEvent } from "react"; import { Input } from "@/components/ui/input"; import { Button } from "@/components/ui/button"; import { @@ -15,15 +14,20 @@ import { } from "@/components/ui/card"; import { ChatMessageBubble } from "@/components/ChatMessageBubble"; import { LoadingIcon } from "@/components/LoadingIcon"; -import { Label } from "./ui/label"; +import { Label } from "@/components/ui/label"; import { CornerDownLeft, Trash2 } from "lucide-react"; import { ConfirmAlert } from "./ConfirmAlert"; import { useContracts } from "@/utils/useContracts"; +import { ChatSettings } from "./ChatSettings"; +import { MODELS } from "@/constants"; export function ChatWindow(props: { titleText?: string }) { const { titleText } = props; const chatContainerRef = useRef(null); const { disabledKeys } = useContracts(); + const [modelName, setModelName] = useState(MODELS["openai"][0]); + const [clearOnChange, setClearOnChange] = useState(false); + const { messages, input, @@ -33,7 +37,7 @@ export function ChatWindow(props: { titleText?: string }) { isLoading, } = useChat({ api: "api/chat", - body: { disabledContractKeys: disabledKeys }, + body: { disabledContractKeys: disabledKeys, modelName }, streamProtocol: "text", onError: (e) => { toast(e.message); @@ -66,7 +70,6 @@ export function ChatWindow(props: { titleText?: string }) { return () => observer.disconnect(); } }, []); - async function sendMessage(e: FormEvent) { e.preventDefault(); if (!messages.length) { @@ -82,69 +85,86 @@ export function ChatWindow(props: { titleText?: string }) { setMessages([]); }; + const onModelNameChange = (value: string) => { + if (clearOnChange) { + onClearMessages(); + } + setModelName(value); + }; + return ( - - - {titleText} - - -
-
- {messages.length > 0 - ? messages.map((m) => ( - - )) - : null} +
+ + + {titleText} + + +
+
+
+ {messages.length > 0 + ? messages.map((m) => ( + + )) + : null} +
+
-
- - -
- - -
- - - Clear messages - - } + + + + + - -
-
-
- +
+ + + Clear messages + + } + /> + +
+ + + + +
); } diff --git a/components/ui/checkbox.tsx b/components/ui/checkbox.tsx new file mode 100644 index 0000000..df61a13 --- /dev/null +++ b/components/ui/checkbox.tsx @@ -0,0 +1,30 @@ +"use client" + +import * as React from "react" +import * as CheckboxPrimitive from "@radix-ui/react-checkbox" +import { Check } from "lucide-react" + +import { cn } from "@/lib/utils" + +const Checkbox = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + + +)) +Checkbox.displayName = CheckboxPrimitive.Root.displayName + +export { Checkbox } diff --git a/constants/index.ts b/constants/index.ts index 4daa0f9..4e56b04 100644 --- a/constants/index.ts +++ b/constants/index.ts @@ -1,5 +1,16 @@ import { IContract } from "@/types"; +export const MODELS = { + openai: ["gpt-4o-mini", "gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-latest"], + together: [ + "mistralai/Mistral-7B-Instruct-v0.3", + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + ], + ollama: ["mistral", "llama3.1"], +}; + export const CHAINS = { 11155111: { name: "ETH Sepolia", diff --git a/package-lock.json b/package-lock.json index c50abc6..da71391 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,6 +16,7 @@ "@magic-sdk/admin": "^2.4.1", "@next/bundle-analyzer": "^13.4.19", "@radix-ui/react-alert-dialog": "^1.1.1", + "@radix-ui/react-checkbox": "^1.1.1", "@radix-ui/react-collapsible": "^1.1.0", "@radix-ui/react-dialog": "^1.1.1", "@radix-ui/react-label": "^2.1.0", @@ -1714,6 +1715,36 @@ } } }, + "node_modules/@radix-ui/react-checkbox": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.1.1.tgz", + "integrity": "sha512-0i/EKJ222Afa1FE0C6pNJxDq1itzcl3HChE9DwskA4th4KRse8ojx8a1nVcOjwJdbpDLcz7uol77yYnQNMHdKw==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.0", + "@radix-ui/react-presence": "1.1.0", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-use-controllable-state": "1.1.0", + "@radix-ui/react-use-previous": "1.1.0", + "@radix-ui/react-use-size": "1.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-collapsible": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.0.tgz", diff --git a/package.json b/package.json index 82d8801..cdece91 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "@magic-sdk/admin": "^2.4.1", "@next/bundle-analyzer": "^13.4.19", "@radix-ui/react-alert-dialog": "^1.1.1", + "@radix-ui/react-checkbox": "^1.1.1", "@radix-ui/react-collapsible": "^1.1.0", "@radix-ui/react-dialog": "^1.1.1", "@radix-ui/react-label": "^2.1.0", diff --git a/types.d.ts b/types.d.ts index 10ac6d9..0d1cd61 100644 --- a/types.d.ts +++ b/types.d.ts @@ -1,7 +1,9 @@ -import { CHAINS } from "./constants"; +import { CHAINS, MODELS } from "./constants"; export type ChainIdEnum = keyof typeof CHAINS; +export type InferenceEnum = keyof typeof MODELS; + export type IContract = { key: number; address: string; diff --git a/utils/reasoningPrompt.ts b/utils/reasoningPrompt.ts index b13c5ac..0fdbc3a 100644 --- a/utils/reasoningPrompt.ts +++ b/utils/reasoningPrompt.ts @@ -39,14 +39,20 @@ const trimIfStartsWith = (str: string, prefix: string) => { }; export async function reasoningPrompt({ + modelName, input, contracts, chatHistory, }: { + modelName: string; input: string; contracts: IContract[]; chatHistory: VercelChatMessage[]; }): Promise { + if (!modelName) { + return []; + } + // Reduce contract.abi to just functions const contractFunctions = contracts .filter(({ abi }) => abi?.length) @@ -70,7 +76,7 @@ export async function reasoningPrompt({ .join("\n"); const model = new ChatOpenAI({ - model: "gpt-4o-mini", + model: modelName, temperature: 0, streaming: true, }).withStructuredOutput( diff --git a/utils/utils.ts b/utils/utils.ts new file mode 100644 index 0000000..4d2eed1 --- /dev/null +++ b/utils/utils.ts @@ -0,0 +1,11 @@ +import { MODELS } from "@/constants"; +import { InferenceEnum } from "@/types"; + +export function findModelKey(modelName: string): InferenceEnum | undefined { + for (const key in MODELS) { + if (MODELS[key as InferenceEnum].includes(modelName)) { + return key as InferenceEnum; + } + } + return undefined; +}