[GEN AI] Improving Langfuse's traces
assouktim committed Nov 22, 2024
1 parent 1f7e145 commit a4222d9
Showing 8 changed files with 62 additions and 36 deletions.
9 changes: 8 additions & 1 deletion bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
@@ -179,7 +179,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
     try {
         val response = ragService.rag(
             query = RAGQuery(
-                history = getDialogHistory(dialog),
+                dialog = DialogDetails(
+                    dialogId = dialog.id.toString(),
+                    userId = dialog.playerIds.firstOrNull { PlayerType.user == it.type }?.id,
+                    history = getDialogHistory(dialog),
+                    tags = listOf(
+                        "connector:${underlyingConnector.connectorType.id}"
+                    )
+                ),
                 questionAnsweringLlmSetting = ragConfiguration.llmSetting,
                 questionAnsweringPromptInputs = mapOf(
                     "question" to action.toString(),
(next changed file)

@@ -24,7 +24,7 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting
 data class RAGQuery(
     // val condenseQuestionLlmSetting: LLMSetting,
     // val condenseQuestionPromptInputs: Map<String, String>,
-    val history: List<ChatMessage> = emptyList(),
+    val dialog: DialogDetails,
     val questionAnsweringLlmSetting: LLMSetting,
     val questionAnsweringPromptInputs: Map<String, String>,
     val embeddingQuestionEmSetting: EMSetting,
@@ -34,6 +34,13 @@ data class RAGQuery(
     val observabilitySetting: ObservabilitySetting?
 )
 
+data class DialogDetails(
+    val dialogId: String? = null,
+    val userId: String? = null,
+    val history: List<ChatMessage> = emptyList(),
+    val tags: List<String> = emptyList(),
+)
+
 data class ChatMessage(
     val text: String,
     val type: ChatMessageType,
(next changed file)

@@ -48,7 +48,7 @@ def __eq__(self, other):
         )
 
     def __hash__(self):
-        return hash((self.title, self.url, self.content))
+        return hash((self.title, str(self.url or ''), self.content))
 
 
 class Footnote(Source):
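For illustration, outside this diff: the point of the str(self.url or '') normalization is that __hash__ must stay consistent with __eq__, so sources whose URL is None (or an URL-like object) still deduplicate correctly in sets. A minimal self-contained sketch follows; the real __eq__ body is not shown in this hunk, so the equality rule here is an assumption.

# Hypothetical stand-in for the project's Source model; the equality rule
# below is an assumption, since the real __eq__ body is not in this diff.
from typing import Optional


class Source:
    def __init__(self, title: str, url: Optional[str], content: str):
        self.title = title
        self.url = url
        self.content = content

    def __eq__(self, other):
        # Assumed rule: a missing URL and an empty URL compare equal.
        if not isinstance(other, Source):
            return NotImplemented
        return (self.title, str(self.url or ''), self.content) == \
               (other.title, str(other.url or ''), other.content)

    def __hash__(self):
        # Must mirror __eq__: normalize the URL the same way, so equal
        # sources land in the same hash bucket and sets deduplicate them.
        return hash((self.title, str(self.url or ''), self.content))


a = Source('FAQ', None, 'Answer text')
b = Source('FAQ', '', 'Answer text')
assert a == b and len({a, b}) == 1  # set-based dedup now works without a URL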
(next changed file)

@@ -133,16 +133,22 @@ class VectorStoreProviderSettingStatusQuery(BaseModel):
         default=None,
     )
 
+class DialogDetails(BaseModel):
+    """The dialog details model"""
+
+    dialog_id: Optional[str] = Field(description="The dialog ID.", default=None, examples=["uuid-0123"])
+    user_id: Optional[str] = Field(description="The user ID.", default=None, examples=["[email protected]"])
+    history: list[ChatMessage] = Field(description="Conversation history, used to reformulate the user's question.")
+    tags: list[str] = Field(description='List of tags', examples=[["my-Tag"]])
+
 
 class RagQuery(BaseQuery):
     """The RAG query model"""
 
-    history: list[ChatMessage] = Field(
-        description="Conversation history, used to reformulate the user's question."
-    )
+    dialog: DialogDetails = Field(description='The user dialog details.')
     question_answering_prompt_inputs: Any = Field(
         description='Key-value inputs for the llm prompt when used as a template. Please note that the '
-        'chat_history field must not be specified here, it will be override by the history field',
+        'chat_history field must not be specified here, it will be overridden by the dialog.history field',
     )
     # condense_question_llm_setting: LLMSetting =
     # Field(description="LLM setting, used to condense the user's question.")
@@ -156,7 +162,7 @@ class RagQuery(BaseQuery):
     )
     question_answering_prompt_inputs: Any = Field(
         description='Key-value inputs for the llm prompt when used as a template. Please note that the '
-        'chat_history field must not be specified here, it will be override by the history field',
+        'chat_history field must not be specified here, it will be overridden by the dialog.history field',
     )
     embedding_question_em_setting: EMSetting = Field(
         description="Embedding model setting, used to calculate the user's question vector."
@@ -182,13 +188,15 @@ class RagQuery(BaseQuery):
     'json_schema_extra': {
         'examples': [
             {
-                'history': [
-                    {'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
-                    {
-                        'text': 'you can do this with the following method ....',
-                        'type': 'AI',
-                    },
-                ],
+                'dialog': {
+                    'history': [
+                        {'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
+                        {
+                            'text': 'you can do this with the following method ....',
+                            'type': 'AI',
+                        },
+                    ]
+                },
                 'question_answering_llm_setting': {
                     'provider': 'OpenAI',
                     'api_key': {
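For illustration, outside this diff: in the Pydantic model above, dialog_id and user_id default to None, while history and tags have no default and are therefore required, which is why the updated test at the end of this commit passes an explicit 'tags': []. A minimal sketch, assuming Pydantic v2, with ChatMessage simplified to a plain dict:

# Minimal sketch of the DialogDetails contract; assumes Pydantic v2.
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class DialogDetails(BaseModel):
    dialog_id: Optional[str] = Field(default=None)
    user_id: Optional[str] = Field(default=None)
    history: list[dict] = Field(description='Conversation history.')
    tags: list[str] = Field(description='List of tags')


# The IDs are optional, but history and tags have no default, so both are required:
ok = DialogDetails(history=[], tags=['connector:web'])
assert ok.dialog_id is None

try:
    DialogDetails(history=[])  # missing 'tags'
except ValidationError as err:
    print(err.error_count(), 'validation error')  # -> 1 validation error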
(next changed file)

@@ -79,7 +79,7 @@ async def generate_and_split_sentences(
     config = {"callbacks": [
         create_observability_callback_handler(
             observability_setting=query.observability_setting,
-            trace_name=ObservabilityTrace.SENTENCE_GENERATION
+            trace_name=ObservabilityTrace.SENTENCE_GENERATION.value
         )]}
 
     sentences = await chain.ainvoke(query.prompt.inputs, config=config)
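For illustration, outside this diff: the added .value matters because create_observability_callback_handler (refactored below) now forwards its keyword arguments untouched instead of unwrapping the enum itself, so callers must pass the plain string. A tiny sketch with assumed enum members, since ObservabilityTrace's definition is not shown in this commit:

# Assumed shape of ObservabilityTrace; the real members and their string
# values are not shown in this diff.
from enum import Enum


class ObservabilityTrace(Enum):
    SENTENCE_GENERATION = 'Sentence Generation'
    RAG = 'RAG'


print(ObservabilityTrace.SENTENCE_GENERATION)        # ObservabilityTrace.SENTENCE_GENERATION
print(ObservabilityTrace.SENTENCE_GENERATION.value)  # Sentence Generation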
(next changed file)

@@ -23,7 +23,7 @@
"""

import logging
from typing import Optional
from typing import Optional, Any

from langchain_core.embeddings import Embeddings
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
@@ -338,22 +338,21 @@ def get_callback_handler_factory(

 def create_observability_callback_handler(
     observability_setting: Optional[ObservabilitySetting],
-    trace_name: ObservabilityTrace,
+    **kwargs: Any
 ) -> Optional[LangfuseCallbackHandler]:
     """
     Create the Observability Callback Handler
     Args:
         observability_setting: The Observability Settings
-        trace_name: The trace name
     Returns:
         The Observability Callback Handler
     """
     if observability_setting is not None:
         return get_callback_handler_factory(
-            setting=observability_setting
-        ).get_callback_handler(trace_name=trace_name.value)
+            setting=observability_setting,
+        ).get_callback_handler(**kwargs)
 
     return None
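For illustration, outside this diff: folding trace_name into **kwargs lets callers attach any trace attribute accepted by Langfuse's LangChain callback handler, which is imported at the top of this file and takes session_id, user_id and tags among its constructor arguments. A hedged sketch of the handler those kwargs presumably end up on; keys and host are placeholders, not real credentials:

# Hedged sketch: constructing the Langfuse handler directly with the same
# attributes the RAG chain below forwards through **kwargs.
from langfuse.callback import CallbackHandler

handler = CallbackHandler(
    public_key='pk-lf-...',           # placeholder
    secret_key='sk-lf-...',           # placeholder
    host='https://langfuse.example',  # placeholder
    trace_name='RAG',
    session_id='uuid-0123',   # groups every trace of one dialog into a session
    user_id='user-42',        # attributes the trace to a user in Langfuse
    tags=['connector:web'],   # free-form labels shown on the trace
)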

(next changed file)

@@ -93,15 +93,17 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:

     conversational_retrieval_chain = create_rag_chain(query=query)
 
+    message_history = ChatMessageHistory()
+    if query.dialog:
+        for msg in query.dialog.history:
+            if ChatMessageType.HUMAN == msg.type:
+                message_history.add_user_message(msg.text)
+            else:
+                message_history.add_ai_message(msg.text)
+
     logger.debug(
-        'RAG chain - Use chat history: %s', 'Yes' if len(query.history) > 0 else 'No'
+        'RAG chain - Use chat history: %s', 'Yes' if len(message_history.messages) > 0 else 'No'
     )
-    message_history = ChatMessageHistory()
-    for msg in query.history:
-        if ChatMessageType.HUMAN == msg.type:
-            message_history.add_user_message(msg.text)
-        else:
-            message_history.add_ai_message(msg.text)
 
     inputs = {
         **query.question_answering_prompt_inputs,
@@ -123,7 +125,10 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
     callback_handlers.append(
         create_observability_callback_handler(
             observability_setting=query.observability_setting,
-            trace_name=ObservabilityTrace.RAG,
+            trace_name=ObservabilityTrace.RAG.value,
+            session_id=query.dialog.dialog_id,
+            user_id=query.dialog.user_id,
+            tags=query.dialog.tags,
         )
     )
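For illustration, outside this diff: a self-contained sketch of the history rebuild above, using LangChain's in-memory ChatMessageHistory. The message dicts mirror the request examples ('HUMAN' / 'AI'); the real code switches on ChatMessageType instead.

# Self-contained sketch of rebuilding a LangChain chat history from the
# dialog history carried in the request.
from langchain_community.chat_message_histories import ChatMessageHistory

history = [
    {'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
    {'text': 'you can do this with the following method ....', 'type': 'AI'},
]

message_history = ChatMessageHistory()
for msg in history:
    if msg['type'] == 'HUMAN':
        message_history.add_user_message(msg['text'])
    else:
        message_history.add_ai_message(msg['text'])

print(len(message_history.messages))  # -> 2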

(next changed file)

@@ -99,13 +99,13 @@ async def test_rag_chain(
"""Test the full execute_qa_chain method by mocking all external calls."""
# Build a test RagQuery
query_dict = {
'history': [
{'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
{
'text': 'you can do this with the following method ....',
'type': 'AI',
},
],
'dialog': {
'history': [
{'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
{'text': 'you can do this with the following method ....', 'type': 'AI'}
],
'tags': []
},
'question_answering_llm_setting': {
'provider': 'OpenAI',
'api_key': {'type': 'Raw', 'value': 'ab7***************************A1IV4B'},