add new losses
zdzheng-nus committed Dec 3, 2021
1 parent 6e6b118 commit 2ff28d5
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions train.py
@@ -32,6 +32,7 @@
except ImportError: # will be 3.x series
print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')

from pytorch_metric_learning import losses, miners #pip install pytorch-metric-learning

######################################################################
# Options
@@ -56,6 +57,10 @@
parser.add_argument('--droprate', default=0.5, type=float, help='drop rate')
parser.add_argument('--PCB', action='store_true', help='use PCB+ResNet50' )
parser.add_argument('--circle', action='store_true', help='use Circle loss' )
parser.add_argument('--contrast', action='store_true', help='use contrastive loss' )
parser.add_argument('--triplet', action='store_true', help='use triplet loss' )
parser.add_argument('--lifted', action='store_true', help='use lifted structure loss' )
parser.add_argument('--sphere', action='store_true', help='use SphereFace loss' )
parser.add_argument('--ibn', action='store_true', help='use resnet+ibn' )
parser.add_argument('--DG', action='store_true', help='use extra DG-Market Dataset for training. Please download it from https://github.com/NVlabs/DG-Net#dg-market.' )
parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50% memory' )
@@ -193,6 +198,15 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
warm_iteration = round(dataset_sizes['train']/opt.batchsize)*opt.warm_epoch # first 5 epoch
if opt.circle:
criterion_circle = CircleLoss(m=0.25, gamma=32) # gamma = 64 may lead to a better result.
if opt.triplet:
miner = miners.MultiSimilarityMiner()
criterion_triplet = losses.TripletMarginLoss(margin=0.3)
if opt.lifted:
criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0)
if opt.contrast:
criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
if opt.sphere:
criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
@@ -236,14 +250,23 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):

sm = nn.Softmax(dim=1)
log_sm = nn.LogSoftmax(dim=1)
if opt.circle:
if opt.circle or opt.triplet or opt.lifted or opt.contrast or opt.sphere:
logits, ff = outputs
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
ff = ff.div(fnorm.expand_as(ff))
loss = criterion(logits, labels) + criterion_circle(*convert_label_to_similarity( ff, labels))/now_batch_size
#loss = criterion_circle(*convert_label_to_similarity( ff, labels))
loss = criterion(logits, labels)
_, preds = torch.max(logits.data, 1)

if opt.circle:
loss += criterion_circle(*convert_label_to_similarity( ff, labels))/now_batch_size
if opt.triplet:
hard_pairs = miner(ff, labels)
loss += criterion_triplet(ff, labels, hard_pairs) #/now_batch_size
if opt.lifted:
loss += criterion_lifted(ff, labels) #/now_batch_size
if opt.contrast:
loss += criterion_contrast(ff, labels) #/now_batch_size
if opt.sphere:
loss += criterion_sphere(ff, labels) #/now_batch_size
elif opt.PCB: # PCB
part = {}
num_part = 6
@@ -382,18 +405,20 @@ def save_network(network, epoch_label):
# Load a pretrained model and reset final fully connected layer.
#

return_feature = opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere

if opt.use_dense:
model = ft_net_dense(len(class_names), opt.droprate, circle = opt.circle)
model = ft_net_dense(len(class_names), opt.droprate, circle = return_feature)
elif opt.use_NAS:
model = ft_net_NAS(len(class_names), opt.droprate)
elif opt.use_swin:
model = ft_net_swin(len(class_names), opt.droprate, opt.stride, circle =opt.circle)
model = ft_net_swin(len(class_names), opt.droprate, opt.stride, circle = return_feature)
elif opt.use_efficient:
model = ft_net_efficient(len(class_names), opt.droprate, circle = opt.circle)
model = ft_net_efficient(len(class_names), opt.droprate, circle = return_feature)
elif opt.use_hr:
model = ft_net_hr(len(class_names), opt.droprate, circle = opt.circle)
model = ft_net_hr(len(class_names), opt.droprate, circle = return_feature)
else:
model = ft_net(len(class_names), opt.droprate, opt.stride, circle =opt.circle, ibn=opt.ibn)
model = ft_net(len(class_names), opt.droprate, opt.stride, circle = return_feature, ibn=opt.ibn)

if opt.PCB:
model = PCB(len(class_names))
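
For readers who want to see the pattern outside the training loop: the block below is a minimal, self-contained sketch of how the pytorch-metric-learning calls added in this commit fit together, i.e. L2-normalize the features, optionally mine pairs, and add the metric losses to the classification loss. The random embeddings, logits, labels, and batch size here are purely illustrative and are not taken from the commit.

import torch
import torch.nn as nn
from pytorch_metric_learning import losses, miners

# Toy batch: 8 samples, 512-dim embeddings, 4 identity classes (illustrative only).
ff = torch.randn(8, 512)
logits = torch.randn(8, 4)
labels = torch.randint(0, 4, (8,))

# L2-normalize the features, as in the fnorm/div step above.
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
ff = ff.div(fnorm.expand_as(ff))

# Same loss/miner constructions as in train_model.
criterion = nn.CrossEntropyLoss()
miner = miners.MultiSimilarityMiner()
criterion_triplet = losses.TripletMarginLoss(margin=0.3)
criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)

loss = criterion(logits, labels)                   # classification (ID) loss
hard_pairs = miner(ff, labels)                     # mine informative pairs
loss += criterion_triplet(ff, labels, hard_pairs)  # triplet loss on mined pairs
loss += criterion_contrast(ff, labels)             # contrastive loss
print(loss.item())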

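Since all of the new options are plain store_true flags, they can be enabled in any combination on the command line. A hypothetical invocation (the batch size is only a placeholder; other flags follow the repository's existing options):

python train.py --triplet --contrast --sphere --batchsize 32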