[DERCBOT-1037] Use of PromptTemplate
assouktim committed Oct 21, 2024
1 parent 47c90f4 commit 1353a37
Showing 11 changed files with 105 additions and 133 deletions.
bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt (12 changes: 8 additions & 4 deletions)
@@ -181,10 +181,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler
     query = RAGQuery(
         history = getDialogHistory(dialog),
         questionAnsweringLlmSetting = ragConfiguration.llmSetting,
-        questionAnsweringPromptInputs = mapOf(
-            "question" to action.toString(),
-            "locale" to userPreferences.locale.displayLanguage,
-            "no_answer" to ragConfiguration.noAnswerSentence
+        questionAnsweringPrompt = PromptTemplate(
+            formatter = Formatter.F_STRING.id,
+            template = ragConfiguration.llmSetting.prompt,
+            inputs = mapOf(
+                "question" to action.toString(),
+                "locale" to userPreferences.locale.displayLanguage,
+                "no_answer" to ragConfiguration.noAnswerSentence
+            )
         ),
         embeddingQuestionEmSetting = ragConfiguration.emSetting,
         documentIndexName = indexName,
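
The PromptTemplate built here is the new transport object for prompts, exchanged between the Kotlin bot and the Python orchestrator. As a rough Pydantic sketch of its shape, inferred only from the fields this commit uses (formatter, template, inputs) and not from the actual model in gen_ai_orchestrator.models.prompt.prompt_template:

    # Illustrative sketch of the PromptTemplate shape this commit passes around;
    # field names come from the diff, everything else is an assumption.
    from enum import Enum
    from typing import Any

    from pydantic import BaseModel, Field


    class PromptFormatter(str, Enum):
        # The two template dialects referenced in this commit
        F_STRING = 'f-string'
        JINJA2 = 'jinja2'


    class PromptTemplate(BaseModel):
        formatter: PromptFormatter = Field(default=PromptFormatter.F_STRING)
        template: str = Field(min_length=1)
        inputs: dict[str, Any] = Field(default_factory=dict)
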
@@ -23,10 +23,10 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting

 data class RAGQuery(
     // val condenseQuestionLlmSetting: LLMSetting,
-    // val condenseQuestionPromptInputs: Map<String, String>,
+    // val condenseQuestionPrompt: PromptTemplate,
     val history: List<ChatMessage> = emptyList(),
     val questionAnsweringLlmSetting: LLMSetting,
-    val questionAnsweringPromptInputs: Map<String, String>,
+    val questionAnsweringPrompt: PromptTemplate,
     val embeddingQuestionEmSetting: EMSetting,
     val documentIndexName: String,
     val documentSearchParams: DocumentSearchParamsBase,
@@ -37,8 +37,3 @@ class BaseLLMSetting(BaseModel):
         ge=0,
         le=2,
     )
-    prompt: str = Field(
-        description='The prompt to generate completions for.',
-        examples=['How to learn to ride a bike without wheels!'],
-        min_length=1,
-    )
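
The net effect of this removal: the prompt text no longer rides inside the LLM setting but travels as a dedicated object. An illustrative before/after of the payloads, using only values that appear in this commit's own examples:

    # Before: the prompt string lived inside the LLM setting.
    llm_setting_before = {
        'provider': 'OpenAI',
        'temperature': 1.2,
        'prompt': 'Use the following context to answer the question at the end. ...',
    }

    # After: the setting keeps model parameters only...
    llm_setting_after = {'provider': 'OpenAI', 'temperature': 1.2, 'model': 'gpt-3.5-turbo'}

    # ...and the prompt becomes a PromptTemplate with an explicit formatter and inputs.
    question_answering_prompt = {
        'formatter': 'f-string',
        'template': 'Use the following context to answer the question at the end. ...',
        'inputs': {'question': 'How to get started playing guitar ?', 'locale': 'French'},
    }
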
@@ -131,20 +131,17 @@ class RagQuery(BaseQuery):
     history: list[ChatMessage] = Field(
         description="Conversation history, used to reformulate the user's question."
     )
-    question_answering_prompt_inputs: Any = Field(
-        description='Key-value inputs for the llm prompt when used as a template. Please note that the '
-        'chat_history field must not be specified here, it will be override by the history field',
-    )
     # condense_question_llm_setting: LLMSetting =
     #     Field(description="LLM setting, used to condense the user's question.")
-    # condense_question_prompt_inputs: Any = (
-    #     Field(
-    #         description='Key-value inputs for the condense question llm prompt, when used as a template.',
-    #     ),
+    # condense_question_prompt: PromptTemplate = Field(
+    #     description='Prompt template, used to create a prompt with inputs for jinja and fstring format'
+    # )
     question_answering_llm_setting: LLMSetting = Field(
         description='LLM setting, used to perform a QA Prompt.'
     )
+    question_answering_prompt: PromptTemplate = Field(
+        description='Prompt template, used to create a prompt with inputs for jinja and fstring format'
+    )

     model_config = {
         'json_schema_extra': {
@@ -164,7 +161,11 @@ class RagQuery(BaseQuery):
                    'value': 'ab7***************************A1IV4B',
                },
                'temperature': 1.2,
-               'prompt': """Use the following context to answer the question at the end.
+               'model': 'gpt-3.5-turbo',
+           },
+           'question_answering_prompt': {
+               'formatter': 'f-string',
+               'template': """Use the following context to answer the question at the end.
                If you don't know the answer, just say {no_answer}.

                Context:
@@ -174,12 +175,11 @@ class RagQuery(BaseQuery):
                {question}

                Answer in {locale}:""",
-               'model': 'gpt-3.5-turbo',
-           },
-           'question_answering_prompt_inputs': {
-               'question': 'How to get started playing guitar ?',
-               'no_answer': "Sorry, I don't know.",
-               'locale': 'French',
+               'inputs': {
+                   'question': 'How to get started playing guitar ?',
+                   'no_answer': 'Sorry, I don t know.',
+                   'locale': 'French',
+               }
           },
           'embedding_question_em_setting': {
               'provider': 'OpenAI',
@@ -16,23 +16,14 @@

 import logging
 import time
-from typing import Optional

-from jinja2 import Template, TemplateError
 from langchain_core.output_parsers import NumberedListOutputParser
-from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
 from langchain_core.runnables import RunnableConfig

-from gen_ai_orchestrator.errors.exceptions.exceptions import (
-    GenAIPromptTemplateException,
-)
 from gen_ai_orchestrator.errors.handlers.openai.openai_exception_handler import (
     openai_exception_handler,
 )
-from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
 from gen_ai_orchestrator.models.observability.observability_trace import ObservabilityTrace
-from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
-from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
 from gen_ai_orchestrator.routers.requests.requests import (
     SentenceGenerationQuery,
 )
@@ -42,6 +33,7 @@
 from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
     get_llm_factory, create_observability_callback_handler,
 )
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

 logger = logging.getLogger(__name__)

@@ -90,29 +82,3 @@ async def generate_and_split_sentences(
     )

     return SentenceGenerationResponse(sentences=sentences)
-
-
-def validate_prompt_template(prompt: PromptTemplate):
-    """
-    Prompt template validation
-
-    Args:
-        prompt: The prompt template
-
-    Returns:
-        Nothing.
-    Raises:
-        GenAIPromptTemplateException: if template is incorrect
-    """
-    if PromptFormatter.JINJA2 == prompt.formatter:
-        try:
-            Template(prompt.template).render(prompt.inputs)
-        except TemplateError as exc:
-            logger.error('Prompt completion - template validation failed!')
-            logger.error(exc)
-            raise GenAIPromptTemplateException(
-                ErrorInfo(
-                    error=exc.__class__.__name__,
-                    cause=str(exc),
-                )
-            )
@@ -18,15 +18,14 @@
 """

 import logging
-import re
 import time
 from logging import ERROR, WARNING
 from typing import List, Optional

 from langchain.chains import ConversationalRetrievalChain
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.documents import Document
-from langchain_core.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate

 from gen_ai_orchestrator.errors.exceptions.exceptions import (
     GenAIGuardCheckException,
@@ -47,9 +46,6 @@
     RagDocumentMetadata,
     TextWithFootnotes,
 )
-from gen_ai_orchestrator.models.vector_stores.vectore_store_provider import (
-    VectorStoreProvider,
-)
 from gen_ai_orchestrator.routers.requests.requests import RagQuery
 from gen_ai_orchestrator.routers.responses.responses import RagResponse
 from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import (
@@ -60,6 +56,7 @@
     get_llm_factory,
     get_vector_store_factory, create_observability_callback_handler,
 )
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

 logger = logging.getLogger(__name__)

@@ -93,7 +90,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
         message_history.add_ai_message(msg.text)

     inputs = {
-        **query.question_answering_prompt_inputs,
+        **query.question_answering_prompt.inputs,
         'chat_history': message_history.messages,
     }

@@ -180,6 +177,11 @@ def create_rag_chain(query: RagQuery) -> ConversationalRetrievalChain:
         index_name=query.document_index_name,
         embedding_function=em_factory.get_embedding_model())

+    logger.info('RAG chain - LLM template validation')
+    validate_prompt_template(query.question_answering_prompt)
+
+
+
     logger.debug('RAG chain - Document index name: %s', query.document_index_name)
     logger.debug('RAG chain - Create a ConversationalRetrievalChain from LLM')
     return ConversationalRetrievalChain.from_llm(
@@ -188,27 +190,13 @@ def create_rag_chain(query: RagQuery) -> ConversationalRetrievalChain:
         return_source_documents=True,
         return_generated_question=True,
         combine_docs_chain_kwargs={
-            'prompt': PromptTemplate(
-                template=llm_factory.setting.prompt,
-                input_variables=__find_input_variables(llm_factory.setting.prompt),
+            'prompt': LangChainPromptTemplate.from_template(
+                template=query.question_answering_prompt.template,
+                template_format=query.question_answering_prompt.formatter.value,
             )
         },
     )


-def __find_input_variables(template):
-    """
-    Search for input variables on a given template
-
-    Args:
-        template: the template to search on
-    """
-
-    motif = r'\{([^}]+)\}'
-    variables = re.findall(motif, template)
-    return variables
-
-
 def __rag_guard(inputs, response):
     """
     If a 'no_answer' input was given as a rag setting,
@@ -315,7 +303,7 @@ def get_rag_debug_data(
     """RAG debug data assembly"""

     return RagDebugData(
-        user_question=query.question_answering_prompt_inputs['question'],
+        user_question=query.question_answering_prompt.inputs['question'],
         condense_question_prompt=get_llm_prompts(records_callback_handler)[0],
         condense_question=get_condense_question(records_callback_handler),
         question_answering_prompt=get_llm_prompts(records_callback_handler)[1],
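
The switch to LangChainPromptTemplate.from_template is what makes the regex helper __find_input_variables unnecessary: LangChain infers the input variables from the template itself, for either formatter. A standalone sketch (the template text is illustrative):

    from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate

    prompt = LangChainPromptTemplate.from_template(
        template='Question:\n{question}\n\nAnswer in {locale}:',
        template_format='f-string',  # or 'jinja2', mirroring PromptTemplate.formatter.value
    )
    print(prompt.input_variables)   # variables inferred from the template, no regex needed
    print(prompt.format(question='How to get started playing guitar ?', locale='French'))
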
@@ -51,28 +51,3 @@ async def check_llm_setting(query: LLMProviderSettingStatusQuery) -> bool:
         trace_name=ObservabilityTrace.CHECK_LLM_SETTINGS.value)

     return await get_llm_factory(query.setting).check_llm_setting(langfuse_callback_handler)
-
-
-def llm_inference_with_parser(
-    llm_factory: LangChainLLMFactory, parser: BaseOutputParser
-) -> AIMessage:
-    """
-    Perform LLM inference and format the output content based on the given parser.
-
-    :param llm_factory: LangChain LLM Factory.
-    :param parser: Parser to format the output.
-
-    :return: Result of the language model inference with the content formatted.
-    """
-
-    # Change the prompt with added format instructions
-    format_instructions = parser.get_format_instructions()
-    formatted_prompt = llm_factory.setting.prompt + '\n' + format_instructions
-
-    # Inference of the LLM with the formatted prompt
-    llm_output = llm_factory.invoke(formatted_prompt)
-
-    # Apply the parsing on the LLM output
-    llm_output.content = parser.parse(llm_output.content)
-
-    return llm_output
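
Note that llm_inference_with_parser read its prompt from llm_factory.setting.prompt, which this commit deletes from BaseLLMSetting, hence the removal. A hypothetical rework that keeps the same logic but takes the prompt explicitly (the signature is an assumption, not code from this repository):

    from langchain_core.messages import AIMessage
    from langchain_core.output_parsers import BaseOutputParser

    def llm_inference_with_parser(
        llm_factory, parser: BaseOutputParser, prompt: str
    ) -> AIMessage:
        # Append the parser's format instructions to the caller-supplied prompt
        formatted_prompt = prompt + '\n' + parser.get_format_instructions()
        # Run the inference, then parse the raw content in place
        llm_output = llm_factory.invoke(formatted_prompt)
        llm_output.content = parser.parse(llm_output.content)
        return llm_output
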
@@ -0,0 +1,14 @@
+# Copyright (C) 2024 Credit Mutuel Arkea
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
@@ -0,0 +1,37 @@
+import logging
+
+from jinja2 import Template, TemplateError
+
+from gen_ai_orchestrator.errors.exceptions.exceptions import (
+    GenAIPromptTemplateException,
+)
+from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
+from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
+from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
+
+logger = logging.getLogger(__name__)
+
+def validate_prompt_template(prompt: PromptTemplate):
+    """
+    Prompt template validation
+
+    Args:
+        prompt: The prompt template
+
+    Returns:
+        Nothing.
+    Raises:
+        GenAIPromptTemplateException: if template is incorrect
+    """
+    if PromptFormatter.JINJA2 == prompt.formatter:
+        try:
+            Template(prompt.template).render(prompt.inputs)
+        except TemplateError as exc:
+            logger.error('Prompt completion - template validation failed!')
+            logger.error(exc)
+            raise GenAIPromptTemplateException(
+                ErrorInfo(
+                    error=exc.__class__.__name__,
+                    cause=str(exc),
+                )
+            )
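
A usage sketch for the relocated helper, assuming PromptTemplate accepts the formatter/template/inputs fields shown elsewhere in this diff. Only Jinja2 templates are actually rendered; f-string templates pass through this check unrendered:

    from gen_ai_orchestrator.errors.exceptions.exceptions import GenAIPromptTemplateException
    from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
    from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
    from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

    template = PromptTemplate(
        formatter=PromptFormatter.JINJA2,
        template='Answer in {{ locale }}: {{ question }}',
        inputs={'locale': 'French', 'question': 'How to get started playing guitar ?'},
    )
    try:
        validate_prompt_template(template)  # dry-run renders the Jinja2 template
    except GenAIPromptTemplateException:
        ...  # malformed template: surface the error to the caller
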
@@ -13,9 +13,7 @@
 # limitations under the License.
 #
 from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
-from gen_ai_orchestrator.services.completion.completion_service import (
-    validate_prompt_template,
-)
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template


 def test_validate_prompt_template():
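
The updated test body is collapsed in this view. A plausible shape for such a test, given the helper's contract (names and assertions here are illustrative, not the repository's code):

    import pytest

    from gen_ai_orchestrator.errors.exceptions.exceptions import GenAIPromptTemplateException
    from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
    from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
    from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

    def test_validate_prompt_template_raises_on_bad_jinja2():
        bad = PromptTemplate(formatter=PromptFormatter.JINJA2, template='{% if %}', inputs={})
        with pytest.raises(GenAIPromptTemplateException):
            validate_prompt_template(bad)
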
