diff --git a/deepeval/metrics/__init__.py b/deepeval/metrics/__init__.py index 28d75e85b..a7d551507 100644 --- a/deepeval/metrics/__init__.py +++ b/deepeval/metrics/__init__.py @@ -16,7 +16,8 @@ from .contextual_precision.contextual_precision import ContextualPrecisionMetric from .knowledge_retention.knowledge_retention import KnowledgeRetentionMetric from .tool_correctness.tool_correctness import ToolCorrectnessMetric -from .viescore.viescore import VIEScore, VIEScoreTask +from .text_to_image.text_to_image import TextToImageMetric +from .image_editing.image_editing import ImageEditingMetric from .conversation_relevancy.conversation_relevancy import ( ConversationRelevancyMetric, ) diff --git a/deepeval/metrics/viescore/__init__.py b/deepeval/metrics/image_editing/__init__.py similarity index 100% rename from deepeval/metrics/viescore/__init__.py rename to deepeval/metrics/image_editing/__init__.py diff --git a/deepeval/metrics/viescore/viescore.py b/deepeval/metrics/image_editing/image_editing.py similarity index 83% rename from deepeval/metrics/viescore/viescore.py rename to deepeval/metrics/image_editing/image_editing.py index 081a29e9f..e6e8bd0d8 100644 --- a/deepeval/metrics/viescore/viescore.py +++ b/deepeval/metrics/image_editing/image_editing.py @@ -5,7 +5,7 @@ from deepeval.metrics import BaseMultimodalMetric from deepeval.test_case import MLLMTestCaseParams, MLLMTestCase, MLLMImage -from deepeval.metrics.viescore.template import VIEScoreTemplate +from deepeval.metrics.image_editing.template import ImageEditingTemplate from deepeval.utils import get_or_create_event_loop from deepeval.metrics.utils import ( construct_verbose_logs, @@ -14,8 +14,7 @@ initialize_multimodal_model, ) from deepeval.models import DeepEvalBaseMLLM -from deepeval.metrics.viescore.schema import ReasonScore -from deepeval.metrics.viescore.task import VIEScoreTask +from deepeval.metrics.image_editing.schema import ReasonScore from deepeval.metrics.indicator import metric_progress_indicator required_params: List[MLLMTestCaseParams] = [ @@ -24,16 +23,14 @@ ] -class VIEScore(BaseMultimodalMetric): +class ImageEditingMetric(BaseMultimodalMetric): def __init__( self, model: Optional[Union[str, DeepEvalBaseMLLM]] = None, - task: VIEScoreTask = VIEScoreTask.TEXT_TO_IMAGE_GENERATION, threshold: float = 0.5, async_mode: bool = True, strict_mode: bool = False, verbose_mode: bool = False, - _include_VIEScore_task_name: bool = True, ): self.model, self.using_native_model = initialize_multimodal_model(model) self.evaluation_model = self.model.get_model_name() @@ -41,16 +38,11 @@ def __init__( self.strict_mode = strict_mode self.async_mode = async_mode self.verbose_mode = verbose_mode - self.task = task - self._include_VIEScore_task_name = _include_VIEScore_task_name def measure( self, test_case: MLLMTestCase, _show_indicator: bool = True ) -> float: - if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: - check_mllm_test_case_params(test_case, required_params, 0, 1, self) - elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: - check_mllm_test_case_params(test_case, required_params, 1, 1, self) + check_mllm_test_case_params(test_case, required_params, 1, 1, self) self.evaluation_cost = 0 if self.using_native_model else None with metric_progress_indicator(self, _show_indicator=_show_indicator): @@ -102,10 +94,7 @@ async def a_measure( test_case: MLLMTestCase, _show_indicator: bool = True, ) -> float: - if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: - check_mllm_test_case_params(test_case, required_params, 0, 1, 
self) - elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: - check_mllm_test_case_params(test_case, required_params, 1, 1, self) + check_mllm_test_case_params(test_case, required_params, 1, 1, self) self.evaluation_cost = 0 if self.using_native_model else None with metric_progress_indicator( @@ -169,13 +158,10 @@ async def _a_evaluate_semantic_consistency( actual_image_output: MLLMImage, ) -> Tuple[List[int], str]: images: List[MLLMImage] = [] - if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: - images.append(image_input) - elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: - images.extend([image_input, actual_image_output]) + images.extend([image_input, actual_image_output]) prompt = [ - VIEScoreTemplate.generate_semantic_consistency_evaluation_results( - text_prompt=text_prompt, task=self.task + ImageEditingTemplate.generate_semantic_consistency_evaluation_results( + text_prompt=text_prompt ) ] if self.using_native_model: @@ -203,13 +189,10 @@ def _evaluate_semantic_consistency( actual_image_output: MLLMImage, ) -> Tuple[List[int], str]: images: List[MLLMImage] = [] - if self.task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: - images.append(image_input) - elif self.task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: - images.extend([image_input, actual_image_output]) + images.extend([image_input, actual_image_output]) prompt = [ - VIEScoreTemplate.generate_semantic_consistency_evaluation_results( - text_prompt=text_prompt, task=self.task + ImageEditingTemplate.generate_semantic_consistency_evaluation_results( + text_prompt=text_prompt ) ] if self.using_native_model: @@ -233,7 +216,7 @@ async def _a_evaluate_perceptual_quality( ) -> Tuple[List[int], str]: images: List[MLLMImage] = [actual_image_output] prompt = [ - VIEScoreTemplate.generate_perceptual_quality_evaluation_results() + ImageEditingTemplate.generate_perceptual_quality_evaluation_results() ] if self.using_native_model: res, cost = await self.model.a_generate(prompt + images) @@ -256,7 +239,7 @@ def _evaluate_perceptual_quality( ) -> Tuple[List[int], str]: images: List[MLLMImage] = [actual_image_output] prompt = [ - VIEScoreTemplate.generate_perceptual_quality_evaluation_results() + ImageEditingTemplate.generate_perceptual_quality_evaluation_results() ] if self.using_native_model: res, cost = self.model.generate(prompt + images) @@ -304,7 +287,4 @@ def _generate_reason( @property def __name__(self): - if self._include_VIEScore_task_name: - return f"{self.task.value} (VIEScore)" - else: - return "VIEScore" + return "Image Editing" diff --git a/deepeval/metrics/viescore/schema.py b/deepeval/metrics/image_editing/schema.py similarity index 100% rename from deepeval/metrics/viescore/schema.py rename to deepeval/metrics/image_editing/schema.py diff --git a/deepeval/metrics/image_editing/template.py b/deepeval/metrics/image_editing/template.py new file mode 100644 index 000000000..0dc75c176 --- /dev/null +++ b/deepeval/metrics/image_editing/template.py @@ -0,0 +1,65 @@ +import textwrap + + +class ImageEditingTemplate: + + context = textwrap.dedent( + """ + You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. + All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials. + + You will have to give your output in this way (Keep your reasoning concise and short.): + { + "score" : [...], + "reasoning" : "..." 
+ } + """ + ) + + @staticmethod + def generate_semantic_consistency_evaluation_results( + text_prompt: str + ): + return textwrap.dedent( + f""" + {ImageEditingTemplate.context} + + RULES: + + Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. + The objective is to evaluate how successfully the editing instruction has been executed in the second image. + + From scale 0 to 10: + A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) + A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) + Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + + Editing instruction: {text_prompt} + """ + ) + + @staticmethod + def generate_perceptual_quality_evaluation_results(): + return textwrap.dedent( + f""" + {ImageEditingTemplate.context} + + RULES: + + The image is an AI-generated image. + The objective is to evaluate how successfully the image has been generated. + + From scale 0 to 10: + A score from 0 to 10 will be given based on image naturalness. + ( + 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image looks natural. + ) + A second score from 0 to 10 will rate the image artifacts. + ( + 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image has no artifacts. + ) + Put the score in a list such that output score = [naturalness, artifacts] + """ + ) diff --git a/deepeval/metrics/text_to_image/__init__.py b/deepeval/metrics/text_to_image/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/deepeval/metrics/text_to_image/schema.py b/deepeval/metrics/text_to_image/schema.py new file mode 100644 index 000000000..4646b6a5f --- /dev/null +++ b/deepeval/metrics/text_to_image/schema.py @@ -0,0 +1,7 @@ +from typing import List +from pydantic import BaseModel, Field + + +class ReasonScore(BaseModel): + reasoning: str + score: List[float] diff --git a/deepeval/metrics/text_to_image/template.py b/deepeval/metrics/text_to_image/template.py new file mode 100644 index 000000000..df5e6fc2a --- /dev/null +++ b/deepeval/metrics/text_to_image/template.py @@ -0,0 +1,66 @@ +import textwrap + + +class TextToImageTemplate: + + context = textwrap.dedent( + """ + You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. + All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials. + + You will have to give your output in this way (Keep your reasoning concise and short.): + { + "score" : [...], + "reasoning" : "..." 
+ } + """ + ) + + @staticmethod + def generate_semantic_consistency_evaluation_results( + text_prompt: str + ): + return textwrap.dedent( + f""" + {TextToImageTemplate.context} + + RULES: + + The image is an AI-generated image according to the text prompt. + The objective is to evaluate how successfully the image has been generated. + + From scale 0 to 10: + A score from 0 to 10 will be given based on the success in following the prompt. + (0 indicates that the AI generated image does not follow the prompt at all. 10 indicates the AI generated image follows the prompt perfectly.) + + Put the score in a list such that output score = [score]. + + Text Prompt: {text_prompt} + """ + ) + + @staticmethod + def generate_perceptual_quality_evaluation_results(): + return textwrap.dedent( + f""" + {TextToImageTemplate.context} + + RULES: + + The image is an AI-generated image. + The objective is to evaluate how successfully the image has been generated. + + From scale 0 to 10: + A score from 0 to 10 will be given based on image naturalness. + ( + 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image looks natural. + ) + A second score from 0 to 10 will rate the image artifacts. + ( + 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image has no artifacts. + ) + Put the score in a list such that output score = [naturalness, artifacts] + """ + ) diff --git a/deepeval/metrics/text_to_image/text_to_image.py b/deepeval/metrics/text_to_image/text_to_image.py new file mode 100644 index 000000000..fce63af54 --- /dev/null +++ b/deepeval/metrics/text_to_image/text_to_image.py @@ -0,0 +1,286 @@ +import asyncio +from typing import Optional, List, Tuple, Union +import math +import textwrap + +from deepeval.metrics import BaseMultimodalMetric +from deepeval.test_case import MLLMTestCaseParams, MLLMTestCase, MLLMImage +from deepeval.metrics.text_to_image.template import TextToImageTemplate +from deepeval.utils import get_or_create_event_loop +from deepeval.metrics.utils import ( + construct_verbose_logs, + trimAndLoadJson, + check_mllm_test_case_params, + initialize_multimodal_model, +) +from deepeval.models import DeepEvalBaseMLLM +from deepeval.metrics.text_to_image.schema import ReasonScore +from deepeval.metrics.indicator import metric_progress_indicator + +required_params: List[MLLMTestCaseParams] = [ + MLLMTestCaseParams.INPUT, + MLLMTestCaseParams.ACTUAL_OUTPUT, +] + + +class TextToImageMetric(BaseMultimodalMetric): + def __init__( + self, + model: Optional[Union[str, DeepEvalBaseMLLM]] = None, + threshold: float = 0.5, + async_mode: bool = True, + strict_mode: bool = False, + verbose_mode: bool = False, + ): + self.model, self.using_native_model = initialize_multimodal_model(model) + self.evaluation_model = self.model.get_model_name() + self.threshold = 1 if strict_mode else threshold + self.strict_mode = strict_mode + self.async_mode = async_mode + self.verbose_mode = verbose_mode + + def measure( + self, test_case: MLLMTestCase, _show_indicator: bool = True + ) -> float: + check_mllm_test_case_params(test_case, required_params, 0, 1, self) + + self.evaluation_cost = 0 if self.using_native_model else None + with metric_progress_indicator(self, _show_indicator=_show_indicator): + if self.async_mode: + loop = get_or_create_event_loop() 
+ loop.run_until_complete( + self.a_measure(test_case, _show_indicator=False) + ) + else: + input_texts, input_images = self.separate_images_from_text( + test_case.input + ) + _, output_images = self.separate_images_from_text( + test_case.actual_output + ) + + self.SC_scores, self.SC_reasoning = ( + self._evaluate_semantic_consistency( + "\n".join(input_texts), + None if len(input_images) == 0 else input_images[0], + ) + ) + self.PQ_scores, self.PQ_reasoning = ( + self._evaluate_perceptual_quality(output_images[0]) + ) + self.score = self._calculate_score() + self.score = ( + 0 + if self.strict_mode and self.score < self.threshold + else self.score + ) + self.reason = self._generate_reason() + self.success = self.score >= self.threshold + self.verbose_logs = construct_verbose_logs( + self, + steps=[ + f"Semantic Consistency Scores:\n{self.SC_scores}", + f"Semantic Consistency Reasoning:\n{self.SC_reasoning}", + f"Perceptual Quality Scores:\n{self.SC_scores}", + f"Perceptual Quality Reasoning:\n{self.PQ_reasoning}", + f"Score: {self.score}\nReason: {self.reason}", + ], + ) + return self.score + + async def a_measure( + self, + test_case: MLLMTestCase, + _show_indicator: bool = True, + ) -> float: + check_mllm_test_case_params(test_case, required_params, 0, 1, self) + + self.evaluation_cost = 0 if self.using_native_model else None + with metric_progress_indicator( + self, + async_mode=True, + _show_indicator=_show_indicator, + ): + input_texts, input_images = self.separate_images_from_text( + test_case.input + ) + _, output_images = self.separate_images_from_text( + test_case.actual_output + ) + (self.SC_scores, self.SC_reasoning), ( + self.PQ_scores, + self.PQ_reasoning, + ) = await asyncio.gather( + self._a_evaluate_semantic_consistency( + "\n".join(input_texts), + None if len(input_images) == 0 else input_images[0], + ), + self._a_evaluate_perceptual_quality(output_images[0]), + ) + self.score = self._calculate_score() + self.score = ( + 0 + if self.strict_mode and self.score < self.threshold + else self.score + ) + self.reason = self._generate_reason() + self.success = self.score >= self.threshold + self.verbose_logs = construct_verbose_logs( + self, + steps=[ + f"Semantic Consistency Scores:\n{self.SC_scores}", + f"Semantic Consistency Reasoning:\n{self.SC_reasoning}", + f"Perceptual Quality Scores:\n{self.SC_scores}", + f"Perceptual Quality Reasoning:\n{self.PQ_reasoning}", + f"Score: {self.score}\nReason: {self.reason}", + ], + ) + return self.score + + def separate_images_from_text( + self, multimodal_list: List[Union[MLLMImage, str]] + ) -> Tuple[List[str], List[MLLMImage]]: + images: List[MLLMImage] = [] + texts: List[str] = [] + for item in multimodal_list: + if isinstance(item, MLLMImage): + images.append(item) + elif isinstance(item, str): + texts.append(item) + return texts, images + + async def _a_evaluate_semantic_consistency( + self, + text_prompt: str, + image_input: MLLMImage, + ) -> Tuple[List[int], str]: + images: List[MLLMImage] = [] + images.append(image_input) + prompt = [ + TextToImageTemplate.generate_semantic_consistency_evaluation_results( + text_prompt=text_prompt + ) + ] + if self.using_native_model: + res, cost = await self.model.a_generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = await self.model.a_generate( + prompt + images, schema=ReasonScore + ) + return res.score, res.reasoning + except TypeError: + res = await self.model.a_generate( + 
prompt + images, input_text=prompt + ) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + def _evaluate_semantic_consistency( + self, + text_prompt: str, + image_input: MLLMImage, + ) -> Tuple[List[int], str]: + images: List[MLLMImage] = [] + images.append(image_input) + prompt = [ + TextToImageTemplate.generate_semantic_consistency_evaluation_results( + text_prompt=text_prompt + ) + ] + if self.using_native_model: + res, cost = self.model.generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = self.model.generate( + prompt + images, schema=ReasonScore + ) + return res.score, res.reasoning + except TypeError: + res = self.model.generate(prompt + images) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + async def _a_evaluate_perceptual_quality( + self, actual_image_output: MLLMImage + ) -> Tuple[List[int], str]: + images: List[MLLMImage] = [actual_image_output] + prompt = [ + TextToImageTemplate.generate_perceptual_quality_evaluation_results() + ] + if self.using_native_model: + res, cost = await self.model.a_generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = await self.model.a_generate( + prompt + images, schema=ReasonScore + ) + return res.score, res.reasoning + except TypeError: + res = await self.model.a_generate(prompt + images) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + def _evaluate_perceptual_quality( + self, actual_image_output: MLLMImage + ) -> Tuple[List[int], str]: + images: List[MLLMImage] = [actual_image_output] + prompt = [ + TextToImageTemplate.generate_perceptual_quality_evaluation_results() + ] + if self.using_native_model: + res, cost = self.model.generate(prompt + images) + self.evaluation_cost += cost + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + else: + try: + res: ReasonScore = self.model.generate( + prompt + images, schema=ReasonScore + ) + return res.score, res.reasoning + except TypeError: + res = self.model.generate(prompt + images) + data = trimAndLoadJson(res, self) + return data["score"], data["reasoning"] + + def _calculate_score(self) -> List[str]: + min_SC_score = min(self.SC_scores) + min_PQ_score = min(self.PQ_scores) + return math.sqrt(min_SC_score * min_PQ_score) / 10 + + def is_successful(self) -> bool: + if self.error is not None: + self.success = False + else: + try: + self.score >= self.threshold + except: + self.success = False + return self.success + + def _generate_reason( + self, + ) -> Tuple[List[float], str]: + return textwrap.dedent( + f""" + The overall score is {self.score:.2f} because the lowest score from semantic consistency was {min(self.SC_scores)} + and the lowest score from perceptual quality was {min(self.PQ_scores)}. These scores were combined to reflect the + overall effectiveness and quality of the AI-generated image(s). 
+ Reason for Semantic Consistency score: {self.SC_reasoning} + Reason for Perceptual Quality score: {self.PQ_reasoning} + """ + ) + + @property + def __name__(self): + return "Text to Image" diff --git a/deepeval/metrics/viescore/task.py b/deepeval/metrics/viescore/task.py deleted file mode 100644 index ce8f6ae66..000000000 --- a/deepeval/metrics/viescore/task.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class VIEScoreTask(Enum): - TEXT_TO_IMAGE_GENERATION = "Text2Image Generation" - TEXT_TO_IMAGE_EDITING = "Text2Image Editing" diff --git a/deepeval/metrics/viescore/template.py b/deepeval/metrics/viescore/template.py deleted file mode 100644 index 8d232e22b..000000000 --- a/deepeval/metrics/viescore/template.py +++ /dev/null @@ -1,86 +0,0 @@ -from .task import VIEScoreTask -import textwrap - - -class VIEScoreTemplate: - - context = textwrap.dedent( - """ - You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. - All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials. - - You will have to give your output in this way (Keep your reasoning concise and short.): - { - "score" : [...], - "reasoning" : "..." - } - """ - ) - - @staticmethod - def generate_semantic_consistency_evaluation_results( - text_prompt: str, task: VIEScoreTask - ): - if task == VIEScoreTask.TEXT_TO_IMAGE_GENERATION: - return textwrap.dedent( - f""" - {VIEScoreTemplate.context} - - RULES: - - The image is an AI-generated image according to the text prompt. - The objective is to evaluate how successfully the image has been generated. - - From scale 0 to 10: - A score from 0 to 10 will be given based on the success in following the prompt. - (0 indicates that the AI generated image does not follow the prompt at all. 10 indicates the AI generated image follows the prompt perfectly.) - - Put the score in a list such that output score = [score]. - - Text Prompt: {text_prompt} - """ - ) - elif task == VIEScoreTask.TEXT_TO_IMAGE_EDITING: - return textwrap.dedent( - f""" - {VIEScoreTemplate.context} - - RULES: - - Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. - The objective is to evaluate how successfully the editing instruction has been executed in the second image. - - From scale 0 to 10: - A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) - A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) - Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. - - Editing instruction: {text_prompt} - """ - ) - - @staticmethod - def generate_perceptual_quality_evaluation_results(): - return textwrap.dedent( - f""" - {VIEScoreTemplate.context} - - RULES: - - The image is an AI-generated image. - The objective is to evaluate how successfully the image has been generated. 
- - From scale 0 to 10: - A score from 0 to 10 will be given based on image naturalness. - ( - 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. - 10 indicates that the image looks natural. - ) - A second score from 0 to 10 will rate the image artifacts. - ( - 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. - 10 indicates the image has no artifacts. - ) - Put the score in a list such that output score = [naturalness, artifacts] - """ - ) diff --git a/docs/docs/metrics-viescore.mdx b/docs/docs/metrics-image-editing.mdx similarity index 58% rename from docs/docs/metrics-viescore.mdx rename to docs/docs/metrics-image-editing.mdx index 5da1278c9..2db5d5ecc 100644 --- a/docs/docs/metrics-viescore.mdx +++ b/docs/docs/metrics-image-editing.mdx @@ -1,38 +1,37 @@ --- -id: metrics-viescore -title: VIEScore -sidebar_label: VIEScore +id: metrics-image-editing +title: Image Editing +sidebar_label: Image Editing --- import Equation from "@site/src/components/equation"; -`VIEScore` assesses the performance of **image generation and editing tasks** by evaluating the quality of synthesized images based on semantic consistency and perceptual quality. `deepeval`'s VIEScore metric is a self-explaining MLLM-Eval, meaning it outputs a reason for its metric score. - -:::tip -Using `VIEScore` with GPT-4v as the evaluation model achieves scores comparable to human ratings in text-to-image generation tasks, and is especially good at detecting undesirable artifacts. -::: +The Image Editing metric assesses the performance of **image editing tasks** by evaluating the quality of synthesized images based on semantic consistency and perceptual quality (similar to the `TextToImageMetric`). `deepeval`'s Image Editing metric is a self-explaining MLLM-Eval, meaning it outputs a reason for its metric score. ## Required Arguments -To use the `VIEScore`, you'll have to provide the following arguments when creating an `MLLMTestCase`: +To use the `ImageEditingMetric`, you'll have to provide the following arguments when creating an `MLLMTestCase`: - `input` - `actual_output` +:::note +Both the input and output should each contain exactly **1 image**. +::: + ## Example ```python from deepeval import evaluate -from deepeval.metrics import VIEScore, VIEScoreTask +from deepeval.metrics import ImageEditingMetric from deepeval.test_case import MLLMTestCase, MLLMImage # Replace this with your actual MLLM application output actual_output=[MLLMImage(url="https://shoe-images.com/edited-shoes", local=False)] -metric = VIEScore( +metric = ImageEditingMetric( threshold=0.7, include_reason=True, - task=VIEScoreTask.TEXT_TO_IMAGE_EDITING ) test_case = MLLMTestCase( input=["Change the color of the shoes to blue.", MLLMImage(url="./shoes.png", local=True)], @@ -48,30 +47,21 @@ print(metric.reason) evaluate([test_case], [metric]) ``` -There are six optional parameters when creating a `VIEScore`: +There are five optional parameters when creating a `ImageEditingMetric`: - [Optional] `threshold`: a float representing the minimum passing threshold, defaulted to 0.5. - [Optional] `include_reason`: a boolean which when set to `True`, will include a reason for its evaluation score. Defaulted to `True`. - [Optional] `strict_mode`: a boolean which when set to `True`, enforces a binary metric score: 1 for perfection, 0 otherwise. 
It also overrides the current threshold and sets it to 1. Defaulted to `False`. - [Optional] `async_mode`: a boolean which when set to `True`, enables [concurrent execution within the `measure()` method.](metrics-introduction#measuring-metrics-in-async) Defaulted to `True`. - [Optional] `verbose_mode`: a boolean which when set to `True`, prints the intermediate steps used to calculate said metric to the console, as outlined in the [How Is It Calculated](#how-is-it-calculated) section. Defaulted to `False`. -- [Optional] `task`: a `VIEScoreTask` enum indicating whether the task is image generation or image editing. Defaulted to `VIEScoreTask.TEXT_TO_IMAGE_GENERATION`. - -:::info -`VIEScoreTask` is an **enumeration that includes two types of tasks**: - -- `TEXT_TO_IMAGE_GENERATION`: the input should contain exactly **0 images**, and the output should contain exactly **1 image**. -- `TEXT_TO_IMAGE_EDITING`: For this task type, both the input and output should each contain exactly **1 image**. - -::: ## How Is It Calculated? -The `VIEScore` score is calculated according to the following equation: +The `ImageEditingMetric` score is calculated according to the following equation: -The `VIEScore` score combines Semantic Consistency (SC) and Perceptual Quality (PQ) sub-scores to provide a comprehensive evaluation of the synthesized image. The final overall score is derived by taking the square root of the product of the minimum SC and PQ scores. +The `ImageEditingMetric` score combines Semantic Consistency (SC) and Perceptual Quality (PQ) sub-scores to provide a comprehensive evaluation of the synthesized image. The final overall score is derived by taking the square root of the product of the minimum SC and PQ scores. ### SC Scores diff --git a/docs/docs/metrics-text-to-image.mdx b/docs/docs/metrics-text-to-image.mdx new file mode 100644 index 000000000..80ad023ac --- /dev/null +++ b/docs/docs/metrics-text-to-image.mdx @@ -0,0 +1,76 @@ +--- +id: metrics-text-to-image +title: Text to Image +sidebar_label: Text to Image +--- + +import Equation from "@site/src/components/equation"; + +The Text to Image metric assesses the performance of **image generation tasks** by evaluating the quality of synthesized images based on semantic consistency and perceptual quality. `deepeval`'s Text to Image metric is a self-explaining MLLM-Eval, meaning it outputs a reason for its metric score. + +:::tip +The Text to Image metric achieves scores **comparable to human evaluations** when GPT-4v is used as the evaluation model. This metric excels in artifact detection. +::: + +## Required Arguments + +To use the `TextToImageMetric`, you'll have to provide the following arguments when creating an `MLLMTestCase`: + +- `input` +- `actual_output` + +:::note +The input should contain exactly **0 images**, and the output should contain exactly **1 image**. 
+::: + +## Example + +```python +from deepeval import evaluate +from deepeval.metrics import TextToImageMetric +from deepeval.test_case import MLLMTestCase, MLLMImage + +# Replace this with your actual MLLM application output +actual_output=[MLLMImage(url="https://shoe-images.com/edited-shoes", local=False)] + +metric = TextToImageMetric( + threshold=0.7, + include_reason=True, +) +test_case = MLLMTestCase( + input=["Generate an image of a blue pair of shoes."], + actual_output=actual_output, + retrieval_context=retrieval_context +) + +metric.measure(test_case) +print(metric.score) +print(metric.reason) + +# or evaluate test cases in bulk +evaluate([test_case], [metric]) +``` + +There are five optional parameters when creating a `TextToImageMetric`: + +- [Optional] `threshold`: a float representing the minimum passing threshold, defaulted to 0.5. +- [Optional] `include_reason`: a boolean which when set to `True`, will include a reason for its evaluation score. Defaulted to `True`. +- [Optional] `strict_mode`: a boolean which when set to `True`, enforces a binary metric score: 1 for perfection, 0 otherwise. It also overrides the current threshold and sets it to 1. Defaulted to `False`. +- [Optional] `async_mode`: a boolean which when set to `True`, enables [concurrent execution within the `measure()` method.](metrics-introduction#measuring-metrics-in-async) Defaulted to `True`. +- [Optional] `verbose_mode`: a boolean which when set to `True`, prints the intermediate steps used to calculate said metric to the console, as outlined in the [How Is It Calculated](#how-is-it-calculated) section. Defaulted to `False`. + +## How Is It Calculated? + +The `TextToImageMetric` score is calculated according to the following equation: + + + +The `TextToImageMetric` score combines Semantic Consistency (SC) and Perceptual Quality (PQ) sub-scores to provide a comprehensive evaluation of the synthesized image. The final overall score is derived by taking the square root of the product of the minimum SC and PQ scores. + +### SC Scores + +These scores assess aspects such as alignment with the prompt and resemblance to concepts. The minimum value among these sub-scores represents the SC score. During the SC evaluation, both the input conditions and the synthesized image are used. + +### PQ Scores + +These scores evaluate the naturalness and absence of artifacts in the image. The minimum value among these sub-scores represents the PQ score. For the PQ evaluation, only the synthesized image is used to prevent confusion from the input conditions. 
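+
+As a rough illustration (the sub-scores below are hypothetical, not produced by a real evaluation run), the final score can be reproduced with a few lines of Python that mirror `TextToImageMetric._calculate_score`:
+
+```python
+import math
+
+# Hypothetical sub-scores on the 0-10 scale returned by the evaluation model
+sc_scores = [8, 9]  # semantic consistency (e.g. prompt following)
+pq_scores = [7, 6]  # perceptual quality: [naturalness, artifacts]
+
+# Overall score = sqrt(min(SC) * min(PQ)), normalized to 0-1 by dividing by 10
+score = math.sqrt(min(sc_scores) * min(pq_scores)) / 10  # sqrt(8 * 6) / 10 ≈ 0.69
+print(score)
+```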
diff --git a/docs/sidebars.js b/docs/sidebars.js index 77c533a90..9034b054c 100644 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -54,8 +54,8 @@ module.exports = { }, { type: "category", - label: "Multimodal Metrics", - items: ["metrics-viescore"], + label: "Image Metrics", + items: ["metrics-text-to-image", "metrics-image-editing"], collapsed: true, }, ], diff --git a/tests/test_viescore.py b/tests/test_image_metrics.py similarity index 51% rename from tests/test_viescore.py rename to tests/test_image_metrics.py index dee4050ef..aeed02cbb 100644 --- a/tests/test_viescore.py +++ b/tests/test_image_metrics.py @@ -3,19 +3,23 @@ from deepeval.dataset import EvaluationDataset from deepeval import assert_test, evaluate from deepeval.test_case import MLLMTestCase, LLMTestCase, MLLMImage -from deepeval.metrics import VIEScore, AnswerRelevancyMetric, VIEScoreTask +from deepeval.metrics import AnswerRelevancyMetric, ImageEditingMetric, TextToImageMetric image_path = "./data/image.webp" edited_image_path = "./data/edited_image.webp" -test_case_1 = MLLMTestCase( +############################################################# +# TestCases +############################################################# + +text_to_image_test_case = MLLMTestCase( input=[ "gesnerate a castle school in fantasy land with the words LLM evaluation on it" ], actual_output=[MLLMImage(image_path, local=True)], ) -test_case_2 = MLLMTestCase( +image_editing_test_case = MLLMTestCase( input=[ "edit this image so that it is night themed, and LLM evaluation is spelled correctly", MLLMImage(image_path, local=True), @@ -23,7 +27,7 @@ actual_output=[MLLMImage(edited_image_path, local=True)], ) -test_case_3 = LLMTestCase( +llm_test_case = LLMTestCase( input="What is this again?", actual_output="this is a latte", expected_output="this is a mocha", @@ -33,36 +37,35 @@ tools_called=["mixer", "creamer", "mixer"], ) -dataset = EvaluationDataset(test_cases=[test_case_2, test_case_3]) -dataset.evaluate( - [ - # VIEScore(verbose_mode=True), - VIEScore(verbose_mode=True, task=VIEScoreTask.TEXT_TO_IMAGE_EDITING), - AnswerRelevancyMetric(), - ] -) + +############################################################# +# Evaluate +############################################################# + +# dataset = EvaluationDataset( +# test_cases=[ +# # text_to_image_test_case, +# image_editing_test_case, +# llm_test_case +# ] +# ) +# dataset.evaluate( +# [ +# # TextToImageMetric(), +# ImageEditingMetric(), +# AnswerRelevancyMetric(), +# ] +# ) evaluate( test_cases=[ - test_case_1, - # test_case_2, - test_case_3, + text_to_image_test_case, + # image_editing_test_case, + llm_test_case ], metrics=[ - VIEScore(verbose_mode=True), - # VIEScore(verbose_mode=True, task=VIEScoreTask.TEXT_TO_IMAGE_EDITING), + TextToImageMetric(), + # ImageEditingMetric(), AnswerRelevancyMetric(), ], - # run_async=False -) - - -# #@pytest.mark.skip(reason="openai is expensive") -# def test_viescore(): -# vie_score = VIEScore(verbose_mode=True) -# vie_score_2 = VIEScore( -# verbose_mode=True, task=VIEScoreTask.TEXT_TO_IMAGE_EDITING -# ) -# assert_test( -# test_case_2, [vie_score_2, AnswerRelevancyMetric()], run_async=False -# ) +) \ No newline at end of file