diff --git a/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt b/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
index 39c92fbc26..60180c3274 100644
--- a/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
+++ b/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
@@ -181,10 +181,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
                 query = RAGQuery(
                     history = getDialogHistory(dialog),
                     questionAnsweringLlmSetting = ragConfiguration.llmSetting,
-                    questionAnsweringPromptInputs = mapOf(
-                        "question" to action.toString(),
-                        "locale" to userPreferences.locale.displayLanguage,
-                        "no_answer" to ragConfiguration.noAnswerSentence
+                    questionAnsweringPrompt = PromptTemplate(
+                        formatter = Formatter.F_STRING.id,
+                        template = ragConfiguration.llmSetting.prompt,
+                        inputs = mapOf(
+                            "question" to action.toString(),
+                            "locale" to userPreferences.locale.displayLanguage,
+                            "no_answer" to ragConfiguration.noAnswerSentence
+                        )
                     ),
                     embeddingQuestionEmSetting = ragConfiguration.emSetting,
                     documentIndexName = indexName,
diff --git a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt
index 458ab163ef..1124632a40 100644
--- a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt
+++ b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt
@@ -23,10 +23,10 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting
 
 data class RAGQuery(
     // val condenseQuestionLlmSetting: LLMSetting,
-    // val condenseQuestionPromptInputs: Map<String, String>,
+    // val condenseQuestionPrompt: PromptTemplate,
     val history: List<ChatMessage> = emptyList(),
     val questionAnsweringLlmSetting: LLMSetting,
-    val questionAnsweringPromptInputs: Map<String, String>,
+    val questionAnsweringPrompt: PromptTemplate,
     val embeddingQuestionEmSetting: EMSetting,
     val documentIndexName: String,
     val documentSearchParams: DocumentSearchParamsBase,
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py
index 929894f42b..c694f942bd 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py
@@ -37,8 +37,3 @@ class BaseLLMSetting(BaseModel):
         ge=0,
         le=2,
     )
-    prompt: str = Field(
-        description='The prompt to generate completions for.',
-        examples=['How to learn to ride a bike without wheels!'],
-        min_length=1,
-    )
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py
index ca27181d2d..9a8eb5983d 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py
@@ -131,20 +131,17 @@ class RagQuery(BaseQuery):
     history: list[ChatMessage] = Field(
         description="Conversation history, used to reformulate the user's question."
    )
-    question_answering_prompt_inputs: Any = Field(
-        description='Key-value inputs for the llm prompt when used as a template. Please note that the '
-        'chat_history field must not be specified here, it will be override by the history field',
-    )
     # condense_question_llm_setting: LLMSetting =
     #     Field(description="LLM setting, used to condense the user's question.")
-    # condense_question_prompt_inputs: Any = (
-    #     Field(
-    #         description='Key-value inputs for the condense question llm prompt, when used as a template.',
-    #     ),
+    # condense_question_prompt: PromptTemplate = Field(
+    #     description='Prompt template, used to create a prompt with inputs for Jinja2 and f-string formats'
     # )
     question_answering_llm_setting: LLMSetting = Field(
         description='LLM setting, used to perform a QA Prompt.'
     )
+    question_answering_prompt: PromptTemplate = Field(
+        description='Prompt template, used to create a prompt with inputs for Jinja2 and f-string formats'
+    )
 
     model_config = {
         'json_schema_extra': {
@@ -164,7 +161,11 @@ class RagQuery(BaseQuery):
                         'value': 'ab7***************************A1IV4B',
                     },
                     'temperature': 1.2,
-                    'prompt': """Use the following context to answer the question at the end.
+                    'model': 'gpt-3.5-turbo',
+                },
+                'question_answering_prompt': {
+                    'formatter': 'f-string',
+                    'template': """Use the following context to answer the question at the end.
 If you don't know the answer, just say {no_answer}.
 
 Context:
@@ -174,12 +175,11 @@ class RagQuery(BaseQuery):
 {question}
 
 Answer in {locale}:""",
-                    'model': 'gpt-3.5-turbo',
-                },
-                'question_answering_prompt_inputs': {
-                    'question': 'How to get started playing guitar ?',
-                    'no_answer': "Sorry, I don't know.",
-                    'locale': 'French',
+                    'inputs': {
+                        'question': 'How to get started playing guitar ?',
+                        'no_answer': "Sorry, I don't know.",
+                        'locale': 'French',
+                    }
                 },
                 'embedding_question_em_setting': {
                     'provider': 'OpenAI',
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py
index 2d9bdda6a3..fc8b10ada4 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py
@@ -16,23 +16,14 @@
 import logging
 import time
-from typing import Optional
 
-from jinja2 import Template, TemplateError
 from langchain_core.output_parsers import NumberedListOutputParser
 from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
-from langchain_core.runnables import RunnableConfig
 
-from gen_ai_orchestrator.errors.exceptions.exceptions import (
-    GenAIPromptTemplateException,
-)
 from gen_ai_orchestrator.errors.handlers.openai.openai_exception_handler import (
     openai_exception_handler,
 )
-from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
 from gen_ai_orchestrator.models.observability.observability_trace import ObservabilityTrace
-from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
-from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
 
 from gen_ai_orchestrator.routers.requests.requests import (
     SentenceGenerationQuery,
 )
@@ -42,6 +33,7 @@ from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
     get_llm_factory,
     create_observability_callback_handler,
 )
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template
 
 logger = logging.getLogger(__name__)
 
@@ -90,29 +82,3 @@ async def generate_and_split_sentences(
     )
 
     return SentenceGenerationResponse(sentences=sentences)
-
-
-def validate_prompt_template(prompt: PromptTemplate):
-    """
-    Prompt template validation
-
-    Args:
-        prompt: The prompt template
-
-    Returns:
-        Nothing.
-    Raises:
-        GenAIPromptTemplateException: if template is incorrect
-    """
-    if PromptFormatter.JINJA2 == prompt.formatter:
-        try:
-            Template(prompt.template).render(prompt.inputs)
-        except TemplateError as exc:
-            logger.error('Prompt completion - template validation failed!')
-            logger.error(exc)
-            raise GenAIPromptTemplateException(
-                ErrorInfo(
-                    error=exc.__class__.__name__,
-                    cause=str(exc),
-                )
-            )
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py
index 33b4839337..895ab11eba 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py
@@ -18,7 +18,6 @@
 """
 
 import logging
-import re
 import time
 from logging import ERROR, WARNING
 from typing import List, Optional
@@ -26,7 +25,7 @@
 from langchain.chains import ConversationalRetrievalChain
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.documents import Document
-from langchain_core.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
 
 from gen_ai_orchestrator.errors.exceptions.exceptions import (
     GenAIGuardCheckException,
@@ -47,9 +46,6 @@
     RagDocumentMetadata,
     TextWithFootnotes,
 )
-from gen_ai_orchestrator.models.vector_stores.vectore_store_provider import (
-    VectorStoreProvider,
-)
 from gen_ai_orchestrator.routers.requests.requests import RagQuery
 from gen_ai_orchestrator.routers.responses.responses import RagResponse
 from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import (
@@ -60,6 +56,7 @@
     get_llm_factory,
     get_vector_store_factory,
     create_observability_callback_handler,
 )
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template
 
 logger = logging.getLogger(__name__)
@@ -93,7 +90,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
         message_history.add_ai_message(msg.text)
 
     inputs = {
-        **query.question_answering_prompt_inputs,
+        **query.question_answering_prompt.inputs,
         'chat_history': message_history.messages,
     }
 
@@ -180,6 +177,9 @@ def create_rag_chain(query: RagQuery) -> ConversationalRetrievalChain:
         index_name=query.document_index_name,
         embedding_function=em_factory.get_embedding_model())
 
+    logger.info('RAG chain - LLM template validation')
+    validate_prompt_template(query.question_answering_prompt)
+
     logger.debug('RAG chain - Document index name: %s', query.document_index_name)
     logger.debug('RAG chain - Create a ConversationalRetrievalChain from LLM')
     return ConversationalRetrievalChain.from_llm(
@@ -188,27 +188,15 @@
         retriever=retriever,
         return_source_documents=True,
         return_generated_question=True,
         combine_docs_chain_kwargs={
-            'prompt': PromptTemplate(
-                template=llm_factory.setting.prompt,
-                input_variables=__find_input_variables(llm_factory.setting.prompt),
+            'prompt': LangChainPromptTemplate.from_template(
+                template=query.question_answering_prompt.template,
+                template_format=query.question_answering_prompt.formatter.value,
             )
         },
     )
-
-def __find_input_variables(template):
-    """
-    Search for input variables on a given template
-
-    Args:
-        template: the template to search on
-    """
-
-    motif = r'\{([^}]+)\}'
-    variables = re.findall(motif, template)
-    return variables
 
 
 def __rag_guard(inputs, response):
     """
     If a 'no_answer' input was given as a rag setting,
@@ -315,7 +303,7 @@ def get_rag_debug_data(
     """RAG debug data assembly"""
 
     return RagDebugData(
-        user_question=query.question_answering_prompt_inputs['question'],
+        user_question=query.question_answering_prompt.inputs['question'],
         condense_question_prompt=get_llm_prompts(records_callback_handler)[0],
         condense_question=get_condense_question(records_callback_handler),
         question_answering_prompt=get_llm_prompts(records_callback_handler)[1],
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py
index 14536566e9..b9f9a25d3a 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py
@@ -51,28 +51,3 @@ async def check_llm_setting(query: LLMProviderSettingStatusQuery) -> bool:
         trace_name=ObservabilityTrace.CHECK_LLM_SETTINGS.value)
 
     return await get_llm_factory(query.setting).check_llm_setting(langfuse_callback_handler)
-
-
-def llm_inference_with_parser(
-    llm_factory: LangChainLLMFactory, parser: BaseOutputParser
-) -> AIMessage:
-    """
-    Perform LLM inference and format the output content based on the given parser.
-
-    :param llm_factory: LangChain LLM Factory.
-    :param parser: Parser to format the output.
-
-    :return: Result of the language model inference with the content formatted.
-    """
-
-    # Change the prompt with added format instructions
-    format_instructions = parser.get_format_instructions()
-    formatted_prompt = llm_factory.setting.prompt + '\n' + format_instructions
-
-    # Inference of the LLM with the formatted prompt
-    llm_output = llm_factory.invoke(formatted_prompt)
-
-    # Apply the parsing on the LLM output
-    llm_output.content = parser.parse(llm_output.content)
-
-    return llm_output
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/__init__.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/__init__.py
new file mode 100644
index 0000000000..0b6c73c789
--- /dev/null
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (C) 2024 Credit Mutuel Arkea
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/prompt_utility.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/prompt_utility.py
new file mode 100644
index 0000000000..7ec3af49fb
--- /dev/null
+++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/prompt_utility.py
@@ -0,0 +1,37 @@
+import logging
+
+from jinja2 import Template, TemplateError
+
+from gen_ai_orchestrator.errors.exceptions.exceptions import (
+    GenAIPromptTemplateException,
+)
+from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
+from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
+from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
+
+logger = logging.getLogger(__name__)
+
+def validate_prompt_template(prompt: PromptTemplate):
+    """
+    Prompt template validation
+
+    Args:
+        prompt: The prompt template
+
+    Returns:
+        Nothing.
+    Raises:
+        GenAIPromptTemplateException: if template is incorrect
+    """
+    if PromptFormatter.JINJA2 == prompt.formatter:
+        try:
+            Template(prompt.template).render(prompt.inputs)
+        except TemplateError as exc:
+            logger.error('Prompt completion - template validation failed!')
+            logger.error(exc)
+            raise GenAIPromptTemplateException(
+                ErrorInfo(
+                    error=exc.__class__.__name__,
+                    cause=str(exc),
+                )
+            )
diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py
index 6637b7af7d..58c1a00aa2 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py
@@ -13,9 +13,7 @@
 # limitations under the License.
 #
 from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
-from gen_ai_orchestrator.services.completion.completion_service import (
-    validate_prompt_template,
-)
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template
 
 
 def test_validate_prompt_template():
diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py
index 6a8da1fea3..73ae69eea3 100644
--- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py
+++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py
@@ -47,8 +47,7 @@
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.get_llm_factory')
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.get_em_factory')
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.get_vector_store_factory')
-@patch('gen_ai_orchestrator.services.langchain.rag_chain.PromptTemplate')
-@patch('gen_ai_orchestrator.services.langchain.rag_chain.__find_input_variables')
+@patch('gen_ai_orchestrator.services.langchain.rag_chain.LangChainPromptTemplate')
 @patch(
     'gen_ai_orchestrator.services.langchain.rag_chain.ConversationalRetrievalChain.from_llm'
 )
@@ -65,7 +64,6 @@ async def test_rag_chain(
     mocked_rag_guard,
     mocked_callback_init,
     mocked_chain_builder,
-    mocked_find_input_variables,
     mocked_prompt_template,
     mocked_get_vector_store_factory,
     mocked_get_em_factory,
@@ -86,7 +84,11 @@
         'provider': 'OpenAI',
         'api_key': {'type': 'Raw', 'value': 'ab7***************************A1IV4B'},
         'temperature': 1.2,
-        'prompt': """Use the following context to answer the question at the end.
+        'model': 'gpt-3.5-turbo',
+    },
+    'question_answering_prompt': {
+        'formatter': 'f-string',
+        'template': """Use the following context to answer the question at the end.
 If you don't know the answer, just say {no_answer}.
 
 Context:
@@ -96,12 +98,11 @@ async def test_rag_chain(
 {question}
 
 Answer in {locale}:""",
-        'model': 'gpt-3.5-turbo',
-    },
-    'question_answering_prompt_inputs': {
-        'question': 'How to get started playing guitar ?',
-        'no_answer': 'Sorry, I don t know.',
-        'locale': 'French',
+        'inputs': {
+            'question': 'How to get started playing guitar ?',
+            'no_answer': "Sorry, I don't know.",
+            'locale': 'French',
+        }
     },
     'embedding_question_em_setting': {
         'provider': 'OpenAI',
@@ -179,16 +180,16 @@
         return_generated_question=True,
         combine_docs_chain_kwargs={
             # PromptTemplate must be mocked or searching for params in it will fail
-            'prompt': mocked_prompt_template(
-                template=query.question_answering_llm_setting.prompt,
-                input_variables=['no_answer', 'context', 'question', 'locale'],
+            'prompt': mocked_prompt_template.from_template(
+                template=query.question_answering_prompt.template,
+                template_format=query.question_answering_prompt.formatter.value,
             )
         },
     )
     # Assert qa chain is ainvoke()d with the expected settings from query
     mocked_chain.ainvoke.assert_called_once_with(
         input={
-            **query.question_answering_prompt_inputs,
+            **query.question_answering_prompt.inputs,
             'chat_history': [
                 HumanMessage(content='Hello, how can I do this?'),
                 AIMessage(content='you can do this with the following method ....'),
@@ -206,12 +207,6 @@
     )
 
 
-def test_find_input_variables():
-    template = 'This is a {sample} text with {multiple} curly brace sections'
-    input_vars = rag_chain.__find_input_variables(template)
-    assert input_vars == ['sample', 'multiple']
-
-
 @patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_log')
 def test_rag_guard_fails_if_no_docs_in_valid_answer(mocked_log):
     inputs = {'no_answer': "Sorry, I don't know."}
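
Reviewer note, not part of the patch: a minimal sketch of how the relocated validate_prompt_template() is exercised with the new PromptTemplate payload. Field names (formatter, template, inputs) and the PromptFormatter.JINJA2 member are taken from the hunks above; constructing the model directly with the enum member is an assumption about the Pydantic definition.

    # Hypothetical usage sketch -- not part of the patch.
    from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
    from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
    from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

    prompt = PromptTemplate(
        formatter=PromptFormatter.JINJA2,
        template='Answer in {{ locale }}: {{ question }}',
        inputs={'locale': 'French', 'question': 'How to get started playing guitar?'},
    )

    # Renders the Jinja2 template against its inputs and raises
    # GenAIPromptTemplateException if the template is malformed; f-string
    # templates pass through unvalidated (see prompt_utility.py above).
    validate_prompt_template(prompt)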