diff --git a/deepeval/dataset/dataset.py b/deepeval/dataset/dataset.py index a6f258a1..b9828496 100644 --- a/deepeval/dataset/dataset.py +++ b/deepeval/dataset/dataset.py @@ -23,7 +23,7 @@ DatasetHttpResponse, ) from deepeval.dataset.golden import Golden, ConversationalGolden -from deepeval.test_case import LLMTestCase, ConversationalTestCase +from deepeval.test_case import LLMTestCase, ConversationalTestCase, MLLMTestCase from deepeval.utils import convert_keys_to_snake_case, is_confident from deepeval.synthesizer.types import * @@ -31,13 +31,12 @@ def validate_test_case_type( - test_case: Union[LLMTestCase, ConversationalTestCase], subject: str + test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], subject: str ): if not isinstance(test_case, LLMTestCase) and not isinstance( - test_case, ConversationalTestCase - ): + test_case, ConversationalTestCase) and not isinstance(test_case, MLLMTestCase): raise TypeError( - f"Provided `{subject}` must be a list of LLMTestCase or ConversationalTestCase" + f"Provided `{subject}` must be a list of LLMTestCase, ConversationalTestCase, or MLLMTestCase" ) @@ -54,7 +53,7 @@ class EvaluationDataset: def __init__( self, - test_cases: List[Union[LLMTestCase, ConversationalTestCase]] = [], + test_cases: List[Union[LLMTestCase, ConversationalTestCase, MLLMTestCase]] = [], goldens: List[Golden] = [], conversational_goldens: List[ConversationalGolden] = [], ): @@ -67,6 +66,7 @@ def __init__( llm_test_cases = [] conversational_test_cases = [] + mllm_test_cases = [] for test_case in test_cases: if isinstance(test_case, LLMTestCase): test_case._dataset_rank = len(llm_test_cases) @@ -74,9 +74,13 @@ def __init__( elif isinstance(test_case, ConversationalTestCase): test_case._dataset_rank = len(conversational_test_cases) conversational_test_cases.append(test_case) + elif isinstance(test_case, MLLMTestCase): + test_case._dataset_rank = len(mllm_test_cases) + mllm_test_cases.append(test_case) self._llm_test_cases = llm_test_cases self._conversational_test_cases = conversational_test_cases + self._mllm_test_cases = mllm_test_cases def __repr__(self): return ( @@ -87,22 +91,23 @@ def __repr__(self): ) @property - def test_cases(self) -> List[Union[LLMTestCase, ConversationalTestCase]]: - return self._llm_test_cases + self._conversational_test_cases + def test_cases(self) -> List[Union[LLMTestCase, ConversationalTestCase, MLLMTestCase]]: + return self._llm_test_cases + self._conversational_test_cases + self._mllm_test_cases @test_cases.setter def test_cases( - self, test_cases: List[Union[LLMTestCase, ConversationalTestCase]] + self, test_cases: List[Union[LLMTestCase, ConversationalTestCase, MLLMTestCase]] ): if not isinstance(test_cases, list): raise TypeError("'test_cases' must be a list") llm_test_cases = [] conversational_test_cases = [] + mllm_test_cases = [] for test_case in test_cases: if not isinstance(test_case, LLMTestCase) and not isinstance( - test_case, ConversationalTestCase - ): + test_case, ConversationalTestCase) and not isinstance( + test_case, MLLMTestCase): continue test_case._dataset_alias = self._alias @@ -113,12 +118,16 @@ def test_cases( elif isinstance(test_case, ConversationalTestCase): test_case._dataset_rank = len(conversational_test_cases) conversational_test_cases.append(test_case) + elif isinstance(test_case, MLLMTestCase): + test_case._dataset_rank = len(mllm_test_cases) + mllm_test_cases.append(test_case) self._llm_test_cases = llm_test_cases self._conversational_test_cases = conversational_test_cases + 
self._mllm_test_cases = mllm_test_cases def add_test_case( - self, test_case: Union[LLMTestCase, ConversationalTestCase] + self, test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase] ): validate_test_case_type(test_case, subject="test cases") @@ -130,7 +139,10 @@ def add_test_case( elif isinstance(test_case, ConversationalTestCase): test_case._dataset_rank = len(self._conversational_test_cases) self._conversational_test_cases.append(test_case) - + elif isinstance(test_case, MLLMTestCase): + test_case._dataset_rank = len(self._mllm_test_cases) + self._mllm_test_cases.append(test_case) + def __iter__(self): return iter(self.test_cases) diff --git a/deepeval/evaluate.py b/deepeval/evaluate.py index 7b7b15a0..318ad333 100644 --- a/deepeval/evaluate.py +++ b/deepeval/evaluate.py @@ -8,6 +8,7 @@ from tqdm.asyncio import tqdm_asyncio from tqdm import tqdm +from deepeval.types import Image from deepeval.metrics.utils import copy_metrics from deepeval.test_run.hyperparameters import process_hyperparameters from deepeval.utils import ( @@ -17,16 +18,17 @@ should_verbose_print, ) from deepeval.telemetry import capture_evaluation_run -from deepeval.metrics import BaseMetric, BaseConversationalMetric +from deepeval.metrics import BaseMetric, BaseConversationalMetric, BaseMultimodalMetric from deepeval.metrics.indicator import ( format_metric_description, measure_metrics_with_indicator, ) -from deepeval.test_case import LLMTestCase, ConversationalTestCase +from deepeval.test_case import LLMTestCase, ConversationalTestCase, MLLMTestCase from deepeval.constants import PYTEST_RUN_TEST_NAME from deepeval.test_run import ( global_test_run_manager, LLMApiTestCase, + MLLMApiTestCase, ConversationalApiTestCase, MetricData, TestRunManager, @@ -44,6 +46,7 @@ @dataclass class TestResult: + """Returned from run_test""" success: bool metrics_data: Union[List[MetricData], None] conversational: bool @@ -53,6 +56,13 @@ class TestResult: context: Optional[List[str]] = None retrieval_context: Optional[List[str]] = None +@dataclass +class MLLMTestResult: + """Returned from run_test""" + success: bool + metrics_data: List[MetricData] + input: List[Union[str, Image]] + actual_output: List[Union[str, Image]] def create_metric_data(metric: BaseMetric) -> MetricData: if metric.error is not None: @@ -84,8 +94,16 @@ def create_metric_data(metric: BaseMetric) -> MetricData: def create_test_result( - test_case: Union[LLMApiTestCase, ConversationalApiTestCase], -) -> TestResult: + test_case: Union[LLMApiTestCase, ConversationalApiTestCase, MLLMApiTestCase], +) -> Union[TestResult, MLLMTestResult]: + if isinstance(test_case, MLLMApiTestCase): + return MLLMTestResult( + success=test_case.success, + metrics_data=test_case.metrics_data, + input=test_case.input, + actual_output=test_case.actual_output, + ) + if isinstance(test_case, ConversationalApiTestCase): return TestResult( success=test_case.success, @@ -111,7 +129,7 @@ def create_test_result( def create_api_test_case( - test_case: Union[LLMTestCase, ConversationalTestCase], + test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], index: Optional[int] = None, conversational_instance_id: Optional[int] = None, additional_metadata: Optional[Dict] = None, @@ -196,11 +214,27 @@ def create_api_test_case( ] return api_test_case + + elif isinstance(test_case, MLLMTestCase): + name = os.getenv( + PYTEST_RUN_TEST_NAME, f"mllm_test_case_{index}" + ) + api_test_case = MLLMApiTestCase( + name=name, + input=test_case.input, + actualOutput=test_case.actual_output, + 
success=True, + metricsData=None, + runDuration=0, + evaluationCost=None, + order=test_case._dataset_rank + ) + return api_test_case def execute_test_cases( - test_cases: List[Union[LLMTestCase, ConversationalTestCase]], - metrics: List[Union[BaseMetric, BaseConversationalMetric]], + test_cases: List[Union[LLMTestCase, ConversationalTestCase, MLLMTestCase]], + metrics: List[Union[BaseMetric, BaseConversationalMetric, BaseMultimodalMetric]], ignore_errors: bool, use_cache: bool, show_indicator: bool, @@ -223,11 +257,14 @@ def execute_test_cases( conversational_metrics: List[BaseConversationalMetric] = [] llm_metrics: List[BaseMetric] = [] + mllm_metrics: List[BaseMultimodalMetric] = [] for metric in metrics: if isinstance(metric, BaseMetric): llm_metrics.append(metric) elif isinstance(metric, BaseConversationalMetric): conversational_metrics.append(metric) + elif isinstance(metric, BaseMultimodalMetric): + mllm_metrics.append(metric) global llm_test_case_lookup_map llm_test_case_lookup_map = {} @@ -237,11 +274,12 @@ def execute_test_cases( if message.should_evaluate: test_cases.append(message.llm_test_case) - test_results: List[TestResult] = [] + test_results: List[Union[TestResult, MLLMTestResult]] = [] def evaluate_test_cases(pbar: Optional[tqdm] = None): llm_test_case_count = -1 conversational_test_case_count = -1 + mllm_test_case_count = -1 show_metric_indicator = show_indicator and not _use_bar_indicator for test_case in test_cases: with capture_evaluation_run("test case"): @@ -386,7 +424,51 @@ def evaluate_test_cases(pbar: Optional[tqdm] = None): ### Update Test Run ### test_run_manager.update_test_run(api_test_case, test_case) + + + # No caching and not sending test cases to Confident AI for multimodal metrics yet + elif isinstance(test_case, MLLMTestCase): + mllm_test_case_count += 1 + api_test_case: MLLMApiTestCase = create_api_test_case( + test_case, mllm_test_case_count + ) + + test_start_time = time.perf_counter() + for metric in mllm_metrics: + # Skip non multimodal metrics for mllm test cases + if isinstance(metric, BaseMultimodalMetric): + metric.async_mode = False # Override metric async + try: + metric.measure( + test_case, + _show_indicator=show_metric_indicator, + ) + except TypeError: + try: + metric.measure(test_case) + except Exception as e: + if ignore_errors: + metric.error = str(e) + metric.success = False + else: + raise + except Exception as e: + if ignore_errors: + metric.error = str(e) + metric.success = False + else: + raise + metric_data = create_metric_data(metric) + api_test_case.update_metric_data(metric_data) + test_end_time = time.perf_counter() + if len(mllm_metrics) > 0: + run_duration = test_end_time - test_start_time + api_test_case.update_run_duration(run_duration) + + ### Update Test Run ### + test_run_manager.update_test_run(api_test_case, test_case) + test_result = create_test_result(api_test_case) test_results.append(test_result) @@ -403,13 +485,13 @@ def evaluate_test_cases(pbar: Optional[tqdm] = None): evaluate_test_cases(pbar) else: evaluate_test_cases() - + return test_results async def a_execute_test_cases( - test_cases: List[Union[LLMTestCase, ConversationalTestCase]], - metrics: List[Union[BaseMetric, BaseConversationalMetric]], + test_cases: List[Union[LLMTestCase, ConversationalTestCase, MLLMTestCase]], + metrics: List[Union[BaseMetric, BaseConversationalMetric, BaseMultimodalMetric]], ignore_errors: bool, use_cache: bool, show_indicator: bool, @@ -419,6 +501,7 @@ async def a_execute_test_cases( test_run_manager: Optional[TestRunManager] 
= None, _use_bar_indicator: bool = True, ) -> List[TestResult]: + global_test_run_cache_manager.disable_write_cache = save_to_disk == False if test_run_manager is None: @@ -433,11 +516,14 @@ async def a_execute_test_cases( llm_metrics: List[BaseMetric] = [] conversational_metrics: List[BaseConversationalMetric] = [] + mllm_metrics: List[BaseMultimodalMetric] = [] for metric in metrics: if isinstance(metric, BaseMetric): llm_metrics.append(metric) elif isinstance(metric, BaseConversationalMetric): conversational_metrics.append(metric) + elif isinstance(metric, BaseMultimodalMetric): + mllm_metrics.append(metric) global llm_test_case_lookup_map llm_test_case_lookup_map = {} @@ -453,7 +539,8 @@ async def a_execute_test_cases( llm_test_case_counter = -1 conversational_test_case_counter = -1 - test_results: List[TestResult] = [] + mllm_test_case_counter = -1 + test_results: List[Union[TestResult, MLLMTestCase]] = [] tasks = [] if show_indicator and _use_bar_indicator: @@ -506,6 +593,24 @@ async def a_execute_test_cases( pbar=pbar, ) tasks.append(asyncio.create_task(task)) + + elif isinstance(test_case, MLLMTestCase): + mllm_test_case_counter += 1 + copied_multimodal_metrics: List[ + BaseMultimodalMetric + ] = copy_metrics(mllm_metrics) + task = a_execute_mllm_test_cases( + metrics=copied_multimodal_metrics, + test_case=test_case, + test_run_manager=test_run_manager, + test_results=test_results, + count=mllm_test_case_counter, + ignore_errors=ignore_errors, + show_indicator=show_indicator, + _use_bar_indicator=_use_bar_indicator, + pbar=pbar, + ) + tasks.append(asyncio.create_task(task)) await asyncio.sleep(throttle_value) await asyncio.gather(*tasks) @@ -553,6 +658,24 @@ async def a_execute_test_cases( show_indicator=show_indicator, ) tasks.append(asyncio.create_task((task))) + + elif isinstance(test_case, MLLMTestCase): + mllm_test_case_counter += 1 + copied_multimodal_metrics: List[ + BaseMultimodalMetric + ] = copy_metrics(mllm_metrics) + task = a_execute_mllm_test_cases( + metrics=copied_multimodal_metrics, + test_case=test_case, + test_run_manager=test_run_manager, + test_results=test_results, + count=mllm_test_case_counter, + ignore_errors=ignore_errors, + _use_bar_indicator=_use_bar_indicator, + show_indicator=show_indicator, + ) + tasks.append(asyncio.create_task(task)) + await asyncio.sleep(throttle_value) await asyncio.gather(*tasks) @@ -563,7 +686,7 @@ async def a_execute_llm_test_cases( metrics: List[BaseMetric], test_case: LLMTestCase, test_run_manager: TestRunManager, - test_results: List[TestResult], + test_results: List[Union[TestResult, MLLMTestCase]], count: int, test_run: TestRun, ignore_errors: bool, @@ -647,7 +770,7 @@ async def a_execute_conversational_test_cases( metrics: List[BaseConversationalMetric], test_case: ConversationalTestCase, test_run_manager: TestRunManager, - test_results: List[TestResult], + test_results: List[Union[TestResult, MLLMTestCase]], count: int, ignore_errors: bool, show_indicator: bool, @@ -689,9 +812,53 @@ async def a_execute_conversational_test_cases( pbar.update(1) +async def a_execute_mllm_test_cases( + metrics: List[BaseMultimodalMetric], + test_case: MLLMTestCase, + test_run_manager: TestRunManager, + test_results: List[Union[TestResult, MLLMTestCase]], + count: int, + ignore_errors: bool, + show_indicator: bool, + _use_bar_indicator: bool, + pbar: Optional[tqdm_asyncio] = None, +): + show_metrics_indicator = show_indicator and not _use_bar_indicator + + for metric in metrics: + metric.error = None # Reset metric error + + api_test_case: 
MLLMApiTestCase = create_api_test_case( + test_case, count + ) + test_start_time = time.perf_counter() + await measure_metrics_with_indicator( + metrics=metrics, + test_case=test_case, + cached_test_case=None, + ignore_errors=ignore_errors, + show_indicator=show_metrics_indicator, + ) + for metric in metrics: + metric_data = create_metric_data(metric) + api_test_case.update_metric_data(metric_data) + + test_end_time = time.perf_counter() + if len(metrics) > 0: + run_duration = test_end_time - test_start_time + api_test_case.update_run_duration(run_duration) + + ### Update Test Run ### + test_run_manager.update_test_run(api_test_case, test_case) + test_results.append(create_test_result(api_test_case)) + + if pbar is not None: + pbar.update(1) + + def assert_test( - test_case: Union[LLMTestCase, ConversationalTestCase], - metrics: List[Union[BaseMetric, BaseConversationalMetric]], + test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], + metrics: List[Union[BaseMetric, BaseConversationalMetric, BaseMultimodalMetric]], run_async: bool = True, ): if run_async: @@ -747,7 +914,7 @@ def assert_test( def evaluate( - test_cases: List[Union[LLMTestCase, ConversationalTestCase]], + test_cases: List[Union[LLMTestCase, ConversationalTestCase, MLLMTestCase]], metrics: List[BaseMetric], hyperparameters: Optional[Dict[str, Union[str, int, float]]] = None, run_async: bool = True, @@ -820,7 +987,7 @@ def evaluate( return test_results -def print_test_result(test_result: TestResult): +def print_test_result(test_result: Union[TestResult, MLLMTestResult]): print("") print("=" * 70 + "\n") print("Metrics Summary\n") @@ -849,21 +1016,28 @@ def print_test_result(test_result: TestResult): ) print("") - if test_result.conversational: - print("For conversational test case:\n") - print( - f" - Unable to print conversational test case. Login to Confident AI (https://app.confident-ai.com) to view conversational evaluations in full." - ) - else: + + if isinstance(test_result, MLLMTestResult): print("For test case:\n") print(f" - input: {test_result.input}") print(f" - actual output: {test_result.actual_output}") - print(f" - expected output: {test_result.expected_output}") - print(f" - context: {test_result.context}") - print(f" - retrieval context: {test_result.retrieval_context}") + + else: + if test_result.conversational: + print("For conversational test case:\n") + print( + f" - Unable to print conversational test case. Login to Confident AI (https://app.confident-ai.com) to view conversational evaluations in full." 
+ ) + else: + print("For test case:\n") + print(f" - input: {test_result.input}") + print(f" - actual output: {test_result.actual_output}") + print(f" - expected output: {test_result.expected_output}") + print(f" - context: {test_result.context}") + print(f" - retrieval context: {test_result.retrieval_context}") -def aggregate_metric_pass_rates(test_results: List[TestResult]) -> dict: +def aggregate_metric_pass_rates(test_results: List[Union[TestResult, MLLMTestResult]]) -> dict: metric_counts = {} metric_successes = {} diff --git a/deepeval/metrics/__init__.py b/deepeval/metrics/__init__.py index e968f4bb..2c8437bf 100644 --- a/deepeval/metrics/__init__.py +++ b/deepeval/metrics/__init__.py @@ -1,4 +1,4 @@ -from .base_metric import BaseMetric, BaseConversationalMetric +from .base_metric import BaseMetric, BaseConversationalMetric, BaseMultimodalMetric from .bias.bias import BiasMetric from .toxicity.toxicity import ToxicityMetric @@ -12,6 +12,7 @@ from .contextual_precision.contextual_precision import ContextualPrecisionMetric from .knowledge_retention.knowledge_retention import KnowledgeRetentionMetric from .tool_correctness.tool_correctness import ToolCorrectnessMetric +from .viescore.viescore import VIEScore, VIEScoreTask from .conversation_relevancy.conversation_relevancy import ( ConversationRelevancyMetric, ) diff --git a/deepeval/metrics/base_metric.py b/deepeval/metrics/base_metric.py index d361efcc..46884c1c 100644 --- a/deepeval/metrics/base_metric.py +++ b/deepeval/metrics/base_metric.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Optional, Dict -from deepeval.test_case import LLMTestCase, ConversationalTestCase +from deepeval.test_case import LLMTestCase, ConversationalTestCase, MLLMTestCase class BaseMetric: @@ -74,3 +74,47 @@ def is_successful(self) -> bool: @property def __name__(self): return "Base Conversational Metric" + +class BaseMultimodalMetric: + score: Optional[float] = None + score_breakdown: Dict = None + reason: Optional[str] = None + success: Optional[bool] = None + evaluation_model: Optional[str] = None + strict_mode: bool = False + async_mode: bool = True + verbose_mode: bool = True + include_reason: bool = False + error: Optional[str] = None + evaluation_cost: Optional[float] = None + verbose_logs: Optional[str] = None + + @property + def threshold(self) -> float: + return self._threshold + + @threshold.setter + def threshold(self, value: float): + self._threshold = value + + @abstractmethod + def measure( + self, test_case: MLLMTestCase, *args, **kwargs + ) -> float: + raise NotImplementedError + + @abstractmethod + async def a_measure( + self, test_case: MLLMTestCase, *args, **kwargs + ) -> float: + raise NotImplementedError( + f"Async execution for {self.__class__.__name__} not supported yet. Please set 'async_mode' to 'False'." 
+ ) + + @abstractmethod + def is_successful(self) -> bool: + raise NotImplementedError + + @property + def __name__(self): + return "Base Multimodal Metric" \ No newline at end of file diff --git a/deepeval/metrics/indicator.py b/deepeval/metrics/indicator.py index 878ef0b6..ddd02206 100644 --- a/deepeval/metrics/indicator.py +++ b/deepeval/metrics/indicator.py @@ -6,8 +6,8 @@ import time import asyncio -from deepeval.metrics import BaseMetric, BaseConversationalMetric -from deepeval.test_case import LLMTestCase, ConversationalTestCase +from deepeval.metrics import BaseMetric, BaseConversationalMetric, BaseMultimodalMetric +from deepeval.test_case import LLMTestCase, ConversationalTestCase, MLLMTestCase from deepeval.test_run.cache import CachedTestCase, Cache from deepeval.telemetry import capture_metric_type @@ -53,8 +53,8 @@ def metric_progress_indicator( async def measure_metric_task( task_id, progress, - metric: Union[BaseMetric, BaseConversationalMetric], - test_case: Union[LLMTestCase, ConversationalTestCase], + metric: Union[BaseMetric, BaseConversationalMetric, BaseMultimodalMetric], + test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], cached_test_case: Union[CachedTestCase, None], ignore_errors: bool, ): @@ -109,8 +109,8 @@ async def measure_metric_task( async def measure_metrics_with_indicator( - metrics: List[Union[BaseMetric, BaseConversationalMetric]], - test_case: Union[LLMTestCase, ConversationalTestCase], + metrics: List[Union[BaseMetric, BaseConversationalMetric, BaseMultimodalMetric]], + test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], cached_test_case: Union[CachedTestCase, None], ignore_errors: bool, show_indicator: bool, @@ -172,8 +172,8 @@ async def measure_metrics_with_indicator( async def safe_a_measure( - metric: Union[BaseMetric, BaseConversationalMetric], - tc: Union[LLMTestCase, ConversationalTestCase], + metric: Union[BaseMetric, BaseConversationalMetric, BaseMultimodalMetric], + tc: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], ignore_errors: bool, ): try: diff --git a/deepeval/metrics/utils.py b/deepeval/metrics/utils.py index 1d1ef23f..f39c0402 100644 --- a/deepeval/metrics/utils.py +++ b/deepeval/metrics/utils.py @@ -1,21 +1,24 @@ import inspect import json from typing import Any, Dict, Optional, List, Union, Tuple -from deepeval.models import GPTModel, DeepEvalBaseLLM +from deepeval.models import GPTModel, DeepEvalBaseLLM, MultimodalGPTModel, DeepEvalBaseMLLM from deepeval.models.gpt_model_schematic import SchematicGPTModel -from deepeval.metrics import BaseMetric, BaseConversationalMetric +from deepeval.metrics import BaseMetric, BaseConversationalMetric, BaseMultimodalMetric from deepeval.test_case import ( LLMTestCase, LLMTestCaseParams, + MLLMTestCase, + MLLMTestCaseParams, ConversationalTestCase, Message, ) +from deepeval.types import Image def copy_metrics( - metrics: Union[List[BaseMetric], List[BaseConversationalMetric]] -) -> Union[List[BaseMetric], List[BaseConversationalMetric]]: + metrics: Union[List[BaseMetric], List[BaseConversationalMetric], List[BaseMultimodalMetric]] +) -> Union[List[BaseMetric], List[BaseConversationalMetric], List[BaseMultimodalMetric]]: copied_metrics = [] for metric in metrics: metric_class = type(metric) @@ -157,6 +160,54 @@ def check_llm_test_case_params( raise ValueError(error_str) +def check_mllm_test_case_params( + test_case: MLLMTestCase, + test_case_params: List[MLLMTestCaseParams], + input_image_count: int, + actual_output_image_count: int, + metric: 
BaseMetric, +): + count = 0 + for ele in test_case.input: + if isinstance(ele, Image): + count += 1 + if count != input_image_count: + error_str = f"Can only evaluate test cases with '{input_image_count}' input images using the '{metric.__name__}' metric. `{count}` found." + raise ValueError(error_str) + + count = 0 + for ele in test_case.actual_output: + if isinstance(ele, Image): + count += 1 + if count != actual_output_image_count: + error_str = f"Unable to evaluate test cases with '{actual_output_image_count}' output images using the '{metric.__name__}' metric. `{count}` found." + raise ValueError(error_str) + + if isinstance(test_case, MLLMTestCase) is False: + error_str = f"Unable to evaluate test cases that are not of type 'MLLMTestCase' using the '{metric.__name__}' metric." + metric.error = error_str + raise ValueError(error_str) + + missing_params = [] + for param in test_case_params: + if getattr(test_case, param.value) is None: + missing_params.append(f"'{param.value}'") + + if missing_params: + if len(missing_params) == 1: + missing_params_str = missing_params[0] + elif len(missing_params) == 2: + missing_params_str = " and ".join(missing_params) + else: + missing_params_str = ( + ", ".join(missing_params[:-1]) + ", and " + missing_params[-1] + ) + + error_str = f"{missing_params_str} cannot be None for the '{metric.__name__}' metric" + metric.error = error_str + raise ValueError(error_str) + + def trimAndLoadJson( input_string: str, metric: Optional[BaseMetric] = None ) -> Any: @@ -196,6 +247,32 @@ def initialize_model( return GPTModel(model=model), True +def initialize_multimodal_model( + model: Optional[Union[str, DeepEvalBaseMLLM, MultimodalGPTModel]] = None, +) -> Tuple[DeepEvalBaseLLM, bool]: + """ + Returns a tuple of (initialized DeepEvalBaseMLLM, using_native_model boolean) + """ + # If model is a MultimodalGPTModel, it should be deemed as using native model + if isinstance(model, MultimodalGPTModel): + return model, True + # If model is a DeepEvalBaseMLLM but not a MultimodalGPTModel, we can not assume it is a native model + if isinstance(model, DeepEvalBaseMLLM): + return model, False + # Otherwise (the model is a string or None), we initialize a GPTModel and use as a native model + return MultimodalGPTModel(model=model), True + + +def print_verbose_logs(metric: str, logs: str): + print("*" * 50) + print(f"{metric} Verbose Logs") + print("*" * 50) + print("") + print(logs) + print("") + print("=" * 70) + + def initialize_schematic_model( model: Optional[Union[str, DeepEvalBaseLLM, SchematicGPTModel]] = None, ) -> Tuple[DeepEvalBaseLLM, bool]: diff --git a/deepeval/metrics/viescore/__init__.py b/deepeval/metrics/viescore/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deepeval/metrics/viescore/schema.py b/deepeval/metrics/viescore/schema.py new file mode 100644 index 00000000..2dcc1125 --- /dev/null +++ b/deepeval/metrics/viescore/schema.py @@ -0,0 +1,6 @@ +from typing import List +from pydantic import BaseModel, Field + +class ReasonScore(BaseModel): + reasoning: str + score: List[float] diff --git a/deepeval/metrics/viescore/task.py b/deepeval/metrics/viescore/task.py new file mode 100644 index 00000000..f171a5b7 --- /dev/null +++ b/deepeval/metrics/viescore/task.py @@ -0,0 +1,5 @@ +from enum import Enum + +class VIEScoreTask(Enum): + TEXT_TO_IMAGE_GENERATION = "Text2Image Generation" + TEXT_TO_IMAGE_EDITING = "Text2Image Editing" diff --git a/deepeval/metrics/viescore/template.py b/deepeval/metrics/viescore/template.py new file mode 100644 
index 00000000..604bcc24 --- /dev/null +++ b/deepeval/metrics/viescore/template.py @@ -0,0 +1,76 @@ +from .task import VIEScoreTask +import textwrap + +class VIEScoreTemplate: + + context = textwrap.dedent(""" + You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. + All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials. + + You will have to give your output in this way (Keep your reasoning concise and short.): + { + "score" : [...], + "reasoning" : "..." + } + """) + + @staticmethod + def generate_semantic_consistency_evaluation_results(text_prompt: str, task: VIEScoreTask): + if task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: + return textwrap.dedent(f""" + {VIEScoreTemplate.context} + + RULES: + + The image is an AI-generated image according to the text prompt. + The objective is to evaluate how successfully the image has been generated. + + From scale 0 to 10: + A score from 0 to 10 will be given based on the success in following the prompt. + (0 indicates that the AI generated image does not follow the prompt at all. 10 indicates the AI generated image follows the prompt perfectly.) + + Put the score in a list such that output score = [score]. + + Text Prompt: {text_prompt} + """) + elif task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: + return textwrap.dedent(f""" + {VIEScoreTemplate.context} + + RULES: + + Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. + The objective is to evaluate how successfully the editing instruction has been executed in the second image. + + From scale 0 to 10: + A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) + A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) + Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + + Editing instruction: {text_prompt} + """) + + @staticmethod + def generate_perceptual_quality_evaluation_results(): + return textwrap.dedent(f""" + {VIEScoreTemplate.context} + + RULES: + + The image is an AI-generated image. + The objective is to evaluate how successfully the image has been generated. + + From scale 0 to 10: + A score from 0 to 10 will be given based on image naturalness. + ( + 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image looks natural. + ) + A second score from 0 to 10 will rate the image artifacts. + ( + 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image has no artifacts. 
+ ) + Put the score in a list such that output score = [naturalness, artifacts] + """) + diff --git a/deepeval/metrics/viescore/viescore.py b/deepeval/metrics/viescore/viescore.py new file mode 100644 index 00000000..42155ad5 --- /dev/null +++ b/deepeval/metrics/viescore/viescore.py @@ -0,0 +1,282 @@ +from typing import Optional, List, Tuple, Union +import math +import textwrap + +from deepeval.types import Image +from deepeval.metrics import BaseMultimodalMetric +from deepeval.test_case import ( + MLLMTestCaseParams, MLLMTestCase +) +from deepeval.metrics.viescore.template import VIEScoreTemplate +from deepeval.utils import get_or_create_event_loop +from deepeval.metrics.utils import ( + construct_verbose_logs, + trimAndLoadJson, + check_mllm_test_case_params, + initialize_multimodal_model, +) +from deepeval.models import DeepEvalBaseMLLM +from deepeval.metrics.viescore.schema import ReasonScore +from deepeval.metrics.viescore.task import VIEScoreTask +from deepeval.metrics.indicator import metric_progress_indicator + +required_params: List[MLLMTestCaseParams] = [ + MLLMTestCaseParams.INPUT, + MLLMTestCaseParams.ACTUAL_OUTPUT, +] + +class VIEScore(BaseMultimodalMetric): + def __init__( + self, + model: Optional[Union[str, DeepEvalBaseMLLM]] = None, + task: VIEScoreTask = VIEScoreTask.TEXT_TO_IMAGE_GENERATION, + threshold: float = 0.5, + async_mode: bool = True, + strict_mode: bool = False, + verbose_mode: bool = False, + _include_VIEScore_task_name: bool = True + ): + self.model, self.using_native_model = initialize_multimodal_model(model) + self.evaluation_model = self.model.get_model_name() + self.threshold = 1 if strict_mode else threshold + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose_mode = verbose_mode + self.task = task + self._include_VIEScore_task_name = _include_VIEScore_task_name + + def measure( + self, + test_case: MLLMTestCase, + _show_indicator: bool = True + ) -> float: + if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: + check_mllm_test_case_params(test_case, required_params, 0, 1, self) + elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: + check_mllm_test_case_params(test_case, required_params, 1, 1, self) + + self.evaluation_cost = 0 if self.using_native_model else None + with metric_progress_indicator(self, _show_indicator=_show_indicator): + if self.async_mode: + loop = get_or_create_event_loop() + loop.run_until_complete( + self.a_measure(test_case, _show_indicator=False) + ) + else: + input_texts, input_images = self.separate_images_from_text(test_case.input) + _, output_images = self.separate_images_from_text(test_case.actual_output) + + self.SC_scores, self.SC_reasoning = self._evaluate_semantic_consistency( + "\n".join(input_texts), + None if len(input_images) == 0 else input_images[0], + output_images[0], + ) + self.PQ_scores, self.PQ_reasoning = self._evaluate_perceptual_quality(output_images[0]) + self.score = self._calculate_score() + self.score = ( + 0 + if self.strict_mode and self.score < self.threshold + else self.score + ) + self.reason = self._generate_reason() + self.success = self.score >= self.threshold + self.verbose_logs = construct_verbose_logs( + self, + steps=[ + f"Semantic Consistency Scores:\n{self.SC_scores}", + f"Semantic Consistency Reasoning:\n{self.SC_reasoning}", + f"Perceptual Quality Scores:\n{self.SC_scores}", + f"Perceptual Quality Reasoning:\n{self.PQ_reasoning}", + f"Score: {self.score}\nReason: {self.reason}", + ], + ) + return self.score + + async def a_measure( + self, + test_case: 
MLLMTestCase, + _show_indicator: bool = True, + ) -> float: + if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: + check_mllm_test_case_params(test_case, required_params, 0, 1, self) + elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: + check_mllm_test_case_params(test_case, required_params, 1, 1, self) + + self.evaluation_cost = 0 if self.using_native_model else None + with metric_progress_indicator( + self, + async_mode=True, + _show_indicator=_show_indicator, + ): + input_texts, input_images = self.separate_images_from_text(test_case.input) + _, output_images = self.separate_images_from_text(test_case.actual_output) + + self.SC_scores, self.SC_reasoning = await self._a_evaluate_semantic_consistency( + "\n".join(input_texts), + None if len(input_images) == 0 else input_images[0], + output_images[0], + ) + self.PQ_scores, self.PQ_reasoning = await self._a_evaluate_perceptual_quality(output_images[0]) + self.score = self._calculate_score() + self.score = ( + 0 + if self.strict_mode and self.score < self.threshold + else self.score + ) + self.reason = self._generate_reason() + self.success = self.score >= self.threshold + self.verbose_logs = construct_verbose_logs( + self, + steps=[ + f"Semantic Consistency Scores:\n{self.SC_scores}", + f"Semantic Consistency Reasoning:\n{self.SC_reasoning}", + f"Perceptual Quality Scores:\n{self.SC_scores}", + f"Perceptual Quality Reasoning:\n{self.PQ_reasoning}", + f"Score: {self.score}\nReason: {self.reason}", + ], + ) + return self.score + + def separate_images_from_text( + self, + multimodal_list: List[Union[Image, str]] + ) -> Tuple[List[str], List[Image]]: + images: List[Image] = [] + texts: List[str] = [] + for item in multimodal_list: + if isinstance(item, Image): + images.append(item) + elif isinstance(item, str): + texts.append(item) + return texts, images + + async def _a_evaluate_semantic_consistency( + self, + text_prompt: str, + image_input: Image, + actual_image_output: Image + ) -> Tuple[List[int], str]: + images: List[Image] = [] + if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: + images.append(image_input) + elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: + images.extend([image_input, actual_image_output]) + prompt = [VIEScoreTemplate.generate_semantic_consistency_evaluation_results( + text_prompt=text_prompt, task=self.task + )] + if self.using_native_model: + res, cost = await self.model.a_generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = await self.model.a_generate(prompt + images, schema=ReasonScore) + return res.score, res.reasoning + except TypeError: + res = await self.model.a_generate(prompt + images, input_text=prompt) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + def _evaluate_semantic_consistency( + self, + text_prompt: str, + image_input: Image, + actual_image_output: Image + ) -> Tuple[List[int], str]: + images: List[Image] = [] + if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: + images.append(image_input) + elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: + images.extend([image_input, actual_image_output]) + prompt = [VIEScoreTemplate.generate_semantic_consistency_evaluation_results( + text_prompt=text_prompt, task=self.task + )] + if self.using_native_model: + res, cost = self.model.generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: 
ReasonScore = self.model.generate(prompt + images, schema=ReasonScore) + return res.score, res.reasoning + except TypeError: + res = self.model.generate(prompt + images) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + async def _a_evaluate_perceptual_quality( + self, + actual_image_output: Image + ) -> Tuple[List[int], str]: + images: List[Image] = [actual_image_output] + prompt = [VIEScoreTemplate.generate_perceptual_quality_evaluation_results()] + if self.using_native_model: + res, cost = await self.model.a_generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = await self.model.a_generate(prompt + images, schema=ReasonScore) + return res.score, res.reasoning + except TypeError: + res = await self.model.a_generate(prompt + images) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + def _evaluate_perceptual_quality( + self, + actual_image_output: Image + ) -> Tuple[List[int], str]: + images: List[Image] = [actual_image_output] + prompt = [VIEScoreTemplate.generate_perceptual_quality_evaluation_results()] + if self.using_native_model: + res, cost = self.model.generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = self.model.generate(prompt + images, schema=ReasonScore) + return res.score, res.reasoning + except TypeError: + res = self.model.generate(prompt + images) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + def _calculate_score( + self + ) -> List[str]: + min_SC_score = min(self.SC_scores) + min_PQ_score = min(self.PQ_scores) + return math.sqrt(min_SC_score * min_PQ_score)/10 + + def is_successful(self) -> bool: + if self.error is not None: + self.success = False + else: + try: + self.score >= self.threshold + except: + self.success = False + return self.success + + def _generate_reason( + self, + ) -> Tuple[List[float], str]: + return textwrap.dedent(f""" + The overall score is {self.score:.2f} because the lowest score from semantic consistency was {min(self.SC_scores)} + and the lowest score from perceptual quality was {min(self.PQ_scores)}. These scores were combined to reflect the + overall effectiveness and quality of the AI-generated image(s). 
+ Reason for Semantic Consistency score: {self.SC_reasoning} + Reason for Perceptual Quality score: {self.PQ_reasoning} + """) + + @property + def __name__(self): + if self._include_VIEScore_task_name: + return f"{self.task.value} (VIEScore)" + else: + return "VIEScore" \ No newline at end of file diff --git a/deepeval/models/__init__.py b/deepeval/models/__init__.py index 4ae24543..44bf6c82 100644 --- a/deepeval/models/__init__.py +++ b/deepeval/models/__init__.py @@ -1,9 +1,10 @@ from deepeval.models.base_model import ( DeepEvalBaseModel, DeepEvalBaseLLM, + DeepEvalBaseMLLM, DeepEvalBaseEmbeddingModel, ) -from deepeval.models.gpt_model import GPTModel +from deepeval.models.gpt_model import GPTModel, MultimodalGPTModel from deepeval.models.openai_embedding_model import OpenAIEmbeddingModel # TODO: uncomment out once fixed diff --git a/deepeval/models/base_model.py b/deepeval/models/base_model.py index 58bfed4a..8e5ea7e1 100644 --- a/deepeval/models/base_model.py +++ b/deepeval/models/base_model.py @@ -73,6 +73,31 @@ def batch_generate(self, *args, **kwargs) -> List[str]: def get_model_name(self, *args, **kwargs) -> str: pass +class DeepEvalBaseMLLM(ABC): + def __init__(self, model_name: Optional[str] = None, *args, **kwargs): + self.model_name = model_name + + @abstractmethod + def generate(self, *args, **kwargs) -> str: + """Runs the model to output MLLM response. + + Returns: + A string. + """ + pass + + @abstractmethod + async def a_generate(self, *args, **kwargs) -> str: + """Runs the model to output MLLM response. + + Returns: + A string. + """ + pass + + @abstractmethod + def get_model_name(self, *args, **kwargs) -> str: + pass class DeepEvalBaseEmbeddingModel(ABC): def __init__(self, model_name: Optional[str] = None, *args, **kwargs): diff --git a/deepeval/models/gpt_model.py b/deepeval/models/gpt_model.py index 05a8aa29..6ad43704 100644 --- a/deepeval/models/gpt_model.py +++ b/deepeval/models/gpt_model.py @@ -1,24 +1,29 @@ import logging +import PIL.Image import openai +import base64 +from io import BytesIO +from openai import OpenAI, AsyncOpenAI +from typing import Optional, Tuple, List, Union -from typing import Optional, Tuple, List from langchain_openai import ChatOpenAI, AzureChatOpenAI from langchain_community.callbacks import get_openai_callback from langchain.schema import HumanMessage from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatResult from tenacity import retry, retry_if_exception_type, wait_exponential_jitter +from PIL.Image import Image as PILImage +import PIL from deepeval.key_handler import KeyValues, KEY_FILE_HANDLER -from deepeval.models import DeepEvalBaseLLM - +from deepeval.models import DeepEvalBaseLLM, DeepEvalBaseMLLM +from deepeval.types import Image def log_retry_error(retry_state): logging.error( f"OpenAI rate limit exceeded. Retrying: {retry_state.attempt_number} time(s)..." 
) - valid_gpt_models = [ "gpt-4o-mini", "gpt-4o", @@ -36,8 +41,22 @@ def log_retry_error(retry_state): "gpt-3.5-turbo-0125", ] -default_gpt_model = "gpt-4o" +model_pricing = { + "gpt-4o-mini": {"input": 0.150 / 1e6, "output": 0.600 / 1e6}, + "gpt-4o": {"input": 5.00 / 1e6, "output": 15.00 / 1e6}, + "gpt-4-turbo": {"input": 10.00 / 1e6, "output": 30.00 / 1e6}, + "gpt-4-turbo-preview": {"input": 10.00 / 1e6, "output": 30.00 / 1e6}, + "gpt-4-0125-preview": {"input": 10.00 / 1e6, "output": 30.00 / 1e6}, + "gpt-4-1106-preview": {"input": 10.00 / 1e6, "output": 30.00 / 1e6}, + "gpt-4": {"input": 30.00 / 1e6, "output": 60.00 / 1e6}, + "gpt-4-32k": {"input": 60.00 / 1e6, "output": 120.00 / 1e6}, + "gpt-3.5-turbo-1106": {"input": 1.00 / 1e6, "output": 2.00 / 1e6}, + "gpt-3.5-turbo": {"input": 1.50 / 1e6, "output": 2.00 / 1e6}, + "gpt-3.5-turbo-16k": {"input": 3.00 / 1e6, "output": 4.00 / 1e6}, + "gpt-3.5-turbo-0125": {"input": 0.50 / 1e6, "output": 1.50 / 1e6}, +} +default_gpt_model = "gpt-4o" # Adding a custom class to enable json mode in Ollama during API calls class CustomChatOpenAI(ChatOpenAI): @@ -230,3 +249,156 @@ def get_model_name(self): return "local model" elif self.model_name: return self.model_name + +############################################### +# Multimodal Model +############################################### + + +valid_multimodal_gpt_models = [ + "gpt-4o-mini", + "gpt-4o", + "gpt-4-turbo", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-32k", + "gpt-4-0613", + "gpt-4-32k-0613", +] + +default_multimodal_gpt_model = "gpt-4o" + +class MultimodalGPTModel(DeepEvalBaseMLLM): + def __init__( + self, + model: Optional[str] = None, + _openai_api_key: Optional[str] = None, + *args, + **kwargs, + ): + model_name = None + if isinstance(model, str): + model_name = model + if model_name not in valid_multimodal_gpt_models: + raise ValueError( + f"Invalid model. 
Available Multimodal GPT models: {', '.join(model for model in valid_multimodal_gpt_models)}" + ) + elif model is None: + model_name = default_multimodal_gpt_model + + self._openai_api_key = _openai_api_key + self.args = args + self.kwargs = kwargs + self.model_name = model_name + + def calculate_cost(self, input_tokens: int, output_tokens: int, model_name: str) -> float: + pricing = model_pricing.get(model_name, model_pricing["gpt-4o"]) # Default to 'gpt-4o' if model not found + input_cost = input_tokens * pricing["input"] + output_cost = output_tokens * pricing["output"] + return input_cost + output_cost + + def calculate_image_tokens(self, image: PILImage, detail: str = 'auto') -> int: + width, height = image.size + + def high_detail_cost() -> int: + if max(width, height) > 2048: + scale_factor = 2048 / max(width, height) + width = int(width * scale_factor) + height = int(height * scale_factor) + scale_factor = 768 / min(width, height) + width = int(width * scale_factor) + height = int(height * scale_factor) + tiles = (width // 512) * (height // 512) + return 85 + (170 * tiles) + + if detail == 'low': + return 85 + if detail == 'high': + return high_detail_cost() + if width > 1024 or height > 1024: + return high_detail_cost() + return 85 + + + def encode_pil_image(self, pil_image: PILImage): + image_buffer = BytesIO() + pil_image.save(image_buffer, format='JPEG') + image_bytes = image_buffer.getvalue() + base64_encoded_image = base64.b64encode(image_bytes).decode('utf-8') + return base64_encoded_image + + def generate_prompt(self, multimodal_input: List[Union[str, Image]] = []): + prompt = [] + for ele in multimodal_input: + if isinstance(ele, str): + prompt.append({ + "type": "text", + "text": ele + }) + elif isinstance(ele, Image): + if ele.local == True: + image = PIL.Image.open(ele.url) + visual_dict = { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{self.encode_pil_image(image)}"} + } + else: + visual_dict = { + "type": "image_url", + "image_url": { + "url": ele.url, + }, + }, + prompt.append(visual_dict) + return prompt + + @retry( + wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10), + retry=retry_if_exception_type(openai.RateLimitError), + after=log_retry_error, + ) + def generate(self, multimodal_input: List[Union[str, Image]]) -> Tuple[str, float]: + client = OpenAI() + prompt = self.generate_prompt(multimodal_input) + response = client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "user", + "content": prompt + } + ], + ) + input_tokens = response.usage.prompt_tokens + output_tokens = response.usage.completion_tokens + total_cost = self.calculate_cost(input_tokens, output_tokens, self.model_name) + generated_text = response.choices[0].message.content + return generated_text, total_cost + + @retry( + wait=wait_exponential_jitter(initial=1, exp_base=2, jitter=2, max=10), + retry=retry_if_exception_type(openai.RateLimitError), + after=log_retry_error, + ) + async def a_generate(self, multimodal_input: List[Union[str, Image]]) -> Tuple[str, float]: + client = AsyncOpenAI() + prompt = self.generate_prompt(multimodal_input) + response = await client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "user", + "content": prompt + } + ], + ) + input_tokens = response.usage.prompt_tokens + output_tokens = response.usage.completion_tokens + total_cost = self.calculate_cost(input_tokens, output_tokens, self.model_name) + generated_text = response.choices[0].message.content + return 
generated_text, total_cost + + def get_model_name(self): + return self.model_name diff --git a/deepeval/test_case/__init__.py b/deepeval/test_case/__init__.py index d4ad5009..05719c65 100644 --- a/deepeval/test_case/__init__.py +++ b/deepeval/test_case/__init__.py @@ -1,2 +1,3 @@ from .llm_test_case import LLMTestCase, LLMTestCaseParams from .conversational_test_case import ConversationalTestCase, Message +from .mllm_test_case import MLLMTestCase, MLLMTestCaseParams diff --git a/deepeval/test_case/conversational_test_case.py b/deepeval/test_case/conversational_test_case.py index 598ac578..55eb234f 100644 --- a/deepeval/test_case/conversational_test_case.py +++ b/deepeval/test_case/conversational_test_case.py @@ -4,7 +4,6 @@ from deepeval.test_case import LLMTestCase - @dataclass class Message: llm_test_case: LLMTestCase @@ -14,7 +13,6 @@ def __post_init__(self): # prevent user referencing the wrong LLM test case in a conversation self.llm_test_case = deepcopy(self.llm_test_case) - @dataclass class ConversationalTestCase: messages: List[Message] diff --git a/deepeval/test_case/llm_test_case.py b/deepeval/test_case/llm_test_case.py index e5fd5119..9bb2412c 100644 --- a/deepeval/test_case/llm_test_case.py +++ b/deepeval/test_case/llm_test_case.py @@ -3,7 +3,6 @@ from typing import List, Optional, Dict from enum import Enum - class LLMTestCaseParams(Enum): INPUT = "input" ACTUAL_OUTPUT = "actual_output" @@ -14,7 +13,6 @@ class LLMTestCaseParams(Enum): EXPECTED_TOOLS = "expected_tools" REASONING = "reasoning" - @dataclass class LLMTestCase: input: str diff --git a/deepeval/test_case/mllm_test_case.py b/deepeval/test_case/mllm_test_case.py new file mode 100644 index 00000000..1d181a53 --- /dev/null +++ b/deepeval/test_case/mllm_test_case.py @@ -0,0 +1,20 @@ +from pydantic import Field +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Union +from enum import Enum + +from deepeval.types import Image + +class MLLMTestCaseParams(Enum): + INPUT = "input" + ACTUAL_OUTPUT = "actual_output" + +@dataclass +class MLLMTestCase: + input: List[Union[str, Image]] + actual_output: List[Union[str, Image]] + additional_metadata: Optional[Dict] = None + comments: Optional[str] = None + _dataset_rank: Optional[int] = field(default=None, repr=False) + _dataset_alias: Optional[str] = field(default=None, repr=False) + _dataset_id: Optional[str] = field(default=None, repr=False) \ No newline at end of file diff --git a/deepeval/test_run/__init__.py b/deepeval/test_run/__init__.py index 2eea5a1a..2766b96f 100644 --- a/deepeval/test_run/__init__.py +++ b/deepeval/test_run/__init__.py @@ -3,6 +3,7 @@ global_test_run_manager, TEMP_FILE_NAME, LLMApiTestCase, + MLLMApiTestCase, ConversationalApiTestCase, TestRunManager, ) diff --git a/deepeval/test_run/api.py b/deepeval/test_run/api.py index b425a358..6e1fad88 100644 --- a/deepeval/test_run/api.py +++ b/deepeval/test_run/api.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List, Union, Dict +from deepeval.types import Image class MetricData(BaseModel): name: str @@ -70,6 +71,49 @@ def update_run_duration(self, run_duration: float): self.run_duration = run_duration +class MLLMApiTestCase(BaseModel): + name: str + input: List[Union[str, Image]] = Field(..., alias="input") + actual_output: List[Union[str, Image]] = Field(..., alias="actualOutput") + success: Union[bool, None] = Field(None) + # make optional, not all test cases in a conversation will be 
evaluated + metrics_data: Union[List[MetricData], None] = Field( + None, alias="metricsData" + ) + run_duration: Union[float, None] = Field(None, alias="runDuration") + evaluation_cost: Union[float, None] = Field(None, alias="evaluationCost") + order: Union[int, None] = Field(None) + + # Allow arbitrary types + model_config = ConfigDict(arbitrary_types_allowed=True) + + def update_metric_data(self, metric_data: MetricData): + if self.metrics_data is None: + self.metrics_data = [metric_data] + else: + self.metrics_data.append(metric_data) + + if self.success is None: + # self.success will be None when it is a message + # in that case we will be setting success for the first time + self.success = metric_data.success + else: + if metric_data.success is False: + self.success = False + + evaluationCost = metric_data.evaluation_cost + if evaluationCost is None: + return + + if self.evaluation_cost is None: + self.evaluation_cost = evaluationCost + else: + self.evaluation_cost += evaluationCost + + def update_run_duration(self, run_duration: float): + self.run_duration = run_duration + + class ConversationalApiTestCase(BaseModel): name: str success: bool diff --git a/deepeval/test_run/test_run.py b/deepeval/test_run/test_run.py index 0d23b656..9cbe4986 100644 --- a/deepeval/test_run/test_run.py +++ b/deepeval/test_run/test_run.py @@ -10,15 +10,19 @@ from rich.table import Table from rich.console import Console from rich import print +import base64 +from io import BytesIO +from PIL import Image from deepeval.metrics import BaseMetric from deepeval.confident.api import Api, Endpoints, HttpMethods from deepeval.test_run.api import ( LLMApiTestCase, + MLLMApiTestCase, ConversationalApiTestCase, TestRunHttpResponse, ) -from deepeval.test_case import LLMTestCase, ConversationalTestCase +from deepeval.test_case import LLMTestCase, ConversationalTestCase, MLLMTestCase from deepeval.utils import ( delete_file_if_exists, get_is_running_deepeval, @@ -88,6 +92,9 @@ class TestRun(BaseModel): conversational_test_cases: List[ConversationalApiTestCase] = Field( alias="conversationalTestCases", default_factory=lambda: [] ) + mllm_test_cases: List[MLLMApiTestCase] = Field( + alias="MLLMTestCases", default_factory=lambda: [] + ) metrics_scores: List[MetricScores] = Field( default_factory=lambda: [], alias="metricsScores" ) @@ -100,10 +107,12 @@ class TestRun(BaseModel): dataset_id: Optional[str] = Field(None, alias="datasetId") def add_test_case( - self, api_test_case: Union[LLMApiTestCase, ConversationalApiTestCase] + self, api_test_case: Union[LLMApiTestCase, ConversationalApiTestCase, MLLMApiTestCase] ): if isinstance(api_test_case, ConversationalApiTestCase): self.conversational_test_cases.append(api_test_case) + elif isinstance(api_test_case, MLLMApiTestCase): + self.mllm_test_cases.append(api_test_case) else: if api_test_case.conversational_instance_id is not None: for conversational_test_case in self.conversational_test_cases: @@ -142,7 +151,7 @@ def add_test_case( self.evaluation_cost += api_test_case.evaluation_cost def set_dataset_properties( - self, test_case: Union[LLMTestCase, ConversationalTestCase] + self, test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase] ): if self.dataset_alias is None: self.dataset_alias = test_case._dataset_alias @@ -170,6 +179,12 @@ def sort_test_cases(self): if test_case.order is None: test_case.order = highest_order highest_order = test_case.order + 1 + # Optionally update order only if not already set + highest_order = 0 + for test_case in 
self.mllm_test_cases: + if test_case.order is None: + test_case.order = highest_order + highest_order = test_case.order + 1 def delete_test_case_instance_ids(self): for conversational_test_case in self.conversational_test_cases: @@ -223,8 +238,22 @@ def construct_metrics_scores(self) -> int: metrics_dict[name].append(score) else: metrics_dict[name] = [score] + + for test_case in self.mllm_test_cases: + if test_case.metrics_data is None: + continue + for metric_data in test_case.metrics_data: + name = metric_data.name + score = metric_data.score + if score is None: + continue + valid_scores += 1 + if name in metrics_dict: + metrics_dict[name].append(score) + else: + metrics_dict[name] = [score] - # metrics_scores combines both conversational and nonconvo scores + # metrics_scores combines both conversational and nonconvo and mllm scores # might need to separate in the future self.metrics_scores = [ MetricScores(metric=metric, scores=scores) @@ -249,6 +278,13 @@ def calculate_test_passes_and_fails(self): test_passed += 1 else: test_failed += 1 + + for test_case in self.mllm_test_cases: + if test_case.success is not None: + if test_case.success: + test_passed += 1 + else: + test_failed += 1 self.test_passed = test_passed self.test_failed = test_failed @@ -264,7 +300,8 @@ def save(self, f): @classmethod def load(cls, f): - return cls(**json.load(f)) + data: dict = json.load(f) + return cls(**data) class TestRunManager: @@ -336,8 +373,8 @@ def save_test_run(self): def update_test_run( self, - api_test_case: Union[LLMApiTestCase, ConversationalApiTestCase], - test_case: Union[LLMTestCase, ConversationalTestCase], + api_test_case: Union[LLMApiTestCase, ConversationalApiTestCase, MLLMApiTestCase], + test_case: Union[LLMTestCase, ConversationalTestCase, MLLMTestCase], ): if self.save_to_disk: try: @@ -380,6 +417,8 @@ def display_results_table(self, test_run: TestRun): table.add_column("Score", justify="left") table.add_column("Status", justify="left") table.add_column("Overall Success Rate", justify="left") + print(test_run.mllm_test_cases) + print(test_run.test_cases) for index, test_case in enumerate(test_run.test_cases): pass_count = 0 @@ -536,6 +575,59 @@ def display_results_table(self, test_run: TestRun): "", "", ) + + for index, test_case in enumerate(test_run.mllm_test_cases): + pass_count = 0 + fail_count = 0 + test_case_name = test_case.name + + for metric_data in test_case.metrics_data: + if metric_data.success: + pass_count += 1 + else: + fail_count += 1 + + table.add_row( + test_case_name, + "", + "", + "", + f"{round((100*pass_count)/(pass_count+fail_count),2)}%", + ) + + for metric_data in test_case.metrics_data: + if metric_data.error: + status = "[red]ERRORED[/red]" + elif metric_data.success: + status = "[green]PASSED[/green]" + else: + status = "[red]FAILED[/red]" + + evaluation_model = metric_data.evaluation_model + if evaluation_model is None: + evaluation_model = "n/a" + + if metric_data.score is not None: + metric_score = round(metric_data.score, 2) + else: + metric_score = None + + table.add_row( + "", + str(metric_data.name), + f"{metric_score} (threshold={metric_data.threshold}, evaluation model={evaluation_model}, reason={metric_data.reason}, error={metric_data.error})", + status, + "", + ) + + if index is not len(self.test_run.test_cases) - 1: + table.add_row( + "", + "", + "", + "", + "", + ) print(table) print( @@ -573,11 +665,13 @@ def post_test_run(self, test_run: TestRun): test_run.test_cases = initial_batch test_run.conversational_test_cases = 
initial_conversational_batch try: - body = test_run.model_dump(by_alias=True, exclude_none=True) + body = test_run.model_dump(by_alias=True, exclude_none=True, exclude={"MLLMTestCases"}) except AttributeError: # Pydantic version below 2.0 - body = test_run.dict(by_alias=True, exclude_none=True) - + body = test_run.dict(by_alias=True, exclude_none=True, exclude={"MLLMTestCases"}) + if 'MLLMTestCases' in body: + del body['MLLMTestCases'] + api = Api() result = api.send_request( method=HttpMethods.POST, @@ -689,6 +783,7 @@ def wrap_up_test_run(self, runDuration: float, display_table: bool = True): elif ( len(test_run.test_cases) == 0 and len(test_run.conversational_test_cases) == 0 + and len(test_run.mllm_test_cases) == 0 ): print("No test cases found, please try again.") delete_file_if_exists(self.temp_file_name) @@ -719,7 +814,9 @@ def wrap_up_test_run(self, runDuration: float, display_table: bool = True): self.save_test_run_locally() delete_file_if_exists(self.temp_file_name) - self.post_test_run(test_run) + + if len(test_run.test_cases) > 0 or len(test_run.conversational_test_cases) > 0: + self.post_test_run(test_run) global_test_run_manager = TestRunManager() diff --git a/deepeval/types.py b/deepeval/types.py index 952365c7..6b34c313 100644 --- a/deepeval/types.py +++ b/deepeval/types.py @@ -1,6 +1,30 @@ from enum import Enum - +from dataclasses import dataclass +from urllib.parse import urlparse +from typing import Optional +import os class Languages(Enum): ENGLISH = "English" SPANISH = "Spanish" + +@dataclass +class Image: + url: str + local: Optional[bool] = None + + def __post_init__(self): + if self.local == None: + self.local = self.is_local_path(self.url) + + @staticmethod + def is_local_path(url): + # Parse the URL + parsed_url = urlparse(url) + + # Check if it's a file scheme or an empty scheme with a local path + if parsed_url.scheme == 'file' or parsed_url.scheme == '': + # Check if the path exists on the filesystem + return os.path.exists(parsed_url.path) + + return False diff --git a/docs/docs/evaluation-test-cases.mdx b/docs/docs/evaluation-test-cases.mdx index 7708e347..c81e2b48 100644 --- a/docs/docs/evaluation-test-cases.mdx +++ b/docs/docs/evaluation-test-cases.mdx @@ -277,6 +277,71 @@ message = Message(llm_test_case=LLMTestCase(...), should_evaluate=True) Most metrics in `deepeval` are non-conversational metrics, and the reason why `should_evaluate` is defaulted to `True` for the **last message in a `ConversationalTestCase`**, is because often times users prefer evaluating the next best LLM response given the previous conversation context, instead of all `Message`s in a `ConversationalTestCase`. ::: +## MLLM Test Case + +An `MLLMTestCase` in deepeval is designed to unit test outputs from MLLM (Multimodal Large Language Model) applications. Unlike an `LLMTestCase`, which only handles textual parameters, an `MLLMTestCase` accepts both text and image inputs and outputs. This is particularly useful for evaluating tasks such as text-to-image generation or MLLM-driven image editing. + +:::caution +You may only evaluate `MLLMTestCase`s using multimodal metrics such as `VIEScore`. 
+:::
+
+```python
+from deepeval.test_case import MLLMTestCase
+from deepeval.types import Image
+
+mllm_test_case = MLLMTestCase(
+    # Replace this with your user input
+    input=["Change the color of the shoes to blue.", Image(url="./shoes.png", local=True)],
+    # Replace this with your actual MLLM application
+    actual_output=["The original image of red shoes now shows the shoes in blue.", Image(url="https://shoe-images.com/edited-shoes", local=False)]
+)
+```
+
+### Input
+
+The `input` mimics a user interacting with your MLLM application. Like an `LLMTestCase` input, an `MLLMTestCase` input is the direct input to your prompt template, and so **SHOULD NOT CONTAIN** your prompt template.
+
+```python
+from deepeval.test_case import MLLMTestCase
+from deepeval.types import Image
+
+mllm_test_case = MLLMTestCase(
+    input=["Change the color of the shoes to blue.", Image(url="./shoes.png", local=True)]
+)
+```
+
+:::info
+The `input` parameter accepts a list of strings and `Image`s, a class specific to `deepeval`. The `Image` class accepts an image URL or local file path and, unless `local` is provided explicitly, automatically sets the `local` attribute to `True` or `False` depending on whether the image is stored locally or hosted online.
+
+```python
+from deepeval.types import Image
+
+# Example of using the Image class
+image_input = Image(url="path/to/image.jpg")
+
+# image_input.local is automatically set to True if the image is a local file
+# and False if the image is hosted online.
+```
+
+:::
+
+### Actual Output
+
+The `actual_output` is simply what your MLLM application returns for a given input. Similarly, it also accepts a list of strings and `Image`s.
+
+```python
+from deepeval.test_case import MLLMTestCase
+from deepeval.types import Image
+
+mllm_test_case = MLLMTestCase(
+    input=["Change the color of the shoes to blue.", Image(url="./shoes.png", local=True)],
+    actual_output=["The original image of red shoes now shows the shoes in blue.", Image(url="https://shoe-images.com/edited-shoes", local=False)]
+)
+```
+
 ## Assert A Test Case
 
 Before we begin going through the final sections, we highly recommend you to login to [Confident AI](https://confident-ai.com) (the platform powering deepeval) via the CLI. This way, you can keep track of all evaluation results generated each time you execute `deepeval test run`.
@@ -287,7 +352,7 @@ deepeval login
 
 Similar to Pytest, `deepeval` allows you to assert any test case you create by calling the `assert_test` function by running `deepeval test run` via the CLI.
 
-**A test case passes only if all metrics passess.** Depending on the metric, a combination of `input`, `actual_output`, `expected_output`, `context`, and `retrieval_context` is used to ascertain whether their criterion have been met.
+**A test case passes only if all metrics pass.** Depending on the metric, a combination of `input`, `actual_output`, `expected_output`, `context`, and `retrieval_context` is used to ascertain whether their criteria have been met.
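+
+An `MLLMTestCase` can be asserted in exactly the same way as the `LLMTestCase` example that follows; here is a minimal sketch (with placeholder image paths and a hypothetical file name) using a multimodal metric such as `VIEScore`:
+
+```python title="test_assert_mllm_example.py"
+from deepeval import assert_test
+from deepeval.metrics import VIEScore, VIEScoreTask
+from deepeval.test_case import MLLMTestCase
+from deepeval.types import Image
+
+def test_image_editing():
+    # Replace these with your user input and your MLLM application's output
+    test_case = MLLMTestCase(
+        input=["Change the color of the shoes to blue.", Image(url="./shoes.png", local=True)],
+        actual_output=["The original image of red shoes now shows the shoes in blue.", Image(url="./edited_shoes.png", local=True)],
+    )
+    metric = VIEScore(task=VIEScoreTask.TEXT_TO_IMAGE_EDITING)
+    assert_test(test_case, [metric])
+```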
 ```python title="test_assert_example.py"
 # A hypothetical LLM application example
diff --git a/docs/docs/metrics-viescore.mdx b/docs/docs/metrics-viescore.mdx
new file mode 100644
index 00000000..ed367ee6
--- /dev/null
+++ b/docs/docs/metrics-viescore.mdx
@@ -0,0 +1,81 @@
+---
+id: metrics-viescore
+title: VIEScore
+sidebar_label: VIEScore
+---
+
+import Equation from "@site/src/components/equation";
+
+`VIEScore` assesses the performance of **image generation and editing tasks** by evaluating the quality of synthesized images based on semantic consistency and perceptual quality. `deepeval`'s `VIEScore` metric is a self-explaining MLLM-Eval, meaning it outputs a reason for its metric score.
+
+:::tip
+Using `VIEScore` with GPT-4V as the evaluation model achieves scores comparable to human ratings in text-to-image generation tasks, and it is especially good at detecting undesirable artifacts.
+:::
+
+## Required Arguments
+
+To use the `VIEScore`, you'll have to provide the following arguments when creating an `MLLMTestCase`:
+
+- `input`
+- `actual_output`
+
+## Example
+
+```python
+from deepeval import evaluate
+from deepeval.metrics import VIEScore, VIEScoreTask
+from deepeval.types import Image
+from deepeval.test_case import MLLMTestCase
+
+# Replace this with your actual MLLM application output
+actual_output = [Image(url="https://shoe-images.com/edited-shoes", local=False)]
+
+metric = VIEScore(
+    threshold=0.7,
+    model="gpt-4o",
+    include_reason=True,
+    task=VIEScoreTask.TEXT_TO_IMAGE_EDITING
+)
+test_case = MLLMTestCase(
+    input=["Change the color of the shoes to blue.", Image(url="./shoes.png", local=True)],
+    actual_output=actual_output,
+)
+
+metric.measure(test_case)
+print(metric.score)
+print(metric.reason)
+
+# or evaluate test cases in bulk
+evaluate([test_case], [metric])
+```
+
+There are seven optional parameters when creating a `VIEScore`:
+
+- [Optional] `threshold`: a float representing the minimum passing threshold, defaulted to 0.5.
+- [Optional] `model`: a string specifying which of OpenAI's GPT models to use, **OR** [any custom MLLM model](metrics-introduction#using-a-custom-llm) of type `DeepEvalBaseMLLM`. Defaulted to 'gpt-4o'.
+- [Optional] `include_reason`: a boolean which when set to `True`, will include a reason for its evaluation score. Defaulted to `True`.
+- [Optional] `strict_mode`: a boolean which when set to `True`, enforces a binary metric score: 1 for perfection, 0 otherwise. It also overrides the current threshold and sets it to 1. Defaulted to `False`.
+- [Optional] `async_mode`: a boolean which when set to `True`, enables [concurrent execution within the `measure()` method.](metrics-introduction#measuring-metrics-in-async) Defaulted to `True`.
+- [Optional] `verbose_mode`: a boolean which when set to `True`, prints the intermediate steps used to calculate said metric to the console, as outlined in the [How Is It Calculated](#how-is-it-calculated) section. Defaulted to `False`.
+- [Optional] `task`: a `VIEScoreTask` enum indicating whether the task is image generation or image editing. Defaulted to `VIEScoreTask.TEXT_TO_IMAGE_GENERATION`.
+
+:::info
+`VIEScoreTask` is an **enumeration that includes two types of tasks**:
+
+- `TEXT_TO_IMAGE_GENERATION`: the input should contain exactly **0 images**, and the output should contain exactly **1 image**.
+- `TEXT_TO_IMAGE_EDITING`: the input and the output should each contain exactly **1 image**.
+
+:::
+
+## How Is It Calculated?
+
+The `VIEScore` score is calculated according to the following equation:
+
+$$\text{VIEScore} = \sqrt{\min(\text{SC scores}) \times \min(\text{PQ scores})}$$
+
+The `VIEScore` score combines Semantic Consistency (SC) and Perceptual Quality (PQ) sub-scores to provide a comprehensive evaluation of the synthesized image. The final overall score is derived by taking the square root of the product of the minimum SC and PQ scores.
+
+**1. SC Scores**: These scores assess aspects such as alignment with the prompt and resemblance to concepts. The minimum value among these sub-scores represents the SC score. During the SC evaluation, both the input conditions and the synthesized image are used.
+
+**2. PQ Scores**: These scores evaluate the naturalness and absence of artifacts in the image. The minimum value among these sub-scores represents the PQ score. For the PQ evaluation, only the synthesized image is used to prevent confusion from the input conditions.
diff --git a/docs/sidebars.js b/docs/sidebars.js
index c02c2316..9fc83858 100644
--- a/docs/sidebars.js
+++ b/docs/sidebars.js
@@ -29,6 +29,7 @@ module.exports = {
         "metrics-contextual-relevancy",
         "metrics-tool-correctness",
         "metrics-hallucination",
+        "metrics-viescore",
         "metrics-bias",
         "metrics-toxicity",
         "metrics-ragas",
diff --git a/tests/test_viescore.py b/tests/test_viescore.py
new file mode 100644
index 00000000..8d0495dd
--- /dev/null
+++ b/tests/test_viescore.py
@@ -0,0 +1,60 @@
+import pytest
+
+from deepeval.dataset import EvaluationDataset
+from deepeval import assert_test, evaluate
+from deepeval.test_case import MLLMTestCase, LLMTestCase
+from deepeval.metrics import VIEScore, AnswerRelevancyMetric, VIEScoreTask
+from deepeval.types import Image
+
+image_path = "./data/image.webp"
+edited_image_path = "./data/edited_image.webp"
+
+test_case_1 = MLLMTestCase(
+    input=["generate a castle school in fantasy land with the words LLM evaluation on it"],
+    actual_output=[Image(image_path, local=True)],
+)
+
+test_case_2 = MLLMTestCase(
+    input=["edit this image so that it is night themed, and LLM evaluation is spelled correctly", Image(image_path, local=True)],
+    actual_output=[Image(edited_image_path, local=True)],
+)
+
+test_case_3 = LLMTestCase(
+    input="What is this again?",
+    actual_output="this is a latte",
+    expected_output="this is a mocha",
+    retrieval_context=["I love coffee"],
+    context=["I love coffee"],
+    expected_tools=["mixer", "creamer", "dripper"],
+    tools_called=["mixer", "creamer", "mixer"],
+)
+
+dataset = EvaluationDataset(
+    test_cases=[
+        test_case_1,
+        test_case_2,
+        test_case_3]
+)
+# dataset.evaluate([
+#     VIEScore(verbose_mode=True),
+#     VIEScore(verbose_mode=True, task=VIEScoreTask.TEXT_TO_IMAGE_EDITING),
+#     AnswerRelevancyMetric()])
+
+# evaluate(
+#     test_cases=[
+#         #test_case_1,
+#         test_case_2,
+#         test_case_3],
+#     metrics=[
+#         #VIEScore(verbose_mode=True),
+#         VIEScore(verbose_mode=True, task=VIEScoreTask.TEXT_TO_IMAGE_EDITING),
+#         AnswerRelevancyMetric()],
+#     #run_async=False
+# )
+
+#@pytest.mark.skip(reason="openai is expensive")
+def test_viescore():
+    vie_score = VIEScore(verbose_mode=True)
+    vie_score_2 = VIEScore(verbose_mode=True, task=VIEScoreTask.TEXT_TO_IMAGE_EDITING)
+    assert_test(test_case_2, [vie_score_2, AnswerRelevancyMetric()], run_async=False)
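+
+# A minimal additional sketch (hypothetical test name): the same assert_test
+# pattern applied to the default TEXT_TO_IMAGE_GENERATION task, reusing
+# test_case_1 defined above and the same local test image.
+#@pytest.mark.skip(reason="openai is expensive")
+def test_viescore_text_to_image_generation():
+    # VIEScore defaults to VIEScoreTask.TEXT_TO_IMAGE_GENERATION
+    vie_score = VIEScore(verbose_mode=True)
+    assert_test(test_case_1, [vie_score], run_async=False)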