fix fastattn
fsx950223 committed Mar 4, 2021
1 parent af0f87b commit c457052
Showing 1 changed file with 8 additions and 17 deletions.
efficientdet/keras/efficientdet_keras.py: 25 changes (8 additions & 17 deletions)
@@ -87,38 +87,29 @@ def fuse_features(self, nodes):
     dtype = nodes[0].dtype
 
     if self.weight_method == 'attn':
-      edge_weights = []
-      for var in self.vars:
-        var = tf.cast(var, dtype=dtype)
-        edge_weights.append(var)
+      edge_weights = [tf.cast(var, dtype=dtype)
+                      for var in self.vars]
       normalized_weights = tf.nn.softmax(tf.stack(edge_weights))
       nodes = tf.stack(nodes, axis=-1)
       new_node = tf.reduce_sum(nodes * normalized_weights, -1)
     elif self.weight_method == 'fastattn':
-      edge_weights = []
-      for var in self.vars:
-        var = tf.cast(var, dtype=dtype)
-        edge_weights.append(var)
+      edge_weights = [tf.nn.relu(tf.cast(var, dtype=dtype))
+                      for var in self.vars]
       weights_sum = add_n(edge_weights)
       nodes = [
           nodes[i] * edge_weights[i] / (weights_sum + 0.0001)
           for i in range(len(nodes))
       ]
       new_node = add_n(nodes)
     elif self.weight_method == 'channel_attn':
-      edge_weights = []
-      for var in self.vars:
-        var = tf.cast(var, dtype=dtype)
-        edge_weights.append(var)
+      edge_weights = [tf.cast(var, dtype=dtype)
+                      for var in self.vars]
       normalized_weights = tf.nn.softmax(tf.stack(edge_weights, -1), axis=-1)
       nodes = tf.stack(nodes, axis=-1)
       new_node = tf.reduce_sum(nodes * normalized_weights, -1)
     elif self.weight_method == 'channel_fastattn':
-      edge_weights = []
-      for var in self.vars:
-        var = tf.cast(var, dtype=dtype)
-        edge_weights.append(var)
-
+      edge_weights = [tf.nn.relu(tf.cast(var, dtype=dtype))
+                      for var in self.vars]
       weights_sum = add_n(edge_weights)
       nodes = [
           nodes[i] * edge_weights[i] / (weights_sum + 0.0001)
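Note on the change: the substantive fix is in the 'fastattn' and 'channel_fastattn' branches, where the cast edge weights are now wrapped in tf.nn.relu before being summed and used to scale each input; the 'attn' and 'channel_attn' branches keep a plain cast followed by softmax normalization, and the accumulation loops in all four branches are rewritten as list comprehensions. This matches the "fast normalized fusion" scheme described in the EfficientDet paper, where weights are kept non-negative with a ReLU and each input is scaled by its weight divided by the sum of all weights plus a small epsilon. The sketch below illustrates that computation in isolation; the function and variable names are hypothetical, and only the arithmetic mirrors the fixed branch.

# Minimal sketch of ReLU-gated, sum-normalized feature fusion (assumption:
# a standalone illustration, not the repository's FNode class or its add_n helper).
import tensorflow as tf

def fast_attention_fuse(nodes, edge_vars, eps=0.0001):
  """Fuse same-shape feature maps with non-negative, sum-normalized weights."""
  dtype = nodes[0].dtype
  # Keep each per-edge weight non-negative, as in the fixed 'fastattn' branch.
  edge_weights = [tf.nn.relu(tf.cast(v, dtype=dtype)) for v in edge_vars]
  weights_sum = tf.add_n(edge_weights)
  # Each input contributes w_i / (sum_j w_j + eps) of the fused output.
  scaled = [n * w / (weights_sum + eps) for n, w in zip(nodes, edge_weights)]
  return tf.add_n(scaled)

# Hypothetical usage: two dummy feature maps, one scalar weight per input edge.
feats = [tf.ones([1, 8, 8, 64]), 2.0 * tf.ones([1, 8, 8, 64])]
weights = [tf.Variable(1.0), tf.Variable(1.0)]
fused = fast_attention_fuse(feats, weights)  # shape [1, 8, 8, 64]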
