Skip to content

Commit

Permalink
fix: get message subgraph adapt kb (#417)
Browse files Browse the repository at this point in the history
using kb graph editor if the chat message is completed by a chat engine
with kb.
  • Loading branch information
Mini256 authored Nov 27, 2024
1 parent f887819 commit aa68dba
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
8 changes: 6 additions & 2 deletions backend/app/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from app.api.deps import SessionDep, OptionalUserDep, CurrentUserDep
from app.rag.chat_config import get_default_embedding_model, ChatEngineConfig
from app.rag.knowledge_base.config import get_kb_embed_model
from app.rag.knowledge_base.index_store import get_kb_tidb_graph_editor
from app.rag.knowledge_graph.graph_store.tidb_graph_editor import legacy_tidb_graph_editor
from app.repositories import chat_repo, knowledge_base_repo
from app.models import Chat, ChatUpdate
from app.rag.chat import (
Expand Down Expand Up @@ -196,15 +198,17 @@ def get_chat_subgraph(session: SessionDep, user: OptionalUserDep, chat_message_i
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Access denied")

engine_options = chat_message.chat.engine_options
chat_engine_config = ChatEngineConfig.validate(engine_options)
chat_engine_config = ChatEngineConfig.model_validate(engine_options)

if chat_engine_config.knowledge_base:
kb = knowledge_base_repo.must_get(session, chat_engine_config.knowledge_base.linked_knowledge_base.id)
embed_model = get_kb_embed_model(session, kb)
graph_editor = get_kb_tidb_graph_editor(session, kb)
else:
embed_model = get_default_embedding_model(session)
graph_editor = legacy_tidb_graph_editor

entities, relations = get_chat_message_subgraph(session, chat_message, embed_model)
entities, relations = get_chat_message_subgraph(graph_editor, session, chat_message, embed_model)
return SubgraphResponse(entities=entities, relationships=relations)


Expand Down
11 changes: 6 additions & 5 deletions backend/app/rag/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from app.rag.knowledge_base.config import get_kb_embed_model
from app.rag.knowledge_graph.graph_store import TiDBGraphStore
from app.rag.vector_store.tidb_vector_store import TiDBVectorStore
from app.rag.knowledge_graph.graph_store.tidb_graph_editor import legacy_tidb_graph_editor
from app.rag.knowledge_graph.graph_store.tidb_graph_editor import TiDBGraphEditor

from app.rag.knowledge_graph import KnowledgeGraphIndex
from app.rag.chat_config import ChatEngineConfig, get_default_embedding_model, KnowledgeGraphOption
Expand Down Expand Up @@ -995,6 +995,7 @@ def get_graph_data_from_langfuse(trace_url: str):


def get_chat_message_subgraph(
graph_editor: TiDBGraphEditor,
session: Session,
chat_message: DBChatMessage,
embed_model: BaseEmbedding,
Expand All @@ -1007,12 +1008,12 @@ def get_chat_message_subgraph(
# try to get subgraph from chat_message.graph_data
try:
if (
chat_message.graph_data
and "relationships" in chat_message.graph_data
and len(chat_message.graph_data["relationships"]) > 0
chat_message.graph_data
and "relationships" in chat_message.graph_data
and len(chat_message.graph_data["relationships"]) > 0
):
relationship_ids = chat_message.graph_data["relationships"]
all_entities, all_relationships = legacy_tidb_graph_editor.get_relationship_by_ids(
all_entities, all_relationships = graph_editor.get_relationship_by_ids(
session, relationship_ids
)
entities = [
Expand Down

0 comments on commit aa68dba

Please sign in to comment.