# Spaces:
# Build error
# Build error
# Copyright (c) OpenMMLab. All rights reserved.
# Code reference from "Temporal Interlacing Network"
# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
# Hao Shao, Shengju Qian, Yu Liu
# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
import torch
import torch.nn as nn
from torch.autograd import Function

from ..utils import ext_loader
# Load the compiled ``_ext`` extension, requesting the two TIN shift ops
# that ``TINShiftFunction`` calls.
ext_module = ext_loader.load_ext(
    '_ext', ['tin_shift_forward', 'tin_shift_backward'])
class TINShiftFunction(Function):
    """Autograd function wrapping the compiled temporal interlace shift ops.

    ``forward`` and ``backward`` delegate the actual shifting to
    ``ext_module.tin_shift_forward`` / ``ext_module.tin_shift_backward``.
    This class validates the inputs, allocates the output and gradient
    buffers, and stores ``shift`` on the autograd context.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Shift ``input`` according to ``shift``.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for ``backward``.
            input (Tensor): Feature map. Dim 0 is the batch and dim 2 is the
                channel dim ``C`` (``TINShift`` documents the layout as
                ``[N, num_segments, C, H * W]``).
            shift (Tensor): Shift tensor. Dim 0 is the batch and dim 1 is
                ``num_segments`` (documented as ``[N, num_segments]``).

        Returns:
            Tensor: A new tensor shaped like ``input`` (same dtype and
            device), filled in by ``tin_shift_forward``.

        Raises:
            ValueError: If ``input`` and ``shift`` have different batch
                sizes, or if ``C`` is not a positive multiple of
                ``num_segments``.
        """
        if input.size(0) != shift.size(0):
            # NOTE(review): presumably the extension reads one row of
            # ``shift`` per sample of ``input``. A batch mismatch would make
            # it read wrong or out-of-range rows, so reject it up front.
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be the '
                f'same, but got {input.size(0)} and {shift.size(0)}.')

        C = input.size(2)
        num_segments = shift.size(1)
        # Test ``num_segments`` first so an empty dim 1 reports this
        # ValueError instead of raising ZeroDivisionError in ``C // ...``.
        if (num_segments <= 0 or C // num_segments <= 0
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')

        ctx.save_for_backward(shift)

        # The extension writes its result into this pre-allocated buffer.
        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back through the shift.

        Args:
            ctx: Autograd context holding the ``shift`` saved in ``forward``.
            grad_output (Tensor): Gradient w.r.t. the output of ``forward``.

        Returns:
            tuple[Tensor, Tensor]: The gradient w.r.t. ``input`` (filled in
            by ``tin_shift_backward``) and an all-zero gradient w.r.t.
            ``shift``.
        """
        shift, = ctx.saved_tensors
        # ``new_zeros`` returns a fresh, contiguous, zero-filled tensor with
        # the dtype and device of the tensor it is called on.
        data_grad_input = grad_output.new_zeros(grad_output.size())
        # ``shift`` always receives a zero gradient.
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
        return data_grad_input, shift_grad_input


# Functional entry point: ``tin_shift(input, shift)``.
tin_shift = TINShiftFunction.apply
class TINShift(nn.Module):
    """Temporal Interlace Shift module.

    A differentiable, temporal-wise frame-shifting operator introduced in
    "Temporal Interlacing Network" (https://arxiv.org/abs/2001.06499).
    The module defines no parameters of its own; ``forward`` hands its
    arguments straight to ``tin_shift``.

    Adapted from https://github.com/mit-han-lab/temporal-shift-module
    """

    def forward(self, input, shift):
        """Apply the temporal interlace shift to ``input``.

        Args:
            input (Tensor): Feature map of shape [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor of shape [N, num_segments].

        Returns:
            Tensor: The feature map after the temporal interlace shift.
        """
        shifted = tin_shift(input, shift)
        return shifted