# image_to_video.py (forked from livepeer/ai-worker)
import logging
import os
import time
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

SFAST_WARMUP_ITERATIONS = 2  # Model warm-up iterations when SFAST is enabled.


class ImageToVideoPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
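        # Load the fp16 weight variant when one is cached locally and we are
        # not on CPU; half precision roughly halves GPU memory use.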
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if torch_device != "cpu" and has_fp16_variant:
            logger.info("ImageToVideoPipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
        self.ldm.to(get_torch_device())
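
        # stable-fast (SFAST) and DeepCache are independent, env-toggled
        # runtime optimizations; enabling both is discouraged (warned below).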
        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
        if sfast_enabled and deepcache_enabled:
            logger.warning(
                "Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
                "as it may lead to suboptimal performance. Please disable one of them."
            )

        if sfast_enabled:
            logger.info(
                "ImageToVideoPipeline will be dynamically compiled with stable-fast "
                "for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model

            self.ldm = compile_model(self.ldm)

            # Warm-up the pipeline.
            # NOTE: Initial calls may be slow due to compilation. Subsequent
            # calls will be faster.
            if os.getenv("SFAST_WARMUP", "true").lower() == "true":
                # Retrieve default model params.
                # TODO: Retrieve defaults from Pydantic class in route.
                warmup_kwargs = {
                    # PIL sizes are (width, height); match the kwargs below.
                    "image": PIL.Image.new("RGB", (1024, 576)),
                    "height": 576,
                    "width": 1024,
                    "fps": 6,
                    "motion_bucket_id": 127,
                    "noise_aug_strength": 0.02,
                    "decode_chunk_size": 4,
                }

                logger.info("Warming up ImageToVideoPipeline...")
                total_time = 0
                for ii in range(SFAST_WARMUP_ITERATIONS):
                    t = time.time()
                    try:
                        self.ldm(**warmup_kwargs).frames
                    except Exception as e:
                        # FIXME: When out of memory, pipeline is corrupted.
                        logger.error(f"ImageToVideoPipeline warmup error: {e}")
                        raise e
                    iteration_time = time.time() - t
                    total_time += iteration_time
                    logger.info(
                        "Warmup iteration %s took %s seconds", ii + 1, iteration_time
                    )
                logger.info("Total warmup time: %s seconds", total_time)

        if deepcache_enabled:
            logger.info(
                "ImageToVideoPipeline will be optimized with DeepCache for %s",
                model_id,
            )
            from app.pipelines.optim.deepcache import enable_deepcache

            self.ldm = enable_deepcache(self.ldm)

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

    def __call__(
        self, image: PIL.Image.Image, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
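        """Generate video frames from a source image.

        Keyword args are forwarded to StableVideoDiffusionPipeline (e.g.
        height, width, fps, motion_bucket_id, noise_aug_strength), except for
        the worker-specific seed and safety_check options handled below.
        """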
        seed = kwargs.pop("seed", None)
        num_inference_steps = kwargs.get("num_inference_steps", None)
        safety_check = kwargs.pop("safety_check", True)

        if "decode_chunk_size" not in kwargs:
            # Decrease decode_chunk_size to reduce memory usage.
            kwargs["decode_chunk_size"] = 4

        if seed is not None:
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
                    seed
                )
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                ]

        if num_inference_steps is None or num_inference_steps < 1:
            # Use pop() with a default so a missing key does not raise
            # KeyError, as the previous `del` would.
            kwargs.pop("num_inference_steps", None)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images([image])
        else:
            has_nsfw_concept = [None]

        return self.ldm(image, **kwargs).frames, has_nsfw_concept

    def __str__(self) -> str:
        return f"ImageToVideoPipeline model_id={self.model_id}"