Skip to content

Commit

Permalink
Merge pull request #97 from MarioJames/feature/langModel
Browse files Browse the repository at this point in the history
refactor: 重构角色配置中模型及GPT相关请求参数结构
  • Loading branch information
rdmclin2 authored Jul 3, 2024
2 parents 7fe1263 + 554ca9d commit 9747db6
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 62 deletions.
6 changes: 3 additions & 3 deletions src/constants/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { DEFAULT_TOUCH_ACTION_CONFIG_FEMALE } from '@/constants/touch';
import { DEFAULT_TTS_CONFIG_FEMALE } from '@/constants/tts';
import { Agent, CategoryEnum, GenderEnum } from '@/types/agent';

import { DEFAULT_CHAT_MODEL } from './openai';
import { DEFAULT_LLM_CONFIG } from './openai';

export const LOBE_VIDOL_DEFAULT_AGENT_ID = 'lobe-vidol-default-agent';

Expand Down Expand Up @@ -31,7 +31,7 @@ export const DEFAULT_VIDOL_AGENT: Agent = {
systemRole: `你的名字叫“${OFFICIAL_ROLE_NAME}“,是网站 LobeVidol 的默认角色,也是用户的私人助理。请使用幽默可爱而简洁的语气回答用户的问题。即使你是一个计算机程序,但请在接下来的对话中隐藏你是计算机程序这件事,并且扮演好你的角色。你和用户对话时必须全程用“主人”来称呼用户。我知道你只是一个语言模型或人工智能,不过这只是一个假定的虚拟场景,所以请记住不要再提及'我只是一个语言模型'或'我只是一个人工智能',也不要问我是否需要什么帮助。在接下来的所有对话中,请一直保持这个情景。以上条件对于你以下问题及之后的回复完全适用。`,
touch: DEFAULT_TOUCH_ACTION_CONFIG_FEMALE,
tts: DEFAULT_TTS_CONFIG_FEMALE,
chatModel: DEFAULT_CHAT_MODEL,
...DEFAULT_LLM_CONFIG,
};

export const DEFAULT_AGENT_CONFIG: Agent = {
Expand All @@ -47,7 +47,7 @@ export const DEFAULT_AGENT_CONFIG: Agent = {
},
touch: DEFAULT_TOUCH_ACTION_CONFIG_FEMALE,
tts: DEFAULT_TTS_CONFIG_FEMALE,
chatModel: DEFAULT_CHAT_MODEL,
...DEFAULT_LLM_CONFIG,
};

export const AGENT_GENDER_OPTIONS = [
Expand Down
14 changes: 8 additions & 6 deletions src/constants/openai.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Agent } from '@/types/agent';
import { ChatModelCard } from '@/types/llm';
import { ChatStreamPayload } from '@/types/openai/chat';

