Merge pull request #24 from decodingml/rag-eval
RAG evaluation using RAGAs
Joywalker authored Jun 9, 2024
2 parents 6353fcd + 014a95e commit 57f5ec5
Showing 8 changed files with 161 additions and 31 deletions.
12 changes: 6 additions & 6 deletions course/module-3/insert_data_mongo.py
@@ -43,12 +43,12 @@ def download_dataset(output_dir: Path = Path("data")) -> list:
"type": "post",
"author_id": "2",
},
{
"file_name": "repositories_paul_iusztin.json",
"file_id": "1tSWrlj_u85twAqVus-l0mzqgYVV6WHVz",
"type": "repository",
"author_id": "2",
},
# {
# "file_name": "repositories_paul_iusztin.json",
# "file_id": "1tSWrlj_u85twAqVus-l0mzqgYVV6WHVz",
# "type": "repository",
# "author_id": "2",
# },
]
for file in files:
file["file_path"] = str(output_dir / file["file_name"])
1 change: 0 additions & 1 deletion course/module-5/evaluation/llm.py
@@ -1,6 +1,5 @@
from langchain_openai import ChatOpenAI
from llm_components.chain import GeneralChain
from llm_components.chain import GeneralChain
from llm_components.prompt_templates import LLMEvaluationTemplate
from settings import settings

63 changes: 61 additions & 2 deletions course/module-5/evaluation/rag.py
@@ -1,18 +1,77 @@
import llm_components.prompt_templates as templates
from datasets import Dataset
from langchain_openai import ChatOpenAI
from llm_components.chain import GeneralChain
from pandas import DataFrame
from ragas import evaluate
from ragas.embeddings import HuggingfaceEmbeddings
from ragas.metrics import (
answer_correctness,
answer_similarity,
context_entity_recall,
context_recall,
context_relevancy,
context_utilization,
)
from settings import settings

# Evaluating against the following metrics
# RETRIEVAL BASED
# 1. Context Utilization - how much of the retrieved context is actually relevant to the generated answer
# 2. Context Relevancy - (vector DB based) measures the relevance of the retrieved context to the question
# 3. Context Recall - measures how much of the ground truth answer is covered by the retrieved context
# 4. Context Entity Recall - the fraction of entities in the ground truth that also appear in the retrieved context

def evaluate(query: str, context: list[str], output: str) -> str:
# END-TO-END
# 5. Answer Similarity - measures the semantic resemblance between the generated answer and the ground truth answer
# 6. Answer Correctness - measures the correctness of the answer compared to the ground truth

METRICS = [
context_utilization,
context_relevancy,
context_recall,
answer_similarity,
context_entity_recall,
answer_correctness,
]


def evaluate_w_template(query: str, context: list[str], output: str) -> str:
evaluation_template = templates.RAGEvaluationTemplate()
prompt_template = evaluation_template.create_template()

model = ChatOpenAI(model=settings.OPENAI_MODEL_ID)
model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, api_key=settings.OPENAI_API_KEY)
chain = GeneralChain.get_chain(
llm=model, output_key="rag_eval", template=prompt_template
)

response = chain.invoke({"query": query, "context": context, "output": output})

return response["rag_eval"]


def evaluate_w_ragas(query: str, context: list[str], output: str) -> DataFrame:
"""
Evaluate the RAG (query, context, response) triplet using RAGAS
"""
data_sample = {
"question": [query], # Question as Sequence(str)
"answer": [output], # Answer as Sequence(str)
"contexts": [context], # Context as Sequence(str)
"ground_truth": ["".join(context)], # Ground Truth as Sequence(str)
}

oai_model = ChatOpenAI(
model=settings.OPENAI_MODEL_ID,
api_key=settings.OPENAI_API_KEY,
)
embd_model = HuggingfaceEmbeddings(model=settings.EMBEDDING_MODEL_ID)
dataset = Dataset.from_dict(data_sample)
score = evaluate(
llm=oai_model,
embeddings=embd_model,
dataset=dataset,
metrics=METRICS,
)

return score
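
For reference, a minimal sketch of how evaluate_w_ragas could be exercised on its own, assuming settings carries a valid OPENAI_API_KEY, OPENAI_MODEL_ID, and EMBEDDING_MODEL_ID; the query, context, and answer strings are illustrative placeholders, not values from the course dataset:

from evaluation.rag import evaluate_w_ragas

# Hypothetical inputs; in the pipeline these come from the vector retriever and the deployed LLM.
query = "What is a feature store?"
context = ["A feature store is a centralized repository for computing, storing and serving ML features."]
answer = "A feature store centralizes feature computation, storage, and serving for ML systems."

scores = evaluate_w_ragas(query=query, context=context, output=answer)
print(scores)  # RAGAS scores for the six metrics listed in METRICS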
46 changes: 37 additions & 9 deletions course/module-5/inference_pipeline.py
@@ -1,5 +1,8 @@
import time

import pandas as pd
from evaluation import evaluate_llm
from evaluation.rag import evaluate_w_ragas
from llm_components.prompt_templates import InferenceTemplate
from monitoring import PromptMonitoringManager
from qwak_inference import RealTimeClient
@@ -14,6 +17,12 @@ def __init__(self) -> None:
)
self.template = InferenceTemplate()
self.prompt_monitoring_manager = PromptMonitoringManager()
self._timings = {
"retrieval": 0.0,
"generation": 0.0,
"evaluation_rag": 0.0,
"evaluation_llm": 0.0,
}

def generate(
self,
@@ -28,6 +37,7 @@ def generate(
}

if enable_rag is True:
st_time = time.time_ns()
retriever = VectorRetriever(query=query)
hits = retriever.retrieve_top_k(
k=settings.TOP_K, to_expand_to_n_queries=settings.EXPAND_N_QUERY
@@ -36,34 +46,52 @@
prompt_template_variables["context"] = context

prompt = prompt_template.format(question=query, context=context)
en_time = time.time_ns()
self._timings["retrieval"] = (en_time - st_time) / 1e9
else:
prompt = prompt_template.format(question=query)

st_time = time.time_ns()
input_ = pd.DataFrame([{"instruction": prompt}]).to_json()

response: list[dict] = self.qwak_client.predict(input_)
answer = response[0]["content"][0]
answer = response[0]["content"]
en_time = time.time_ns()
self._timings["generation"] = (en_time - st_time) / 1e9

if enable_evaluation is True:
evaluation_result = evaluate_llm(query=query, output=answer)
if enable_rag:
st_time = time.time_ns()
rag_eval_scores = evaluate_w_ragas(
query=query, output=answer, context=context
)
en_time = time.time_ns()
self._timings["evaluation_rag"] = (en_time - st_time) / 1e9
st_time = time.time_ns()
llm_eval = evaluate_llm(query=query, output=answer)
en_time = time.time_ns()
self._timings["evaluation_llm"] = (en_time - st_time) / 1e9
evaluation_result = {
"llm_evaluation": "" if not llm_eval else llm_eval,
"rag_evaluation": {} if not rag_eval_scores else rag_eval_scores,
}
else:
evaluation_result = None

if enable_monitoring is True:
if evaluation_result is not None:
metadata = {"llm_evaluation_result": evaluation_result}
else:
metadata = None

self.prompt_monitoring_manager.log(
prompt=prompt,
prompt_template=prompt_template.template,
prompt_template_variables=prompt_template_variables,
output=answer,
metadata=metadata,
)
self.prompt_monitoring_manager.log_chain(
query=query, response=answer, eval_output=evaluation_result
query=query,
context=context,
llm_gen=answer,
llm_eval_output=evaluation_result["llm_evaluation"],
rag_eval_scores=evaluation_result["rag_evaluation"],
timings=self._timings,
)

return {"answer": answer, "llm_evaluation_result": evaluation_result}
6 changes: 3 additions & 3 deletions course/module-5/main.py
@@ -10,14 +10,14 @@
query = """
Hello my author_id is 1.
Could you please draft a LinkedIn post discussing Vector Databases?
I'm particularly interested in how do they work.
Could you please draft a LinkedIn post discussing Feature Stores?
I'm particularly interested in their importance and how they can be used in ML systems.
"""

response = inference_endpoint.generate(
query=query,
enable_rag=True,
enable_evaluation=False,
enable_evaluation=True,
enable_monitoring=True,
)

59 changes: 52 additions & 7 deletions course/module-5/monitoring/prompt_monitoring.py
@@ -1,3 +1,5 @@
from typing import List

import comet_llm
from settings import settings

@@ -32,7 +34,15 @@ def log(
)

@classmethod
def log_chain(cls, query: str, response: str, eval_output: str):
def log_chain(
cls,
query: str,
context: List[str],
llm_gen: str,
llm_eval_output: str,
rag_eval_scores: dict | None = None,
timings: dict | None = None,
) -> None:
comet_llm.init(project=f"{settings.COMET_PROJECT}-monitoring")
comet_llm.start_chain(
inputs={"user_query": query},
@@ -41,14 +51,49 @@ def log_chain(cls, query: str, response: str, eval_output: str):
workspace=settings.COMET_WORKSPACE,
)
with comet_llm.Span(
category="twin_response",
category="Vector Retrieval",
name="retrieval_step",
inputs={"user_query": query},
metadata={"duration": timings.get("retrieval")},
) as span:
span.set_outputs(outputs={"retrieved_context": context})

with comet_llm.Span(
category="LLM Generation",
name="generation_step",
inputs={"user_query": query},
metadata={
"model_used": settings.OPENAI_MODEL_ID,
"duration": timings.get("generation"),
},
) as span:
span.set_outputs(outputs={"generation": llm_gen})

with comet_llm.Span(
category="Evaluation",
name="llm_eval_step",
inputs={"query": llm_gen, "user_query": query},
metadata={
"model_used": settings.OPENAI_MODEL_ID,
"duration": timings.get("evaluation_llm"),
},
) as span:
span.set_outputs(outputs=response)
span.set_outputs(outputs={"llm_eval_result": llm_eval_output})

with comet_llm.Span(
category="gpt3.5-eval",
inputs={"eval_result": eval_output},
category="Evaluation",
name="rag_eval_step",
inputs={
"user_query": query,
"retrieved_context": context,
"llm_gen": llm_gen,
},
metadata={
"model_used": settings.OPENAI_MODEL_ID,
"embd_model": settings.EMBEDDING_MODEL_ID,
"eval_framework": "RAGAS",
"duration": timings.get("evaluation_rag"),
},
) as span:
span.set_outputs(outputs=response)
comet_llm.end_chain(outputs={"response": response, "eval_output": eval_output})
span.set_outputs(outputs={"rag_eval_scores": rag_eval_scores})
comet_llm.end_chain(outputs={"response": llm_gen})
3 changes: 1 addition & 2 deletions course/module-5/pyproject.toml
@@ -7,7 +7,6 @@ authors = [
"Paul Iusztin <[email protected]>",
"Alex Vesa <[email protected]>",
]
package-mode = false
readme = "README.md"


@@ -39,7 +38,7 @@ datasets = "^2.19.1"
peft = "^0.11.1"
bitsandbytes = "^0.43.1"
qwak-inference = "^0.1.17"

ragas = "^0.1.9"

[build-system]
requires = ["poetry-core"]
2 changes: 1 addition & 1 deletion course/module-5/settings.py
@@ -20,7 +20,7 @@ class AppSettings(BaseSettings):
QDRANT_DATABASE_URL: str = "http://localhost:6333"

QDRANT_CLOUD_URL: str = "str"
USE_QDRANT_CLOUD: bool = True
USE_QDRANT_CLOUD: bool = False
QDRANT_APIKEY: str | None = None

# MQ config
