-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathNTM_Model.py
169 lines (105 loc) · 6.41 KB
/
NTM_Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import tensorflow as tf
import NTMCell
class NTM_CopyTask_Model(tf.keras.Model):
    """Neural Turing Machine model for the copy task.

    A single ``NTMCell`` first consumes the whole input sequence and is then
    driven with an all-zero "response sheet", carrying its state over from the
    first pass. The sigmoid of the cell outputs produced on the response sheet
    is the model's prediction.
    """

    def __init__(self, batch_size, output_dim, rnn_size, memory_rows, memory_columns, num_read_heads, num_write_heads, addressing_type='LOC',
                 shift_range=tf.range(-1,2), return_all_states = False, **kwargs):
        """Build the NTM cell and its initial state.

        Args:
            batch_size: Batch size used for the initial state and for the
                zero response sheet built in ``call``.
            output_dim: Size of the last axis of the response sheet; also
                forwarded to ``NTMCell``.
            rnn_size, memory_rows, memory_columns, num_read_heads,
            num_write_heads: Forwarded positionally to ``NTMCell``.
            addressing_type: Forwarded to ``NTMCell`` (previously the
                argument was ignored and ``'LOC'`` was always used).
            shift_range: Forwarded to ``NTMCell`` (previously the argument
                was ignored and ``tf.range(-1, 2)`` was always used).
            return_all_states: Stored on the instance. ``call`` always
                returns only the final state, because ``dynamic_rnn`` does
                not expose per-timestep states.
            **kwargs: Accepted for call-site compatibility; not forwarded.
        """
        super().__init__()
        self.batch_size = batch_size
        self.output_dim = output_dim
        # Honour the caller's addressing configuration instead of
        # hard-coding the defaults.
        self.cell = NTMCell.NTMCell(
            rnn_size,
            memory_rows,
            memory_columns,
            num_read_heads,
            num_write_heads,
            output_dim,
            addressing_type=addressing_type,
            shift_range=shift_range,
        )
        self.init_state = self.cell.get_initial_state(batch_size = batch_size)
        self.return_all_states = return_all_states

    @tf.function
    def call(self,inputs):
        """Read ``inputs``, then emit the copy on a zero response sheet.

        Args:
            inputs: Tensor indexed as ``[batch, time, feature]``. Per the
                original author's note the time axis includes one
                start-of-sequence and one end-of-sequence delimiter, which is
                why the response sheet is ``timesteps - 2`` steps long.
                Presumably the feature axis equals ``output_dim`` since the
                same cell consumes both tensors - TODO confirm.

        Returns:
            A tuple ``(predictions, states)`` where ``predictions`` is the
            sigmoid of the cell outputs on the response sheet and ``states``
            is the final cell state returned by ``dynamic_rnn``.
        """
        timesteps = inputs.shape[1]

        # NOTE: an earlier revision repeated the two passes below inside an
        # ``if self.return_all_states:`` branch and then recomputed them
        # unconditionally, discarding the first result. The duplicate pass is
        # dropped here; the returned values are the same.

        # Pass 1: let the cell consume the input sequence. Only the final
        # state is needed; the outputs of this pass are not used.
        _, states = tf.compat.v1.nn.dynamic_rnn(self.cell,inputs,initial_state=self.init_state)

        # Pass 2: drive the cell with zeros so it can write out its answer.
        # timesteps - 2 because inputs contains 1 sof and 1 eof delimiter.
        response_sheet = tf.zeros([self.batch_size,timesteps-2,self.output_dim])
        outputs, states = tf.compat.v1.nn.dynamic_rnn(self.cell,response_sheet,initial_state=states)

        return tf.nn.sigmoid(outputs), states