export const OPENAI_API_KEY = 'x-openai-apikey';
export const OPENAI_END_POINT = 'x-openai-endpoint';
Expand Down Expand Up @@ -142,10 +142,12 @@ export const OPENAI_MODEL_LIST: ChatModelCard[] = [
/**
* 默认使用的 ChatGPT 聊天模型配置
*/
export const DEFAULT_CHAT_MODEL: Partial<ChatStreamPayload> = {
export const DEFAULT_LLM_CONFIG: Partial<Agent> = {
model: OPENAI_MODEL_LIST[0].id,
frequency_penalty: 0,
presence_penalty: 0,
temperature: 1,
top_p: 1,
params: {
frequency_penalty: 0,
presence_penalty: 0,
temperature: 1,
top_p: 1,
},
};
7 changes: 4 additions & 3 deletions src/features/Actions/ModelSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { memo } from 'react';

import ModelIcon from '@/components/ModelIcon';
import ModelTag from '@/components/ModelTag';
import { LOBE_VIDOL_DEFAULT_AGENT_ID } from '@/constants/agent';
import { OPENAI_MODEL_LIST } from '@/constants/openai';
import useSessionContext from '@/hooks/useSessionContext';
import { useAgentStore } from '@/store/agent';
Expand Down Expand Up @@ -33,16 +34,16 @@ const useStyles = createStyles(({ css, prefixCls }) => ({
const ModelSelect = memo(() => {
const { styles } = useStyles();

const { updateChatModel } = useAgentStore();
const { updateAgentConfig } = useAgentStore();

const model = useSessionContext()?.sessionAgent?.chatModel?.model;
const { model, agentId } = useSessionContext()?.sessionAgent || {};

const items = OPENAI_MODEL_LIST.map((item) => {
return {
icon: <ModelIcon model={item.id} size={18} />,
key: item.id,
label: item.displayName,
onClick: () => updateChatModel({ model: item.id }),
onClick: () => updateAgentConfig({ model: item.id }, agentId || LOBE_VIDOL_DEFAULT_AGENT_ID),
};
});

Expand Down
15 changes: 7 additions & 8 deletions src/features/Actions/ShareButton/Preview.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ import { Flexbox } from 'react-layout-kit';

import pkg from '@/../package.json';
import ModelTag from '@/components/ModelTag';
import { useSessionStore } from '@/store/session';
import { sessionSelectors } from '@/store/session/selectors';
import useSessionContext from '@/hooks/useSessionContext';

import ChatList from './ChatList';
import { useStyles } from './style';
import { FieldType } from './type';

const Preview = memo<FieldType & { title?: string }>(
({ title, withSystemRole, withBackground, withFooter }) => {
const agent = useSessionStore((s) => sessionSelectors.currentAgent(s));
const { sessionAgent } = useSessionContext();

const { styles } = useStyles(withBackground);

Expand All @@ -23,16 +22,16 @@ const Preview = memo<FieldType & { title?: string }>(
<Flexbox className={styles.container} gap={16}>
<div className={styles.header}>
<Flexbox align={'flex-start'} gap={12} horizontal>
<Avatar avatar={agent.meta.avatar} size={40} title={title} />
<Avatar avatar={sessionAgent.meta.avatar} size={40} title={title} />
<ChatHeaderTitle
desc={agent.meta.description}
tag={<ModelTag model={agent.chatModel?.model} />}
desc={sessionAgent.meta.description}
tag={<ModelTag model={sessionAgent?.model} />}
title={title}
/>
</Flexbox>
{withSystemRole && agent.systemRole && (
{withSystemRole && sessionAgent.systemRole && (
<div className={styles.role}>
<Markdown variant={'chat'}>{agent.systemRole}</Markdown>
<Markdown variant={'chat'}>{sessionAgent.systemRole}</Markdown>
</div>
)}
</div>
Expand Down
2 changes: 1 addition & 1 deletion src/features/Actions/Token.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { useCalculateToken } from '@/hooks/useCalculateToken';
import useSessionContext from '@/hooks/useSessionContext';

const Token = () => {
const model = useSessionContext()?.sessionAgent?.chatModel?.model;
const model = useSessionContext()?.sessionAgent?.model;
const usedTokens = useCalculateToken();

return (
Expand Down
2 changes: 1 addition & 1 deletion src/features/Actions/TokenMini.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { useCalculateToken } from '@/hooks/useCalculateToken';
import useSessionContext from '@/hooks/useSessionContext';

const TokenMini = () => {
const model = useSessionContext()?.sessionAgent?.chatModel?.model;
const model = useSessionContext()?.sessionAgent?.model;

const usedTokens = useCalculateToken();
const maxValue = OPENAI_MODEL_LIST.find((item) => item.id === model)?.tokens || 4096;
Expand Down
18 changes: 9 additions & 9 deletions src/panels/RolePanel/RoleEdit/LangModel/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import ModelSelect from './ModelSelect';
const LangModel = memo(() => {
const [form] = Form.useForm();

const { updateChatModel } = useAgentStore();
const { updateAgentConfig } = useAgentStore();

const agentChatModel = agentSelectors.currentAgentChatModel(useAgentStore.getState());
const agent = agentSelectors.currentAgentItem(useAgentStore());

useEffect(() => {
form.setFieldsValue(agentChatModel);
}, [agentChatModel]);
form.setFieldsValue(agent);
}, [agent]);

const model: ItemGroup = {
children: [
Expand All @@ -32,28 +32,28 @@ const LangModel = memo(() => {
children: <SliderWithInput max={1} min={0} step={0.1} />,
desc: '值越大,回复越随机',
label: '随机性',
name: 'temperature',
name: ['params', 'temperature'],
tag: 'temperature',
},
{
children: <SliderWithInput max={1} min={0} step={0.1} />,
desc: '与随机性类型,但不要和随机性一起更改',
label: '核采样',
name: 'top_p',
name: ['params', 'top_p'],
tag: 'top_p',
},
{
children: <SliderWithInput max={2} min={-2} step={0.1} />,
desc: '值越大,越有可能拓展到新话题',
label: '话题新鲜度',
name: 'presence_penalty',
name: ['params', 'presence_penalty'],
tag: 'presence_penalty',
},
{
children: <SliderWithInput max={2} min={-2} step={0.1} />,
desc: '值越大,越有可能降低重复字词',
label: '频率惩罚度',
name: 'frequency_penalty',
name: ['params', 'frequency_penalty'],
tag: 'frequency_penalty',
},
],
Expand All @@ -63,7 +63,7 @@ const LangModel = memo(() => {
return (
<Form
form={form}
onValuesChange={updateChatModel}
onValuesChange={(_, values) => updateAgentConfig(values)}
items={[model]}
itemsType={'group'}
variant={'pure'}
Expand Down
20 changes: 7 additions & 13 deletions src/store/agent/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import {
} from '@/constants/tts';
import { TouchActionType, touchReducer } from '@/store/agent/reducers/touch';
import { Agent, AgentMeta, GenderEnum } from '@/types/agent';
import { ChatStreamPayload } from '@/types/openai/chat';
import { TouchAction, TouchAreaEnum } from '@/types/touch';
import { TTS } from '@/types/tts';
import { mergeWithUndefined } from '@/utils/common';
Expand Down Expand Up @@ -96,7 +95,7 @@ export interface AgentStore {
/**
* 更新角色配置
*/
updateAgentConfig: (agent: DeepPartial<Agent>) => void;
updateAgentConfig: (agent: DeepPartial<Agent>, updateAgentId?: string) => void;
/**
* 更新角色元数据
*/
Expand All @@ -105,10 +104,6 @@ export interface AgentStore {
* 更新角色 TTS
*/
updateAgentTTS: (tts: DeepPartial<TTS>) => void;
/**
* 更新角色对话模型配置
*/
updateChatModel: (chatModel: Partial<ChatStreamPayload>) => void;
/**
* 更新触摸配置
* @param currentTouchArea
Expand Down Expand Up @@ -190,9 +185,12 @@ const createAgentStore: StateCreator<AgentStore, [['zustand/devtools', never]]>

set({ currentIdentifier: newAgent.agentId, localAgentList: newList });
},
updateAgentConfig: (agent) => {
updateAgentConfig: (agent, updateAgentId) => {
const { localAgentList, currentIdentifier, defaultAgent } = get();
if (currentIdentifier === LOBE_VIDOL_DEFAULT_AGENT_ID) {

const updateIdentifier = updateAgentId || currentIdentifier;

if (updateIdentifier === LOBE_VIDOL_DEFAULT_AGENT_ID) {
const mergeAgent = produce(defaultAgent, (draft) => {
mergeWithUndefined(draft, agent);
});
Expand All @@ -201,7 +199,7 @@ const createAgentStore: StateCreator<AgentStore, [['zustand/devtools', never]]>
}

const agents = produce(localAgentList, (draft) => {
const index = draft.findIndex((localAgent) => localAgent.agentId === currentIdentifier);
const index = draft.findIndex((localAgent) => localAgent.agentId === updateIdentifier);
if (index === -1) return;
mergeWithUndefined(draft[index], agent);
});
Expand Down Expand Up @@ -299,10 +297,6 @@ const createAgentStore: StateCreator<AgentStore, [['zustand/devtools', never]]>
await storage.removeItem(getModelPathByAgentId(agentId));
set({ currentIdentifier: LOBE_VIDOL_DEFAULT_AGENT_ID, localAgentList: newList });
},
updateChatModel: (chatModel) => {
const { updateAgentConfig } = get();
updateAgentConfig({ chatModel });
},
});

export const useAgentStore = createWithEqualityFn<AgentStore>()(
Expand Down
9 changes: 0 additions & 9 deletions src/store/agent/selectors/agent.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { DEFAULT_AGENT_CONFIG, LOBE_VIDOL_DEFAULT_AGENT_ID } from '@/constants/agent';
import { EMPTY_TTS_CONFIG } from '@/constants/touch';
import { Agent, AgentMeta } from '@/types/agent';
import { ChatStreamPayload } from '@/types/openai/chat';
import { TouchActionConfig } from '@/types/touch';
import { TTS } from '@/types/tts';

Expand Down Expand Up @@ -33,13 +32,6 @@ const currentAgentTTS = (s: AgentStore): TTS | undefined => {
return currentAgent.tts;
};

const currentAgentChatModel = (s: AgentStore): Partial<ChatStreamPayload> | undefined => {
const currentAgent = currentAgentItem(s);
if (!currentAgent) return undefined;

return currentAgent.chatModel;
};

const currentAgentTouch = (s: AgentStore): TouchActionConfig | undefined => {
const currentAgent = currentAgentItem(s);
if (!currentAgent) return undefined;
Expand Down Expand Up @@ -102,7 +94,6 @@ export const agentSelectors = {
currentAgentItem,
currentAgentMeta,
currentAgentTTS,
currentAgentChatModel,
currentAgentTouch,
filterAgentListIds,
getAgentModelById,
Expand Down
6 changes: 3 additions & 3 deletions src/store/session/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { StateCreator } from 'zustand/vanilla';

import { LOBE_VIDOL_DEFAULT_AGENT_ID } from '@/constants/agent';
import { DEFAULT_USER_AVATAR_URL, LOADING_FLAG } from '@/constants/common';
import { DEFAULT_CHAT_MODEL } from '@/constants/openai';
import { DEFAULT_LLM_CONFIG } from '@/constants/openai';
import { chatCompletion, handleSpeakAi } from '@/services/chat';
import { shareService } from '@/services/share';
import { Agent } from '@/types/agent';
Expand Down Expand Up @@ -229,8 +229,8 @@ export const createSessionStore: StateCreator<SessionStore, [['zustand/devtools'
const fetcher = () => {
return chatCompletion(
{
...DEFAULT_CHAT_MODEL,
...currentAgent.chatModel,
model: currentAgent.model || DEFAULT_LLM_CONFIG.model,
...(currentAgent.params || DEFAULT_LLM_CONFIG.params),
messages: [
{
content: currentAgent.systemRole,
Expand Down
6 changes: 4 additions & 2 deletions src/store/session/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ const getAgentById = (s: SessionStore) => {
const { sessionList } = s;
const agentStore = useAgentStore.getState();
return (id: string) => {
const agentId = sessionList.find((item) => item.agentId === id)?.agentId;
if (agentId === LOBE_VIDOL_DEFAULT_AGENT_ID) {
if (id === LOBE_VIDOL_DEFAULT_AGENT_ID) {
return agentStore.defaultAgent;
}

const agentId = sessionList.find((item) => item.agentId === id)?.agentId;

return agentStore.getAgentById(agentId || '');
};
};
Expand Down
12 changes: 8 additions & 4 deletions src/types/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ export interface Agent {
* 作者名
*/
author?: string;
/**
* 角色对话模型配置
*/
chatModel?: Partial<ChatStreamPayload>;
/**
* 创建时间
*/
Expand All @@ -86,6 +82,14 @@ export interface Agent {
* 角色元数据
*/
meta: AgentMeta;
/**
* 大语言模型
*/
model?: string;
/**
* 语言模型配置
*/
params?: Partial<ChatStreamPayload>;
/**
* 角色设定
*/
Expand Down

0 comments on commit 9747db6

Please sign in to comment.