-
Notifications
You must be signed in to change notification settings - Fork 3
/
style_transfer.py
46 lines (34 loc) · 1.39 KB
/
style_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import src
from epicpath import EPath
def _save_if_missing(path, saver):
    """Call ``saver(path)`` only when ``path`` does not already exist."""
    if not path.exists():
        saver(path)


# Ask the data layer for the candidate content/style files and pick the next
# combination to process; get_next_files returns None when nothing is left.
content_path_list, style_path_list = src.st.data.get_data()
file_combination = src.st.data.get_next_files(content_path_list, style_path_list)

if file_combination is None:
    print('No result_path left...')
else:
    param = src.st.var.param

    # Tag the shared parameter object with this combination's index.
    param.n = file_combination.n

    extractor = src.st.StyleContentModel(
        style_layers=param.style_layers.value,
        content_layers=param.content_layers.value,
        content_gram_layers=param.content_gram_layers.value,
    )

    # NOTE(review): `length` is presumably the number of parameter sets;
    # the index is only printed when it is greater than 1.
    if param.length > 1:
        print(f'param: {param.n}')

    optimizers = src.st.Optimizers(shape=(1,), lr=param.lr.value)

    # NOTE(review): mkdir() is called without exist_ok/parents, so this
    # assumes get_next_files only yields combinations whose folder is new.
    file_combination.results_folder.mkdir()

    # Persist the full parameter description once, then the description of
    # the currently selected parameter set (one file per index n).
    parameters_path = EPath('results/parameters.txt')
    _save_if_missing(parameters_path, param.save_all_txt)
    p_path = EPath(f'results/p{param.n}.txt')
    _save_if_missing(p_path, param.save_current_txt)

    src.st.train.style_transfer(
        file_combination=file_combination,
        extractor=extractor,
        optimizers=optimizers,
        epochs=param.epochs.value,
        steps_per_epoch=param.steps_per_epoch.value,
    )