-
Notifications
You must be signed in to change notification settings - Fork 5
/
data_structures.py
150 lines (109 loc) · 4.87 KB
/
data_structures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class Image2BInpainted:
    """An image to be inpainted, together with its patch geometry and optional extra data.

    ``height`` and ``width`` are taken from the first two entries of ``rgb.shape``.
    Every other argument is stored unchanged; the optional ones default to ``None``
    (``inpainting_approach`` defaults to -1).
    """

    # Presumably the valid values of ``inpainting_approach`` -- TODO confirm.
    # NOTE(review): "RBG" looks like a typo for "RGB"; the name is kept because
    # other code may reference it.
    USING_RBG_VALUES = 0
    USING_IR = 1
    USING_STORED_DESCRIPTORS = 2

    def __init__(self, rgb, mask, patch_size, stride, inpainting_approach=-1, ir=None,
                 patch_descriptors=None, half_patch_landscape_descriptors=None,
                 half_patch_portrait_descriptors=None, inpainted=None, order_image=None):
        self.rgb = rgb
        self.mask = mask
        self.patch_size = patch_size
        self.stride = stride
        # Image extent: first shape entry is the height, second is the width.
        image_shape = rgb.shape
        self.height = image_shape[0]
        self.width = image_shape[1]
        self.inpainting_approach = inpainting_approach
        self.ir = ir
        self.patch_descriptors = patch_descriptors
        self.half_patch_landscape_descriptors = half_patch_landscape_descriptors
        self.half_patch_portrait_descriptors = half_patch_portrait_descriptors
        self.inpainted = inpainted
        self.order_image = order_image
class Node:
    """A patch to be inpainted.

    ``x_coord`` is the row offset of the patch (it is bounded by ``image.height``)
    and ``y_coord`` the column offset (bounded by ``image.width``). The remaining
    attributes hold per-node state (labels, differences, potential matrices,
    messages, beliefs) that this class only stores.
    """

    def __init__(self, node_id, overlap_source_region, x_coord, y_coord,
                 priority=0, labels=None, pruned_labels=None, differences=None,
                 committed=False, additional_differences=None,
                 potential_matrix_up=None, potential_matrix_down=None,
                 potential_matrix_left=None, potential_matrix_right=None,
                 label_cost=None, local_likelihood=None, mask=None,
                 messages=None, beliefs=None, beliefs_new=None):
        self.node_id = node_id
        self.overlap_source_region = overlap_source_region
        self.x_coord = x_coord
        self.y_coord = y_coord
        self.priority = priority
        # ``None`` defaults become fresh containers so nodes never share one list/dict.
        self.labels = labels if labels is not None else []
        self.pruned_labels = pruned_labels if pruned_labels is not None else []
        self.differences = differences if differences is not None else {}
        self.committed = committed
        self.additional_differences = (
            additional_differences if additional_differences is not None else {}
        )
        self.potential_matrix_up = potential_matrix_up
        self.potential_matrix_down = potential_matrix_down
        self.potential_matrix_left = potential_matrix_left
        self.potential_matrix_right = potential_matrix_right
        self.label_cost = label_cost
        self.local_likelihood = local_likelihood
        self.mask = mask
        self.messages = messages
        self.beliefs = beliefs
        self.beliefs_new = beliefs_new

    def prune_labels(self, MAX_NB_LABELS):
        """Keep the ``MAX_NB_LABELS`` labels with the smallest additional difference.

        Labels (the keys of ``self.additional_differences``) are ranked by their
        value in ascending order, ties keeping dictionary order, and the first
        ``MAX_NB_LABELS`` of them are stored in ``self.pruned_labels``.
        """
        ranked = sorted(self.additional_differences, key=self.additional_differences.get)
        self.pruned_labels = ranked[:MAX_NB_LABELS]

    def _position_of(self, image, x_coord, y_coord):
        """Map patch coordinates to a position index using ``image``'s geometry."""
        return coordinates_to_position(x_coord, y_coord, image.height, image.patch_size)

    def get_up_neighbor_position(self, image):
        """Position of the patch ``image.stride`` rows above, or None if it would leave the image."""
        if self.x_coord < image.stride:
            return None
        return self._position_of(image, self.x_coord - image.stride, self.y_coord)

    def get_down_neighbor_position(self, image):
        """Position of the patch ``image.stride`` rows below, or None if it would leave the image."""
        last_valid_x = image.height - (image.patch_size + image.stride)
        if self.x_coord > last_valid_x:
            return None
        return self._position_of(image, self.x_coord + image.stride, self.y_coord)

    def get_left_neighbor_position(self, image):
        """Position of the patch ``image.stride`` columns left, or None if it would leave the image."""
        if self.y_coord < image.stride:
            return None
        return self._position_of(image, self.x_coord, self.y_coord - image.stride)

    def get_right_neighbor_position(self, image):
        """Position of the patch ``image.stride`` columns right, or None if it would leave the image."""
        last_valid_y = image.width - (image.patch_size + image.stride)
        if self.y_coord > last_valid_y:
            return None
        return self._position_of(image, self.x_coord, self.y_coord + image.stride)
def coordinates_to_position(x, y, image_height, patch_size):
    """Flatten patch coordinates ``(x, y)`` into a single position index.

    Positions are numbered column by column: ``x`` (the row offset) varies
    fastest, and each step in ``y`` skips one full column of the
    ``image_height - patch_size + 1`` valid row offsets (0 if negative).
    """
    positions_per_column = len(range(image_height - patch_size + 1))
    return x + y * positions_per_column
def position_to_coordinates(position, image_height, patch_size):
    """Split a position index back into patch coordinates ``(x, y)``.

    Inverse of ``coordinates_to_position`` when ``image_height >= patch_size``:
    ``x`` is the remainder and ``y`` the quotient of dividing by the number of
    valid row offsets, ``image_height - patch_size + 1``.
    """
    y, x = divmod(position, image_height - patch_size + 1)
    return x, y
# Sides of a patch. Opposite sides are numeric negatives of each other, which
# is what ``opposite_side`` relies on.
UP = 1
DOWN = -1
LEFT = 2
RIGHT = -2


def opposite_side(side):
    """Return the side opposite to ``side`` (UP <-> DOWN, LEFT <-> RIGHT)."""
    return -side


def get_half_patch_from_patch(patch, stride, side):
    """Return the band of ``patch`` shared with its neighbour on ``side``.

    Neighbouring patches are ``stride`` pixels apart, so a patch overlaps each
    neighbour over ``size - stride`` rows (UP/DOWN) or columns (LEFT/RIGHT):

    * UP:    the first ``height - stride`` rows
    * DOWN:  rows ``stride`` to the end
    * LEFT:  the first ``width - stride`` columns
    * RIGHT: columns ``stride`` to the end (also returned for any other ``side``
      value, as before)

    UP/DOWN pieces therefore always have the same shape, and so do LEFT/RIGHT
    pieces, so a piece can be compared with the opposite-side piece of a
    neighbour. When ``2 * stride`` equals the (square) patch size this is the
    same as taking the top/left ``stride`` rows/columns. The previous version
    always returned ``stride`` rows/columns for UP/LEFT, which mismatched
    DOWN/RIGHT for any other stride.

    Args:
        patch: array-like with at least 2 dimensions (rows, columns, ...);
            trailing dimensions such as colour channels are kept whole.
        stride: distance in pixels between neighbouring patches.
        side: one of UP, DOWN, LEFT, RIGHT.

    Returns:
        A slice (view for numpy arrays) of ``patch``; empty along the sliced
        axis when ``stride`` is at least the patch extent (no overlap).
    """
    height = patch.shape[0]
    # Use the real column count so non-square patches are handled correctly.
    width = patch.shape[1]
    if side == UP:
        # max(0, ...) keeps a negative bound from wrapping around to the end.
        return patch[:max(0, height - stride)]
    if side == DOWN:
        return patch[stride:]
    if side == LEFT:
        return patch[:, :max(0, width - stride)]
    return patch[:, stride:]