-
Notifications
You must be signed in to change notification settings - Fork 358
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a19f13
commit 7952613
Showing
6 changed files
with
154 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from deepeval.metrics import JsonCorrectnessMetric | ||
|
||
from deepeval.metrics.faithfulness.schema import FaithfulnessVerdict, Verdicts | ||
from deepeval.test_case import LLMTestCase | ||
|
||
metric = JsonCorrectnessMetric(expected_schema=Verdicts, verbose_mode=True) | ||
|
||
answer = """{\n"verdicts": [\n{\n"verdict": "yes"\n},\n{\n "verdict": "no",\n "reason": "blah blah"\n},' | ||
'\n{\n "verdict": "yes",\n "reason":null \n}\n]\n}""" | ||
|
||
test_case = LLMTestCase(input="...", actual_output=answer) | ||
|
||
metric.measure(test_case=test_case) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import List, Union | ||
import json | ||
from pydantic import BaseModel, ValidationError | ||
|
||
from deepeval.test_case import ( | ||
LLMTestCase, | ||
LLMTestCaseParams, | ||
ConversationalTestCase, | ||
) | ||
from deepeval.metrics import BaseMetric | ||
from deepeval.metrics.utils import ( | ||
construct_verbose_logs, | ||
check_llm_test_case_params, | ||
) | ||
from deepeval.metrics.indicator import metric_progress_indicator | ||
|
||
|
||
required_params: List[LLMTestCaseParams] = [ | ||
LLMTestCaseParams.INPUT, | ||
LLMTestCaseParams.ACTUAL_OUTPUT, | ||
] | ||
|
||
|
||
class JsonCorrectnessMetric(BaseMetric): | ||
def __init__( | ||
self, | ||
expected_schema: BaseModel, | ||
threshold: float = 0.5, | ||
include_reason: bool = True, | ||
strict_mode: bool = False, | ||
verbose_mode: bool = False, | ||
): | ||
self.threshold = 1 if strict_mode else threshold | ||
self.include_reason = include_reason | ||
self.strict_mode = strict_mode | ||
self.verbose_mode = verbose_mode | ||
self.expected_schema = expected_schema | ||
|
||
def measure( | ||
self, | ||
test_case: Union[LLMTestCase, ConversationalTestCase], | ||
_show_indicator: bool = True, | ||
) -> float: | ||
if isinstance(test_case, ConversationalTestCase): | ||
test_case = test_case.turns[0] | ||
check_llm_test_case_params(test_case, required_params, self) | ||
|
||
self.evaluation_cost = 0 | ||
with metric_progress_indicator(self, _show_indicator=_show_indicator): | ||
valid_json = True | ||
try: | ||
self.expected_schema.model_validate_json( | ||
test_case.actual_output | ||
) | ||
except ValidationError as e: | ||
valid_json = False | ||
if self.include_reason: | ||
self.reason = self.generate_friendly_error_message(e) | ||
|
||
self.score = 1 if valid_json else 0 | ||
self.success = self.score >= self.threshold | ||
self.verbose_logs = construct_verbose_logs( | ||
self, | ||
steps=[ | ||
f"LLM outputed Json:\n{test_case.actual_output}", | ||
f"Expected Json Schema:\n{json.dumps(self.expected_schema.model_json_schema(), indent=4)}", | ||
f"Score: {self.score}\nReason: {self.reason}", | ||
], | ||
) | ||
|
||
return self.score | ||
|
||
async def a_measure( | ||
self, | ||
test_case: Union[LLMTestCase, ConversationalTestCase], | ||
_show_indicator: bool = True, | ||
) -> float: | ||
return self.measure(test_case, _show_indicator=_show_indicator) | ||
|
||
def generate_friendly_error_message(self, error: ValidationError) -> str: | ||
error_messages = [] | ||
for err in error.errors(): | ||
# Extract error location, message, and type | ||
loc = " -> ".join(map(str, err.get("loc", []))) | ||
msg = err.get("msg", "Unknown error") | ||
error_type = err.get("type", "Unknown type") | ||
|
||
# Format each error message in a readable way | ||
error_message = f"Error in '{loc}': {msg} (Type: {error_type})" | ||
error_messages.append(error_message) | ||
|
||
# Join all error messages into a single formatted string | ||
return "\n".join(error_messages) | ||
|
||
def is_successful(self) -> bool: | ||
if self.error is not None: | ||
self.success = False | ||
else: | ||
try: | ||
self.success = self.score >= self.threshold | ||
except: | ||
self.success = False | ||
return self.success | ||
|
||
@property | ||
def __name__(self): | ||
return "Json Correctness" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import Optional | ||
|
||
|
||
class JsonCorrectnessTemplate: | ||
@staticmethod | ||
def generate_reason( | ||
generated_json: str, expected_schema: str, is_valid_json: bool | ||
): | ||
return f"""Based on the given generated json, generated by an LLM, and a boolean stating whether it is a valid JSON based on the expected json schema, give a reason why it is OR is not a valid Json. | ||
** | ||
IMPORTANT: Please make sure to only return in JSON format, with the 'reason' key providing the reason. | ||
Example JSON: | ||
{{ | ||
"reason": "The generated Json is <is_valid_json> because <your_reason>." | ||
}} | ||
If the json is not a valid one, your reason MUST compare `Expected Json Schema` and `Generated Json` in your reason. | ||
** | ||
Generated Json: | ||
{generated_json} | ||
Expected Json Schema: | ||
{expected_schema} | ||
Is Valid Json? | ||
{is_valid_json} | ||
JSON: | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters