diff --git a/fgfa_rfcn/demo.py b/fgfa_rfcn/demo.py index 0a2902c..96a122c 100644 --- a/fgfa_rfcn/demo.py +++ b/fgfa_rfcn/demo.py @@ -111,7 +111,7 @@ def main(): # load demo data - image_names = glob.glob(cur_path + '/../demo/ILSVRC2015_val_00007010/*.JPEG') + image_names = sorted(glob.glob(cur_path + '/../demo/ILSVRC2015_val_00007010/*.JPEG')) output_dir = cur_path + '/../demo/rfcn_fgfa/' if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -180,7 +180,7 @@ def main(): file_idx = 0 thresh = 1e-3 for idx, element in enumerate(data): - + file_name = '{:06d}'.format(file_idx) data_batch = mx.io.DataBatch(data=[element], label=[], pad=0, index=idx, provide_data=[[(k, v.shape) for k, v in zip(data_names, element)]], provide_label=[None]) @@ -213,8 +213,8 @@ def main(): data_list[cfg.TEST.KEY_FRAME_INTERVAL].asnumpy(), scales) total_time = time.time()-t1 if (cfg.TEST.SEQ_NMS==False): - save_image(output_dir, file_idx, out_im) - print 'testing {} {:.4f}s'.format(str(file_idx)+'.JPEG', total_time /(file_idx+1)) + save_image(output_dir, file_name, out_im) + print 'testing {} {:.4f}s'.format(file_name+'.JPEG', total_time /(file_idx+1)) file_idx += 1 else: ################################################# @@ -234,8 +234,8 @@ def main(): total_time = time.time() - t1 if (cfg.TEST.SEQ_NMS == False): - save_image(output_dir, file_idx, out_im) - print 'testing {} {:.4f}s'.format(str(file_idx)+'.JPEG', total_time / (file_idx+1)) + save_image(output_dir, file_name, out_im) + print 'testing {} {:.4f}s'.format(file_name+'.JPEG', total_time / (file_idx+1)) file_idx += 1 end_counter+=1 @@ -247,9 +247,10 @@ def main(): keep = nms(dets) all_boxes[cls_ind + 1][frame_ind] = dets[keep, :] for idx in range(len(data)): + file_name = '{:06d}'.format(idx) boxes_this_image = [[]] + [all_boxes[j][idx] for j in range(1, num_classes)] out_im = draw_all_detection(data[idx][0].asnumpy(), boxes_this_image, classes, scales[0], cfg) - save_image(output_dir, idx, out_im) + save_image(output_dir, file_name, out_im) print 'done' diff --git a/fgfa_rfcn/symbols/resnet_v1_101_flownet_rfcn.py b/fgfa_rfcn/symbols/resnet_v1_101_flownet_rfcn.py index 107cc0d..fff9d96 100644 --- a/fgfa_rfcn/symbols/resnet_v1_101_flownet_rfcn.py +++ b/fgfa_rfcn/symbols/resnet_v1_101_flownet_rfcn.py @@ -1078,7 +1078,7 @@ def get_aggregation_symbol(self, cfg): warp_list = mx.sym.SliceChannel(conv_feat, axis=0, num_outputs=data_range) for i in range(data_range): tiled_weight = mx.symbol.tile(data=weights[i], reps=(1, 1024, 1, 1)) - aggregated_conv_feat += tiled_weight * warp_list[i] + aggregated_conv_feat = aggregated_conv_feat + tiled_weight * warp_list[i] #weights = mx.symbol.tile(data=weights, reps=(1, 1024, 1, 1)) #aggregated_conv_feat = mx.sym.sum(weights * conv_feat, axis=0, keepdims=True)