Commit

New metric
penguine-ip committed Nov 13, 2024
1 parent 9a19f13 commit 7952613
Showing 6 changed files with 154 additions and 1 deletion.
13 changes: 13 additions & 0 deletions c.py
@@ -0,0 +1,13 @@
from deepeval.metrics import JsonCorrectnessMetric

from deepeval.metrics.faithfulness.schema import FaithfulnessVerdict, Verdicts
from deepeval.test_case import LLMTestCase

metric = JsonCorrectnessMetric(expected_schema=Verdicts, verbose_mode=True)

answer = """{\n"verdicts": [\n{\n"verdict": "yes"\n},\n{\n "verdict": "no",\n "reason": "blah blah"\n},'
'\n{\n "verdict": "yes",\n "reason":null \n}\n]\n}"""

test_case = LLMTestCase(input="...", actual_output=answer)

metric.measure(test_case=test_case)
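
For context, the `Verdicts` schema imported above lives in `deepeval.metrics.faithfulness.schema`; the sketch below only approximates its shape, inferred from the example output rather than copied from the source.

# Approximate shape only -- the authoritative definitions are in
# deepeval/metrics/faithfulness/schema.py and may differ in detail.
from typing import List, Optional
from pydantic import BaseModel


class FaithfulnessVerdict(BaseModel):
    verdict: str
    reason: Optional[str] = None


class Verdicts(BaseModel):
    verdicts: List[FaithfulnessVerdict]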
1 change: 1 addition & 0 deletions deepeval/metrics/__init__.py
@@ -16,6 +16,7 @@
from .contextual_precision.contextual_precision import ContextualPrecisionMetric
from .knowledge_retention.knowledge_retention import KnowledgeRetentionMetric
from .tool_correctness.tool_correctness import ToolCorrectnessMetric
+from .json_correctness.json_correctness import JsonCorrectnessMetric
from .text_to_image.text_to_image import TextToImageMetric
from .image_editing.image_editing import ImageEditingMetric
from .conversation_relevancy.conversation_relevancy import (
Empty file.
107 changes: 107 additions & 0 deletions deepeval/metrics/json_correctness/json_correctness.py
@@ -0,0 +1,107 @@
from typing import List, Union
import json
from pydantic import BaseModel, ValidationError

from deepeval.test_case import (
    LLMTestCase,
    LLMTestCaseParams,
    ConversationalTestCase,
)
from deepeval.metrics import BaseMetric
from deepeval.metrics.utils import (
    construct_verbose_logs,
    check_llm_test_case_params,
)
from deepeval.metrics.indicator import metric_progress_indicator


required_params: List[LLMTestCaseParams] = [
    LLMTestCaseParams.INPUT,
    LLMTestCaseParams.ACTUAL_OUTPUT,
]


class JsonCorrectnessMetric(BaseMetric):
    def __init__(
        self,
        expected_schema: BaseModel,
        threshold: float = 0.5,
        include_reason: bool = True,
        strict_mode: bool = False,
        verbose_mode: bool = False,
    ):
        self.threshold = 1 if strict_mode else threshold
        self.include_reason = include_reason
        self.strict_mode = strict_mode
        self.verbose_mode = verbose_mode
        self.expected_schema = expected_schema

    def measure(
        self,
        test_case: Union[LLMTestCase, ConversationalTestCase],
        _show_indicator: bool = True,
    ) -> float:
        if isinstance(test_case, ConversationalTestCase):
            test_case = test_case.turns[0]
        check_llm_test_case_params(test_case, required_params, self)

        self.evaluation_cost = 0
        with metric_progress_indicator(self, _show_indicator=_show_indicator):
            valid_json = True
            try:
                self.expected_schema.model_validate_json(
                    test_case.actual_output
                )
            except ValidationError as e:
                valid_json = False
                if self.include_reason:
                    self.reason = self.generate_friendly_error_message(e)

            self.score = 1 if valid_json else 0
            self.success = self.score >= self.threshold
            self.verbose_logs = construct_verbose_logs(
                self,
                steps=[
                    f"LLM outputted Json:\n{test_case.actual_output}",
                    f"Expected Json Schema:\n{json.dumps(self.expected_schema.model_json_schema(), indent=4)}",
                    f"Score: {self.score}\nReason: {self.reason}",
                ],
            )

            return self.score

    async def a_measure(
        self,
        test_case: Union[LLMTestCase, ConversationalTestCase],
        _show_indicator: bool = True,
    ) -> float:
        return self.measure(test_case, _show_indicator=_show_indicator)

    def generate_friendly_error_message(self, error: ValidationError) -> str:
        error_messages = []
        for err in error.errors():
            # Extract error location, message, and type
            loc = " -> ".join(map(str, err.get("loc", [])))
            msg = err.get("msg", "Unknown error")
            error_type = err.get("type", "Unknown type")

            # Format each error message in a readable way
            error_message = f"Error in '{loc}': {msg} (Type: {error_type})"
            error_messages.append(error_message)

        # Join all error messages into a single formatted string
        return "\n".join(error_messages)

    def is_successful(self) -> bool:
        if self.error is not None:
            self.success = False
        else:
            try:
                self.success = self.score >= self.threshold
            except TypeError:
                self.success = False
        return self.success

    @property
    def __name__(self):
        return "Json Correctness"
31 changes: 31 additions & 0 deletions deepeval/metrics/json_correctness/template.py
@@ -0,0 +1,31 @@
from typing import Optional


class JsonCorrectnessTemplate:
    @staticmethod
    def generate_reason(
        generated_json: str, expected_schema: str, is_valid_json: bool
    ):
        return f"""Based on the given generated json, generated by an LLM, and a boolean stating whether it is a valid JSON based on the expected json schema, give a reason why it is OR is not a valid Json.
**
IMPORTANT: Please make sure to only return in JSON format, with the 'reason' key providing the reason.
Example JSON:
{{
"reason": "The generated Json is <is_valid_json> because <your_reason>."
}}
If the json is not a valid one, your reason MUST compare `Expected Json Schema` and `Generated Json` in your reason.
**
Generated Json:
{generated_json}
Expected Json Schema:
{expected_schema}
Is Valid Json?
{is_valid_json}
JSON:
"""
3 changes: 2 additions & 1 deletion deepeval/metrics/utils.py
@@ -226,7 +226,8 @@ def check_mllm_test_case_params(


def trimAndLoadJson(
-    input_string: str, metric: Optional[BaseMetric] = None
+    input_string: str,
+    metric: Optional[BaseMetric] = None,
) -> Any:
    start = input_string.find("{")
    end = input_string.rfind("}") + 1
