# styleremoval_image.py
# Forked from mac-mvak/StyleDiffusion.
import os
import time
from glob import glob

import cv2
import numpy as np
import torch
import torchvision.utils as tvu
from PIL import Image
from torch import nn
from tqdm import tqdm

from models.ddpm.diffusion import DDPM
from models.improved_ddpm.script_util import i_DDPM
from utils.text_dic import SRC_TRG_TXT_DIC
from utils.diffusion_utils import get_beta_schedule, denoising_step
from losses import id_loss
from losses.clip_loss import CLIPLoss
from datasets.data_utils import get_dataset, get_dataloader
from configs.paths_config import DATASET_PATHS, MODEL_PATHS, HYBRID_MODEL_PATHS, HYBRID_CONFIG
from datasets.imagenet_dic import IMAGENET_DIC
from datasets.GENERIC_dataset import GENERIC_dataset
from utils.align_utils import run_alignment
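

# A minimal reference sketch, kept here for documentation only, of the
# deterministic DDIM update (eta = 0) that `denoising_step` is assumed to
# implement; the project's actual logic lives in utils.diffusion_utils.
# `alphas_cumprod` is the cumulative product of (1 - beta), `eps` is the
# model's noise prediction at step t, and t_next = -1 denotes the clean image.
def _ddim_step_reference(x_t, eps, alphas_cumprod, t, t_next):
    a_t = alphas_cumprod[t]
    a_next = alphas_cumprod[t_next] if t_next >= 0 else 1.0
    # Predict x_0 from the noisy sample, then move it to the t_next marginal.
    x0_pred = (x_t - (1.0 - a_t) ** 0.5 * eps) / a_t ** 0.5
    return a_next ** 0.5 * x0_pred + (1.0 - a_next) ** 0.5 * eps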


class StyleRemovalImage(object):
    def __init__(self, args, config, device=None):
        self.args = args
        self.config = config
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device

        self.model_var_type = config.model.var_type
        betas = get_beta_schedule(
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )
        self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]

        # Log-variance of the reverse process, following the DDPM convention:
        # "fixedlarge" uses beta_t itself (with the posterior variance at t=1
        # to avoid log(0)), "fixedsmall" the true posterior variance
        # beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        if self.model_var_type == "fixedlarge":
            self.logvar = np.log(np.append(posterior_variance[1], betas[1:]))
        elif self.model_var_type == "fixedsmall":
            self.logvar = np.log(np.maximum(posterior_variance, 1e-20))

    def remove_style(self):
        """Invert the style image into a DDIM latent, reconstruct it, and
        cache the resulting (image, reconstruction, latent) tensors on disk."""
        print(self.args.exp)

        # ----------- Model -----------#
        model = i_DDPM(self.config.data.dataset, self.args.image_size)
        if self.args.image_size == 256:
            model_path = 'pretrained/256x256_diffusion_uncond.pt'
        elif self.args.image_size == 512:
            model_path = 'pretrained/512x512_diffusion.pt'
        else:
            raise ValueError(f'No pretrained checkpoint for image size {self.args.image_size}')
        init_ckpt = torch.load(model_path)
        model.load_state_dict(init_ckpt)
        model.to(self.device)
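        # The checkpoints above are the OpenAI guided-diffusion weights; the
        # learn_sigma=True flag passed to denoising_step below matches their
        # learned-variance output head.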

        # ----------- Precompute Latents -----------#
        print("Prepare style latent")
        # Uniformly spaced timesteps from 0 up to t_0_remove, plus the shifted
        # sequence of predecessors used to form (t, t_next) pairs.
        seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0_remove
        seq_inv = [int(s) for s in list(seq_inv)]
        seq_inv_next = [-1] + list(seq_inv[:-1])
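        # Example: t_0_remove = 500 with n_inv_step = 5 gives
        # seq_inv = [0, 125, 250, 375, 500] and seq_inv_next = [-1, 0, 125, 250, 375].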

        n = self.args.bs_train
        style_lat_pairs = []
        style_image_path = self.args.style_image

        # Load the style image twice: once in color (kept as the target) and
        # once in grayscale (the version that actually gets inverted).
        style_color_ds = GENERIC_dataset(style_image_path, color=True, img_size=self.args.image_size)
        style_gray_ds = GENERIC_dataset(style_image_path, img_size=self.args.image_size)

        color_img = torch.from_numpy(style_color_ds[0][1])
        tvu.save_image((color_img + 1) * 0.5, os.path.join(self.args.image_folder,
                                                           f'style_color_rec_ninv{self.args.n_inv_step}.png'))
        style_lat_pairs.append(color_img.unsqueeze(0))

        x0 = torch.from_numpy(style_gray_ds[0][0])
        tvu.save_image((x0 + 1) * 0.5, os.path.join(self.args.image_folder, 'style_0_orig.png'))
        x = x0.clone().to(self.device).unsqueeze(0)

        model.eval()
        time_s = time.time()
        with torch.no_grad():
            # Deterministic DDIM inversion: walk the timesteps forward
            # (t_next > t) with eta = 0 to map x0 to its latent.
            with tqdm(total=len(seq_inv), desc="Inversion process style") as progress_bar:
                for it, (i, j) in enumerate(zip(seq_inv_next[1:], seq_inv[1:])):
                    t = (torch.ones(n) * i).to(self.device)
                    t_prev = (torch.ones(n) * j).to(self.device)
                    x = denoising_step(x, t=t, t_next=t_prev, models=model,
                                       logvars=self.logvar,
                                       sampling_type='ddim',
                                       b=self.betas,
                                       eta=0,
                                       learn_sigma=True)
                    progress_bar.update(1)
            time_e = time.time()
            print(f'Inversion took {time_e - time_s:.2f} seconds')

            x_lat = x.clone()
            tvu.save_image((x_lat + 1) * 0.5, os.path.join(self.args.image_folder,
                                                           f'style_1_lat_ninv{self.args.n_inv_step}.png'))
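
            # Reverse (generative) pass from the inverted latent back to an
            # image; with a deterministic sampler the reconstruction should be
            # close to x0, which sanity-checks the inversion.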
            with tqdm(total=len(seq_inv), desc="Generative process style") as progress_bar:
                time_s = time.time()
                for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
                    t = (torch.ones(n) * i).to(self.device)
                    t_next = (torch.ones(n) * j).to(self.device)
                    x = denoising_step(x, t=t, t_next=t_next, models=model,
                                       logvars=self.logvar,
                                       sampling_type=self.args.sample_type,
                                       b=self.betas,
                                       learn_sigma=True)
                    progress_bar.update(1)
            time_e = time.time()
            print(f'Generation took {time_e - time_s:.2f} seconds')

            style_lat_pairs += [x0, x.detach().clone(), x_lat.detach().clone()]
            tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                                       f'style_1_rec_ninv{self.args.n_inv_step}.png'))

        # Cache the (color image, gray original, reconstruction, latent)
        # tensors; derive the file stem from the style image's basename so
        # paths with directories or extra dots are handled safely.
        image_name = os.path.splitext(os.path.basename(self.args.style_image))[0]
        os.makedirs('precomputed', exist_ok=True)
        pairs_path = os.path.join(
            'precomputed',
            f'{self.config.data.category}_style_{image_name}_t{self.args.t_0_remove}'
            f'_size{self.args.image_size}_nim{self.args.n_precomp_img}'
            f'_ninv{self.args.n_inv_step}_{self.args.removal_mode}_pairs.pth')
        torch.save([style_lat_pairs], pairs_path)
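

# A hypothetical usage sketch, not part of the original entry point: the real
# runner builds `args` and `config` elsewhere (typically argparse plus a YAML
# config). The attribute names below follow their uses above; the dataset name
# and beta values are assumptions (the common linear DDPM schedule), and the
# project's actual config may differ.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        exp='style_removal_demo',
        image_folder='runs/style_removal_demo',
        style_image='style.png',
        image_size=256,
        t_0_remove=500,
        n_inv_step=40,
        n_precomp_img=1,
        bs_train=1,
        sample_type='ddim',
        removal_mode='gray',
    )
    config = SimpleNamespace(
        data=SimpleNamespace(dataset='IMAGENET', category='imagenet'),
        model=SimpleNamespace(var_type='fixedsmall'),
        diffusion=SimpleNamespace(beta_start=0.0001, beta_end=0.02,
                                  num_diffusion_timesteps=1000),
    )
    os.makedirs(args.image_folder, exist_ok=True)
    # Requires the pretrained checkpoint under pretrained/ and a readable
    # style image at args.style_image.
    StyleRemovalImage(args, config).remove_style()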