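"""Image-to-image pipeline wrapper for the ai-worker.

Loads a diffusers image-to-image pipeline behind the app's ``Pipeline``
interface, with special handling for SDXL-Lightning (a swapped-in distilled
UNet) and InstructPix2Pix, optional stable-fast / DeepCache optimizations,
and a post-generation NSFW safety check.
"""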
import logging
import os
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (SafetyChecker, get_model_dir,
                                 get_torch_device, is_lightning_model,
                                 is_turbo_model)
from diffusers import (AutoPipelineForImage2Image,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       StableDiffusionInstructPix2PixPipeline,
                       StableDiffusionXLPipeline, UNet2DConditionModel)
from huggingface_hub import file_download, hf_hub_download
from PIL import ImageFile
from safetensors.torch import load_file

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)


class ModelName(Enum):
    """Enumeration mapping model names to their corresponding IDs."""

    SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
    INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"

    @classmethod
    def list(cls):
        """Return a list of all model IDs."""
        return list(map(lambda c: c.value, cls))


class ImageToImagePipeline(Pipeline):
    def __init__(self, model_id: str):
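        """Load the diffusers pipeline for ``model_id``, selecting the fp16
        variant when one is cached and applying the optional stable-fast and
        DeepCache optimizations."""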
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}
        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)

        # Load the fp16 variant if fp16 'safetensors' files are present in the cache.
        # NOTE: Exception for SDXL-Lightning model: despite having fp16 'safetensors'
        # files, they are not named according to the standard convention.
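        # (A standard fp16 variant file is typically named like
        # "diffusion_pytorch_model.fp16.safetensors".)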
        has_fp16_variant = (
            any(
                ".fp16.safetensors" in fname
                for _, _, files in os.walk(folder_path)
                for fname in files
            )
            or ModelName.SDXL_LIGHTNING.value in model_id
        )
        if torch_device != "cpu" and has_fp16_variant:
            logger.info("ImageToImagePipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        # Special case SDXL-Lightning because the unet for SDXL needs to be swapped
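        # (SDXL-Lightning publishes distilled UNet checkpoints, each trained for
        # a fixed number of inference steps, so the step count is baked into the
        # checkpoint selected below.)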
        if ModelName.SDXL_LIGHTNING.value in model_id:
            base = "stabilityai/stable-diffusion-xl-base-1.0"

            # ByteDance/SDXL-Lightning-2step
            if "2step" in model_id:
                unet_id = "sdxl_lightning_2step_unet"
            # ByteDance/SDXL-Lightning-4step
            elif "4step" in model_id:
                unet_id = "sdxl_lightning_4step_unet"
            # ByteDance/SDXL-Lightning-8step
            elif "8step" in model_id:
                unet_id = "sdxl_lightning_8step_unet"
            else:
                # Default to 2step
                unet_id = "sdxl_lightning_2step_unet"

            # ``kwargs.get`` avoids a KeyError on CPU, where no fp16
            # torch_dtype was set above.
            unet = UNet2DConditionModel.from_config(
                base, subfolder="unet", cache_dir=kwargs["cache_dir"]
            ).to(torch_device, kwargs.get("torch_dtype"))
            unet.load_state_dict(
                load_file(
                    hf_hub_download(
                        ModelName.SDXL_LIGHTNING.value,
                        f"{unet_id}.safetensors",
                        cache_dir=kwargs["cache_dir"],
                    ),
                    device=str(torch_device),
                )
            )

            self.ldm = StableDiffusionXLPipeline.from_pretrained(
                base, unet=unet, **kwargs
            ).to(torch_device)

            self.ldm.scheduler = EulerDiscreteScheduler.from_config(
                self.ldm.scheduler.config, timestep_spacing="trailing"
            )
        elif ModelName.INSTRUCT_PIX2PIX.value in model_id:
            self.ldm = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                model_id, **kwargs
            ).to(torch_device)

            self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.ldm.scheduler.config
            )
        else:
            self.ldm = AutoPipelineForImage2Image.from_pretrained(
                model_id, **kwargs
            ).to(torch_device)

        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"

        if sfast_enabled and deepcache_enabled:
            logger.warning(
                "Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
                "as it may lead to suboptimal performance. Please disable one of them."
            )

        if sfast_enabled:
            logger.info(
                "ImageToImagePipeline will be dynamically compiled with stable-fast "
                "for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model

            self.ldm = compile_model(self.ldm)

            # Warm-up the pipeline.
            # TODO: Not yet supported for ImageToImagePipeline.
            if os.getenv("SFAST_WARMUP", "true").lower() == "true":
                logger.warning(
                    "The 'SFAST_WARMUP' flag is not yet supported for the "
                    "ImageToImagePipeline and will be ignored. As a result the first "
                    "call may be slow if 'SFAST' is enabled."
                )

        if deepcache_enabled and not (
            is_lightning_model(model_id) or is_turbo_model(model_id)
        ):
            logger.info(
                "ImageToImagePipeline will be optimized with DeepCache for %s",
                model_id,
            )
            from app.pipelines.optim.deepcache import enable_deepcache

            self.ldm = enable_deepcache(self.ldm)
        elif deepcache_enabled:
            logger.warning(
                "DeepCache is not supported for Lightning or Turbo models. "
                "ImageToImagePipeline will NOT be optimized with DeepCache for %s",
                model_id,
            )

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

    def __call__(
        self, prompt: str, image: PIL.Image.Image, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
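        """Generate images from ``image`` guided by ``prompt``.

        Returns the generated images and a per-image NSFW flag, which is
        ``None`` for every image when ``safety_check`` is disabled.
        """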
        seed = kwargs.pop("seed", None)
        num_inference_steps = kwargs.get("num_inference_steps", None)
        safety_check = kwargs.pop("safety_check", True)

        if seed is not None:
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
                    seed
                )
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                ]

        if num_inference_steps is None or num_inference_steps < 1:
            # ``pop`` (rather than ``del``) avoids a KeyError when the caller
            # did not supply num_inference_steps at all.
            kwargs.pop("num_inference_steps", None)

        if (
            self.model_id == "stabilityai/sdxl-turbo"
            or self.model_id == "stabilityai/sd-turbo"
        ):
            # SD turbo models were trained without guidance_scale so
            # it should be set to 0.
            kwargs["guidance_scale"] = 0.0

            # Ensure num_inference_steps * strength >= 1 for minimum pipeline
            # execution steps.
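            # (For example, num_inference_steps=4 raises strength to at least
            # 0.25, since img2img runs roughly int(steps * strength) denoising
            # steps.)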
if "num_inference_steps" in kwargs:
kwargs["strength"] = max(
1.0 / kwargs.get("num_inference_steps", 1),
kwargs.get("strength", 0.5),
)
        elif ModelName.SDXL_LIGHTNING.value in self.model_id:
            # SDXL-Lightning models should have guidance_scale = 0 and use
            # the correct number of inference steps for the unet checkpoint loaded.
            kwargs["guidance_scale"] = 0.0
            if "2step" in self.model_id:
                kwargs["num_inference_steps"] = 2
            elif "4step" in self.model_id:
                kwargs["num_inference_steps"] = 4
            elif "8step" in self.model_id:
                kwargs["num_inference_steps"] = 8
            else:
                # Default to 2step
                kwargs["num_inference_steps"] = 2

        output = self.ldm(prompt, image=image, **kwargs)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
        else:
            has_nsfw_concept = [None] * len(output.images)

        return output.images, has_nsfw_concept

    def __str__(self) -> str:
        return f"ImageToImagePipeline model_id={self.model_id}"