import torch
from torch import nn
from transformers.activations import ACT2FN


class Conv2dFeatureExtractor(nn.Module):
    """Stack of unpadded 2-D convolutions followed by a linear projection.

    Reads the following attributes from ``config``:

    * ``conv_dim``, ``conv_kernel``, ``conv_stride`` -- per-layer output
      channels, (square) kernel size and (square) stride. One conv layer is
      built per zipped entry, so the shortest of the three sequences
      determines the number of layers.
    * ``feat_extract_activation`` -- key into ``ACT2FN`` selecting the
      activation applied after every convolution.
    * ``second_dim_input_size`` -- size of the last axis of the input passed
      to :meth:`forward`; used to size the projection layer.
    * ``hidden_size`` -- output width of the projection layer.
    """

    def __init__(self, config):
        super().__init__()

        # Materialise the per-layer specs once so the conv stack and the
        # projection-size computation below are guaranteed to agree.
        # zip() stops at the shortest sequence, exactly as the layer
        # construction always did.
        layer_specs = list(
            zip(
                [1, *config.conv_dim],
                config.conv_dim,
                config.conv_kernel,
                config.conv_stride,
            )
        )

        self.conv = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(
                        conv_in,
                        out_channels=conv_out,
                        kernel_size=(conv_kernel, conv_kernel),
                        stride=(conv_stride, conv_stride),
                    ),
                    ACT2FN[config.feat_extract_activation],
                )
                for conv_in, conv_out, conv_kernel, conv_stride in layer_specs
            ]
        )

        # Track how the channel count and the last input axis evolve through
        # the stack. The convolutions use no padding and no dilation, so each
        # layer maps a length n to (n - kernel) // stride + 1. For the
        # kernel=3 / stride=2 case this is (n - 1) // 2, i.e. the value that
        # used to be hard-coded for exactly two such layers.
        out_channels = 1  # forward() feeds a single-channel input
        last_dim_size = config.second_dim_input_size
        for _, conv_out, conv_kernel, conv_stride in layer_specs:
            out_channels = conv_out
            last_dim_size = (last_dim_size - conv_kernel) // conv_stride + 1

        self.out = nn.Linear(out_channels * last_dim_size, config.hidden_size, bias=True)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack and project to ``config.hidden_size``.

        Args:
            input_values: 3-D tensor ``(batch, d1, d2)`` where ``d2`` must
                equal ``config.second_dim_input_size`` for the projection
                layer to match. (Presumably ``d1`` is time and ``d2`` the
                feature axis -- confirm against the caller.)

        Returns:
            Tensor of shape ``(batch, config.hidden_size, d1_out)`` where
            ``d1_out`` is ``d1`` after the strided convolutions.
        """
        # Insert a singleton channel axis: (batch, 1, d1, d2).
        hidden_states = self.conv(input_values[:, None, ...])
        # (batch, C, d1', d2') -> (batch, d1', C, d2') -> (batch, d1', C * d2')
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = self.out(hidden_states)
        # Return channels-first: (batch, hidden_size, d1').
        return hidden_states.transpose(1, 2)