import math

import numpy as np
import torch
import torch.nn as nn

from lipreading.models.resnet import ResNet, BasicBlock
from lipreading.models.resnet1D import ResNet1D, BasicBlock1D
from lipreading.models.shufflenetv2 import ShuffleNetV2
from lipreading.models.tcn import MultibranchTemporalConvNet, TemporalConvNet


# -- auxiliary functions
def threeD_to_2D_tensor(x):
    """Fold the time axis of a 5-D video tensor into the batch axis.

    Args:
        x: tensor of shape (B, C, T, H, W).

    Returns:
        Tensor of shape (B*T, C, H, W), so each frame can be fed to a 2-D
        CNN trunk independently.
    """
    n_batch, n_channels, s_time, sx, sy = x.shape
    x = x.transpose(1, 2)  # -> (B, T, C, H, W); reshape handles non-contiguity
    return x.reshape(n_batch * s_time, n_channels, sx, sy)


def _average_batch(x, lengths, B):
    """Mean-pool each sequence over its valid time steps only.

    Args:
        x: tensor of shape (B, C, T) -- per-frame features.
        lengths: per-sample valid lengths (each <= T); padded tail frames
            beyond ``lengths[i]`` are excluded from the average.
        B: batch size; unused, kept to match the consensus-function signature.

    Returns:
        Tensor of shape (B, C) with per-sample temporal averages.
    """
    return torch.stack(
        [torch.mean(x[index][:, 0:i], 1) for index, i in enumerate(lengths)], 0)


class MultiscaleMultibranchTCN(nn.Module):
    """Multibranch TCN back-end: one branch per kernel size, then a linear
    classifier on the length-aware temporal average of the trunk output."""

    def __init__(self, input_size, num_channels, num_classes, tcn_options,
                 dropout, relu_type, dwpw=False):
        super(MultiscaleMultibranchTCN, self).__init__()

        self.kernel_sizes = tcn_options['kernel_size']
        self.num_kernels = len(self.kernel_sizes)

        self.mb_ms_tcn = MultibranchTemporalConvNet(
            input_size, num_channels, tcn_options,
            dropout=dropout, relu_type=relu_type, dwpw=dwpw)
        self.tcn_output = nn.Linear(num_channels[-1], num_classes)

        self.consensus_func = _average_batch

    def forward(self, x, lengths, B):
        # x arrives as (B, T, C); the TCN trunk expects (B, C, T).
        xtrans = x.transpose(1, 2)
        out = self.mb_ms_tcn(xtrans)
        out = self.consensus_func(out, lengths, B)
        return self.tcn_output(out)


class TCN(nn.Module):
    """Implements Temporal Convolutional Network (TCN)

    __https://arxiv.org/pdf/1803.01271.pdf
    """

    def __init__(self, input_size, num_channels, num_classes, tcn_options,
                 dropout, relu_type, dwpw=False):
        super(TCN, self).__init__()
        self.tcn_trunk = TemporalConvNet(
            input_size, num_channels, dropout=dropout,
            tcn_options=tcn_options, relu_type=relu_type, dwpw=dwpw)
        self.tcn_output = nn.Linear(num_channels[-1], num_classes)

        self.consensus_func = _average_batch

        self.has_aux_losses = False

    def forward(self, x, lengths, B):
        # x arrives as (B, T, C); the trunk needs (B, C, T) for 1-D convs.
        x = self.tcn_trunk(x.transpose(1, 2))
        x = self.consensus_func(x, lengths, B)
        return self.tcn_output(x)


