Spaces:
Sleeping
Sleeping
File size: 6,459 Bytes
907b7f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
import torch.nn as nn
import math
import numpy as np
from lipreading.models.resnet import ResNet, BasicBlock
from lipreading.models.resnet1D import ResNet1D, BasicBlock1D
from lipreading.models.shufflenetv2 import ShuffleNetV2
from lipreading.models.tcn import MultibranchTemporalConvNet, TemporalConvNet
# -- auxiliary functions
def threeD_to_2D_tensor(x):
    """Fold the time axis of a 5-D video tensor into the batch axis.

    Input is (batch, channels, time, height, width); output is
    (batch * time, channels, height, width) so each frame can be fed
    through a 2-D CNN independently.
    """
    batch, channels, frames, height, width = x.shape
    flattened = x.transpose(1, 2)  # -> (batch, time, channels, H, W)
    return flattened.reshape(batch * frames, channels, height, width)
def _average_batch(x, lengths, B):
return torch.stack( [torch.mean( x[index][:,0:i], 1 ) for index, i in enumerate(lengths)],0 )
class MultiscaleMultibranchTCN(nn.Module):
    """Multibranch TCN backend: one temporal branch per kernel size,
    followed by length-aware average pooling and a linear classifier."""

    def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False):
        super(MultiscaleMultibranchTCN, self).__init__()

        self.kernel_sizes = tcn_options['kernel_size']
        self.num_kernels = len(self.kernel_sizes)

        self.mb_ms_tcn = MultibranchTemporalConvNet(
            input_size, num_channels, tcn_options,
            dropout=dropout, relu_type=relu_type, dwpw=dwpw)
        self.tcn_output = nn.Linear(num_channels[-1], num_classes)

        self.consensus_func = _average_batch

    def forward(self, x, lengths, B):
        # The convolutional trunk expects (N, C, L), so swap time/feature axes.
        features = self.mb_ms_tcn(x.transpose(1, 2))
        pooled = self.consensus_func(features, lengths, B)
        return self.tcn_output(pooled)
class TCN(nn.Module):
    """Implements Temporal Convolutional Network (TCN)
    __https://arxiv.org/pdf/1803.01271.pdf
    """

    def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False):
        super(TCN, self).__init__()
        self.tcn_trunk = TemporalConvNet(
            input_size, num_channels,
            dropout=dropout, tcn_options=tcn_options, relu_type=relu_type, dwpw=dwpw)
        self.tcn_output = nn.Linear(num_channels[-1], num_classes)

        self.consensus_func = _average_batch
        self.has_aux_losses = False

    def forward(self, x, lengths, B):
        # The convolutional trunk expects (N, C, L), so swap time/feature axes.
        trunk_out = self.tcn_trunk(x.transpose(1, 2))
        pooled = self.consensus_func(trunk_out, lengths, B)
        return self.tcn_output(pooled)
