-
Notifications
You must be signed in to change notification settings - Fork 52
/
loss.py
87 lines (79 loc) · 3.15 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch.autograd import Function
from itertools import repeat
import numpy as np
# Intersection = dot(A, B)
# Union = dot(A, A) + dot(B, B)
# The Dice coefficient (the score DiceLoss.forward returns) is defined as
# 2 * intersection / union
#
# The derivative is 2[(union * target - 2 * intersect * input) / union^2]
class DiceLoss(Function):
    """Dice score of a two-class prediction, written as a legacy autograd Function.

    ``forward`` returns ``2 * dot(R, T) / (sum(R) + sum(T) + 2 * eps)``, where
    ``R`` is the hard prediction (argmax over dim 1 of ``input``) and ``T`` is
    ``target``.

    ``backward`` does not differentiate that hard value. It builds the gradient
    from the soft foreground scores ``input[:, 1]`` together with the
    intersection/union cached by ``forward``.

    NOTE(review): this uses the old instance-method autograd API (per-call
    state kept on ``self``, ``self.save_for_backward``). Recent PyTorch
    releases only support static ``forward(ctx, ...)`` functions invoked via
    ``Function.apply``; porting is left as a follow-up.

    Keyword args (via ``**kwargs``):
        verbose: when true, ``forward`` prints per-call statistics.
            Off by default.
    """

    def __init__(self, *args, **kwargs):
        # The base-class constructor is intentionally not called, as before.
        self.verbose = kwargs.get('verbose', False)

    def forward(self, input, target, save=True):
        """Return a 1-element FloatTensor holding the Dice score.

        Args:
            input: class scores with the class axis at dim 1. Presumably
                ``(N, 2)``, since ``backward`` reads ``input[:, 1]`` — TODO confirm.
            target: binary labels. After squeezing, the prediction must be 1-D
                and match ``target`` element for element, because ``torch.dot``
                only accepts 1-D operands.
            save: when true, keep ``input``/``target`` for ``backward``.
        """
        if save:
            self.save_for_backward(input, target)
        eps = 0.000001
        _, result_ = input.max(1)
        result_ = torch.squeeze(result_)
        # Float copies of prediction and labels on the input's device.
        if input.is_cuda:
            result = torch.cuda.FloatTensor(result_.size())
            self.target_ = torch.cuda.FloatTensor(target.size())
        else:
            result = torch.FloatTensor(result_.size())
            self.target_ = torch.FloatTensor(target.size())
        result.copy_(result_)
        self.target_.copy_(target)
        target = self.target_
        intersect = torch.dot(result, target)
        # Values are binary, so the plain sum equals the sum of squares.
        result_sum = torch.sum(result)
        target_sum = torch.sum(target)
        union = result_sum + target_sum + (2 * eps)
        # The target volume can be empty. Clamp the intersection so that an
        # empty prediction of an empty target scores 2 * eps / (2 * eps) == 1
        # instead of 0 (same rule as dice_error).
        intersect = max(intersect, eps)
        IoU = intersect / union
        if getattr(self, 'verbose', False):
            print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f}\t result_sum: {:.0f}\t dice: {:.7f}'.format(
                union, intersect, target_sum, result_sum, 2 * IoU))
        out = torch.FloatTensor(1).fill_(2 * IoU)
        # Cached for backward.
        self.intersect, self.union = intersect, union
        return out

    def backward(self, grad_output):
        """Return ``(grad_input, None)``; ``target`` gets no gradient.

        Following the derivative note above the class, the foreground column
        gets ``2 * T / union - 4 * P * intersect / union**2``, where
        ``P = input[:, 1]``. The background column gets the negation. Both
        are scaled by ``grad_output[0]``.
        """
        input, _ = self.saved_tensors
        intersect, union = self.intersect, self.union
        target = self.target_
        gt = torch.div(target, union)
        IoU2 = intersect / (union * union)
        pred = torch.mul(input[:, 1], IoU2)
        dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
        # Stack one column per class along dim 1 so grad_input has input's
        # shape. torch.cat(..., 0) produced a flat (2N,) tensor instead.
        grad_input = torch.stack((torch.mul(dDice, -grad_output[0]),
                                  torch.mul(dDice, grad_output[0])), 1)
        return grad_input, None
def dice_loss(input, target):
    """Evaluate a freshly constructed DiceLoss on ``input`` and ``target``.

    A new instance is created for every call because DiceLoss keeps per-call
    state (saved tensors, intersect/union, target copy) on ``self``.
    """
    criterion = DiceLoss()
    return criterion(input, target)
def dice_error(input, target):
eps = 0.000001
_, result_ = input.max(1)
result_ = torch.squeeze(result_)
if input.is_cuda:
result = torch.cuda.FloatTensor(result_.size())
target_ = torch.cuda.FloatTensor(target.size())
else:
result = torch.FloatTensor(result_.size())
target_ = torch.FloatTensor(target.size())
result.copy_(result_.data)
target_.copy_(target.data)
target = target_
intersect = torch.dot(result, target)
result_sum = torch.sum(result)
target_sum = torch.sum(target)
union = result_sum + target_sum + 2*eps
intersect = np.max([eps, intersect])
# the target volume can be empty - so we still want to
# end up with a score of 1 if the result is 0/0
IoU = intersect / union
# print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f} IoU: result_sum: {:.0f} IoU {:.7f}'.format(
# union, intersect, target_sum, result_sum, 2*IoU))
return 2*IoU