import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra

from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection
class BaselineLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        # Feature extractors: an SSL speech backbone loaded from the
        # wav2vec_small.pt checkpoint, plus a 3-way domain embedding of size 128.
        self.feature_extractors = nn.ModuleList([
            load_ssl_model(cp_path='wav2vec_small.pt'),
            DomainEmbedding(3, 128),
        ])
        output_dim = sum(feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors)
        # Output head: listener-dependent conditioning over 3000 judges,
        # followed by a projection MLP that produces the final score.
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim)
        )
        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        # Merge the outputs of all feature extractors into one dict, then pass it
        # through the output layers, each of which also sees the raw inputs.
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
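

# Minimal usage sketch (not part of the original file). The config contents, the
# input dict keys ('wav', 'domains', 'judge_id'), and the presence of the
# 'wav2vec_small.pt' checkpoint next to this script are assumptions made for
# illustration only.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({})  # hypothetical placeholder; the real cfg comes from Hydra
    model = BaselineLightningModule(cfg)
    model.eval()

    batch = {
        "wav": torch.randn(1, 16000),                  # assumed key: ~1 s of 16 kHz waveform
        "domains": torch.zeros(1, dtype=torch.long),   # assumed key: domain index in [0, 3)
        "judge_id": torch.zeros(1, dtype=torch.long),  # assumed key: judge index in [0, 3000)
    }
    with torch.no_grad():
        score = model(batch)                           # model output from the Projection head
    print(score.shape)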