Skip to content

Commit

Permalink
Add model choice to Persona
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed Dec 7, 2023
1 parent 26e808d commit 56785e6
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Add llm_model_version_override to Persona
Revision ID: baf71f781b9e
Revises: 50b683a8295c
Create Date: 2023-12-06 21:56:50.286158
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "baf71f781b9e"
down_revision = "50b683a8295c"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"persona",
sa.Column("llm_model_version_override", sa.String(), nullable=True),
)


def downgrade() -> None:
op.drop_column("persona", "llm_model_version_override")
3 changes: 3 additions & 0 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def upsert_persona(
persona_id: int | None = None,
default_persona: bool = False,
document_sets: list[DocumentSetDBModel] | None = None,
llm_model_version_override: str | None = None,
commit: bool = True,
overwrite_duplicate_named_persona: bool = False,
) -> Persona:
Expand Down Expand Up @@ -355,6 +356,7 @@ def upsert_persona(
persona.num_chunks = num_chunks
persona.apply_llm_relevance_filter = apply_llm_relevance_filter
persona.default_persona = default_persona
persona.llm_model_version_override = llm_model_version_override

# Do not delete any associations manually added unless
# a new updated list is provided
Expand All @@ -375,6 +377,7 @@ def upsert_persona(
apply_llm_relevance_filter=apply_llm_relevance_filter,
default_persona=default_persona,
document_sets=document_sets if document_sets else [],
llm_model_version_override=llm_model_version_override,
)
db_session.add(persona)

Expand Down
7 changes: 7 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,13 @@ class Persona(Base):
apply_llm_relevance_filter: Mapped[bool | None] = mapped_column(
Boolean, nullable=True
)
# allows the Persona to specify a different LLM version than is controlled
# globablly via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
# auto-detected time filters, relevance filters, etc.
llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True
)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
Expand Down
6 changes: 5 additions & 1 deletion backend/danswer/direct_qa/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def get_qa_model_for_persona(
timeout: int = QA_TIMEOUT,
) -> QAModel:
return QABlock(
llm=get_default_llm(api_key=api_key, timeout=timeout),
llm=get_default_llm(
api_key=api_key,
timeout=timeout,
gen_ai_model_version_override=persona.llm_model_version_override,
),
qa_handler=PersonaBasedQAHandler(
system_prompt=persona.system_text or "", task_prompt=persona.hint_text or ""
),
Expand Down
8 changes: 7 additions & 1 deletion backend/danswer/llm/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@ def get_default_llm(
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
use_fast_llm: bool = False,
gen_ai_model_version_override: str | None = None,
) -> LLM:
"""A single place to fetch the configured LLM for Danswer
Also allows overriding certain LLM defaults"""
model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
if gen_ai_model_version_override:
model_version = gen_ai_model_version_override
else:
model_version = (
FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
)
if api_key is None:
api_key = get_gen_ai_api_key()

Expand Down
48 changes: 48 additions & 0 deletions backend/danswer/server/features/persona/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import fetch_persona_by_id
from danswer.db.chat import fetch_personas
from danswer.db.chat import mark_persona_as_deleted
Expand Down Expand Up @@ -50,6 +52,7 @@ def create_persona(
num_chunks=create_persona_request.num_chunks,
apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter,
document_sets=document_sets,
llm_model_version_override=create_persona_request.llm_model_version_override,
)
except ValueError as e:
logger.exception("Failed to update persona")
Expand Down Expand Up @@ -84,6 +87,7 @@ def update_persona(
num_chunks=update_persona_request.num_chunks,
apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter,
document_sets=document_sets,
llm_model_version_override=update_persona_request.llm_model_version_override,
persona_id=persona_id,
)
except ValueError as e:
Expand Down Expand Up @@ -134,3 +138,47 @@ def build_final_template_prompt(
system_prompt=system_prompt, task_prompt=task_prompt
).build_dummy_prompt()
)


"""Utility endpoints for selecting which model to use for a persona.
Putting here for now, since we have no other flows which use this."""

GPT_4_MODEL_VERSIONS = [
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
]
GPT_3_5_TURBO_MODEL_VERSIONS = [
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]


@router.get("/persona-utils/list-available-models")
def list_available_model_versions(
_: User | None = Depends(current_admin_user),
) -> list[str]:
# currently only support selecting different models for OpenAI
if GEN_AI_MODEL_PROVIDER != "openai":
return []

return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS


@router.get("/persona-utils/default-model")
def get_default_model(
_: User | None = Depends(current_admin_user),
) -> str:
# currently only support selecting different models for OpenAI
if GEN_AI_MODEL_PROVIDER != "openai":
return ""

return GEN_AI_MODEL_VERSION
3 changes: 3 additions & 0 deletions backend/danswer/server/features/persona/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class CreatePersonaRequest(BaseModel):
task_prompt: str
num_chunks: int | None = None
apply_llm_relevance_filter: bool | None = None
llm_model_version_override: str | None = None


class PersonaSnapshot(BaseModel):
Expand All @@ -21,6 +22,7 @@ class PersonaSnapshot(BaseModel):
system_prompt: str
task_prompt: str
document_sets: list[DocumentSet]
llm_model_version_override: str | None

@classmethod
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
Expand All @@ -34,6 +36,7 @@ def from_model(cls, persona: Persona) -> "PersonaSnapshot":
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
llm_model_version_override=persona.llm_model_version_override,
)


