From 2ff28d59d8b6d1070d1d4d15aa53dfd27d37ddae Mon Sep 17 00:00:00 2001 From: zdzheng-nus Date: Fri, 3 Dec 2021 23:46:01 +0800 Subject: [PATCH] add new losses --- train.py | 43 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index 9bd4359..07b7f3e 100755 --- a/train.py +++ b/train.py @@ -32,6 +32,7 @@ except ImportError: # will be 3.x series print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') +from pytorch_metric_learning import losses, miners #pip install pytorch-metric-learning ###################################################################### # Options @@ -56,6 +57,10 @@ parser.add_argument('--droprate', default=0.5, type=float, help='drop rate') parser.add_argument('--PCB', action='store_true', help='use PCB+ResNet50' ) parser.add_argument('--circle', action='store_true', help='use Circle loss' ) +parser.add_argument('--contrast', action='store_true', help='use contrast loss' ) +parser.add_argument('--triplet', action='store_true', help='use triplet loss' ) +parser.add_argument('--lifted', action='store_true', help='use lifted loss' ) +parser.add_argument('--sphere', action='store_true', help='use sphere loss' ) parser.add_argument('--ibn', action='store_true', help='use resnet+ibn' ) parser.add_argument('--DG', action='store_true', help='use extra DG-Market Dataset for training. Please download it from https://github.com/NVlabs/DG-Net#dg-market.' ) parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50% memory' ) @@ -193,6 +198,15 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): warm_iteration = round(dataset_sizes['train']/opt.batchsize)*opt.warm_epoch # first 5 epoch if opt.circle: criterion_circle = CircleLoss(m=0.25, gamma=32) # gamma = 64 may lead to a better result. + if opt.triplet: + miner = miners.MultiSimilarityMiner() + criterion_triplet = losses.TripletMarginLoss(margin=0.3) + if opt.lifted: + criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0) + if opt.contrast: + criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1) + if opt.sphere: + criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4) for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) @@ -236,14 +250,23 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): sm = nn.Softmax(dim=1) log_sm = nn.LogSoftmax(dim=1) - if opt.circle: + if opt.circle or opt.triplet or opt.lifted or opt.contrast or opt.sphere: logits, ff = outputs fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) ff = ff.div(fnorm.expand_as(ff)) - loss = criterion(logits, labels) + criterion_circle(*convert_label_to_similarity( ff, labels))/now_batch_size - #loss = criterion_circle(*convert_label_to_similarity( ff, labels)) + loss = criterion(logits, labels) _, preds = torch.max(logits.data, 1) - + if opt.circle: + loss += criterion_circle(*convert_label_to_similarity( ff, labels))/now_batch_size + if opt.triplet: + hard_pairs = miner(ff, labels) + loss += criterion_triplet(ff, labels, hard_pairs) #/now_batch_size + if opt.lifted: + loss += criterion_lifted(ff, labels) #/now_batch_size + if opt.contrast: + loss += criterion_contrast(ff, labels) #/now_batch_size + if opt.sphere: + loss += criterion_sphere(ff, labels) #/now_batch_size elif opt.PCB: # PCB part = {} num_part = 6 @@ -382,18 +405,20 @@ def save_network(network, epoch_label): # Load a pretrainied model and reset final fully connected layer. # +return_feature = opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere + if opt.use_dense: - model = ft_net_dense(len(class_names), opt.droprate, circle = opt.circle) + model = ft_net_dense(len(class_names), opt.droprate, circle = return_feature) elif opt.use_NAS: model = ft_net_NAS(len(class_names), opt.droprate) elif opt.use_swin: - model = ft_net_swin(len(class_names), opt.droprate, opt.stride, circle =opt.circle) + model = ft_net_swin(len(class_names), opt.droprate, opt.stride, circle = return_feature) elif opt.use_efficient: - model = ft_net_efficient(len(class_names), opt.droprate, circle = opt.circle) + model = ft_net_efficient(len(class_names), opt.droprate, circle = return_feature) elif opt.use_hr: - model = ft_net_hr(len(class_names), opt.droprate, circle = opt.circle) + model = ft_net_hr(len(class_names), opt.droprate, circle = return_feature) else: - model = ft_net(len(class_names), opt.droprate, opt.stride, circle =opt.circle, ibn=opt.ibn) + model = ft_net(len(class_names), opt.droprate, opt.stride, circle = return_feature, ibn=opt.ibn) if opt.PCB: model = PCB(len(class_names))