class Lipreading(nn.Module):
    """End-to-end lipreading network: a 3-D/2-D conv front-end (video) or a
    1-D ResNet (raw audio) followed by a TCN back-end classifier.

    Args:
        modality: 'video' or 'raw_audio'; selects the front-end/trunk.
        hidden_dim: base channel width of the TCN back-end.
        backbone_type: 'resnet' or 'shufflenet' (video modality only).
        num_classes: number of output classes.
        relu_type: 'prelu' selects PReLU activations, anything else ReLU.
        tcn_options: dict with keys 'kernel_size', 'num_layers', 'dropout',
            'dwpw', 'width_mult' configuring the back-end. Must be supplied
            by the caller; an empty dict would fail on key lookup.
        width_mult: ShuffleNetV2 width multiplier (0.5/1.0/1.5/2.0).
        extract_feats: if True, forward() returns trunk features instead of
            class logits.
    """

    def __init__(self, modality='video', hidden_dim=256, backbone_type='resnet',
                 num_classes=30, relu_type='prelu', tcn_options=None,
                 width_mult=1.0, extract_feats=False):
        super(Lipreading, self).__init__()
        # BUGFIX: the default used to be the mutable literal ``{}`` shared
        # across all instances; use a None sentinel instead (backward
        # compatible -- explicit callers are unaffected).
        if tcn_options is None:
            tcn_options = {}
        self.extract_feats = extract_feats
        self.backbone_type = backbone_type
        self.modality = modality

        if self.modality == 'raw_audio':
            self.frontend_nout = 1
            self.backend_out = 512
            self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type)
        elif self.modality == 'video':
            if self.backbone_type == 'resnet':
                self.frontend_nout = 64
                self.backend_out = 512
                self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
            elif self.backbone_type == 'shufflenet':
                assert width_mult in [0.5, 1.0, 1.5, 2.0], "Width multiplier not correct"
                shufflenet = ShuffleNetV2(input_size=96, width_mult=width_mult)
                self.trunk = nn.Sequential(
                    shufflenet.features, shufflenet.conv_last, shufflenet.globalpool)
                self.frontend_nout = 24
                self.backend_out = 1024 if width_mult != 2.0 else 2048
                self.stage_out_channels = shufflenet.stage_out_channels[-1]

            # 3-D conv front-end: spatial downsampling only (temporal stride 1),
            # so the frame count is preserved through the front-end.
            frontend_relu = (nn.PReLU(num_parameters=self.frontend_nout)
                             if relu_type == 'prelu' else nn.ReLU())
            self.frontend3D = nn.Sequential(
                nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7),
                          stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
                nn.BatchNorm3d(self.frontend_nout),
                frontend_relu,
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2),
                             padding=(0, 1, 1)))
        else:
            raise NotImplementedError

        # Single-branch TCN when one kernel size is given, multibranch otherwise.
        tcn_class = TCN if len(tcn_options['kernel_size']) == 1 else MultiscaleMultibranchTCN
        self.tcn = tcn_class(
            input_size=self.backend_out,
            num_channels=[hidden_dim * len(tcn_options['kernel_size'])
                          * tcn_options['width_mult']] * tcn_options['num_layers'],
            num_classes=num_classes,
            tcn_options=tcn_options,
            dropout=tcn_options['dropout'],
            relu_type=relu_type,
            dwpw=tcn_options['dwpw'],
        )
        # -- initialize
        self._initialize_weights_randomly()

    def forward(self, x, lengths):
        """Run the front-end + trunk, then the TCN classifier.

        Args:
            x: (B, C, T, H, W) video or (B, C, T) raw-audio tensor.
            lengths: per-sample valid lengths; for raw audio these are sample
                counts and are converted to frame counts below.

        Returns:
            Trunk features (B, T, backend_out) if ``extract_feats`` is set,
            otherwise class logits from the TCN back-end.
        """
        if self.modality == 'video':
            B, C, T, H, W = x.size()
            x = self.frontend3D(x)
            Tnew = x.shape[2]  # output should be B x C2 x Tnew x H x W
            x = threeD_to_2D_tensor(x)
            x = self.trunk(x)
            if self.backbone_type == 'shufflenet':
                x = x.view(-1, self.stage_out_channels)
            x = x.view(B, Tnew, x.size(1))
        elif self.modality == 'raw_audio':
            B, C, T = x.size()
            x = self.trunk(x)
            x = x.transpose(1, 2)
            # 640 audio samples per output frame -- presumably the trunk's
            # overall temporal downsampling factor; TODO confirm against
            # ResNet1D strides.
            lengths = [_ // 640 for _ in lengths]

        return x if self.extract_feats else self.tcn(x, lengths, B)

    def _initialize_weights_randomly(self):
        """He-style random init: conv/linear weights ~ N(0, sqrt(2/n)),
        BatchNorm scales to 1 and all biases to 0.

        (The original code carried a dead ``use_sqrt`` flag hard-coded to
        True selecting between sqrt(2/n) and 2/n; only the sqrt branch was
        ever taken, so the dead arm has been removed.)
        """
        def f(n):
            return math.sqrt(2.0 / float(n))

        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                # fan = receptive-field size x output channels
                n = np.prod(m.kernel_size) * m.out_channels
                m.weight.data.normal_(0, f(n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = float(m.weight.data[0].nelement())  # fan-in per output unit
                m.weight.data = m.weight.data.normal_(0, f(n))