From 795261371b461eafa32516aa7bff8569f3c38f6d Mon Sep 17 00:00:00 2001
From: Jeffrey Ip
Date: Wed, 13 Nov 2024 12:42:05 +0800
Subject: [PATCH] New metric

---
 c.py                                          |  13 +++
 deepeval/metrics/__init__.py                  |   1 +
 deepeval/metrics/json_correctness/__init__.py |   0
 .../json_correctness/json_correctness.py      | 107 ++++++++++++++++++
 deepeval/metrics/json_correctness/template.py |  31 +++++
 deepeval/metrics/utils.py                     |   3 +-
 6 files changed, 154 insertions(+), 1 deletion(-)
 create mode 100644 c.py
 create mode 100644 deepeval/metrics/json_correctness/__init__.py
 create mode 100644 deepeval/metrics/json_correctness/json_correctness.py
 create mode 100644 deepeval/metrics/json_correctness/template.py

diff --git a/c.py b/c.py
new file mode 100644
index 000000000..6531fe776
--- /dev/null
+++ b/c.py
@@ -0,0 +1,13 @@
+from deepeval.metrics import JsonCorrectnessMetric
+
+from deepeval.metrics.faithfulness.schema import FaithfulnessVerdict, Verdicts
+from deepeval.test_case import LLMTestCase
+
+metric = JsonCorrectnessMetric(expected_schema=Verdicts, verbose_mode=True)
+
+answer = """{\n"verdicts": [\n{\n"verdict": "yes"\n},\n{\n "verdict": "no",\n "reason": "blah blah"\n},'
+'\n{\n "verdict": "yes",\n "reason":null \n}\n]\n}"""
+
+test_case = LLMTestCase(input="...", actual_output=answer)
+
+metric.measure(test_case=test_case)
diff --git a/deepeval/metrics/__init__.py b/deepeval/metrics/__init__.py
index a7d551507..b29d60b04 100644
--- a/deepeval/metrics/__init__.py
+++ b/deepeval/metrics/__init__.py
@@ -16,6 +16,7 @@
 from .contextual_precision.contextual_precision import ContextualPrecisionMetric
 from .knowledge_retention.knowledge_retention import KnowledgeRetentionMetric
 from .tool_correctness.tool_correctness import ToolCorrectnessMetric
+from .json_correctness.json_correctness import JsonCorrectnessMetric
 from .text_to_image.text_to_image import TextToImageMetric
 from .image_editing.image_editing import ImageEditingMetric
 from .conversation_relevancy.conversation_relevancy import (
diff --git a/deepeval/metrics/json_correctness/__init__.py b/deepeval/metrics/json_correctness/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/deepeval/metrics/json_correctness/json_correctness.py b/deepeval/metrics/json_correctness/json_correctness.py
new file mode 100644
index 000000000..bea354aa3
--- /dev/null
+++ b/deepeval/metrics/json_correctness/json_correctness.py
@@ -0,0 +1,107 @@
+from typing import List, Union
+import json
+from pydantic import BaseModel, ValidationError
+
+from deepeval.test_case import (
+    LLMTestCase,
+    LLMTestCaseParams,
+    ConversationalTestCase,
+)
+from deepeval.metrics import BaseMetric
+from deepeval.metrics.utils import (
+    construct_verbose_logs,
+    check_llm_test_case_params,
+)
+from deepeval.metrics.indicator import metric_progress_indicator
+
+
+required_params: List[LLMTestCaseParams] = [
+    LLMTestCaseParams.INPUT,
+    LLMTestCaseParams.ACTUAL_OUTPUT,
+]
+
+
+class JsonCorrectnessMetric(BaseMetric):
+    def __init__(
+        self,
+        expected_schema: BaseModel,
+        threshold: float = 0.5,
+        include_reason: bool = True,
+        strict_mode: bool = False,
+        verbose_mode: bool = False,
+    ):
+        self.threshold = 1 if strict_mode else threshold
+        self.include_reason = include_reason
+        self.strict_mode = strict_mode
+        self.verbose_mode = verbose_mode
+        self.expected_schema = expected_schema
+
+    def measure(
+        self,
+        test_case: Union[LLMTestCase, ConversationalTestCase],
+        _show_indicator: bool = True,
+    ) -> float:
+        if isinstance(test_case, ConversationalTestCase):
+            test_case = test_case.turns[0]
+        check_llm_test_case_params(test_case, required_params, self)
+
+        self.evaluation_cost = 0
+        with metric_progress_indicator(self, _show_indicator=_show_indicator):
+            valid_json = True
+            try:
+                self.expected_schema.model_validate_json(
+                    test_case.actual_output
+                )
+            except ValidationError as e:
+                valid_json = False
+                if self.include_reason:
+                    self.reason = self.generate_friendly_error_message(e)
+
+            self.score = 1 if valid_json else 0
+            self.success = self.score >= self.threshold
+            self.verbose_logs = construct_verbose_logs(
+                self,
+                steps=[
+                    f"LLM outputted Json:\n{test_case.actual_output}",
+                    f"Expected Json Schema:\n{json.dumps(self.expected_schema.model_json_schema(), indent=4)}",
+                    f"Score: {self.score}\nReason: {self.reason}",
+                ],
+            )
+
+            return self.score
+
+    async def a_measure(
+        self,
+        test_case: Union[LLMTestCase, ConversationalTestCase],
+        _show_indicator: bool = True,
+    ) -> float:
+        return self.measure(test_case, _show_indicator=_show_indicator)
+
+    def generate_friendly_error_message(self, error: ValidationError) -> str:
+        error_messages = []
+        for err in error.errors():
+            # Extract error location, message, and type
+            loc = " -> ".join(map(str, err.get("loc", [])))
+            msg = err.get("msg", "Unknown error")
+            error_type = err.get("type", "Unknown type")
+
+            # Format each error message in a readable way
+            error_message = f"Error in '{loc}': {msg} (Type: {error_type})"
+            error_messages.append(error_message)
+
+        # Join all error messages into a single formatted string
+        return "\n".join(error_messages)
+
+    def is_successful(self) -> bool:
+        if self.error is not None:
+            self.success = False
+        else:
+            try:
+                self.success = self.score >= self.threshold
+            except:
+                self.success = False
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Json Correctness"
diff --git a/deepeval/metrics/json_correctness/template.py b/deepeval/metrics/json_correctness/template.py
new file mode 100644
index 000000000..c4aa02868
--- /dev/null
+++ b/deepeval/metrics/json_correctness/template.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+
+class JsonCorrectnessTemplate:
+    @staticmethod
+    def generate_reason(
+        generated_json: str, expected_schema: str, is_valid_json: bool
+    ):
+        return f"""Based on the given generated json, generated by an LLM, and a boolean stating whether it is a valid JSON based on the expected json schema, give a reason why it is OR is not a valid Json.
+
+**
+IMPORTANT: Please make sure to only return in JSON format, with the 'reason' key providing the reason.
+Example JSON:
+{{
+    "reason": "The generated Json is <valid/invalid> because <reason>."
+}}
+
+If the json is not a valid one, your reason MUST compare `Expected Json Schema` and `Generated Json` in your reason.
+**
+
+Generated Json:
+{generated_json}
+
+Expected Json Schema:
+{expected_schema}
+
+Is Valid Json?
+{is_valid_json}
+
+JSON:
+"""
diff --git a/deepeval/metrics/utils.py b/deepeval/metrics/utils.py
index ace63657a..8dd653ecd 100644
--- a/deepeval/metrics/utils.py
+++ b/deepeval/metrics/utils.py
@@ -226,7 +226,8 @@ def check_mllm_test_case_params(
 
 
 def trimAndLoadJson(
-    input_string: str, metric: Optional[BaseMetric] = None
+    input_string: str,
+    metric: Optional[BaseMetric] = None,
 ) -> Any:
     start = input_string.find("{")
     end = input_string.rfind("}") + 1
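
For reference, a minimal usage sketch of the new metric against a user-defined pydantic schema. The Product model, input text, and output string below are hypothetical examples for illustration only; they are not part of this patch, which only ships the metric itself and the c.py scratch script above.

    from pydantic import BaseModel

    from deepeval.metrics import JsonCorrectnessMetric
    from deepeval.test_case import LLMTestCase

    # Hypothetical schema the LLM output is expected to conform to
    class Product(BaseModel):
        name: str
        price: float

    metric = JsonCorrectnessMetric(expected_schema=Product, verbose_mode=True)
    test_case = LLMTestCase(
        input="Return the product as JSON.",
        actual_output='{"name": "Widget", "price": 9.99}',
    )
    metric.measure(test_case=test_case)
    print(metric.score)  # 1 if actual_output validates against Product, else 0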