Skip to content

Commit

Permalink
QAConv 2.1
Browse files Browse the repository at this point in the history
* Simplified graph sampling
* Einstein summation for QAConv
* Hard triplet loss
* Adaptive epoch and learning rate scheduling
* Automatic mixed precision training
  • Loading branch information
Shengcai Liao committed Sep 16, 2021
1 parent 7d4cb8a commit 9bdea33
Show file tree
Hide file tree
Showing 10 changed files with 330 additions and 883 deletions.
248 changes: 122 additions & 126 deletions main.py

Large diffs are not rendered by default.

379 changes: 0 additions & 379 deletions main_gs.py

This file was deleted.

29 changes: 20 additions & 9 deletions reid/evaluators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import print_function, absolute_import
import sys
import time
from collections import OrderedDict

import torch
import numpy as np
from .utils import to_torch

from .evaluation_metrics import cmc, mean_ap
from .tlift import TLift
Expand All @@ -23,10 +23,9 @@ def pre_tlift(gallery, query):

def extract_cnn_feature(model, inputs):
model = model.cuda().eval()
inputs = to_torch(inputs).cuda()
with torch.no_grad():
outputs = model(inputs)
outputs = outputs.data.cpu()
outputs = outputs.cpu()
return outputs


Expand Down Expand Up @@ -180,14 +179,24 @@ def evaluate(self, matcher, testset, query_loader, gallery_loader, gal_batch_siz
prob_batch_size=4096, tau=100, sigma=200, K=10, alpha=0.2):
query = testset.query
gallery = testset.gallery
prob_fea, _ = extract_features(self.model, query_loader, verbose=True)
prob_fea = torch.cat([prob_fea[f].unsqueeze(0) for f, _, _, _ in query], 0)
gal_fea, _ = extract_features(self.model, gallery_loader, verbose=True)
gal_fea = torch.cat([gal_fea[f].unsqueeze(0) for f, _, _, _ in gallery], 0)

print('Compute similarity...', end='\t')
print('Compute similarity ...', end='\t')
start = time.time()
dist = pairwise_distance(matcher, prob_fea, gal_fea, gal_batch_size, prob_batch_size) # [p, g]

prob_fea, _ = extract_features(self.model, query_loader)
prob_fea = torch.cat([prob_fea[f].unsqueeze(0) for f, _, _, _ in query], 0)
num_prob = len(query)
num_gal = len(gallery)
batch_size = gallery_loader.batch_size
dist = torch.zeros(num_prob, num_gal)

for i, (imgs, fnames, pids, _) in enumerate(gallery_loader):
print('Compute similarity %d / %d. \t' % (i + 1, len(gallery_loader)), end='\r', file=sys.stdout.console)
gal_fea = extract_cnn_feature(self.model, imgs)
g0 = i * batch_size
g1 = min(num_gal, (i + 1) * batch_size)
dist[:, g0:g1] = pairwise_distance(matcher, prob_fea, gal_fea, batch_size, prob_batch_size) # [p, g]

print('Time: %.3f seconds.' % (time.time() - start))
rank1, mAP = evaluate_all(dist, query=query, gallery=gallery)