Expand Down
55 changes: 46 additions & 9 deletions web/src/app/admin/personas/PersonaEditor.tsx
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
"use client";

import { DocumentSet } from "@/lib/types";
import { Button, Divider } from "@tremor/react";
import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
} from "formik";
import { Button, Divider, Text } from "@tremor/react";
import { ArrayHelpers, FieldArray, Form, Formik } from "formik";

import * as Yup from "yup";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
Expand All @@ -20,6 +13,7 @@ import Link from "next/link";
import { useEffect, useState } from "react";
import {
BooleanFormField,
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";

Expand All @@ -40,9 +34,13 @@ function SubLabel({ children }: { children: string | JSX.Element }) {
export function PersonaEditor({
existingPersona,
documentSets,
llmOverrideOptions,
defaultLLM,
}: {
existingPersona?: Persona | null;
documentSets: DocumentSet[];
llmOverrideOptions: string[];
defaultLLM: string;
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
Expand Down Expand Up @@ -74,6 +72,7 @@ export function PersonaEditor({
<div className="dark">
{popup}
<Formik
enableReinitialize={true}
initialValues={{
name: existingPersona?.name ?? "",
description: existingPersona?.description ?? "",
Expand All @@ -86,6 +85,8 @@ export function PersonaEditor({
num_chunks: existingPersona?.num_chunks ?? null,
apply_llm_relevance_filter:
existingPersona?.apply_llm_relevance_filter ?? false,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
}}
validationSchema={Yup.object().shape({
name: Yup.string().required("Must give the Persona a name!"),
Expand All @@ -101,6 +102,7 @@ export function PersonaEditor({
document_set_ids: Yup.array().of(Yup.number()),
num_chunks: Yup.number().max(20).nullable(),
apply_llm_relevance_filter: Yup.boolean().required(),
llm_model_version_override: Yup.string().nullable(),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
Expand Down Expand Up @@ -259,6 +261,41 @@ export function PersonaEditor({

<Divider />

{llmOverrideOptions.length > 0 && defaultLLM && (
<>
<SectionHeader>[Advanced] Model Selection</SectionHeader>

<Text>
Pick which LLM to use for this Persona. If left as Default,
will use <b className="italic">{defaultLLM}</b>.
<br />
<br />
For more information on the different LLMs, checkout the{" "}
<a
href="https://platform.openai.com/docs/models"
target="_blank"
className="text-blue-500"
>
OpenAI docs
</a>
.
</Text>

<SelectorFormField
name="llm_model_version_override"
options={llmOverrideOptions.map((llmOption) => {
return {
name: llmOption,
value: llmOption,
};
})}
includeDefault={true}
/>
</>
)}

<Divider />

<SectionHeader>[Advanced] Retrieval Customization</SectionHeader>

<TextFormField
Expand Down
44 changes: 40 additions & 4 deletions web/src/app/admin/personas/[personaId]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@ import { DocumentSet } from "@/lib/types";
import { BackButton } from "@/components/BackButton";
import { Card, Title } from "@tremor/react";
import { DeletePersonaButton } from "./DeletePersonaButton";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";

export default async function Page({
params,
}: {
params: { personaId: string };
}) {
const personaResponse = await fetchSS(`/persona/${params.personaId}`);
const [
personaResponse,
documentSetsResponse,
llmOverridesResponse,
defaultLLMResponse,
] = await Promise.all([
fetchSS(`/persona/${params.personaId}`),
fetchSS("/manage/document-set"),
fetchSS("/persona-utils/list-available-models"),
fetchSS("/persona-utils/default-model"),
]);

if (!personaResponse.ok) {
return (
Expand All @@ -23,8 +34,6 @@ export default async function Page({
);
}

const documentSetsResponse = await fetchSS("/manage/document-set");

if (!documentSetsResponse.ok) {
return (
<ErrorCallout
Expand All @@ -34,18 +43,45 @@ export default async function Page({
);
}

if (!llmOverridesResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch LLM override options - ${await documentSetsResponse.text()}`}
/>
);
}

if (!defaultLLMResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch default LLM - ${await documentSetsResponse.text()}`}
/>
);
}

const documentSets = (await documentSetsResponse.json()) as DocumentSet[];
const persona = (await personaResponse.json()) as Persona;
const llmOverrideOptions = (await llmOverridesResponse.json()) as string[];
const defaultLLM = (await defaultLLMResponse.json()) as string;

return (
<div className="dark">
<InstantSSRAutoRefresh />

<BackButton />
<div className="pb-2 mb-4 flex">
<h1 className="text-3xl font-bold pl-2">Edit Persona</h1>
</div>

<Card>
<PersonaEditor existingPersona={persona} documentSets={documentSets} />
<PersonaEditor
existingPersona={persona}
documentSets={documentSets}
llmOverrideOptions={llmOverrideOptions}
defaultLLM={defaultLLM}
/>
</Card>

<div className="mt-12">
Expand Down
1 change: 1 addition & 0 deletions web/src/app/admin/personas/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ export interface Persona {
document_sets: DocumentSet[];
num_chunks?: number;
apply_llm_relevance_filter?: boolean;
llm_model_version_override?: string;
}
Loading

1 comment on commit 56785e6

@vercel
Copy link

@vercel vercel bot commented on 56785e6 Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.