import torch
from torch import nn
from transformers import WhisperModel

from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas
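
# Note: `jyutping` is a local module assumed to export the Jyutping phoneme
# inventories as lists of strings, roughly (illustrative subsets, not the full lists):
#   jyutping_initials = ["b", "p", "m", "f", "d", "t", "n", "l", ...]
#   jyutping_nuclei   = ["aa", "a", "e", "i", "o", "u", ...]
#   jyutping_codas    = ["p", "t", "k", "m", "n", "ng", ...]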

class WhisperAudioClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Load only the encoder of a Cantonese fine-tuned Whisper-small checkpoint;
        # the decoder is not needed for classification
        self.whisper_encoder = WhisperModel.from_pretrained(
            "alvanlii/whisper-small-cantonese", device_map="auto"
        ).get_encoder()
        self.whisper_encoder.eval()  # Keep the frozen encoder in evaluation mode

        # Hidden size (d_model) of the Whisper-small encoder
        whisper_output_size = 768

        # One multi-head self-attention block per classification target
        self.tone_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.initial_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.coda_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.pool = nn.AdaptiveAvgPool1d(1)

        # Two parallel output heads for each class set; tone has six classes
        # (Cantonese has six tones)
        self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc1 = nn.Linear(whisper_output_size, 6)
        self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc2 = nn.Linear(whisper_output_size, 6)
        self.dropout = nn.Dropout(0.1)
    def forward(self, x):
        # Encode the log-mel input with the frozen Whisper encoder
        with torch.no_grad():  # No need to track gradients for the frozen encoder
            x = self.whisper_encoder(x).last_hidden_state  # [batch_size, seq_len, 768]

        # Initial: self-attend, mean-pool over time, then classify
        initial, _ = self.initial_attention(x, x, x, need_weights=False)
        initial = initial.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        initial = self.pool(initial)  # [batch_size, channels, 1]
        initial = initial.squeeze(-1)  # [batch_size, channels]
        initial_out1 = self.initial_fc1(initial)
        initial_out2 = self.initial_fc2(initial)

        # Nucleus
        nucleus, _ = self.nucleus_attention(x, x, x, need_weights=False)
        nucleus = nucleus.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        nucleus = self.pool(nucleus)  # [batch_size, channels, 1]
        nucleus = nucleus.squeeze(-1)  # [batch_size, channels]
        nucleus_out1 = self.nucleus_fc1(nucleus)
        nucleus_out2 = self.nucleus_fc2(nucleus)

        # Coda
        coda, _ = self.coda_attention(x, x, x, need_weights=False)
        coda = coda.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        coda = self.pool(coda)  # [batch_size, channels, 1]
        coda = coda.squeeze(-1)  # [batch_size, channels]
        coda_out1 = self.coda_fc1(coda)
        coda_out2 = self.coda_fc2(coda)

        # Tone
        tone, _ = self.tone_attention(x, x, x, need_weights=False)
        tone = tone.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        tone = self.pool(tone)  # [batch_size, channels, 1]
        tone = tone.squeeze(-1)  # [batch_size, channels]
        tone_out1 = self.tone_fc1(tone)
        tone_out2 = self.tone_fc2(tone)

        return (
            initial_out1, nucleus_out1, coda_out1, tone_out1,
            initial_out2, nucleus_out2, coda_out2, tone_out2,
        )
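
# Minimal usage sketch, under two assumptions: CPU execution (so device_map="auto"
# places the encoder on the same device as the classifier heads), and inputs
# prepared by WhisperFeatureExtractor as Whisper-small's 80-bin, 3000-frame
# log-mel features.
if __name__ == "__main__":
    from transformers import WhisperFeatureExtractor

    extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
    model = WhisperAudioClassifier()

    # Dummy 1-second, 16 kHz waveform; the extractor pads it to Whisper's 30 s window
    waveform = torch.zeros(16000)
    features = extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_features
    # features: [1, 80, 3000] -> encoder -> last_hidden_state: [1, 1500, 768]

    outs = model(features)
    initial_out1, nucleus_out1, coda_out1, tone_out1 = outs[:4]  # first of the two head sets
    print(tone_out1.shape)  # torch.Size([1, 6]): logits over the six Cantonese tones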