-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
29 lines (24 loc) · 1.08 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# https://github.com/youansheng/RefinementTest/blob/master/loss.py
import torch
import torch.nn.functional as F
import torch.nn as nn
# Recommended: module-based pixel-wise cross-entropy (log_softmax + NLL loss).
class CrossEntropyLoss2d(nn.Module):
    """Pixel-wise cross-entropy loss as an ``nn.Module``.

    Applies ``log_softmax`` over dimension 1 of ``inputs`` (the class
    dimension) and then the negative log-likelihood loss against ``targets``.
    Presumably used with ``inputs`` of shape (N, C, H, W) and integer
    ``targets`` of shape (N, H, W) -- confirm against callers.

    Args:
        weight: Optional per-class rescaling weights, forwarded to
            ``nn.NLLLoss``.
        size_average: If true (default), average the loss (a weighted mean
            when ``weight`` is given); if false, sum it. ``None`` is treated
            as true, matching PyTorch's legacy ``size_average`` semantics.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()
        # nn.NLLLoss2d and the size_average argument are deprecated: NLLLoss
        # accepts (N, C, d1, ...) input directly, and `reduction` replaces
        # size_average.
        reduction = "mean" if size_average is None or size_average else "sum"
        # Attribute name kept as `nll_loss` so state_dict keys are unchanged.
        self.nll_loss = nn.NLLLoss(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        """Return the loss of class scores ``inputs`` against labels ``targets``."""
        # Explicit dim=1: the implicit-dim form is deprecated. For 4-D input
        # the implicit choice was also dim 1, so results are unchanged there.
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
# This may be unstable in some cases. Note that size_average defaults to False here (the loss is summed), unlike CrossEntropyLoss2d.
def CrossEntropy2d(input, target, weight=None, size_average=False):
    """Pixel-wise cross-entropy that skips pixels with negative labels.

    Args:
        input: Class scores of shape (n, c, h, w) (unpacked as such below).
        target: Integer class labels with n*h*w elements, expected shape
            (n, h, w). Pixels whose label is negative are excluded.
        weight: Optional per-class rescaling weights, forwarded to
            ``F.cross_entropy``.
        size_average: If true, divide the summed loss by the number of valid
            (non-negative-label) pixels; otherwise return the sum. Defaults
            to False here.

    Returns:
        Scalar loss tensor.
    """
    n, c, h, w = input.size()
    # Move channels last, then flatten: row i holds the c class scores of
    # pixel i, in the same (n, h, w) order as the flattened target.
    logits = input.permute(0, 2, 3, 1).reshape(-1, c)
    # reshape to n*h*w also checks that target has one label per pixel.
    labels = target.reshape(n * h * w)
    valid = labels >= 0
    loss = F.cross_entropy(
        logits[valid], labels[valid], weight=weight, reduction="sum"
    )
    if size_average:
        # Normalize by the count of valid pixels. This differs from
        # F.cross_entropy's weighted "mean", which divides by the sum of the
        # weights. clamp(min=1) yields 0 instead of 0/0 = nan when every
        # pixel is ignored. (The old `.data[0]` on a 0-dim tensor raises on
        # modern PyTorch.)
        loss = loss / valid.sum().clamp(min=1)
    return loss