Skip to content

Commit

Permalink
stock classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
aoirint committed Jul 24, 2019
1 parent 8d959d8 commit b46322b
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 68 deletions.
78 changes: 78 additions & 0 deletions SSBUBoundingBoxUtil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@

def fighters_info_bbox(fighter_num):
assert fighter_num in (2, 3, 4), 'Not implemented'

# FIXME: magic number
if fighter_num == 2:
return [
# Fighter 1
[ 245 ,560, 240,160 ],
# Fighter 2
[ 245+495,560, 240,160 ],
]
elif fighter_num == 3:
return [
# Fighter 1
[ 75 ,560, 240,160 ],
# Fighter 2
[ 75+416,560, 240,160 ],
# Fighter 3
[ 75+832,560, 240,160 ],
]
elif fighter_num == 4:
return [
# Fighter 1
[ 98 ,560, 240,160 ],
# Fighter 2
[ 98+272,560, 240,160 ],
# Fighter 3
[ 98+544,560, 240,160 ],
# Fighter 4
[ 98+816,560, 240,160 ],
]


def fighters_damage_bboxes(fighter_num):
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
left = bbox[0] + 85
top = bbox[1] + 50

ret.append([
[ left ,top, 35,55 ],
[ left+30,top, 35,55 ],
[ left+60,top, 35,55 ],
[ left+97,top+28, 18,25 ],
])
return ret

def fighters_name_bbox(fighter_num):
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
ret.append([ bbox[0]+105,bbox[1]+110, 120,16 ])
return ret

def fighters_chara_bbox(fighter_num):
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
ret.append([ bbox[0]+10,bbox[1]+28, 110,110 ])
return ret

def fighters_stock_bboxes(fighter_num, stock_num=3):
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
left = bbox[0] + 73
top = bbox[1] + 131

ret.append(
[ [ left + 17*k, top, 16, 16, ] for k in range(stock_num) ]
)
return ret
120 changes: 52 additions & 68 deletions SSBUFrameAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,30 @@
from SSBUDigitClassifier import SSBUDigitClassifier
from SSBUNameRecognizer import SSBUNameRecognizer
from SSBUCharaClassifier import SSBUCharaClassifier
from SSBUStockClassifier import SSBUStockClassifier
import SSBUBoundingBoxUtil


class SSBUFrameAnalyzer:
def __init__(self, digit_classifier, name_recognizer, chara_classifier):
def __init__(self, digit_classifier, name_recognizer, chara_classifier, stock_classifier):
self.digit_classifier = digit_classifier
self.name_recognizer = name_recognizer
self.chara_classifier = chara_classifier
self.stock_classifier = stock_classifier

def __call__(self, frame, fighter_num=2):

dmgs = self.analyze_damage(frame, fighter_num=fighter_num)
names = self.analyze_name(frame, fighter_num=fighter_num)
charas = self.analyze_chara(frame, fighter_num=fighter_num)
stocks = self.analyze_stock(frame, fighter_num=fighter_num)

fighters = {}
for fighter_idx in range(fighter_num):
fighters[fighter_idx] = {
'chara_name': charas[fighter_idx],
'name': names[fighter_idx],
'damage': dmgs[fighter_idx],
'stocks': stocks[fighter_idx],
}

result = {
Expand All @@ -32,69 +37,8 @@ def __call__(self, frame, fighter_num=2):

return result

def fighters_info_bbox(self, fighter_num):
assert fighter_num in (2, 3, 4), 'Not implemented'

# FIXME: magic number
if fighter_num == 2:
return [
# Fighter 1
[ 245 ,560, 240,160 ],
# Fighter 2
[ 245+495,560, 240,160 ],
]
elif fighter_num == 3:
return [
# Fighter 1
[ 75 ,560, 240,160 ],
# Fighter 2
[ 75+416,560, 240,160 ],
# Fighter 3
[ 75+832,560, 240,160 ],
]
elif fighter_num == 4:
return [
# Fighter 1
[ 98 ,560, 240,160 ],
# Fighter 2
[ 98+272,560, 240,160 ],
# Fighter 3
[ 98+545,560, 240,160 ],
# Fighter 4
[ 98+817,560, 240,160 ],
]
def fighters_damage_bboxes(self, fighter_num):
info_bboxes = self.fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
left = bbox[0] + 85
top = bbox[1] + 50

ret.append([
[ left ,top, 35,55 ],
[ left+30,top, 35,55 ],
[ left+60,top, 35,55 ],
[ left+97,top+28, 18,25 ],
])
return ret
def fighters_name_bbox(self, fighter_num):
info_bboxes = self.fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
ret.append([ bbox[0]+105,bbox[1]+110, 120,16 ])
return ret
def fighters_chara_bbox(self, fighter_num):
info_bboxes = self.fighters_info_bbox(fighter_num=fighter_num)

ret = []
for fighter_idx, bbox in enumerate(info_bboxes):
ret.append([ bbox[0]+10,bbox[1]+28, 110,110 ])
return ret

def analyze_damage(self, frame, fighter_num):
fighters_dmg_bboxes = self.fighters_damage_bboxes(fighter_num=fighter_num)
fighters_dmg_bboxes = SSBUBoundingBoxUtil.fighters_damage_bboxes(fighter_num=fighter_num)
assert fighter_num == len(fighters_dmg_bboxes)

dc = self.digit_classifier
Expand Down Expand Up @@ -136,7 +80,7 @@ def predict_digit(img):
return result

def analyze_name(self, frame, fighter_num):
fighters_name_bbox = self.fighters_name_bbox(fighter_num=fighter_num)
fighters_name_bbox = SSBUBoundingBoxUtil.fighters_name_bbox(fighter_num=fighter_num)
assert fighter_num == len(fighters_name_bbox)

nr = self.name_recognizer
Expand All @@ -156,15 +100,15 @@ def analyze_name(self, frame, fighter_num):
return result

def analyze_chara(self, frame, fighter_num):
fighters_chara_bbox = self.fighters_chara_bbox(fighter_num=fighter_num)
fighters_chara_bbox = SSBUBoundingBoxUtil.fighters_chara_bbox(fighter_num=fighter_num)
assert fighter_num == len(fighters_chara_bbox)

cc = self.chara_classifier
def predict_chara(img):
names, dists = cc(img, k=3)
min_dist = dists[0]

print(names, dists)
# print(names, dists)
thresh_dist = 10.

name = names[0] if min_dist < thresh_dist else None
Expand All @@ -184,6 +128,41 @@ def predict_chara(img):

return result

def analyze_stock(self, frame, fighter_num):
fighters_stock_bboxes = SSBUBoundingBoxUtil.fighters_stock_bboxes(fighter_num=fighter_num, stock_num=5)
assert fighter_num == len(fighters_stock_bboxes)

sc = self.stock_classifier
def predict_stock(img):
stocks, dists = sc(img, k=3)
min_dist = dists[0]

# print(stocks, dists)
thresh_dist = 0.6

stock = stocks[0] if min_dist < thresh_dist else None
return stock

result = {}
for fighter_idx in range(fighter_num):
bboxes = fighters_stock_bboxes[fighter_idx]

stocks = []
for bbox_idx, bbox in enumerate(bboxes):
simg = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]] # RGB
simg = cv2.cvtColor(simg, cv2.COLOR_BGR2GRAY) # GRAY

