Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

thanks #241

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2016 Joshua Z. Zhang
Copyright (c) 2016 Prasad9

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
231 changes: 47 additions & 184 deletions README.md

Large diffs are not rendered by default.

103 changes: 103 additions & 0 deletions dataset/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,106 @@ def _data_augmentation(self, data, label):
data = data.astype('float32')
data = data - self._mean_pixels
return data, label

class DetTestImageIter(mx.io.DataIter):
    """
    Detection iterator for inference on images that are already in memory.
    Each image is resized, converted to channel-first float32 and has the
    mean pixel subtracted before it is placed in the batch. No labels are
    provided.

    Parameters:
    ----------
    test_images : list of array-like
        images to run detection on; each one must be accepted by
        mx.nd.array and is treated as (height, width, 3)
    batch_size : int
        batch size
    data_shape : int or (int, int)
        image shape to be resized to, as (height, width); a single int is
        used for both
    mean_pixels : float or float list
        [R, G, B], mean pixel values
    """
    def __init__(self, test_images, batch_size, data_shape,
                 mean_pixels=(128, 128, 128)):
        super(DetTestImageIter, self).__init__()

        self.test_images = test_images
        self.batch_size = batch_size
        if isinstance(data_shape, int):
            data_shape = (data_shape, data_shape)
        self._data_shape = data_shape
        # (3, 1, 1) so the mean broadcasts over a channel-first image
        self._mean_pixels = mx.nd.array(mean_pixels).reshape((3, 1, 1))

        self._current = 0
        self._size = len(test_images)
        self._index = np.arange(self._size)

        self._data = None
        self._label = None
        # Load the first batch up front so provide_data can report shapes
        self._get_batch()
        # Loading the first batch recorded a resized image; clear it so
        # current_data() returns None until next() has been called
        self.resized_data = None

    @property
    def provide_data(self):
        """List of (name, shape) pairs for the currently loaded data batch."""
        return [(k, v.shape) for k, v in self._data.items()]

    @property
    def provide_label(self):
        """This iterator provides no labels."""
        return []

    def reset(self):
        """Rewind the iterator to the first image."""
        self._current = 0

    def iter_next(self):
        """Return True while there are images that were not served yet."""
        return self._current < self._size

    def next(self):
        """
        Load and return the next mx.io.DataBatch.

        Raises:
        ----------
        StopIteration when all images have been served
        """
        if not self.iter_next():
            raise StopIteration
        self._get_batch()
        data_batch = mx.io.DataBatch(data=list(self._data.values()),
                                     label=list(self._label.values()),
                                     pad=self.getpad(), index=self.getindex())
        self._current += self.batch_size
        return data_batch

    def getindex(self):
        """Index of the current batch."""
        return self._current // self.batch_size

    def getpad(self):
        """Number of padding (empty) slots at the end of the current batch."""
        pad = self._current + self.batch_size - self._size
        return 0 if pad < 0 else pad

    def _get_batch(self):
        """
        Load the batch starting at the current position into self._data.
        Slots past the end of the image list are left as zeros.
        """
        batch_data = mx.nd.zeros((self.batch_size, 3,
                                  self._data_shape[0], self._data_shape[1]))
        for i in range(self.batch_size):
            position = self._current + i
            if position >= self._size:
                # every following slot is past the end as well
                break
            img = mx.nd.array(self.test_images[self._index[position]])
            batch_data[i] = self._data_augmentation(img)

        self._data = {'data': batch_data}
        self._label = {'label': None}

    def _data_augmentation(self, data):
        """
        Prepare one image for the network: resize, swap channels to
        channel-first, cast to float32 and subtract the mean pixel.
        The resized image (before the channel swap) is kept in
        self.resized_data.
        """
        # A fixed interpolation keeps test-time preprocessing deterministic
        # and does not touch the global numpy random state.
        data = mx.img.imresize(data, self._data_shape[1], self._data_shape[0],
                               cv2.INTER_LINEAR)
        self.resized_data = data
        data = mx.nd.transpose(data, (2, 0, 1))
        data = data.astype('float32')
        data = data - self._mean_pixels
        return data

    def current_data(self):
        """Return the most recently resized image, or None if none was served."""
        return self.resized_data
