From 668000202077afb6b4d204ce60a61d0518af18f1 Mon Sep 17 00:00:00 2001 From: mikbry Date: Thu, 15 Feb 2024 19:37:28 +0100 Subject: [PATCH 1/4] feat: edit a local model --- webapp/components/common/Form/index.tsx | 3 +- webapp/components/models/Explorer.tsx | 47 +++++++++++++------ webapp/components/models/Model.tsx | 42 ++++++++++------- webapp/components/models/NewLocalModel.tsx | 1 + webapp/components/models/index.tsx | 20 +++++++- webapp/hooks/useParameters.ts | 53 ++++++++++++++++++++++ webapp/native/src/data/model/mod.rs | 7 +++ webapp/types/index.ts | 1 + 8 files changed, 141 insertions(+), 33 deletions(-) create mode 100644 webapp/hooks/useParameters.ts diff --git a/webapp/components/common/Form/index.tsx b/webapp/components/common/Form/index.tsx index 2c586191..4a23dfec 100644 --- a/webapp/components/common/Form/index.tsx +++ b/webapp/components/common/Form/index.tsx @@ -32,10 +32,11 @@ export default function Form({ onParametersChanged, debounceDelay = 600, }: FormProps) { + const { t } = useTranslation(); + const [updatedParameters, setUpdatedParameters] = useState( undefined, ); - const { t } = useTranslation(); const handleParameterChange = (name: string, value?: ParameterValue) => { logger.info('handleParameterChange', name, value); diff --git a/webapp/components/models/Explorer.tsx b/webapp/components/models/Explorer.tsx index f37a0f03..c6e92543 100644 --- a/webapp/components/models/Explorer.tsx +++ b/webapp/components/models/Explorer.tsx @@ -27,16 +27,21 @@ import { shortcutAsText } from '@/utils/shortcuts'; import useShortcuts, { ShortcutIds } from '@/hooks/useShortcuts'; import ContextMenuList from '../ui/ContextMenu/ContextMenuList'; import { Button } from '../ui/button'; +import EditableItem from '../common/EditableItem'; + +export type ModelsExplorerProps = { + models: Model[]; + selectedModelId?: string; + collection: Model[]; + onModelRename: (id: string, name: string) => void; +}; function ModelsExplorer({ models, selectedModelId, collection, -}: { - models: Model[]; - selectedModelId?: string; - collection: Model[]; -}) { + onModelRename, +}: ModelsExplorerProps) { const router = useRouter(); const { t } = useTranslation(); const { showModal } = useContext(ModalsContext); @@ -51,6 +56,11 @@ function ModelsExplorer({ showModal(ModalIds.NewLocalModel); }; + const handleChangeModelName = (id: string, name: string) => { + logger.info(`change model name ${id} ${name}`); + onModelRename(id, name); + }; + useShortcuts(ShortcutIds.INSTALL_MODEL, (event) => { event.preventDefault(); logger.info('shortcut install Model'); @@ -118,13 +128,24 @@ function ModelsExplorer({ className="flex cursor-pointer flex-row items-center" tabIndex={0} > -
-
-
- {model.title || model.name} + {!model.editable && ( +
+
+ {model.title || model.name}
-
+ )} + {model.editable && ( +
+ +
+ )}
@@ -166,10 +187,8 @@ function ModelsExplorer({ tabIndex={0} >
-
-
- {model.title || model.name} -
+
+ {model.title || model.name}
diff --git a/webapp/components/models/Model.tsx b/webapp/components/models/Model.tsx index 42ee3af0..7b4eb24e 100644 --- a/webapp/components/models/Model.tsx +++ b/webapp/components/models/Model.tsx @@ -32,6 +32,7 @@ import { DownloadIcon } from '@radix-ui/react-icons'; import useTranslation from '@/hooks/useTranslation'; import { Model } from '@/types'; import { getEntityName, getResourceUrl } from '@/utils/data'; +import useParameters, { ParametersCallback } from '@/hooks/useParameters'; import Parameter from '../common/Parameter'; import { Button } from '../ui/button'; import { Table, TableBody, TableRow, TableCell, TableHeader, TableHead } from '../ui/table'; @@ -42,20 +43,25 @@ import { Table, TableBody, TableRow, TableCell, TableHeader, TableHead } from '. DropdownMenuTrigger, } from '../ui/dropdown-menu'; */ +export type ModelViewProps = { + model: Model; + isDownloading: boolean; + local: boolean; + downloadables: Model[]; + onChange: (item?: Model) => void; + onParametersChange: ParametersCallback; +}; + function ModelView({ model, isDownloading, local, downloadables, onChange, -}: { - model: Model; - isDownloading: boolean; - local: boolean; - downloadables: Model[]; - onChange: (item?: Model) => void; -}) { + onParametersChange, +}: ModelViewProps) { const { t } = useTranslation(); + const [updatedParameters, setUpdatedParameters] = useParameters(onParametersChange); if (!model) { return null; @@ -116,9 +122,10 @@ function ModelView({ {model.fileName && ( )} {getEntityName(model.creator).toLowerCase() !== @@ -144,7 +152,7 @@ function ModelView({ title={t('Creator')} name="version" value={`${getEntityName(model.creator)}`} - disabled + disabled={!model.editable} type="text" /> )} @@ -157,7 +165,7 @@ function ModelView({ title={t('Publisher')} name="version" value={`${getEntityName(model.publisher)}`} - disabled + disabled={!model.editable} type="text" /> )} @@ -165,21 +173,21 @@ function ModelView({ title={t('Version')} name="version" value={`${model.version}`} - disabled + disabled={!model.editable} type="text" />
diff --git a/webapp/components/models/NewLocalModel.tsx b/webapp/components/models/NewLocalModel.tsx index bd57c6ab..9c99dd7a 100644 --- a/webapp/components/models/NewLocalModel.tsx +++ b/webapp/components/models/NewLocalModel.tsx @@ -107,6 +107,7 @@ function NewLocalModel({ toast.error(`File not found ${file}`); return; } + model.editable = true; const id = await installModel(model, undefined, filepath, download); await updateBackendStore(); logger.info('onLocalInstall', id, model, filepath, download); diff --git a/webapp/components/models/index.tsx b/webapp/components/models/index.tsx index 62b68105..02c8a0b5 100644 --- a/webapp/components/models/index.tsx +++ b/webapp/components/models/index.tsx @@ -27,6 +27,7 @@ import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '../ui/resi import Explorer from './Explorer'; import ModelView from './Model'; import NewLocalModel from './NewLocalModel'; +import { ParametersRecord } from '../common/Parameter'; export default function Models({ selectedModelId }: { selectedModelId?: string }) { const { backendContext, updateBackendStore } = useBackend(); @@ -108,6 +109,17 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } } }; + const handleParametersChange = (parameters: ParametersRecord) => { + logger.info(`change model parameters ${parameters}`); + // onParametersChange(id, parameters); + return undefined; + }; + + const handleModelRename = (id: string, name: string) => { + logger.info(`change model name ${id} ${name}`); + // onModelRename(id, name); + }; + const { downloads = [] } = backendContext; const isDownloading = downloads.findIndex((d) => d.id === model?.id) !== -1; @@ -115,7 +127,12 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } return ( - + @@ -125,6 +142,7 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } local={local} downloadables={downloadables} onChange={handleChange} + onParametersChange={handleParametersChange} /> diff --git a/webapp/hooks/useParameters.ts b/webapp/hooks/useParameters.ts new file mode 100644 index 00000000..3d2e0c93 --- /dev/null +++ b/webapp/hooks/useParameters.ts @@ -0,0 +1,53 @@ +// Copyright 2024 mik +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { ParameterValue, ParametersRecord } from '@/components/common/Parameter'; +import logger from '@/utils/logger'; +import { useState } from 'react'; +import useDebounceFunc from './useDebounceFunc'; + +export type ParametersCallback = (params: ParametersRecord) => ParametersRecord | undefined; + +function useParameters( + onParametersChanged: ParametersCallback, + debounceDelay = 600, +): [ParametersRecord | undefined, (name: string, value?: ParameterValue) => void] { + const [updatedParameters, setUpdatedParameters] = useState( + undefined, + ); + + const handleParameterChange = (name: string, value?: ParameterValue) => { + logger.info('handleParameterChange', name, value); + const newParams = updatedParameters || {}; + if (newParams[name] !== value) { + setUpdatedParameters({ ...newParams, [name]: value }); + } + }; + + const updateParameters = (newParameters: ParametersRecord) => { + logger.info('updateParameters', newParameters); + let changedParameters = onParametersChanged(newParameters); + if (changedParameters && Object.keys(changedParameters).length === 0) { + changedParameters = undefined; + } + // TODO handle errors + setUpdatedParameters(changedParameters); + }; + + useDebounceFunc(updateParameters, updatedParameters, debounceDelay); + + return [updatedParameters, handleParameterChange]; +} + +export default useParameters; diff --git a/webapp/native/src/data/model/mod.rs b/webapp/native/src/data/model/mod.rs index fd66c0b8..a97b91e3 100644 --- a/webapp/native/src/data/model/mod.rs +++ b/webapp/native/src/data/model/mod.rs @@ -139,11 +139,17 @@ pub struct Model { )] pub paper: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] pub include: Option>, + #[serde(skip_serializing_if = "Option::is_none", default)] pub system: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] pub context_window: Option, + + #[serde(skip_serializing_if = "Option::is_none", default)] + pub editable: Option, } impl Model { @@ -183,6 +189,7 @@ impl Model { include: None, system: None, context_window: None, + editable: None, } } diff --git a/webapp/types/index.ts b/webapp/types/index.ts index 1fbc3293..ade53aa5 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -218,6 +218,7 @@ export type Model = BaseNamedRecord & { system?: string; contextWindow?: number; + editable?: boolean; }; export type ModelsCollection = { From 553d55a173743ac240e2d8ce3d2befa13a992893 Mon Sep 17 00:00:00 2001 From: mikbry Date: Fri, 16 Feb 2024 10:03:07 +0100 Subject: [PATCH 2/4] feat: useParameters in Form --- webapp/components/common/Form/index.tsx | 37 ++++--------------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/webapp/components/common/Form/index.tsx b/webapp/components/common/Form/index.tsx index 4a23dfec..af14a961 100644 --- a/webapp/components/common/Form/index.tsx +++ b/webapp/components/common/Form/index.tsx @@ -12,51 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { useState } from 'react'; import { ParametersDefinition } from '@/types'; import useTranslation from '@/hooks/useTranslation'; -import useDebounceFunc from '@/hooks/useDebounceFunc'; -import logger from '@/utils/logger'; +import useParameters from '@/hooks/useParameters'; import Parameter, { ParameterValue, ParametersRecord } from '../Parameter'; export type FormProps = { parameters: Record | undefined; parametersDefinition: ParametersDefinition; debounceDelay?: number; - onParametersChanged: (params: ParametersRecord) => ParametersRecord | undefined; + onParametersChange: (params: ParametersRecord) => ParametersRecord | undefined; }; export default function Form({ parameters, parametersDefinition, - onParametersChanged, - debounceDelay = 600, + onParametersChange, + debounceDelay, }: FormProps) { const { t } = useTranslation(); - - const [updatedParameters, setUpdatedParameters] = useState( - undefined, - ); - - const handleParameterChange = (name: string, value?: ParameterValue) => { - logger.info('handleParameterChange', name, value); - const newParams = updatedParameters || {}; - if (newParams[name] !== value) { - setUpdatedParameters({ ...newParams, [name]: value }); - } - }; - - const updateParameters = (newParameters: ParametersRecord) => { - logger.info('updateParameters', newParameters); - let changedParameters = onParametersChanged(newParameters); - if (changedParameters && Object.keys(changedParameters).length === 0) { - changedParameters = undefined; - } - // TODO handle errors - setUpdatedParameters(changedParameters); - }; - - useDebounceFunc(updateParameters, updatedParameters, debounceDelay); + const [updatedParameters, setUpdatedParameters] = useParameters(onParametersChange, debounceDelay); return (
@@ -73,7 +48,7 @@ export default function Form({ } description={t(parametersDefinition[key].description)} inputCss="max-w-20 pl-2" - onChange={handleParameterChange} + onChange={setUpdatedParameters} /> ))} From db031cd07b32482df6d05f30cc748fbb426c322b Mon Sep 17 00:00:00 2001 From: mikbry Date: Fri, 16 Feb 2024 10:03:45 +0100 Subject: [PATCH 3/4] chore: prettier --- webapp/components/common/Form/index.tsx | 5 ++++- webapp/components/threads/Settings.tsx | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/webapp/components/common/Form/index.tsx b/webapp/components/common/Form/index.tsx index af14a961..b63906c8 100644 --- a/webapp/components/common/Form/index.tsx +++ b/webapp/components/common/Form/index.tsx @@ -31,7 +31,10 @@ export default function Form({ debounceDelay, }: FormProps) { const { t } = useTranslation(); - const [updatedParameters, setUpdatedParameters] = useParameters(onParametersChange, debounceDelay); + const [updatedParameters, setUpdatedParameters] = useParameters( + onParametersChange, + debounceDelay, + ); return (
diff --git a/webapp/components/threads/Settings.tsx b/webapp/components/threads/Settings.tsx index e6df5361..725f10fc 100644 --- a/webapp/components/threads/Settings.tsx +++ b/webapp/components/threads/Settings.tsx @@ -189,7 +189,7 @@ export default function Settings({ conversationId }: { conversationId?: string } parameters={selectedConversation?.parameters} parametersDefinition={parametersDefinition} - onParametersChanged={updateParameters} + onParametersChange={updateParameters} /> From c1733ca6d39ec5fe52c77f86e14bdd637a4b55b3 Mon Sep 17 00:00:00 2001 From: mikbry Date: Fri, 16 Feb 2024 11:38:34 +0100 Subject: [PATCH 4/4] feat: update and save model --- .../components/common/EditableItem/index.tsx | 5 +- webapp/components/common/Form/index.tsx | 9 ++- webapp/components/models/Model.tsx | 50 ++++++++------- webapp/components/models/index.tsx | 63 +++++++++++++++---- webapp/components/providers/opla/index.tsx | 2 +- webapp/components/threads/Settings.tsx | 23 +++---- webapp/hooks/useParameters.ts | 10 ++- webapp/hooks/useProviderState.ts | 11 +++- webapp/native/src/data/model/mod.rs | 11 +++- webapp/native/src/main.rs | 17 +++++ webapp/utils/backend/commands.ts | 9 +++ webapp/utils/data/index.ts | 40 +++++++----- 12 files changed, 177 insertions(+), 73 deletions(-) diff --git a/webapp/components/common/EditableItem/index.tsx b/webapp/components/common/EditableItem/index.tsx index 2b986143..0e48ab9c 100644 --- a/webapp/components/common/EditableItem/index.tsx +++ b/webapp/components/common/EditableItem/index.tsx @@ -34,7 +34,10 @@ export default function EditableItem({ const [changedValue, setChangedValue] = useState(undefined); const onDebouncedChange = (value: string) => { - onChange?.(value, id); + if (value !== title) { + onChange?.(value, id); + } + // setChangedValue(undefined); }; useDebounceFunc(onDebouncedChange, changedValue, 500); diff --git a/webapp/components/common/Form/index.tsx b/webapp/components/common/Form/index.tsx index b63906c8..34467808 100644 --- a/webapp/components/common/Form/index.tsx +++ b/webapp/components/common/Form/index.tsx @@ -14,17 +14,19 @@ import { ParametersDefinition } from '@/types'; import useTranslation from '@/hooks/useTranslation'; -import useParameters from '@/hooks/useParameters'; -import Parameter, { ParameterValue, ParametersRecord } from '../Parameter'; +import useParameters, { ParametersCallback } from '@/hooks/useParameters'; +import Parameter, { ParameterValue } from '../Parameter'; export type FormProps = { + id: string | undefined; parameters: Record | undefined; parametersDefinition: ParametersDefinition; debounceDelay?: number; - onParametersChange: (params: ParametersRecord) => ParametersRecord | undefined; + onParametersChange: ParametersCallback; }; export default function Form({ + id, parameters, parametersDefinition, onParametersChange, @@ -32,6 +34,7 @@ export default function Form({ }: FormProps) { const { t } = useTranslation(); const [updatedParameters, setUpdatedParameters] = useParameters( + id, onParametersChange, debounceDelay, ); diff --git a/webapp/components/models/Model.tsx b/webapp/components/models/Model.tsx index 7b4eb24e..719348f9 100644 --- a/webapp/components/models/Model.tsx +++ b/webapp/components/models/Model.tsx @@ -61,7 +61,7 @@ function ModelView({ onParametersChange, }: ModelViewProps) { const { t } = useTranslation(); - const [updatedParameters, setUpdatedParameters] = useParameters(onParametersChange); + const [updatedParameters, setUpdatedParameters] = useParameters(model?.id, onParametersChange); if (!model) { return null; @@ -169,27 +169,33 @@ function ModelView({ type="text" /> )} - - - + {model.version && ( + + )} + {model.license && ( + + )} + {model.repository && ( + + )} {downloadables.length > 0 && (
diff --git a/webapp/components/models/index.tsx b/webapp/components/models/index.tsx index 02c8a0b5..4e1728c4 100644 --- a/webapp/components/models/index.tsx +++ b/webapp/components/models/index.tsx @@ -18,9 +18,14 @@ import { useEffect, useState } from 'react'; import { useRouter } from 'next/router'; import { Model } from '@/types'; import logger from '@/utils/logger'; -import { getModelsCollection, installModel, uninstallModel } from '@/utils/backend/commands'; +import { + getModelsCollection, + installModel, + uninstallModel, + updateModel, +} from '@/utils/backend/commands'; import useBackend from '@/hooks/useBackendContext'; -import { deepMerge, getEntityName, getResourceUrl } from '@/utils/data'; +import { deepCopy, deepMerge, getEntityName, getResourceUrl } from '@/utils/data'; import { getDownloadables, isValidFormat } from '@/utils/data/models'; import { Page } from '@/types/ui'; import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '../ui/resizable'; @@ -45,7 +50,6 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } }; getCollection(); }, []); - logger.info('collection: ', collection); const models = backendContext.config.models.items; let local = true; @@ -59,7 +63,7 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } : getDownloadables(model).filter((d) => d.private !== true && isValidFormat(d)); const handleInstall = async (item?: Model) => { - const selectedModel: Model = deepMerge(model, item || {}, true); + const selectedModel: Model = deepMerge(model, item || {}, true); logger.info(`install ${model.name}`, selectedModel, item); if (selectedModel.private === true) { delete selectedModel.private; @@ -88,7 +92,7 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } router.replace(`/models${nextModelId ? `/${nextModelId}` : ''}`); }; - const handleChange = (selectedModel?: Model) => { + const handleLocalInstall = (selectedModel?: Model) => { if (local && !selectedModel) { // showModal(ModalIds.DeleteItem, { item: model, onAction: onUninstall }); handleUninstall(); @@ -109,15 +113,52 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } } }; - const handleParametersChange = (parameters: ParametersRecord) => { - logger.info(`change model parameters ${parameters}`); + const handleParametersChange = async (id: string | undefined, parameters: ParametersRecord) => { + logger.info(`change model parameters ${id} ${parameters}`); // onParametersChange(id, parameters); + let updatedModel = models.find((m) => m.id === id) as Model; + if (updatedModel) { + let needUpdate = false; + updatedModel = deepCopy(updatedModel); + Object.keys(parameters).forEach((key) => { + switch (key) { + case 'name': + if (parameters[key] !== updatedModel.name) { + updatedModel.name = parameters[key] as string; + needUpdate = true; + } + break; + case 'description': + if (parameters[key] !== updatedModel.description) { + updatedModel.description = parameters[key] as string; + needUpdate = true; + } + break; + case 'author': + if (parameters[key] !== updatedModel.author) { + updatedModel.author = parameters[key] as string; + needUpdate = true; + } + break; + default: + logger.warn(`unknown parameter ${key}`); + } + }); + if (needUpdate) { + await updateModel(updatedModel); + await updateBackendStore(); + } + } return undefined; }; - const handleModelRename = (id: string, name: string) => { - logger.info(`change model name ${id} ${name}`); - // onModelRename(id, name); + const handleModelRename = async (name: string, id: string) => { + const updatedModel = models.find((m) => m.id === id); + logger.info(`change model name ${id} ${name}`, updatedModel, models); + if (updatedModel && updatedModel.name !== name) { + await updateModel({ ...updatedModel, name }); + await updateBackendStore(); + } }; const { downloads = [] } = backendContext; @@ -141,7 +182,7 @@ export default function Models({ selectedModelId }: { selectedModelId?: string } isDownloading={isDownloading} local={local} downloadables={downloadables} - onChange={handleChange} + onChange={handleLocalInstall} onParametersChange={handleParametersChange} /> diff --git a/webapp/components/providers/opla/index.tsx b/webapp/components/providers/opla/index.tsx index c62ec75b..9af164a0 100644 --- a/webapp/components/providers/opla/index.tsx +++ b/webapp/components/providers/opla/index.tsx @@ -45,7 +45,7 @@ export default function Opla({ (provider, 'description'))} disabled type="large-text" /> diff --git a/webapp/components/threads/Settings.tsx b/webapp/components/threads/Settings.tsx index 725f10fc..696bad46 100644 --- a/webapp/components/threads/Settings.tsx +++ b/webapp/components/threads/Settings.tsx @@ -71,19 +71,22 @@ export default function Settings({ conversationId }: { conversationId?: string } } }; - const updateParameters = (params: ParametersRecord): ParametersRecord | undefined => { + const updateParameters = async ( + id: string | undefined, + params: ParametersRecord, + ): Promise => { let newParams: ParametersRecord | undefined; - if (selectedConversation) { + if (id && selectedConversation) { const { parameters = {} } = selectedConversation; let newConversation: Conversation | undefined; newParams = { ...params }; - let update = false; + let needUpdate = false; Object.keys(params).forEach((key) => { const value = params[key]; if (value === undefined) { delete parameters[key]; delete newParams?.[key]; - update = true; + needUpdate = true; } else { const parameterDef = parametersDefinition[key]; const result = parameterDef.z.safeParse(value); @@ -93,12 +96,12 @@ export default function Settings({ conversationId }: { conversationId?: string } } else { parameters[key] = result.data; delete newParams?.[key]; - update = true; + needUpdate = true; } } }); - if (update) { + if (needUpdate) { if (selectedConversation.parameters && Object.keys(parameters).length === 0) { newConversation = { ...selectedConversation }; delete newConversation.parameters; @@ -113,13 +116,6 @@ export default function Settings({ conversationId }: { conversationId?: string } return newParams; }; - /* const handleParameterChange = (name: string, value?: ParameterValue) => { - logger.info('handleParameterChange', name, value); - setParams({ ...params, [name]: value }); - }; - - useDebounceFunc(updateParameters, params, 600); */ - const handlePolicyChange = (policy: ContextWindowPolicy) => { if (selectedConversation) { const newConversations = updateConversation( @@ -187,6 +183,7 @@ export default function Settings({ conversationId }: { conversationId?: string } {t('Parameters')} + id={selectedConversation?.id} parameters={selectedConversation?.parameters} parametersDefinition={parametersDefinition} onParametersChange={updateParameters} diff --git a/webapp/hooks/useParameters.ts b/webapp/hooks/useParameters.ts index 3d2e0c93..83fd1c4f 100644 --- a/webapp/hooks/useParameters.ts +++ b/webapp/hooks/useParameters.ts @@ -17,9 +17,13 @@ import logger from '@/utils/logger'; import { useState } from 'react'; import useDebounceFunc from './useDebounceFunc'; -export type ParametersCallback = (params: ParametersRecord) => ParametersRecord | undefined; +export type ParametersCallback = ( + key: string | undefined, + params: ParametersRecord, +) => Promise; function useParameters( + key: string | undefined, onParametersChanged: ParametersCallback, debounceDelay = 600, ): [ParametersRecord | undefined, (name: string, value?: ParameterValue) => void] { @@ -35,9 +39,9 @@ function useParameters( } }; - const updateParameters = (newParameters: ParametersRecord) => { + const updateParameters = async (newParameters: ParametersRecord) => { logger.info('updateParameters', newParameters); - let changedParameters = onParametersChanged(newParameters); + let changedParameters = await onParametersChanged(key, newParameters); if (changedParameters && Object.keys(changedParameters).length === 0) { changedParameters = undefined; } diff --git a/webapp/hooks/useProviderState.ts b/webapp/hooks/useProviderState.ts index 5bf783e7..08a73042 100644 --- a/webapp/hooks/useProviderState.ts +++ b/webapp/hooks/useProviderState.ts @@ -52,13 +52,20 @@ const useProviderState = (providerId?: string, newProvider?: Provider) => { }, [backendContext, hasParametersChanged, providerId, providers, updatedProvider, newProvider]); const handleParameterChange = (name: string, value: ParameterValue) => { - const mergedProvider = deepSet(updatedProvider, name, value); + const mergedProvider = deepSet( + updatedProvider as Provider, + name, + value, + ); logger.info('handleParameterChange', name, value, mergedProvider); setUpdatedProvider(mergedProvider); }; const handleParametersSave = (partialProvider: Partial = {}) => { - const mergedProvider = deepMerge(provider, partialProvider); + if (!provider) { + return; + } + const mergedProvider = deepMerge(provider, partialProvider); const newProviders = updateProvider(mergedProvider, providers); logger.info('handleParametersSave', mergedProvider, newProviders); setProviders(newProviders); diff --git a/webapp/native/src/data/model/mod.rs b/webapp/native/src/data/model/mod.rs index a97b91e3..66718286 100644 --- a/webapp/native/src/data/model/mod.rs +++ b/webapp/native/src/data/model/mod.rs @@ -351,14 +351,19 @@ impl ModelStorage { self.items.retain(|m| !m.reference.is_same_id(id)); } - pub fn update_model(&mut self, model: ModelEntity) { + pub fn update_model(&mut self, model: Model) { if let Some(index) = self.items .iter() - .position(|m| m.reference.is_same_model(&model.reference)) + .position(|m| m.reference.is_same_model(&model)) { + let mut model_entity = match self.items.get(index) { + Some(model_entity) => model_entity.clone(), + None => return, + }; + model_entity.reference = model; self.items.remove(index); - self.items.insert(index, model); + self.items.insert(index, model_entity.clone()); } } } diff --git a/webapp/native/src/main.rs b/webapp/native/src/main.rs index bca213a9..d3619faf 100644 --- a/webapp/native/src/main.rs +++ b/webapp/native/src/main.rs @@ -259,6 +259,22 @@ async fn install_model( Ok(model_id.clone()) } +#[tauri::command] +async fn update_model( + _app: tauri::AppHandle, + _window: tauri::Window, + context: State<'_, OplaContext>, + model: Model, +) -> Result<(), String> { + let mut store = context.store.lock().map_err(|err| err.to_string())?; + + store.models.update_model(model); + + store.save().map_err(|err| err.to_string())?; + + Ok(()) +} + #[tauri::command] async fn uninstall_model( _app: tauri::AppHandle, @@ -611,6 +627,7 @@ fn main() { search_hfhub_models, install_model, uninstall_model, + update_model, set_active_model, llm_call_completion ] diff --git a/webapp/utils/backend/commands.ts b/webapp/utils/backend/commands.ts index 70c8942a..aa0fc2a9 100644 --- a/webapp/utils/backend/commands.ts +++ b/webapp/utils/backend/commands.ts @@ -82,3 +82,12 @@ export const uninstallModel = async (modelId: String) => { const id = (await invokeTauri('uninstall_model', { modelId })) as String; return id; }; + +export const updateModel = async (model: Model) => { + try { + await invokeTauri('update_model', { model }); + } catch (error) { + logger.error(error); + toast.error(`Error installing model ${error}`); + } +}; diff --git a/webapp/utils/data/index.ts b/webapp/utils/data/index.ts index d0cfd5f4..a74502a6 100644 --- a/webapp/utils/data/index.ts +++ b/webapp/utils/data/index.ts @@ -39,13 +39,16 @@ const createBaseNamedRecord = (name: string, description?: string) => { return item; }; -const deepCopy = (obj: any) => - window?.structuredClone ? window.structuredClone(obj) : JSON.parse(JSON.stringify(obj)); +const deepCopy = (obj: T): T => + window?.structuredClone + ? (window.structuredClone(obj) as T) + : (JSON.parse(JSON.stringify(obj)) as T); -const deepMerge = (_target: any, source: any, copy = false) => { - const target = copy ? deepCopy(_target) : _target; - Object.keys(source).forEach((key: string) => { - const value = source[key]; +const deepMerge = (_target: T, source: Partial, copy = false): T => { + const target = (copy ? deepCopy(_target) : _target) as Record; + const obj = source as Record; + Object.keys(obj).forEach((key: string) => { + const value = obj[key]; if (value !== null && typeof value === 'object') { if (typeof target[key] !== 'object') { target[key] = {}; @@ -55,14 +58,19 @@ const deepMerge = (_target: any, source: any, copy = false) => { target[key] = value; } }); - return target; + return target as T; }; -const deepSet = (obj: any, path: string, _value: any, root = path): any => { +const deepSet = (obj: V, path: string, _value: T, root = path): Record => { const [property, ...properties] = path.split('.'); let value = _value; - if (properties.length) { - value = deepSet(obj[property] || {}, properties.join('.'), _value, root); + if (typeof obj === 'object' && properties.length) { + value = deepSet( + (obj as Record)[property] || {}, + properties.join('.'), + _value, + root, + ) as T; } if (typeof obj !== 'object') { throw new Error(`Path '${root}' is not accessible`); @@ -70,15 +78,19 @@ const deepSet = (obj: any, path: string, _value: any, root = path): any => { return { ...obj, [property]: value }; }; -const deepGet = (obj: any, path: string, defaultValue?: any, root = path): any => { +const deepGet = (obj: T, path: string, defaultValue?: V, root = path): V => { const [property, ...properties] = path.split('.'); + if (obj === undefined || obj === null) { + return defaultValue as V; + } + const prop = (obj as Record)[property]; if (properties.length) { - return deepGet(obj[property], properties.join('.'), root); + return deepGet(prop as Record, properties.join('.'), defaultValue, root); } if (typeof obj !== 'object' || property in obj === false) { - return defaultValue; + return defaultValue as V; } - return obj[property]; + return prop; }; // Inspiration: https://github.com/rayepps/radash/blob/31c1397437d7fb7a78e97499c8d46f992c49844c/src/object.ts