# cv2.imwrite('fighter-stock-%d-%d.png' % (fighter_idx, bbox_idx, ), simg)

stock = predict_stock(simg)
if stock is None:
break
stocks.append(stock)

result[fighter_idx] = stocks

return result



if __name__ == '__main__':
Expand All @@ -194,6 +173,7 @@ def predict_chara(img):
parser = argparse.ArgumentParser()
parser.add_argument('digit_dictionary', type=str)
parser.add_argument('chara_dictionary', type=str)
parser.add_argument('stock_dictionary', type=str)
parser.add_argument('input', type=str)
parser.add_argument('fighter_num', type=int) # TODO: predict
args = parser.parse_args()
Expand All @@ -210,7 +190,11 @@ def predict_chara(img):
chara_classifier = SSBUCharaClassifier(feature_json=args.chara_dictionary)
print('loaded chara classifier')

analyzer = SSBUFrameAnalyzer(digit_classifier=digit_classifier, name_recognizer=name_recognizer, chara_classifier=chara_classifier)
print('loading stock classifier...')
stock_classifier = SSBUStockClassifier(feature_json=args.stock_dictionary)
print('loaded stock classifier')

analyzer = SSBUFrameAnalyzer(digit_classifier=digit_classifier, name_recognizer=name_recognizer, chara_classifier=chara_classifier, stock_classifier=stock_classifier)

frame = cv2.imread(args.input, 1) # RGB
frame = cv2.resize(frame, (1280, 720))
Expand Down
89 changes: 89 additions & 0 deletions SSBUStockClassifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from skimage.feature import hog
import numpy as np
import json
import cv2

class SSBUStockClassifier:
def __init__(self, feature_json):
self.feature_json = feature_json

with open(feature_json, 'r') as fp:
data = json.load(fp)

chara_id2name = sorted(list(data.keys()))

charas = []
features = []
for chara_id, chara_name in enumerate(chara_id2name):
fts = data[chara_name]

# n = 0
for feature in fts:
charas.append(int(chara_id))
features.append(np.asarray(feature, dtype=np.float32))
# n += 1
# if n == 4:
# break

self.categories = set(charas)
self.datacount = len(charas)

self.chara_id2name = chara_id2name
self.charas = np.asarray(charas, dtype=np.int32)
self.features = np.asarray(features, dtype=np.float32)


def __call__(self, img, k=3):
# img isinstance of np.ndarray
# print(img.shape)
assert len(img.shape) == 2 # gray
assert img.shape[1] == 16 and img.shape[0] == 16
img = cv2.resize(img, (30, 30)) # hog requirement
h0 = hog(img)

dists = np.linalg.norm(self.features - h0, axis=1)

sarg = np.argsort(dists) # sorted-arg
topKarg = sarg[:k]

charas = self.charas[topKarg].tolist()
names = []
for i in range(len(charas)):
chara_id = int(charas[i])
name = self.chara_id2name[chara_id]
names.append(name)

return names, dists[topKarg].tolist()

if __name__ == '__main__':
import argparse
import time
import cv2
import SSBUBoundingBoxUtil

parser = argparse.ArgumentParser()
parser.add_argument('dictionary', type=str)
parser.add_argument('input', type=str)
parser.add_argument('fighter_num', type=int)
args = parser.parse_args()

print('loading...')
classifier = SSBUStockClassifier(feature_json=args.dictionary)
print('loaded')

img = cv2.imread(args.input, 0)
stock_bboxes = SSBUBoundingBoxUtil.fighters_stock_bboxes(fighter_num=args.fighter_num, stock_num=3)

t = time.time()
# img = cv2.resize(img, (110, 110))

for fighter_idx, stock_bboxes in enumerate(stock_bboxes):
for stock_idx, bbox in enumerate(stock_bboxes):
simg = img[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]

ret = classifier(simg)
print(ret)

elapsed = time.time() - t

print('FPS: %f (%f s)' % (1/elapsed, elapsed, ))

0 comments on commit b46322b

Please sign in to comment.