# cogvideo_transformer.py (forked from TheDenk/cogvideox-controlnet)
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.utils import is_torch_version
from diffusers.models.transformers.cogvideox_transformer_3d import (
    CogVideoXTransformer3DModel,
    Transformer2DModelOutput,
)


class CustomCogVideoXTransformer3DModel(CogVideoXTransformer3DModel):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        start_frame: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        controlnet_states: Optional[torch.Tensor] = None,
        controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
        return_dict: bool = True,
    ):
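        """Denoise a batch of video latents, optionally injecting ControlNet residuals.

        Args:
            hidden_states: Video latents of shape (batch, frames, channels, height, width).
            encoder_hidden_states: Text-encoder states of shape (batch, text_seq_len, text_embed_dim).
            timestep: Diffusion timestep(s) used for the time embedding.
            start_frame: Optional conditioning latent concatenated to `hidden_states`
                along the channel dimension (dim=2) before patch embedding.
            timestep_cond: Optional extra conditioning for the time embedding.
            image_rotary_emb: Rotary positional embeddings (cos, sin); used by CogVideoX-5B.
            controlnet_states: Optional sequence of ControlNet residuals, one per
                controlled transformer block, added to the video tokens after each block.
            controlnet_weights: Scalar weight, or per-block list/array/tensor of weights,
                applied to the ControlNet residuals.
            return_dict: If False, return a plain tuple instead of Transformer2DModelOutput.
        """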
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # Optionally concatenate the conditioning start frame along the channel axis.
        if start_frame is not None:
            hidden_states = torch.cat([start_frame, hidden_states], dim=2)

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # `time_proj` has no weights and always returns f32 tensors, but
        # `time_embedding` might actually be running in fp16, so we need to
        # cast here. There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        # Split the joint sequence back into text tokens and video patch tokens.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 2.1 ControlNet states (retained, commented-out variants)
        # Variant A: normalize the ControlNet states to the latent statistics and
        # inject them once, before the transformer blocks:
        # if controlnet_states is not None:
        #     mean_latents = torch.mean(hidden_states, dim=(1, 2), keepdim=True)
        #     std_latents = torch.std(hidden_states, dim=(1, 2), keepdim=True)
        #     mean_control = torch.mean(controlnet_states, dim=(1, 2), keepdim=True)
        #     std_control = torch.std(controlnet_states, dim=(1, 2), keepdim=True)
        #     controlnet_states = (controlnet_states - mean_control) * (std_latents / (std_control + 1e-5)) + mean_latents
        #     hidden_states = hidden_states + controlnet_states
        # Variant B: inject only into the last len(controlnet_states) blocks:
        # if controlnet_states is not None:
        #     controlnet_start_index = len(self.transformer_blocks) - len(controlnet_states)

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            # Retained, commented-out variant: normalize each ControlNet block's
            # residual to the latent statistics before injecting (pairs with
            # `controlnet_start_index` above):
            # if (controlnet_states is not None) and (i >= controlnet_start_index):
            #     controlnet_states_block = controlnet_states[i - controlnet_start_index]
            #     mean_latents = torch.mean(hidden_states, dim=(1, 2), keepdim=True)
            #     std_latents = torch.std(hidden_states, dim=(1, 2), keepdim=True)
            #     mean_control = torch.mean(controlnet_states_block, dim=(1, 2), keepdim=True)
            #     std_control = torch.std(controlnet_states_block, dim=(1, 2), keepdim=True)
            #     controlnet_states_block = (controlnet_states_block - mean_control) * (std_latents / (std_control + 1e-5)) + mean_latents
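            # Active variant: add each ControlNet block's residual directly to the
            # video tokens, scaled by a global or per-block weight.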
            if (controlnet_states is not None) and (i < len(controlnet_states)):
                controlnet_states_block = controlnet_states[i]
                controlnet_block_weight = 1.0
                if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                    controlnet_block_weight = controlnet_weights[i]
                elif isinstance(controlnet_weights, (float, int)):
                    controlnet_block_weight = controlnet_weights
                hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
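

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal, hedged example of
# loading this patched transformer in place of the stock one. It assumes the
# public "THUDM/CogVideoX-2b" checkpoint; `controlnet_states`, if used, would
# come from a separately trained ControlNet (e.g. TheDenk/cogvideox-controlnet)
# and hold one residual tensor per controlled transformer block.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=dtype
    ).to(device)

    cfg = transformer.config
    # 13 latent frames correspond to a 49-frame video at 4x temporal compression.
    hidden_states = torch.randn(
        1, 13, cfg.in_channels, cfg.sample_height, cfg.sample_width,
        dtype=dtype, device=device,
    )
    encoder_hidden_states = torch.randn(
        1, cfg.max_text_seq_length, cfg.text_embed_dim,
        dtype=dtype, device=device,
    )
    timestep = torch.tensor([999], device=device)

    with torch.no_grad():
        sample = transformer(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timestep,
            controlnet_states=None,  # pass per-block residuals here to control
        ).sample
    print(sample.shape)  # torch.Size([1, 13, 16, 60, 90])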