class Lipreading(nn.Module):
    """End-to-end lipreading / audio classification model.

    modality='video': 3D-conv front-end -> 2D ResNet or ShuffleNetV2 trunk.
    modality='raw_audio': 1D ResNet trunk over the raw waveform.
    Either way, per-frame features are fed to a (multibranch) TCN backend
    that produces class logits; with extract_feats=True, `forward` returns
    the per-frame features instead.
    """

    def __init__(self, modality='video', hidden_dim=256, backbone_type='resnet', num_classes=30,
                  relu_type='prelu', tcn_options=None, width_mult=1.0, extract_feats=False):
        """
        :param modality: 'video' or 'raw_audio'.
        :param hidden_dim: base width of each TCN layer.
        :param backbone_type: 'resnet' or 'shufflenet' (video only).
        :param num_classes: number of output classes.
        :param relu_type: 'prelu' or 'relu' activations in front-end/trunk.
        :param tcn_options: dict with keys 'kernel_size', 'num_layers',
            'width_mult', 'dropout', 'dwpw' (required — an empty dict raises
            KeyError below, same as before).
        :param width_mult: ShuffleNetV2 width multiplier.
        :param extract_feats: if True, `forward` returns trunk features.
        """
        super(Lipreading, self).__init__()
        # Bug fix: the original signature used a mutable default (tcn_options={}),
        # which is a single dict shared across every call; use a None sentinel.
        if tcn_options is None:
            tcn_options = {}
        self.extract_feats = extract_feats
        self.backbone_type = backbone_type
        self.modality = modality

        if self.modality == 'raw_audio':
            self.frontend_nout = 1
            self.backend_out = 512
            self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type)
        elif self.modality == 'video':
            if self.backbone_type == 'resnet':
                self.frontend_nout = 64
                self.backend_out = 512
                self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
            elif self.backbone_type == 'shufflenet':
                assert width_mult in [0.5, 1.0, 1.5, 2.0], "Width multiplier not correct"
                shufflenet = ShuffleNetV2(input_size=96, width_mult=width_mult)
                self.trunk = nn.Sequential(shufflenet.features, shufflenet.conv_last, shufflenet.globalpool)
                self.frontend_nout = 24
                self.backend_out = 1024 if width_mult != 2.0 else 2048
                self.stage_out_channels = shufflenet.stage_out_channels[-1]

            # 3D convolutional front-end shared by both video backbones;
            # temporal stride is 1 so the frame count is preserved.
            frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU()
            self.frontend3D = nn.Sequential(
                nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
                nn.BatchNorm3d(self.frontend_nout),
                frontend_relu,
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
        else:
            raise NotImplementedError

        # Single kernel size -> plain TCN; several -> one branch per kernel size.
        tcn_class = TCN if len(tcn_options['kernel_size']) == 1 else MultiscaleMultibranchTCN
        self.tcn = tcn_class(input_size=self.backend_out,
                             num_channels=[hidden_dim * len(tcn_options['kernel_size']) * tcn_options['width_mult']] * tcn_options['num_layers'],
                             num_classes=num_classes,
                             tcn_options=tcn_options,
                             dropout=tcn_options['dropout'],
                             relu_type=relu_type,
                             dwpw=tcn_options['dwpw'],
                             )
        # -- initialize
        self._initialize_weights_randomly()

    def forward(self, x, lengths):
        """Run the front-end + trunk, then the TCN backend.

        :param x: (B, C, T, H, W) video batch, or (B, C, T) waveform batch.
        :param lengths: per-sample valid lengths (frames for video,
            samples for raw audio).
        :return: trunk features if extract_feats, otherwise class logits.
        """
        if self.modality == 'video':
            B, C, T, H, W = x.size()
            x = self.frontend3D(x)
            Tnew = x.shape[2]  # output should be B x C2 x Tnew x H x W
            x = threeD_to_2D_tensor(x)
            x = self.trunk(x)
            if self.backbone_type == 'shufflenet':
                x = x.view(-1, self.stage_out_channels)
            x = x.view(B, Tnew, x.size(1))
        elif self.modality == 'raw_audio':
            B, C, T = x.size()
            x = self.trunk(x)
            x = x.transpose(1, 2)
            # Convert sample counts to frame counts; presumably the 1D trunk
            # downsamples by a factor of 640 — TODO confirm against ResNet1D.
            lengths = [length // 640 for length in lengths]
        return x if self.extract_feats else self.tcn(x, lengths, B)

    def _initialize_weights_randomly(self):
        """He-style init: conv/linear weights ~ N(0, f(n)), zeroed biases,
        BatchNorm weights set to 1."""
        use_sqrt = True

        if use_sqrt:
            def f(n):
                return math.sqrt(2.0 / float(n))
        else:
            def f(n):
                return 2.0 / float(n)

        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                # fan-out: receptive-field size times output channels
                n = np.prod(m.kernel_size) * m.out_channels
                m.weight.data.normal_(0, f(n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = float(m.weight.data[0].nelement())  # fan-in of the layer
                m.weight.data = m.weight.data.normal_(0, f(n))
|