From 0433c97b780ad08577462320565a47b83779b344 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Tue, 13 Aug 2024 13:24:10 +0200 Subject: [PATCH] refactor: resolve CLIPFeatureExtractor deprecation warning (#152) This commit resolves a CLIPFeatureExtractor deprecation warning thrown by the NSFW check logic. --- runner/app/pipelines/utils/utils.py | 4 ++-- runner/test_prompts.py | 12 ------------ 2 files changed, 2 insertions(+), 14 deletions(-) delete mode 100644 runner/test_prompts.py diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index ebac62d1..7edc903e 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -11,7 +11,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from PIL import Image from torch import dtype as TorchDtype -from transformers import CLIPFeatureExtractor +from transformers import CLIPImageProcessor from typing import Dict logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ def __init__( self._safety_checker = StableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker" ).to(self.device) - self._feature_extractor = CLIPFeatureExtractor.from_pretrained( + self._feature_extractor = CLIPImageProcessor.from_pretrained( "openai/clip-vit-base-patch32" ) diff --git a/runner/test_prompts.py b/runner/test_prompts.py deleted file mode 100644 index b87a60dc..00000000 --- a/runner/test_prompts.py +++ /dev/null @@ -1,12 +0,0 @@ -from app.pipelines.utils.utils import split_prompt - -if __name__ == "__main__": - input_prompt = "A photo of a cat.|" - test = split_prompt(input_prompt) - - input_prompt2 = "" - test2 = split_prompt(input_prompt2) - - input_pormpt3 = "A photo of a cat.|A photo of a dog.|A photo of a bird." - test3 = split_prompt(input_pormpt3) -