class NTM_Associative_Recall_Model(tf.keras.Model):
    """Neural Turing Machine model for the associative-recall task.

    Unlike a ``dynamic_rnn`` based unroll, this model steps the cell manually
    so that the output and state of every timestep can be returned alongside
    the prediction.
    """

    def __init__(self, batch_size, output_dim, item_len ,rnn_size, memory_rows, memory_columns, num_read_heads, num_write_heads, controller = tf.keras.layers.LSTMCell, addressing_type='LOC',
                 shift_range=tf.range(-1,2), **kwargs):
        """Build the NTM cell and its initial state.

        Args:
            batch_size: Batch size used for the initial state and for the
                zero response sheet built in ``call``.
            output_dim: Size of the last axis of the response sheet; also
                forwarded to ``NTMCell``. Per the original author's note it
                should include the two delimiters' rows.
            item_len: Number of vectors in each item; also the number of
                steps the cell is run on the response sheet.
            rnn_size, memory_rows, memory_columns, num_read_heads,
            num_write_heads: Forwarded positionally to ``NTMCell``.
            controller: Forwarded to ``NTMCell`` as ``controller``.
            addressing_type: Forwarded to ``NTMCell`` (previously the
                argument was ignored and ``'LOC'`` was always used).
            shift_range: Forwarded to ``NTMCell`` (previously the argument
                was ignored and ``tf.range(-1, 2)`` was always used).
            **kwargs: Accepted for call-site compatibility; not forwarded.
        """
        super().__init__()
        self.batch_size = batch_size
        self.output_dim = output_dim
        self.item_len = item_len
        self.controller = controller
        # Honour the caller's addressing configuration instead of
        # hard-coding the defaults.
        self.cell = NTMCell.NTMCell(
            rnn_size,
            memory_rows,
            memory_columns,
            num_read_heads,
            num_write_heads,
            output_dim,
            controller = self.controller,
            addressing_type=addressing_type,
            shift_range=shift_range,
        )
        self.init_state = self.cell.get_initial_state(batch_size = batch_size)

    def _unroll(self, sequence, num_steps, state):
        """Step the cell over ``sequence[:, t, :]`` for ``t < num_steps``.

        Returns the per-step outputs, the per-step states, and the state left
        after the last step (the incoming ``state`` if ``num_steps`` is 0).
        """
        step_outputs = []
        step_states = []
        for step in range(num_steps):
            output, state = self.cell(sequence[:, step, :], state)
            step_outputs.append(output)
            step_states.append(state)
        return step_outputs, step_states, state

    @tf.function
    def call(self, inputs):
        """Read ``inputs``, then write the answer on a zero response sheet.

        Args:
            inputs: Tensor documented by the original author as having shape
                ``[batch_size, timesteps, output_dim]``.

        Returns:
            A tuple ``(predictions, cached_stuff)``. ``predictions`` is the
            sigmoid of the write-phase cell outputs stacked along axis 1.
            ``cached_stuff`` maps ``'While_Reading'`` and ``'While_Writing'``
            to ``(outputs, states)`` pairs of per-timestep lists; the cached
            outputs are the raw cell outputs (no sigmoid applied).
        """
        timesteps = inputs.shape[1]

        # Reading the inputs.
        read_outputs, read_states, state = self._unroll(inputs, timesteps, self.init_state)

        # Response strip for the NTM: all zeros, one step per item vector.
        response_sheet = tf.zeros([self.batch_size, self.item_len, self.output_dim])

        # Writing the answer, continuing from the state left by the read phase.
        write_outputs, write_states, state = self._unroll(response_sheet, self.item_len, state)

        cached_stuff = {
            'While_Reading' : (read_outputs, read_states),
            'While_Writing' : (write_outputs, write_states)
        }
        predictions = tf.nn.sigmoid(tf.stack(write_outputs, axis = 1))
        return predictions, cached_stuff
class CE_Loss_Function(tf.keras.losses.Loss):
    """Element-wise binary cross-entropy averaged over every element."""

    def call(self,y_true, y_pred):
        """Return the mean binary cross-entropy between targets and predictions.

        Args:
            y_true: Target tensor. The original author documents the shape as
                (batch_size, timesteps*2 + 2, output_dim + 1).
            y_pred: Model output, documented with the same shape as ``y_true``.

        Returns:
            The negated mean, over all elements, of
            ``y_true*log(y_pred + eps) + (1 - y_true)*log(1 - y_pred + eps)``
            with ``eps = 1e-8``.
        """
        # Small constant added inside each log so that log(0) is never taken.
        eps = 1e-8
        positive_term = y_true * tf.math.log(y_pred + eps)
        negative_term = (1 - y_true) * tf.math.log(1 - y_pred + eps)
        return -tf.reduce_mean(positive_term + negative_term)
class CustomLoss(tf.keras.losses.Loss):
    """Mean of the square root of the absolute error."""

    def call(self,y_true, y_pred):
        """Return ``mean(sqrt(|y_true - y_pred|))`` over all elements."""
        # NOTE(review): the derivative of sqrt is unbounded at 0, so exact
        # matches between y_true and y_pred may yield non-finite gradients -
        # confirm whether an epsilon is wanted here.
        abs_error = tf.abs(y_true - y_pred)
        root_error = tf.sqrt(abs_error)
        return tf.reduce_mean(root_error)
def HuberLoss(y_true, y_pred, delta):
    """Return ten times the mean Huber loss between ``y_true`` and ``y_pred``.

    Elements whose absolute error is below ``delta`` contribute
    ``0.5 * error**2``; all others contribute
    ``delta * (|error| - 0.5 * delta)``.
    """
    error = y_true - y_pred
    abs_error = tf.abs(error)
    quadratic_part = .5 * error ** 2
    linear_part = delta * (abs_error - 0.5 * delta)
    per_element = tf.where(abs_error < delta, quadratic_part, linear_part)
    # Result enhanced by a factor of 10 for easy interpretability.
    return 10 * tf.reduce_mean(per_element)