208 changes: 208 additions & 0 deletions detect/image_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
from __future__ import print_function
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
from dataset.iterator import DetTestImageIter
import cv2

class ImageDetector(object):
    """
    SSD detector which holds a detection network and wraps the detection API
    for images that are already loaded in memory

    Parameters:
    ----------
    symbol : mx.Symbol or None
        detection network Symbol; if None the symbol stored in the
        checkpoint is used
    model_prefix : str
        name prefix of trained model
    epoch : int
        load epoch of trained model
    data_shape : int
        input data resize shape
    mean_pixels : tuple of float
        (mean_r, mean_g, mean_b)
    classes : list of str
        class names, indexed by class id
    thresh : float
        only detections with a score above this value are drawn
    plot_confidence : bool
        whether scale_and_plot_rects also draws the score
    batch_size : int
        run detection with batch size
    ctx : mx.ctx
        device to use, if None, use mx.cpu() as default context
    """
    def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
                 classes, thresh=0.6, plot_confidence=True, batch_size=1,
                 ctx=None):
        self.ctx = mx.cpu() if ctx is None else ctx
        load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
        if symbol is None:
            symbol = load_symbol
        # Use the resolved context: the raw ``ctx`` argument may be None,
        # in which case the documented mx.cpu() default has to be applied.
        self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
        self.data_shape = data_shape
        self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))])
        self.mod.set_params(args, auxs)
        self.mean_pixels = mean_pixels
        self.classes = classes
        self.colors = []
        self.fill_random_colors_int()
        self.thresh = thresh
        self.plot_confidence = plot_confidence

    def fill_random_colors(self):
        """Append one random color per class, each channel a float in [0, 1)."""
        import random
        self.colors.extend(
            (random.random(), random.random(), random.random())
            for _ in self.classes)

    def fill_random_colors_int(self):
        """Append one random color per class, each channel an int in [0, 255]."""
        import random
        self.colors.extend(
            (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            for _ in self.classes)

    def detect(self, det_iter, show_timer=False):
        """
        detect all images in iterator

        Parameters:
        ----------
        det_iter : DetTestImageIter
            iterator for all testing images
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        (result, resized_img) : result is the list of detection results,
        one array per image with the rows whose class id is negative
        removed; resized_img is whatever det_iter.current_data() returns
        after prediction
        """
        num_images = det_iter._size
        start = timer()
        detections = [pred[0].asnumpy()
                      for pred, _, _ in self.mod.iter_predict(det_iter)]
        time_elapsed = timer() - start
        if show_timer:
            print("Detection time for {} images: {:.4f} sec".format(num_images, time_elapsed))
        result = []
        for output in detections:
            for det in output:
                # a negative class id marks an empty detection slot
                result.append(det[det[:, 0] >= 0])
        resized_img = det_iter.current_data()
        return result, resized_img

    def im_detect(self, img, show_timer=False):
        """
        wrapper for detecting a single in-memory image

        Parameters:
        ----------
        img : array-like
            image to run detection on; it is wrapped in a one-image
            DetTestImageIter
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        (dets, resized_img) as returned by detect; dets is a list of
        detection results in format [det0, det1...], det is in
        format np.array([id, score, xmin, ymin, xmax, ymax]...)
        """
        test_iter = DetTestImageIter([img], 1, self.data_shape, self.mean_pixels)
        return self.detect(test_iter, show_timer)

    def _boxes_above_thresh(self, img, dets):
        """
        Yield (cls_id, score, xmin, ymin, xmax, ymax) for every detection
        with a non-negative class id and a score above self.thresh, with
        the box scaled by the width and height of img.
        """
        height, width = img.shape[0], img.shape[1]
        for i in range(dets.shape[0]):
            cls_id = int(dets[i, 0])
            if cls_id < 0:
                continue
            score = dets[i, 1]
            if score > self.thresh:
                yield (cls_id, score,
                       int(dets[i, 2] * width), int(dets[i, 3] * height),
                       int(dets[i, 4] * width), int(dets[i, 5] * height))

    def plot_rects(self, img, dets):
        """
        Draw, in place on img, a class-colored rectangle and the class name
        for every detection in dets scoring above self.thresh.
        """
        for cls_id, _, xmin, ymin, xmax, ymax in self._boxes_above_thresh(img, dets):
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), self.colors[cls_id], 4)
            cv2.putText(img, self.classes[cls_id], (xmin, ymin),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)

    def detect_and_visualize_image(self, img, show_timer=False):
        """
        wrapper for im_detect and plot_rects

        Parameters:
        ----------
        img : array-like
            image to run detection on
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        the resized image as a numpy array, divided by 255, with the
        detections drawn on it
        """
        dets, resized_img = self.im_detect(img, show_timer=show_timer)
        resized_img = resized_img.asnumpy()
        # An in-place true divide cannot store its float result in an
        # integer array, so integer images are converted to float first.
        if not np.issubdtype(resized_img.dtype, np.floating):
            resized_img = resized_img.astype(np.float32)
        resized_img /= 255.0
        for det in dets:
            self.plot_rects(resized_img, det)
        return resized_img

    def scale_and_plot_rects(self, img, dets):
        """
        Draw, in place on img, a class-colored rectangle and the class name
        for every detection in dets scoring above self.thresh. When
        self.plot_confidence is set the score is drawn as well, in
        (0, 255, 0) if it is above 0.5 and in (255, 0, 0) otherwise.
        """
        for cls_id, score, xmin, ymin, xmax, ymax in self._boxes_above_thresh(img, dets):
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), self.colors[cls_id], 4)
            cv2.putText(img, self.classes[cls_id], (xmin, ymin - 15),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 3)
            if self.plot_confidence:
                score_color = (0, 255, 0) if score > 0.5 else (255, 0, 0)
                cv2.putText(img, '{:.3f}'.format(score), (xmax - 60, ymin - 15),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, score_color, 1)

    def detect_and_layover_image(self, img, show_timer=False):
        """
        wrapper for im_detect and scale_and_plot_rects: the detections are
        drawn directly on the given image instead of the resized copy

        Parameters:
        ----------
        img : numpy array
            image to run detection on; it is modified in place
        show_timer : Boolean
            whether to print out detection exec time

        Returns:
        ----------
        img, with the detections drawn on it
        """
        dets, _ = self.im_detect(img, show_timer=show_timer)
        for det in dets:
            self.scale_and_plot_rects(img, det)
        return img
23 changes: 23 additions & 0 deletions flags/data_utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Folder and label names
XML_FOLDER = 'Annotations'
GENERATED_DATA = 'GeneratedData'
TRAIN_FOLDER = 'Train'
VAL_FOLDER = 'Val'
TEST_FOLDER = 'Test'
LABEL = 'Label'

# Height and width of the raw flag images
FLAG_HEIGHT = 144
FLAG_WIDTH = 224

# Number of images in the CelebA dataset (202599 in the author's copy).
# Hardcoded to save memory; set this to match your own copy of the dataset.
CELEBA_TOTAL_FILES = 202599

# Minimum and maximum number of flags in one image.
# At most 2 flags per image are currently supported.
MIN_FLAGS = 1
MAX_FLAGS = 2

# Percentage of the card that should be covered with white area
BORDER_WHITE_AREA = 40

# Input image size
IMAGE_SIZE = 224

# How many train and validation images to generate
TOTAL_TRAIN_IMAGES = 120000
TOTAL_VALIDATION_IMAGES = 10000
Loading