Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable automatic mixed precision training #80

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
parser.add_argument('--style_seg_path', default=[])
parser.add_argument('--output_image_path', default='./results/example1.png')
parser.add_argument('--cuda', type=int, default=1, help='Enable CUDA.')
parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
parser.add_argument("--engine", type=str, help="run serialized TRT engine")
parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
parser.add_argument('--verbose', action='store_true', default = False, help='toggles verbose')
parser.add_argument("-d", "--data_type", default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")


args = parser.parse_args()

# Load model
Expand All @@ -32,7 +39,7 @@

if args.cuda:
p_wct.cuda(0)

process_stylization.stylization(
p_wct=p_wct,
content_image_path=args.content_image_path,
Expand All @@ -41,4 +48,5 @@
style_seg_path=args.style_seg_path,
output_image_path=args.output_image_path,
cuda=args.cuda,
args=args
)
2 changes: 1 addition & 1 deletion demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ axel -n 1 https://vignette.wikia.nocookie.net/strangerthings8338/images/e/e0/Wik
convert -resize 25% content1.png content1.png;
convert -resize 50% style1.png style1.png;
cd ..;
python demo.py;
python demo.py $@;
11 changes: 10 additions & 1 deletion photo_wct.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self):
self.d3 = VGGDecoder(3)
self.e4 = VGGEncoder(4)
self.d4 = VGGDecoder(4)

def transform(self, cont_img, styl_img, cont_seg, styl_seg):
self.__compute_label_info(cont_seg, styl_seg)

Expand Down Expand Up @@ -53,8 +53,15 @@ def transform(self, cont_img, styl_img, cont_seg, styl_seg):
csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
Im1 = self.d1(csF1)
return Im1

def forward(self, args):
    """Run whole-image photorealistic WCT stylization.

    Accepts a single sequence argument (rather than four positional
    tensors) so the module can be traced/exported with a one-argument
    input, e.g. by ``torch.onnx.export``.

    Args:
        args: sequence of four elements
            ``(cont_img, styl_img, cont_seg, styl_seg)`` — content image,
            style image, content segmentation, style segmentation.

    Returns:
        The stylized image produced by ``self.transform``.
    """
    cont_img, styl_img, cont_seg, styl_seg = args
    # Return the result: the original dropped it (implicit None), which
    # breaks callers of __call__ and makes ONNX export trace no output.
    return self.transform(cont_img, styl_img, cont_seg, styl_seg)

def __compute_label_info(self, cont_seg, styl_seg):
cont_seg=cont_seg.numpy()
styl_seg=styl_seg.numpy()
if cont_seg.size == False or styl_seg.size == False:
return
max_label = np.max(cont_seg) + 1
Expand All @@ -69,6 +76,8 @@ def __compute_label_info(self, cont_seg, styl_seg):
self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)

def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
cont_seg = cont_seg.numpy()
styl_seg = styl_seg.numpy()
cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
cont_feat_view = cont_feat.view(cont_c, -1).clone()
Expand Down
21 changes: 13 additions & 8 deletions process_stylization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import time

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torch.onnx import export
import torchvision.transforms as transforms
import torchvision.utils as utils

Expand All @@ -32,8 +34,7 @@ def __exit__(self, exc_type, exc_value, exc_tb):
print(self.msg % (time.time() - self.start_time))


def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path,
cuda):
def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path, cuda, args):
# Load image
cont_img = Image.open(content_image_path).convert('RGB')
styl_img = Image.open(style_image_path).convert('RGB')
Expand All @@ -52,12 +53,16 @@ def stylization(p_wct, content_image_path, style_image_path, content_seg_path, s
styl_img = styl_img.cuda(0)
p_wct.cuda(0)

cont_img = Variable(cont_img, volatile=True)
styl_img = Variable(styl_img, volatile=True)

cont_seg = np.asarray(cont_seg)
styl_seg = np.asarray(styl_seg)

cont_img = Variable(cont_img, requires_grad=False)
styl_img = Variable(styl_img, requires_grad=False)
cont_seg = torch.FloatTensor(np.asarray(cont_seg))
styl_seg = torch.FloatTensor(np.asarray(styl_seg))

if args.export_onnx:
assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
export(p_wct, [cont_img, styl_img, cont_seg, styl_seg],
f=args.export_onnx, verbose=args.verbose)

with Timer("Elapsed time in stylization: %f"):
stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg)
utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1)
Expand Down