Skip to content

Commit

Permalink
Merge pull request #381 from Opla/feat_376d
Browse files Browse the repository at this point in the history
feat: refactor conversation connectors
  • Loading branch information
mikbry authored Mar 1, 2024
2 parents 7177840 + c793a48 commit 1255ee9
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 15 deletions.
5 changes: 4 additions & 1 deletion webapp/components/views/Threads/Explorer/AssistantsList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export default function AssistantsList({ selectedId }: AssistantsListProps) {
<ExplorerGroup
title="Assistants"
toolbar={
<Button variant="outline" size="sm">
<Button variant="outline" size="sm" className="text-primary">
<Store className="mr-2 h-4 w-4" strokeWidth={1.5} />
{t('Explore the store')}
</Button>
Expand All @@ -43,6 +43,9 @@ export default function AssistantsList({ selectedId }: AssistantsListProps) {
<ExplorerList<Assistant>
selectedId={selectedId || OplaAssistant.id}
items={assistants}
getItemTitle={(assistant) =>
assistant.id === OplaAssistant.id ? t('Use your local AI Models') : assistant.name
}
renderLeftSide={(assistant) =>
assistant.id === OplaAssistant.id ? (
<Opla className="h-4 w-4" />
Expand Down
45 changes: 33 additions & 12 deletions webapp/components/views/Threads/Thread.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import { AppContext } from '@/context';
import {
Asset,
Conversation,
ConversationConnector,
ConversationConnectorType,
LlmParameters,
Message,
MessageStatus,
Expand All @@ -38,6 +40,11 @@ import {
updateOrCreateConversation,
addAssetsToConversation,
getConversationAssets,
getConnectorModelId,
getConversationModelId,
addConnector,
getConversationProvider,
addConversationConnector,
} from '@/utils/data/conversations';
import useBackend from '@/hooks/useBackendContext';
import { buildContext, completion, getCompletionParametersDefinition } from '@/utils/providers';
Expand Down Expand Up @@ -105,10 +112,11 @@ function Thread({
} = useContext(AppContext);
const { backendContext, setActiveModel } = useBackend();
const { activeModel: aModel } = backendContext.config.models;
const [tempModelProvider, setTempModelProvider] = useState<[string, ProviderType] | undefined>(
/* const [tempModelProvider, setTempModelProvider] = useState<[string, ProviderType] | undefined>(
undefined,
);
const activeModel = tempModelProvider?.[0] || aModel;
); */
const [connector, setConnector] = useState<ConversationConnector | undefined>(undefined);
const activeModel = getConnectorModelId(connector) || aModel;
const [tempConversationId, setTempConversationId] = useState<string | undefined>(undefined);
const conversationId = _conversationId || tempConversationId;
const selectedConversation = conversations.find((c) => c.id === conversationId);
Expand Down Expand Up @@ -162,7 +170,7 @@ function Thread({
]);

const showEmptyChat = !conversationId;
const selectedModel = selectedConversation?.model || activeModel;
const selectedModel = getConversationModelId(selectedConversation) || activeModel;

const { modelItems, commandManager } = useMemo(() => {
const items = getModelsAsItems(providers, backendContext, selectedModel);
Expand Down Expand Up @@ -228,17 +236,24 @@ function Thread({
`handleSelectModel ${model} ${provider} activeModel=${typeof activeModel}`,
selectedConversation,
);
const newConnector: ConversationConnector = {
type: ConversationConnectorType.Model,
modelId: model as string,
provider,
};
if (model && selectedConversation) {
const connectors = addConnector(selectedConversation.connectors, newConnector);
const newConversations = updateConversation(
{ ...selectedConversation, model, provider, parameters: {}, ...partial },
{ ...selectedConversation, connectors, parameters: {}, ...partial },
conversations,
true,
);
updateConversations(newConversations);
} else if (model && !activeModel) {
await setActiveModel(model);
} else if (model) {
setTempModelProvider([model, provider]);
// setTempModelProvider([model, provider]);
setConnector(newConnector);
}
};

Expand All @@ -252,14 +267,16 @@ function Thread({
let providerName: string | undefined = model?.provider;
const returnedMessage = { ...message };
let provider: Provider | undefined;
if (conversation.provider && conversation.model) {
provider = findProvider(conversation.provider, providers);
model = findModel(conversation.model, provider?.models || []);
const conversationProvider = getConversationProvider(conversation);
const conversationModel = getConversationModelId(conversation);
if (conversationProvider && conversationModel) {
provider = findProvider(conversationProvider, providers);
model = findModel(conversationModel, provider?.models || []);
if (provider) {
providerName = provider.name;
}
}
const modelName = message.author.name || model?.name || conversation.model || activeModel;
const modelName = message.author.name || model?.name || conversationModel || activeModel;
if (!model || model.name !== modelName) {
model = findModelInAll(modelName, providers, backendContext);
}
Expand Down Expand Up @@ -628,16 +645,20 @@ function Thread({
newConversation.temp = true;
newConversation.name = conversationName;
newConversation.currentPrompt = prompt;
if (tempModelProvider) {
/* if (tempModelProvider) {
[newConversation.model, newConversation.provider] = tempModelProvider;
setTempModelProvider(undefined);
} */
if (connector) {
addConversationConnector(newConversation, connector);
setConnector(undefined);
}
setTempConversationId(newConversation.id);
}
updateConversations(updatedConversations);
setChangedPrompt(undefined);
},
[tempConversationId, conversationId, conversations, updateConversations, tempModelProvider],
[tempConversationId, conversationId, conversations, updateConversations, connector],
);

useDebounceFunc<ParsedPrompt | undefined>(handleUpdatePrompt, changedPrompt, 500);
Expand Down
4 changes: 3 additions & 1 deletion webapp/components/views/Threads/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ export default function Threads({ selectedThreadId, view = ViewName.Recent }: Th
return (
<ResizablePanelGroup direction="horizontal">
<ResizablePanel
minSize={10}
minSize={14}
maxSize={40}
defaultSize={pageSettings.explorerWidth}
onResize={handleResizeExplorer}
className={pageSettings.explorerHidden === true ? 'hidden' : ''}
Expand Down Expand Up @@ -233,6 +234,7 @@ export default function Threads({ selectedThreadId, view = ViewName.Recent }: Th
<ResizablePanel
minSize={20}
defaultSize={20}
maxSize={50}
onResize={handleResizeSettings}
className={!pageSettings.settingsHidden && view === ViewName.Recent ? '' : 'hidden'}
>
Expand Down
25 changes: 25 additions & 0 deletions webapp/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,39 @@ export type ConversationUsage = {
totalPerSecond?: number;
};

export enum ConversationConnectorType {
Model = 'model',
Assistant = 'assistant',
}

export type ConversationConnector = {
disabled?: boolean;
} & (
| {
type: ConversationConnectorType.Model;
modelId: string;
provider?: ProviderType;
}
| {
type: ConversationConnectorType.Assistant;
assistantId: string;
targetId?: string;
}
);

export type Conversation = BaseNamedRecord & {
messages: Message[] | undefined;
pluginIds?: string[];
preset?: string;
currentPrompt?: string | ParsedPrompt;
note?: string;

// Deprecated replaced by connectors
model?: string;
provider?: string;

connectors?: ConversationConnector[];

importedFrom?: string;
temp?: boolean;

Expand Down
87 changes: 86 additions & 1 deletion webapp/utils/data/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Asset, ContextWindowPolicy, Conversation } from '@/types';
import {
Asset,
ContextWindowPolicy,
Conversation,
ConversationConnector,
ConversationConnectorType,
ProviderType,
} from '@/types';
import { createBaseRecord, createBaseNamedRecord, updateRecord } from '.';

export const getConversationAssets = (conversation: Conversation) =>
Expand Down Expand Up @@ -110,3 +117,81 @@ export const mergeConversations = (
});
return Array.from(conversationMap.values());
};

export const getConversationConnector = (
conversation: Conversation,
connectorType: ConversationConnectorType,
) => {
let connector = conversation.connectors?.find((c) => c.type === connectorType);
if (!connector) {
if (conversation.model && connectorType === ConversationConnectorType.Model) {
connector = {
type: connectorType,
modelId: conversation.model,
provider: conversation.provider as ProviderType,
};
}
}
return connector;
};

export const addConnector = (
_connectors: ConversationConnector[] | undefined,
connector: ConversationConnector,
): ConversationConnector[] => {
const connectors = _connectors || [];
const index = connectors?.findIndex((c) => c.type === connector.type) ?? -1;
if (index !== -1) {
connectors[index] = connector;
} else {
connectors.push(connector);
}
return connectors;
};

export const addConversationConnector = (
conversation: Conversation,
connector: ConversationConnector,
): Conversation => {
const index = conversation.connectors?.findIndex((c) => c.type === connector.type) ?? -1;
const connectors = conversation.connectors || [];
if (index !== -1) {
connectors[index] = connector;
} else {
connectors.push(connector);
}
return {
...conversation,
connectors,
};
};

export const getConnectorModelId = (modelConnector: ConversationConnector | undefined) => {
if (modelConnector && modelConnector.type === ConversationConnectorType.Model) {
return modelConnector.modelId;
}
return undefined;
};

export const getConnectorProvider = (modelConnector: ConversationConnector | undefined) => {
if (modelConnector && modelConnector.type === ConversationConnectorType.Model) {
return modelConnector.provider;
}
return undefined;
};

export const getConversationModelId = (conversation: Conversation | undefined) => {
if (!conversation) {
return undefined;
}
const modelConnector = getConversationConnector(conversation, ConversationConnectorType.Model);
return getConnectorModelId(modelConnector);
};

export const getConversationProvider = (conversation: Conversation | undefined) => {
if (!conversation) {
return undefined;
}
const modelConnector = getConversationConnector(conversation, ConversationConnectorType.Model);
return getConnectorProvider(modelConnector);
};

0 comments on commit 1255ee9

Please sign in to comment.