diff --git a/backend/src/api/routers/graphs.py b/backend/src/api/routers/graphs.py index 23f36e8b..f07d7a04 100644 --- a/backend/src/api/routers/graphs.py +++ b/backend/src/api/routers/graphs.py @@ -110,13 +110,20 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse: if ( isinstance(output, list) and len(output) > 2 - and "generate" in output[2] - and "messages" in output[2]["generate"] - and len(output[2]["generate"]["messages"]) > 0 + and "generate" in output[-1] + and "messages" in output[-1]["generate"] + and len(output[-1]["generate"]["messages"]) > 0 ): - llm_response = output[2]["generate"]["messages"][0] - tool = list(output[-2].keys())[0] - urls = list(set(output[-2][tool]["urls"])) + llm_response = output[-1]["generate"]["messages"][0] + tools = output[0]["agent"]["tools"] + + urls = [] + context = [] + tool_index = 1 + for tool in tools: + urls.extend(list(output[tool_index].values())[0]["urls"]) + context.extend(list(set(list(output[tool_index].values())[0]["context"]))) + tool_index += 1 else: llm_response = "LLM response extraction failed" logging.error("LLM response extraction failed") @@ -126,13 +133,13 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse: "response": llm_response, "sources": (urls), "context": (context), - "tool": tool, + "tool": tools, } elif user_input.list_sources: - response = {"response": llm_response, "sources": (urls), "tool": tool} + response = {"response": llm_response, "sources": (urls), "tool": tools} elif user_input.list_context: - response = {"response": llm_response, "context": (context), "tool": tool} + response = {"response": llm_response, "context": (context), "tool": tools} else: - response = {"response": llm_response, "tool": tool} + response = {"response": llm_response, "tool": tools} return ChatResponse(**response)