|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import Wav2Vec2Config |
|
|
|
from .torch_utils import get_mask_from_lengths |
|
from .wav2vec2 import Wav2Vec2Model |
|
|
|
|
|
class Audio2MeshModel(nn.Module):
    """Map raw audio to per-frame mesh coefficients via a Wav2Vec2 encoder.

    Architecture: frozen-CNN Wav2Vec2 encoder -> linear bottleneck
    (``in_fn``) -> zero-initialized linear head (``out_fn``).

    Args:
        config (dict): expects the keys
            - ``out_dim``: size of the output coefficient vector per frame.
            - ``latent_dim``: width of the bottleneck between encoder and head.
            - ``model_path``: local path to the Wav2Vec2 checkpoint/config
              (loaded with ``local_files_only=True``).
            - ``only_last_fetures``: (sic — typo preserved for config
              compatibility) if True use only the encoder's last hidden state,
              otherwise average all hidden states.
            - ``from_pretrained``: if True load encoder weights from
              ``model_path``, otherwise build the encoder from config alone.
    """

    def __init__(self, config):
        super().__init__()
        out_dim = config['out_dim']
        latent_dim = config['latent_dim']
        model_path = config['model_path']
        # NOTE: 'only_last_fetures' is a typo in the config schema; the key is
        # read as-is so existing config files keep working.
        only_last_features = config['only_last_fetures']
        from_pretrained = config['from_pretrained']

        self._only_last_features = only_last_features

        self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
        if from_pretrained:
            self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
        else:
            self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
        # The convolutional feature extractor is always frozen; only the
        # transformer layers (and the linear heads below) remain trainable.
        self.audio_encoder.feature_extractor._freeze_parameters()

        hidden_size = self.audio_encoder_config.hidden_size

        self.in_fn = nn.Linear(hidden_size, latent_dim)
        self.out_fn = nn.Linear(latent_dim, out_dim)
        # Zero-init the output head so the model initially predicts all-zero
        # coefficients (no deformation) regardless of the encoder features.
        nn.init.constant_(self.out_fn.weight, 0)
        nn.init.constant_(self.out_fn.bias, 0)

    def _select_hidden_states(self, embeddings):
        """Pick encoder features: last layer only, or the mean of all layers.

        Shared by :meth:`forward` and :meth:`infer` (previously duplicated).
        """
        if self._only_last_features:
            return embeddings.last_hidden_state
        return sum(embeddings.hidden_states) / len(embeddings.hidden_states)

    def forward(self, audio, label, audio_len=None):
        """Training-time pass.

        Args:
            audio: raw waveform batch fed to the Wav2Vec2 encoder.
            label: target tensor; only ``label.shape[1]`` (frame count) is
                used, to tell the encoder how many feature frames to produce.
            audio_len: optional per-sample lengths used to build an attention
                mask; ``None`` disables masking.

        Returns:
            tuple: ``(out, None)`` — the second element is kept so callers
            expecting ``(prediction, auxiliary)`` keep working.
        """
        # BUGFIX: was `if audio_len`, which evaluates tensor truthiness and
        # raises "Boolean value of Tensor ... is ambiguous" for batch size > 1.
        # Explicitly compare against None instead.
        attention_mask = ~get_mask_from_lengths(audio_len) if audio_len is not None else None

        seq_len = label.shape[1]

        # NOTE(review): `seq_len` is a non-standard kwarg — presumably the
        # local Wav2Vec2Model wrapper resamples features to the label frame
        # rate; confirm against .wav2vec2.
        embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True,
                                        attention_mask=attention_mask)

        hidden_states = self._select_hidden_states(embeddings)

        layer_in = self.in_fn(hidden_states)
        out = self.out_fn(layer_in)

        return out, None

    def infer(self, input_value, seq_len):
        """Inference pass: like :meth:`forward` but with an explicit frame
        count and no attention mask; returns only the prediction."""
        embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)

        hidden_states = self._select_hidden_states(embeddings)

        layer_in = self.in_fn(hidden_states)
        out = self.out_fn(layer_in)

        return out
|
|
|
|
|
|