
Commit

Update _SimpleConsensus to use static autograd methods
tchang1997 committed Jan 11, 2021
1 parent 117a4d1 commit b02523d
Showing 1 changed file with 14 additions and 19 deletions.
mmaction/models/tenons/segmental_consensuses/simple_consensus.py (33 changes: 14 additions & 19 deletions)
@@ -1,33 +1,26 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from ...registry import SEGMENTAL_CONSENSUSES
 
 
 class _SimpleConsensus(torch.autograd.Function):
     """Simplest segmental consensus module"""
 
-    def __init__(self,
-                 consensus_type='avg',
-                 dim=1):
-        super(_SimpleConsensus, self).__init__()
-
-        assert consensus_type in ['avg']
-        self.consensus_type = consensus_type
-        self.dim = dim
-        self.shape = None
-
-    def forward(self, x):
-        self.shape = x.size()
-        if self.consensus_type == 'avg':
-            output = x.mean(dim=self.dim, keepdim=True)
+    @staticmethod
+    def forward(ctx, x, dim, consensus_type):
+        ctx.save_for_backward(x, dim, consensus_type)
+        if consensus_type == 'avg':
+            output = x.mean(dim=dim, keepdim=True)
         else:
             output = None
         return output
 
-    def backward(self, grad_output):
-        if self.consensus_type == 'avg':
-            grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim])
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, dim, consensus_type = ctx.saved_tensors
+        shape = x.size()
+        if consensus_type == 'avg':
+            grad_in = grad_output.expand(shape) / float(shape[dim])
         else:
             grad_in = None
         return grad_in
@@ -46,4 +39,6 @@ def init_weights(self):
         pass
 
     def forward(self, input):
-        return _SimpleConsensus(self.consensus_type, self.dim)(input)
+        return _SimpleConsensus.apply(input,
+                                      self.dim,
+                                      self.consensus_type)
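
Note that the static version, as committed, still cannot run: ctx.save_for_backward() accepts only tensors, so passing the int dim and the string consensus_type raises a TypeError, and a static backward() must return one gradient per forward() input (three here), not a single value. A minimal corrected sketch, not part of this commit, might stash the non-tensor arguments directly on ctx instead:

import torch


class _SimpleConsensus(torch.autograd.Function):
    """Simplest segmental consensus module (static, corrected sketch)."""

    @staticmethod
    def forward(ctx, x, dim, consensus_type):
        # save_for_backward() accepts tensors only; non-tensor
        # arguments can be stored as plain attributes on ctx.
        ctx.shape = x.size()
        ctx.dim = dim
        ctx.consensus_type = consensus_type
        if consensus_type == 'avg':
            return x.mean(dim=dim, keepdim=True)
        return None

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.consensus_type == 'avg':
            # Spread the incoming gradient evenly across the reduced dim.
            grad_in = grad_output.expand(ctx.shape) / float(ctx.shape[ctx.dim])
        else:
            grad_in = None
        # backward() must return one gradient per forward() input
        # (x, dim, consensus_type); the non-tensor inputs get None.
        return grad_in, None, None

With this version, _SimpleConsensus.apply(x, 1, 'avg') averages over dim 1 in the forward pass and distributes the incoming gradient evenly across that dim in the backward pass.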

1 comment on commit b02523d

@ZienZhang6


This change is still useful. Although other problems have since appeared, at a glance the new problems look unrelated to this one.
