-
Notifications
You must be signed in to change notification settings - Fork 37
/
img2vid.py
99 lines (82 loc) · 3.3 KB
/
img2vid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import io
import av
import comfy.latent_formats
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.supported_models_base
import comfy.utils
import numpy as np
import torch
from ltx_video.models.autoencoders.vae_encode import get_vae_size_scale_factor
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream(
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
format="yuv420p"
)
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file):
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def videofy(image: torch.Tensor, crf=29):
if crf == 0:
return image
image_array = (image * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
def pad_tensor(tensor, target_len):
dim = 2
repeat_factor = target_len - tensor.shape[dim] # Ceiling division
last_element = tensor.select(dim, -1).unsqueeze(dim)
padding = last_element.repeat(1, 1, repeat_factor, 1, 1)
return torch.cat([tensor, padding], dim=dim)
def encode_media_conditioning(
init_media, vae, width, height, frames_number, image_compression, initial_latent
):
pixels = comfy.utils.common_upscale(
init_media.movedim(-1, 1), width, height, "bilinear", ""
).movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
if image_compression > 0:
for i in range(encode_pixels.shape[0]):
image = videofy(encode_pixels[i], image_compression)
encode_pixels[i] = image
encoded_latents = vae.encode(encode_pixels).float()
video_scale_factor, _, _ = get_vae_size_scale_factor(vae.first_stage_model)
video_scale_factor = video_scale_factor if frames_number > 1 else 1
target_len = (frames_number // video_scale_factor) + 1
encoded_latents = encoded_latents[:, :, :target_len]
if initial_latent is None:
initial_latent = encoded_latents
else:
if encoded_latents.shape[2] > initial_latent.shape[2]:
initial_latent = pad_tensor(initial_latent, encoded_latents.shape[2])
initial_latent[:, :, : encoded_latents.shape[2], ...] = encoded_latents
init_image_frame_number = init_media.shape[0]
if init_image_frame_number == 1:
result = pad_tensor(initial_latent, target_len)
elif init_image_frame_number % 8 != 1:
result = pad_tensor(initial_latent, target_len)
else:
result = initial_latent
return result