Expand All @@ -204,6 +213,8 @@ def evaluate(self, matcher, testset, query_loader, gallery_loader, gal_batch_siz
dist_rerank[num_prob:, :num_prob] = dist.t()
dist_rerank[:num_prob, :num_prob] = pairwise_distance(matcher, prob_fea, prob_fea, gal_batch_size,
prob_batch_size)
gal_fea, _ = extract_features(self.model, gallery_loader, verbose=True)
gal_fea = torch.cat([gal_fea[f].unsqueeze(0) for f, _, _, _ in gallery], 0)
dist_rerank[num_prob:, num_prob:] = pairwise_distance(matcher, gal_fea, gal_fea, gal_batch_size,
prob_batch_size)

Expand Down
58 changes: 58 additions & 0 deletions reid/loss/triplet_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Class for the hard triplet loss
Shengcai Liao and Ling Shao, "Graph Sampling Based Deep Metric Learning for Generalizable Person Re-Identification." In arXiv preprint, arXiv:2104.01546, 2021.
Author:
Shengcai Liao
[email protected]
Version:
V1.0
April 1, 2021
"""

import torch
from torch.nn import Module
from torch import nn


class TripletLoss(Module):
def __init__(self, matcher, margin=16):
"""
Inputs:
matcher: a class for matching pairs of images
margin: margin parameter for the triplet loss
"""
super(TripletLoss, self).__init__()
self.matcher = matcher
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin, reduction='none')

def reset_running_stats(self):
self.matcher.reset_running_stats()

def reset_parameters(self):
self.matcher.reset_parameters()

def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))

def forward(self, feature, target):
self._check_input_dim(feature)
self.matcher.make_kernel(feature)

score = self.matcher(feature) # [b, b]

target1 = target.unsqueeze(1)
mask = (target1 == target1.t())
pair_labels = mask.float()

min_pos = torch.min(score * pair_labels +
(1 - pair_labels + torch.eye(score.size(0), device=score.device)) * 1e15, dim=1)[0]
max_neg = torch.max(score * (1 - pair_labels) - pair_labels * 1e15, dim=1)[0]

# Compute ranking hinge loss
loss = self.ranking_loss(min_pos, max_neg, torch.ones_like(target))

with torch.no_grad():
acc = (min_pos >= max_neg).float()

return loss, acc
23 changes: 11 additions & 12 deletions reid/models/qaconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
Shengcai Liao
[email protected]
Version:
V1.2
Mar. 31, 2021
V1.3
July 1, 2021
"""

import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F


class QAConv(Module):
Expand All @@ -29,41 +28,41 @@ def __init__(self, num_features, height, width):
self.height = height
self.width = width
self.bn = nn.BatchNorm1d(1)
self.fc = nn.Linear(self.height * self.width * 2, 1)
self.fc = nn.Linear(self.height * self.width, 1)
self.logit_bn = nn.BatchNorm1d(1)
self.kernel = None
self.reset_parameters()

def reset_running_stats(self):
self.bn.reset_running_stats()
self.logit_bn.reset_running_stats()

def reset_parameters(self):
self.bn.reset_parameters()
self.fc.reset_parameters()
self.logit_bn.reset_parameters()
with torch.no_grad():
self.fc.weight.fill_(1. / (self.height * self.width))

def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))

def make_kernel(self, features): # probe features
kernel = features.permute([0, 2, 3, 1]) # [p, h, w, d]
kernel = kernel.reshape(-1, self.num_features, 1, 1) # [phw, d, 1, 1]
self.kernel = kernel
self.kernel = features

def forward(self, features): # gallery features
self._check_input_dim(features)

hw = self.height * self.width
batch_size = features.size(0)

score = F.conv2d(features, self.kernel) # [g, phw, h, w]
score = torch.einsum('g c h w, p c y x -> g p y x h w', features, self.kernel)
score = score.view(batch_size, -1, hw, hw)
score = torch.cat((score.max(dim=2)[0], score.max(dim=3)[0]), dim=-1)

score = score.view(-1, 1, 2 * hw)
score = self.bn(score).view(-1, 2 * hw)
score = score.view(-1, 1, hw)
score = self.bn(score).view(-1, hw)
score = self.fc(score)
score = score.view(-1, 2).sum(dim=-1, keepdim=True)
score = self.logit_bn(score)
score = score.view(batch_size, -1).t() # [p, g]

Expand Down
8 changes: 3 additions & 5 deletions reid/models/resmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
Shengcai Liao
[email protected]
Version:
V1.1
Feb. 7, 2021
V1.2
July 4, 2021
"""

from __future__ import absolute_import
Expand Down Expand Up @@ -58,9 +58,8 @@ def __init__(self, depth, ibn_type=None, final_layer='layer3', neck=128, pretrai
out_planes = fea_dims[final_layer]

if neck > 0:
self.neck_conv = nn.Conv2d(out_planes, neck, kernel_size=3, padding=1, bias=False)
self.neck_conv = nn.Conv2d(out_planes, neck, kernel_size=3, padding=1)
out_planes = neck
self.neck_bn = nn.BatchNorm2d(out_planes)

self.num_features = out_planes

Expand All @@ -73,7 +72,6 @@ def forward(self, inputs):

if self.neck > 0:
x = self.neck_conv(x)
x = self.neck_bn(x)

x = F.normalize(x)

Expand Down
151 changes: 0 additions & 151 deletions reid/pretrainer.py

This file was deleted.

Loading

0 comments on commit 9bdea33

Please sign in to comment.