# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
# MIT License (https://opensource.org/licenses/MIT)

# SSLMOS model
# modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper)
import torch
import torch.nn as nn

from .modules import Projection
class SSLMOS(torch.nn.Module):
    """SSL-based MOS (mean opinion score) predictor.

    A self-supervised speech encoder (loaded through s3prl) turns raw
    waveforms into frame-level features; a small projection head (the
    "mean net") maps each frame to a quality score. Optionally, listener
    embeddings can be concatenated to the frame features and decoded into
    listener-dependent scores.
    """

    def __init__(
        self,
        # dummy, for signature need
        model_input: str,
        # model related
        ssl_module: str,
        s3prl_name: str,
        ssl_model_output_dim: int,
        ssl_model_layer_idx: int,
        # mean net related
        mean_net_dnn_dim: int = 64,
        mean_net_output_type: str = "scalar",
        mean_net_output_dim: int = 5,
        mean_net_output_step: float = 0.25,
        mean_net_range_clipping: bool = True,
        # listener related
        use_listener_modeling: bool = False,
        num_listeners: int = None,
        listener_emb_dim: int = None,
        use_mean_listener: bool = True,
        # decoder related
        decoder_type: str = "ffn",
        decoder_dnn_dim: int = 64,
        output_type: str = "scalar",
        range_clipping: bool = True,
        # dummy
        num_domains: int = None,
    ):
        # must run before any submodule is registered on self
        super().__init__()
        self.use_mean_listener = use_mean_listener
        self.output_type = output_type

        # define ssl model (s3prl is the only supported backend)
        if ssl_module == "s3prl":
            from s3prl.nn import S3PRLUpstream

            if s3prl_name in S3PRLUpstream.available_names():
                self.ssl_model = S3PRLUpstream(s3prl_name)
                self.ssl_model_layer_idx = ssl_model_layer_idx
        else:
            raise NotImplementedError

        # mean net: default uses ffn type
        self.mean_net_dnn = Projection(
            ssl_model_output_dim,
            mean_net_dnn_dim,
            nn.ReLU,
            mean_net_output_type,
            mean_net_output_dim,
            mean_net_output_step,
            mean_net_range_clipping,
        )

        # listener modeling related
        self.use_listener_modeling = use_listener_modeling
        self.decoder_type = decoder_type
        if use_listener_modeling:
            self.num_listeners = num_listeners
            self.listener_embeddings = nn.Embedding(
                num_embeddings=num_listeners, embedding_dim=listener_emb_dim
            )
            # The decoder is only consumed in forward() when listener
            # modeling is on; building it unconditionally crashed with
            # listener_emb_dim=None (the default), so it lives here now.
            if decoder_type == "ffn":
                decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim
            else:
                raise NotImplementedError
            # FIX: the original passed `self.activation`, which is never
            # defined anywhere in this class (AttributeError), and passed
            # `range_clipping` in the positional slot the mean-net call
            # above uses for output_dim. Use nn.ReLU like the mean net and
            # pass range_clipping by keyword.
            # NOTE(review): relies on Projection's defaults for
            # output_dim / output_step -- confirm against modules.Projection.
            self.decoder_dnn = Projection(
                decoder_dnn_input_dim,
                decoder_dnn_dim,
                nn.ReLU,
                output_type,
                range_clipping=range_clipping,
            )

    def get_num_params(self):
        """Return the total number of model parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, inputs):
        """Calculate forward propagation.

        Args:
            inputs: dict with
                waveform: (batch, time)
                waveform_lengths: (batch,)
                listener_idxs: (batch,), read only when listener modeling
                    is enabled.

        Returns:
            dict with waveform/frame lengths (for masked loss calculation),
            frame-level "mean_scores", and "ld_scores" (listener-dependent
            scores, or None when listener modeling is off).
        """
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]
        batch, time = waveform.shape

        # get listener embedding
        if self.use_listener_modeling:
            listener_ids = inputs["listener_idxs"]
            # NOTE(unlight): not tested yet
            listener_embs = self.listener_embeddings(listener_ids)  # (batch, emb_dim)
            listener_embs = torch.stack(
                [listener_embs for i in range(time)], dim=1
            )  # (batch, time, feat_dim)

        # ssl model forward; one entry per layer, select the configured one
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx]

        # inject listener embedding
        if self.use_listener_modeling:
            # NOTE(unlight): not tested yet
            # NOTE(review): `time` here is the waveform sample count, not the
            # SSL frame count -- this view looks suspect; confirm before use.
            encoder_outputs = encoder_outputs.view(
                (batch, time, -1)
            )  # (batch, time, feat_dim)
            decoder_inputs = torch.cat(
                [encoder_outputs, listener_embs], dim=-1
            )  # concat along feature dimension
        else:
            decoder_inputs = encoder_outputs

        # mean net
        mean_net_outputs = self.mean_net_dnn(
            decoder_inputs
        )  # [batch, time, 1 (scalar) / 5 (categorical)]

        # decoder: listener-dependent branch only (decoder_outputs would be
        # undefined otherwise, and decoder_dnn only exists in that mode)
        if self.use_listener_modeling:
            if self.decoder_type == "rnn":
                decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs)
            else:
                decoder_outputs = decoder_inputs
            decoder_outputs = self.decoder_dnn(
                decoder_outputs
            )  # [batch, time, 1 (scalar) / 5 (categorical)]

        # return lengths for masked loss calculation
        ret = {
            "waveform_lengths": waveform_lengths,
            "frame_lengths": encoder_outputs_lens,
        }
        ret["mean_scores"] = mean_net_outputs
        ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None
        return ret

    def mean_net_inference(self, inputs):
        """Run mean-net-only inference; returns SSL embeddings and utterance scores."""
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]
        # ssl model forward
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        # mean net
        decoder_inputs = encoder_outputs
        mean_net_outputs = self.mean_net_dnn(
            decoder_inputs, inference=True
        )  # [batch, time, 1 (scalar) / 5 (categorical)]
        mean_net_outputs = mean_net_outputs.squeeze(-1)
        # average frame scores into one utterance-level score
        # NOTE(review): unmasked mean -- padded frames are included; confirm
        # callers pass unpadded batches here.
        scores = torch.mean(mean_net_outputs, dim=1)  # [batch]
        return {
            "ssl_embeddings": encoder_outputs,
            "scores": scores,
        }

    def mean_net_inference_p1(self, waveform, waveform_lengths):
        """Phase 1 of split inference: extract SSL frame embeddings only."""
        all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths)
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        return encoder_outputs

    def mean_net_inference_p2(self, encoder_outputs):
        """Phase 2 of split inference: score pre-extracted SSL embeddings.

        NOTE(review): unlike mean_net_inference, this does not pass
        inference=True to the mean net -- confirm whether that is intended.
        """
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs
        )  # [batch, time, 1 (scalar) / 5 (categorical)]
        mean_net_outputs = mean_net_outputs.squeeze(-1)
        scores = torch.mean(mean_net_outputs, dim=1)  # [batch]
        return scores

    def get_ssl_embeddings(self, inputs):
        """Return the SSL encoder outputs at the configured layer."""
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        return encoder_outputs