Capsule_Keras.py
#! -*- coding: utf-8 -*-
# refer: https://kexue.fm/archives/5112
import tensorflow as tf
from keras import activations
from keras import backend as K
from keras.engine.topology import Layer

# the modified squash proposed in the article: scale = |x| / (0.5 + |x|^2),
# i.e. squash(x) = |x|^2 / (0.5 + |x|^2) * x/|x|; epsilon avoids division by zero
def squash(x, axis=-1):
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
    return scale * x
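# Illustrative note (not in the original file): a vector with norm 1 is
# scaled by 1/(0.5+1) = 2/3, and scale*|x| = |x|^2/(0.5+|x|^2) -> 1 as |x|
# grows, so squashed norms lie in [0, 1) while short vectors shrink less
# than under the original |x|^2/(1+|x|^2) squash.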

# define our own softmax instead of K.softmax,
# because K.softmax cannot take an arbitrary axis
def softmax(x, axis=-1):
    ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return ex / K.sum(ex, axis=axis, keepdims=True)

# a Capsule layer implemented in pure Keras
class Capsule(Layer):
    def __init__(self, num_capsule, dim_capsule, routings=3,
                 share_weights=True, activation='squash', **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'squash':
            self.activation = squash  # default: the modified squash above
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        super(Capsule, self).build(input_shape)
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            # a single transform matrix shared by all input capsules
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(1, input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)
        else:
            # a separate transform matrix for each input capsule
            input_num_capsule = input_shape[-2]
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(input_num_capsule,
                                            input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)

    def call(self, u_vecs):
        # compute the prediction vectors u_hat = W * u for every
        # (input capsule, output capsule) pair
        if self.share_weights:
            u_hat_vecs = K.conv1d(u_vecs, self.W)
        else:
            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])

        batch_size = K.shape(u_vecs)[0]
        input_num_capsule = K.shape(u_vecs)[1]
        u_hat_vecs = K.reshape(u_hat_vecs, (batch_size, input_num_capsule,
                                            self.num_capsule, self.dim_capsule))
        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
        # final u_hat_vecs.shape = [None, num_capsule, input_num_capsule, dim_capsule]

        # dynamic routing: coupling logits b start at zero; softmax(b, 1)
        # normalizes them over the num_capsule axis
        b = K.zeros_like(u_hat_vecs[:, :, :, 0])  # shape = [None, num_capsule, input_num_capsule]
        for i in range(self.routings):
            c = softmax(b, 1)
            # o = K.batch_dot(c, u_hat_vecs, [2, 2])
            o = tf.einsum('bin,binj->bij', c, u_hat_vecs)
            if K.backend() == 'theano':  # leftover from the batch_dot version
                o = K.sum(o, axis=1)
            if i < self.routings - 1:
                o = K.l2_normalize(o, -1)
                # b = K.batch_dot(o, u_hat_vecs, [2, 3])
                b = tf.einsum('bij,binj->bin', o, u_hat_vecs)
                if K.backend() == 'theano':  # leftover from the batch_dot version
                    b = K.sum(b, axis=1)

        return self.activation(o)

    def compute_output_shape(self, input_shape):
        return (None, self.num_capsule, self.dim_capsule)
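
# ----------------------------------------------------------------------
# Usage sketch (not part of the original file): feed the layer a batch of
# input capsules and check the output shape. The sizes below (32 input
# capsules of dimension 64, 10 output capsules of dimension 16) are
# hypothetical placeholders.
if __name__ == '__main__':
    import numpy as np
    from keras.layers import Input
    from keras.models import Model

    x_in = Input(shape=(32, 64))  # 32 input capsules, each 64-dimensional
    caps = Capsule(num_capsule=10, dim_capsule=16, routings=3)(x_in)
    model = Model(x_in, caps)
    model.summary()  # expected output shape: (None, 10, 16)

    x = np.random.randn(2, 32, 64).astype('float32')
    print(model.predict(x).shape)  # -> (2, 10, 16)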