fix(deepeval/metrics): catch attribute error on metric async evaluation
ottingbob committed Oct 5, 2024
1 parent 37c5fdf commit 81da3d9
Showing 5 changed files with 34 additions and 6 deletions.
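
The recurring change below broadens except TypeError handlers to also catch AttributeError. A minimal sketch of the failure mode, assuming a hypothetical StubModel in place of deepeval's real model wrappers: when schema-constrained generation silently falls back to returning a raw string, attribute access on the result raises AttributeError rather than TypeError, so the old handlers let it propagate.

import asyncio
import json


class StubModel:
    # Hypothetical stand-in: a schema-unaware model that returns plain text
    # even when a schema is requested.
    async def a_generate(self, prompt: str, schema=None):
        return '{"reason": "context matches the input"}'


async def main():
    res = await StubModel().a_generate("prompt", schema=object)
    try:
        print(res.reason)  # a str has no `reason` attribute -> AttributeError
    except (TypeError, AttributeError) as exc:
        # The broadened handler can now fall back to parsing the raw string.
        print("fallback:", type(exc).__name__, "->", json.loads(res)["reason"])


asyncio.run(main())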
15 changes: 14 additions & 1 deletion deepeval/metrics/contextual_relevancy/contextual_relevancy.py
@@ -137,7 +137,7 @@ async def _a_generate_reason(self, input: str):
         try:
             res: Reason = await self.model.a_generate(prompt, schema=Reason)
             return res.reason
-        except TypeError:
+        except (TypeError, AttributeError):
             res = await self.model.a_generate(prompt)
             data = trimAndLoadJson(res, self)
             return data["reason"]
@@ -175,6 +175,19 @@ def _calculate_score(self):
         if total_verdicts == 0:
             return 0
 
+        # Convert verdicts to specific type if LLM has constructed them as a string
+        if (
+            isinstance(self.verdicts, list)
+            and len(self.verdicts) > 0
+            and isinstance(self.verdicts[0], str)
+        ):
+            self.verdicts = [
+                ContextualRelevancyVerdict(
+                    verdict=trimAndLoadJson(v, self).get("verdict")
+                )
+                for v in self.verdicts
+            ]
+
         relevant_nodes = 0
         for verdict in self.verdicts:
             if verdict.verdict.lower() == "yes":
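
A simplified sketch of the normalization above, with a hypothetical Verdict dataclass and plain json.loads standing in for ContextualRelevancyVerdict and trimAndLoadJson:

import json
from dataclasses import dataclass


@dataclass
class Verdict:
    verdict: str


# An LLM may emit each verdict as a JSON string rather than a parsed object;
# normalize before scoring, mirroring the guard added in _calculate_score.
verdicts = ['{"verdict": "yes"}', '{"verdict": "no"}', '{"verdict": "yes"}']

if isinstance(verdicts, list) and len(verdicts) > 0 and isinstance(verdicts[0], str):
    verdicts = [Verdict(verdict=json.loads(v).get("verdict")) for v in verdicts]

relevant_nodes = sum(1 for v in verdicts if v.verdict.lower() == "yes")
print(relevant_nodes / len(verdicts))  # 0.666...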
8 changes: 4 additions & 4 deletions deepeval/metrics/faithfulness/faithfulness.py
@@ -202,7 +202,7 @@ async def _a_generate_verdicts(self) -> List[FaithfulnessVerdict]:
             )
             verdicts = [item for item in res.verdicts]
             return verdicts
-        except TypeError:
+        except (TypeError, AttributeError):
             res = await self.model.a_generate(prompt)
             data = trimAndLoadJson(res, self)
             verdicts = [
@@ -253,7 +253,7 @@ async def _a_generate_truths(self, retrieval_context: str) -> List[str]:
         try:
             res: Truths = await self.model.a_generate(prompt, schema=Truths)
             return res.truths
-        except TypeError:
+        except (TypeError, AttributeError):
             res = await self.model.a_generate(prompt)
             data = trimAndLoadJson(res, self)
             return data["truths"]
@@ -272,7 +272,7 @@ def _generate_truths(self, retrieval_context: str) -> List[str]:
         try:
             res: Truths = self.model.generate(prompt, schema=Truths)
             return res.truths
-        except TypeError:
+        except (TypeError, AttributeError):
             res = self.model.generate(prompt)
             data = trimAndLoadJson(res, self)
             return data["truths"]
@@ -288,7 +288,7 @@ async def _a_generate_claims(self, actual_output: str) -> List[str]:
         try:
             res: Claims = await self.model.a_generate(prompt, schema=Claims)
             return res.claims
-        except TypeError:
+        except (TypeError, AttributeError):
             res = await self.model.a_generate(prompt)
             data = trimAndLoadJson(res, self)
             return data["claims"]
2 changes: 1 addition & 1 deletion deepeval/metrics/g_eval/g_eval.py
@@ -273,7 +273,7 @@ async def _a_evaluate(
                 prompt, schema=ReasonScore
             )
             return res.score, res.reason
-        except TypeError:
+        except (TypeError, AttributeError):
             res = await self.model.a_generate(prompt)
             data = trimAndLoadJson(res, self)
             return data["score"], data["reason"]
3 changes: 3 additions & 0 deletions deepeval/metrics/utils.py
@@ -1,5 +1,6 @@
 import inspect
 import json
+import re
 from typing import Any, Dict, Optional, List, Union, Tuple
 from deepeval.errors import MissingTestCaseParamsError
 from deepeval.models import (
@@ -235,6 +236,8 @@ def trimAndLoadJson(
         end = len(input_string)
 
     jsonStr = input_string[start:end] if start != -1 and end != 0 else ""
+    # Remove trailing comma if one is present
+    jsonStr = re.sub(r",\s*([\]}])", r"\1", jsonStr)
 
     try:
         return json.loads(jsonStr)
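
A quick standalone illustration of the trailing-comma cleanup added to trimAndLoadJson, reusing the same regex outside deepeval:

import json
import re

raw = '{\n    "verdict": "yes",\n}'  # trailing comma makes this invalid JSON
cleaned = re.sub(r",\s*([\]}])", r"\1", raw)
print(json.loads(cleaned))  # {'verdict': 'yes'}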
12 changes: 12 additions & 0 deletions tests/test_utils.py
@@ -13,3 +13,15 @@ def test_check_llm_test_case_params_raies_ValueError_for_wrong_type():
             test_case_params=[LLMTestCaseParams.ACTUAL_OUTPUT],
             metric=BaseMetric(),
         )
+
+
+def test_trimAndLoadJson_correctly_parses_with_trailing_comma():
+    test_data = [
+        '{\n "verdict": "yes",\n}',
+        '{\n "verdict": "yes",\n}',
+    ]
+    verdicts = [utils.trimAndLoadJson(v) for v in test_data]
+
+    assert len(verdicts) == 2
+    for v in verdicts:
+        assert v.get("verdict") == "yes"
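
The new test should be selectable locally with something like pytest tests/test_utils.py -k trailing_comma, assuming pytest as the runner (the assert-style tests above suggest it).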
