-
Notifications
You must be signed in to change notification settings - Fork 3
/
focal_loss.py
83 lines (73 loc) · 2.98 KB
/
focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw (pre-softmax) scores.

    Per sample the loss is ``-alpha[t] * (1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the softmax probability assigned to the target class
    ``t`` (plus a small epsilon for numerical safety).

    Args:
        num_class: Number of classes, i.e. the size of dim 1 of the input.
        alpha: Per-class weighting.
            * ``None``: every class gets weight 1.
            * list / ``np.ndarray`` of length ``num_class``: used as class
              weights after being normalised to sum to 1.
            * float: class ``balance_index`` gets weight ``alpha`` and every
              other class gets ``1 - alpha``.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        balance_index: Class index that receives ``alpha`` when ``alpha`` is
            a float.
        smooth: Optional value in ``[0, 1]``. When truthy, the one-hot target
            matrix is clamped to ``[smooth, 1 - smooth]``.
        size_average: Return the mean of the per-sample losses when True,
            otherwise their sum.

    Raises:
        TypeError: If ``alpha`` is of an unsupported type.
        ValueError: If ``smooth`` lies outside ``[0, 1]``.
        AssertionError: If a list/array ``alpha`` does not have
            ``num_class`` entries.
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        # Normalise every accepted form of ``alpha`` to a (num_class, 1) float tensor.
        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            weights = torch.tensor(self.alpha, dtype=torch.float).view(self.num_class, 1)
            self.alpha = weights / weights.sum()
        elif isinstance(self.alpha, float):
            # The original check referenced an undefined name, which made
            # this branch raise NameError instead of ever running.
            weights = torch.ones(self.num_class, 1) * (1 - self.alpha)
            weights[balance_index] = self.alpha
            self.alpha = weights
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0, 1]')

    def forward(self, input, target):
        """Return the reduced focal loss.

        Args:
            input: Raw class scores with the class axis at dim 1. Inputs with
                more than two dims are flattened to ``(-1, num_class)``.
            target: Integer class indices, one per row of the flattened
                input; it is reshaped to ``(-1, 1)``.

        Returns:
            A scalar tensor: the mean (``size_average=True``) or the sum of
            the per-sample losses.
        """
        prob = F.softmax(input, dim=1)
        if prob.dim() > 2:
            # (N, C, d1, d2, ...) -> (N, C, D) -> (N, D, C) -> (N * D, C)
            prob = prob.view(prob.size(0), prob.size(1), -1)
            prob = prob.permute(0, 2, 1).contiguous()
            prob = prob.view(-1, prob.size(-1))

        # Keep the indices on the same device as the probabilities so no
        # host round trip is needed to build the one-hot matrix.
        idx = target.reshape(-1, 1).to(device=prob.device, dtype=torch.long)
        epsilon = 1e-10

        one_hot_key = torch.zeros(
            idx.size(0), self.num_class, dtype=torch.float, device=prob.device
        )
        one_hot_key.scatter_(1, idx, 1)
        if self.smooth:
            one_hot_key = torch.clamp(one_hot_key, self.smooth, 1.0 - self.smooth)

        pt = (one_hot_key * prob).sum(1) + epsilon
        logpt = pt.log()

        # Gather one weight per sample as a flat (N,) vector. Indexing the
        # (C, 1) tensor with (N, 1) indices would give (N, 1, 1) and make the
        # product below broadcast to (N, 1, N), corrupting the reduction.
        alpha = self.alpha.to(prob.device).view(-1)[idx.view(-1)]

        loss = -1 * alpha * torch.pow(1 - pt, self.gamma) * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()
class BCEFocalLoss(nn.Module):
    """Binary focal loss computed from raw (pre-sigmoid) scores.

    Per element, with ``p = sigmoid(x)``, the loss is::

        -alpha * (1 - p) ** gamma * target * log(p)
        - (1 - alpha) * p ** gamma * (1 - target) * log(1 - p)

    Args:
        gamma: Focusing exponent.
        alpha: Weight of the positive term; the negative term is weighted by
            ``1 - alpha``.
        reduction: ``'elementwise_mean'`` (or its modern spelling ``'mean'``)
            averages the per-element losses, ``'sum'`` sums them, and any
            other value returns the per-element losses unreduced.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
        super(BCEFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        """Return the focal loss of ``_input`` scores against ``target``.

        Args:
            _input: Raw scores; the sigmoid is applied here.
            target: Tensor broadcastable with ``_input`` holding the binary
                targets (1 for positive, 0 for negative).

        Returns:
            A scalar tensor for the mean/sum reductions, otherwise a tensor
            of per-element losses.
        """
        pt = torch.sigmoid(_input)
        # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)) are taken
        # via logsigmoid: calling torch.log on a saturated sigmoid yields
        # -inf, and the 0 * -inf that follows turns the whole loss into NaN.
        log_pt = F.logsigmoid(_input)
        log_one_minus_pt = F.logsigmoid(-_input)

        alpha = self.alpha
        positive_term = alpha * (1 - pt) ** self.gamma * target * log_pt
        negative_term = (1 - alpha) * pt ** self.gamma * (1 - target) * log_one_minus_pt
        loss = -positive_term - negative_term

        if self.reduction in ('elementwise_mean', 'mean'):
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss