-
Notifications
You must be signed in to change notification settings - Fork 0
/
refine_embeddings.py
98 lines (71 loc) · 3.84 KB
/
refine_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import better_exceptions
import numpy as np
import tensorflow as tf
rnn = tf.contrib.rnn
class EmbeddingsRefiner(object):
    """Refine point-feature (hypercolumn) embeddings before stereo matching.

    Follows the "full context embeddings" idea from Matching Networks:
    the left point feature is refined with an attention-LSTM conditioned on
    the set of right candidate features (``fce_left``), and the right
    candidate features are refined jointly with a bidirectional LSTM over
    the candidate set (``fce_right``).
    """

    def __init__(self, embedding_dimensions=128):
        """
        :param embedding_dimensions: dimensionality of the input/output
            embeddings. Should be even, since ``fce_right`` splits it across
            the forward and backward LSTM halves.
        """
        # Fixed number of attLSTM unrolling steps K used in fce_left.
        self.num_refinement_steps = 5
        # Flags reserved for toggling each refinement branch (currently
        # not consulted by refine()).
        self.use_left_refinement = True
        self.use_right_refinement = True
        self.embedding_dimensions = embedding_dimensions

    def refine(self, left_hypercolumn, right_hypercolumns):
        """Refine hypercolumn embeddings of both left and right image.

        :param left_hypercolumn: [batch_size, embedding_dimensions] tensor.
        :param right_hypercolumns: length-L list of
            [batch_size, embedding_dimensions] tensors (candidate points on
            the epipolar line in the right image).
        :returns: tuple (left_feature_refined, right_features_refined) with
            shapes [batch_size, D] and (L, batch_size, D) respectively.
        """
        # Right features are refined first: the left refinement attends
        # over the *refined* right features.
        right_features_refined = self.fce_right(right_hypercolumns)  # (L, batch_size, D)
        left_feature_refined = self.fce_left(left_hypercolumn, right_features_refined)
        return left_feature_refined, right_features_refined

    def fce_left(self, left_hypercolumn, right_hypercolumns):
        """Refine the hypercolumn of the left image point.

        f(x_i, S) = attLSTM(f'(x_i), g(S), K)

        Refinement runs an LSTM for a fixed number of steps
        (num_refinement_steps); attention over the hypercolumns of points on
        the epipolar line in the right image is folded into the LSTM state.

        :param left_hypercolumn: [batch_size, D] tensor (point feature).
        :param right_hypercolumns: [L, batch_size, D] tensor.
        :returns: [batch_size, D] refined left feature.
        """
        # assert(tf.shape(left_hypercolumn) == tf.shape(right_hypercolumns[0]))
        batch_size = tf.shape(left_hypercolumn)[0]
        cell = rnn.BasicLSTMCell(self.embedding_dimensions)
        prev_state = cell.zero_state(batch_size, tf.float32)  # state[0] is c, state[1] is h
        # xrange -> range: xrange does not exist on Python 3.
        for step in range(self.num_refinement_steps):
            output, state = cell(left_hypercolumn, prev_state)  # output: (batch_size, D)
            # Skip connection from the raw input feature.
            h_k = tf.add(output, left_hypercolumn)  # (batch_size, D)
            # NOTE(review): this softmax normalizes over the last (feature)
            # axis of the element-wise product; classic content-based
            # attention would reduce over features and softmax over the L
            # candidates instead — confirm this is intended.
            content_based_attention = tf.nn.softmax(
                tf.multiply(prev_state[1], right_hypercolumns))  # (L, batch_size, D)
            # Attention-weighted read over the candidate set.
            r_k = tf.reduce_sum(
                tf.multiply(content_based_attention, right_hypercolumns), axis=0)  # (batch_size, D)
            # Inject the attention read-out into the hidden state for the
            # next unrolling step (cell state c is carried through as-is).
            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))
        return output

    def fce_right(self, right_hypercolumns):
        """Refine hypercolumns of the right-image candidate points.

        g(x_i, S) = h_i(->) + h_i(<-) + g'(x_i)

        Set information is incorporated into each embedding with a
        bidirectional LSTM over the candidate list, plus a skip connection.

        :param right_hypercolumns: length-L list of [batch_size, D] tensors
            (point features).
        :returns: (L, batch_size, D) tensor of refined features.
        """
        # Each direction gets half the width so the concatenated
        # bidirectional output matches embedding_dimensions.
        # Integer division: under Python 3, / yields a float and
        # BasicLSTMCell requires an int num_units. (If
        # embedding_dimensions is odd, the concatenated output would be
        # one unit short and the skip-add below would fail.)
        fw_cell = rnn.BasicLSTMCell(self.embedding_dimensions // 2)
        bw_cell = rnn.BasicLSTMCell(self.embedding_dimensions // 2)
        outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(
            fw_cell, bw_cell, right_hypercolumns, dtype=tf.float32)
        # Skip connection: refined = raw + biLSTM output, stacked to (L, B, D).
        right_features_refined = tf.add(tf.stack(right_hypercolumns), tf.stack(outputs))
        return right_features_refined
def main():
    """Smoke test: refine random embeddings and print the output shapes."""
    embedding_dimensions = 32
    refiner = EmbeddingsRefiner(embedding_dimensions=embedding_dimensions)
    sess = tf.InteractiveSession()
    L = 10           # number of candidate points on the epipolar line
    batch_size = 4
    left_hypercolumn = tf.constant(
        np.random.randn(batch_size, embedding_dimensions), dtype=tf.float32)
    # xrange -> range: xrange does not exist on Python 3. Build the list
    # directly instead of pre-allocating and filling by index.
    right_hypercolumns = [
        tf.constant(np.random.randn(batch_size, embedding_dimensions), dtype=tf.float32)
        for _ in range(L)
    ]
    left_feature_refined, right_features_refined = refiner.refine(
        left_hypercolumn, right_hypercolumns)
    # Variables (LSTM weights) are created inside refine(); initialize after.
    sess.run(tf.global_variables_initializer())
    left_f, right_fs = sess.run([left_feature_refined, right_features_refined])
    print(left_f.shape)        # expect (batch_size, embedding_dimensions)
    print(len(right_fs))       # expect L
    print(right_fs[0].shape)   # expect (batch_size, embedding_dimensions)
    sess.close()


if __name__ == '__main__':
    main()