Skip to content

Commit

Permalink
Persona / prompt hardening (#3375)
Browse files (browse the repository at this point in the history)
* Persona / prompt hardening

* fix it
  • Loading branch information
Weves authored Dec 9, 2024
1 parent 4a7bd55 commit 970320b
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 46 deletions.
88 changes: 52 additions & 36 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,9 @@ def upsert_persona(
"""

if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = _get_persona_by_name(
existing_persona = _get_persona_by_name(
persona_name=name, user=user, db_session=db_session
)

Expand All @@ -481,62 +481,78 @@ def upsert_persona(
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")

if prompts is not None and len(prompts) == 0:
raise ValueError(
f"Invalid Persona config, no valid prompts "
f"specified. Specified IDs were: '{prompt_ids}'"
)

# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)

if persona:
if existing_persona:
# Built-in personas can only be updated through YAML configuration.
# This ensures that core system personas are not modified unintentionally.
if persona.builtin_persona and not builtin_persona:
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")

# this checks if the user has permission to edit the persona
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)

# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
persona.name = name
persona.description = description
persona.num_chunks = num_chunks
persona.chunks_above = chunks_above
persona.chunks_below = chunks_below
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
persona.icon_color = icon_color
persona.icon_shape = icon_shape
existing_persona.name = name
existing_persona.description = description
existing_persona.num_chunks = num_chunks
existing_persona.chunks_above = chunks_above
existing_persona.chunks_below = chunks_below
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.starter_messages = starter_messages
existing_persona.deleted = False # Un-delete if previously deleted
existing_persona.is_public = is_public
existing_persona.icon_color = icon_color
existing_persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.category_id = category_id
existing_persona.uploaded_image_id = uploaded_image_id
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets or []
existing_persona.document_sets.clear()
existing_persona.document_sets = document_sets or []

if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts or []
existing_persona.prompts.clear()
existing_persona.prompts = prompts

if tools is not None:
persona.tools = tools or []
existing_persona.tools = tools or []

persona = existing_persona

else:
persona = Persona(
if not prompts:
raise ValueError(
"Invalid Persona config. "
"Must specify at least one prompt for a new persona."
)

new_persona = Persona(
id=persona_id,
user_id=user.id if user else None,
is_public=is_public,
Expand All @@ -549,7 +565,7 @@ def upsert_persona(
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
builtin_persona=builtin_persona,
prompts=prompts or [],
prompts=prompts,
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
Expand All @@ -564,8 +580,8 @@ def upsert_persona(
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)

db_session.add(new_persona)
persona = new_persona
if commit:
db_session.commit()
else:
Expand Down
19 changes: 13 additions & 6 deletions backend/danswer/seeding/load_yamls.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def load_personas_from_yaml(
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]

if not prompt_ids:
raise ValueError("Invalid Persona config, no prompts exist")

p_id = persona.get("id")
tool_ids = []

Expand Down Expand Up @@ -123,12 +126,16 @@ def load_personas_from_yaml(
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
display_priority=(
existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority")
),
is_visible=(
existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible")
),
db_session=db_session,
)

Expand Down
11 changes: 8 additions & 3 deletions backend/ee/danswer/server/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
if personas:
logger.notice("Seeding Personas")
for persona in personas:
if not persona.prompt_ids:
raise ValueError(
f"Invalid Persona with name {persona.name}; no prompts exist"
)

upsert_persona(
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks
if persona.num_chunks is not None
else 0.0,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/integration/common_utils/managers/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def create(
"is_public": is_public,
"llm_filter_extraction": llm_filter_extraction,
"recency_bias": recency_bias,
"prompt_ids": prompt_ids or [],
"prompt_ids": prompt_ids or [0],
"document_set_ids": document_set_ids or [],
"tool_ids": tool_ids or [],
"llm_model_provider_override": llm_model_provider_override,
Expand Down

0 comments on commit 970320b

Please sign in to comment.