From b02523d94d39f4439bc081d8097c054ca010d6ce Mon Sep 17 00:00:00 2001
From: Trenton Chang
Date: Thu, 7 Jan 2021 22:57:01 -0800
Subject: [PATCH] Update _SimpleConsensus to use static autograd methods

---
 .../segmental_consensuses/simple_consensus.py | 39 +++++++++++++-----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py
index 950fffb..2dddc48 100644
--- a/mmaction/models/tenons/segmental_consensuses/simple_consensus.py
+++ b/mmaction/models/tenons/segmental_consensuses/simple_consensus.py
@@ -1,33 +1,30 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...registry import SEGMENTAL_CONSENSUSES
 
 class _SimpleConsensus(torch.autograd.Function):
     """Simplest segmental consensus module"""
 
-    def __init__(self,
-                 consensus_type='avg',
-                 dim=1):
-        super(_SimpleConsensus, self).__init__()
-
-        assert consensus_type in ['avg']
-        self.consensus_type = consensus_type
-        self.dim = dim
-        self.shape = None
-
-    def forward(self, x):
-        self.shape = x.size()
-        if self.consensus_type == 'avg':
-            output = x.mean(dim=self.dim, keepdim=True)
+    @staticmethod
+    def forward(ctx, x, dim, consensus_type):
+        # dim and consensus_type are non-tensor arguments, so stash them
+        # on ctx directly; ctx.save_for_backward accepts tensors only.
+        ctx.dim = dim
+        ctx.consensus_type = consensus_type
+        ctx.shape = x.size()
+        if consensus_type == 'avg':
+            output = x.mean(dim=dim, keepdim=True)
         else:
             output = None
         return output
 
-    def backward(self, grad_output):
-        if self.consensus_type == 'avg':
-            grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim])
+    @staticmethod
+    def backward(ctx, grad_output):
+        shape = ctx.shape
+        if ctx.consensus_type == 'avg':
+            grad_in = grad_output.expand(shape) / float(shape[ctx.dim])
         else:
             grad_in = None
-        return grad_in
+        # backward must return one gradient per input to forward
+        return grad_in, None, None
@@ -46,4 +43,6 @@ def init_weights(self):
         pass
 
     def forward(self, input):
-        return _SimpleConsensus(self.consensus_type, self.dim)(input)
+        return _SimpleConsensus.apply(input,
+                                      self.dim,
+                                      self.consensus_type)
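
Note: a minimal, self-contained sketch of how the patched Function can be
exercised. The standalone class copy, the tensor shape, and the gradcheck
call below are illustrative assumptions, not part of the patch (the real
class lives in mmaction and depends on its registry imports):

    import torch

    # Standalone copy of the patched static-method Function, for illustration.
    class _SimpleConsensus(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, dim, consensus_type):
            # Non-tensor state travels through ctx, not self.
            ctx.dim = dim
            ctx.consensus_type = consensus_type
            ctx.shape = x.size()
            if consensus_type == 'avg':
                return x.mean(dim=dim, keepdim=True)
            return None

        @staticmethod
        def backward(ctx, grad_output):
            shape = ctx.shape
            if ctx.consensus_type == 'avg':
                # Gradient of a mean: broadcast and divide by the reduced size.
                grad_in = grad_output.expand(shape) / float(shape[ctx.dim])
            else:
                grad_in = None
            # One gradient per forward input: x, dim, consensus_type.
            return grad_in, None, None

    # gradcheck compares the hand-written backward against numerical
    # gradients; double precision is needed for its default tolerances.
    x = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(
        lambda t: _SimpleConsensus.apply(t, 1, 'avg'), (x,))

The static methods matter because recent PyTorch releases reject the legacy
pattern of calling a torch.autograd.Function instance (the old
`_SimpleConsensus(self.consensus_type, self.dim)(input)` call) with a
RuntimeError, so any state shared between forward and backward has to be
carried on ctx rather than on self.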