from transformers import WhisperModel
import torch
from torch import nn

from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas


class WhisperAudioClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        # Load the encoder of a Cantonese fine-tuned Whisper model and use it
        # as a frozen feature extractor.
        self.whisper_encoder = WhisperModel.from_pretrained(
            "alvanlii/whisper-small-cantonese", device_map="auto"
        ).get_encoder()
        self.whisper_encoder.eval()  # keep the encoder in evaluation mode

        # Hidden size (d_model) of the Whisper-small encoder.
        whisper_output_size = 768

        # One self-attention block per prediction target (tone, initial,
        # nucleus, coda), each attending over the encoder's hidden states.
        self.tone_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.initial_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.coda_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)

        # Average-pool the attended sequence down to one vector per example.
        self.pool = nn.AdaptiveAvgPool1d(1)

        # Separate output layers for each class set; the *_fc2 heads produce a
        # second, independent prediction per target.
        self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc1 = nn.Linear(whisper_output_size, 6)  # six Cantonese tones

        self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc2 = nn.Linear(whisper_output_size, 6)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Encode the audio input with the frozen Whisper encoder.
        with torch.no_grad():  # no need to track gradients for the encoder
            x = self.whisper_encoder(x).last_hidden_state

        initial, _ = self.initial_attention(x, x, x, need_weights=False)
        initial = initial.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        initial = self.pool(initial)        # [batch_size, channels, 1]
        initial = initial.squeeze(-1)       # [batch_size, channels]
        initial = self.dropout(initial)     # regularize before the heads
        initial_out1 = self.initial_fc1(initial)
        initial_out2 = self.initial_fc2(initial)

        nucleus, _ = self.nucleus_attention(x, x, x, need_weights=False)
        nucleus = nucleus.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        nucleus = self.pool(nucleus)        # [batch_size, channels, 1]
        nucleus = nucleus.squeeze(-1)       # [batch_size, channels]
        nucleus = self.dropout(nucleus)
        nucleus_out1 = self.nucleus_fc1(nucleus)
        nucleus_out2 = self.nucleus_fc2(nucleus)

        coda, _ = self.coda_attention(x, x, x, need_weights=False)
        coda = coda.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        coda = self.pool(coda)        # [batch_size, channels, 1]
        coda = coda.squeeze(-1)       # [batch_size, channels]
        coda = self.dropout(coda)
        coda_out1 = self.coda_fc1(coda)
        coda_out2 = self.coda_fc2(coda)

        tone, _ = self.tone_attention(x, x, x, need_weights=False)
        tone = tone.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        tone = self.pool(tone)        # [batch_size, channels, 1]
        tone = tone.squeeze(-1)       # [batch_size, channels]
        tone = self.dropout(tone)
        tone_out1 = self.tone_fc1(tone)
        tone_out2 = self.tone_fc2(tone)

        return (
            initial_out1, nucleus_out1, coda_out1, tone_out1,
            initial_out2, nucleus_out2, coda_out2, tone_out2,
        )
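

# A minimal usage sketch, not part of the original module: it assumes 16 kHz
# mono audio preprocessed with the same checkpoint's WhisperFeatureExtractor.
# The silent waveform below is a placeholder standing in for real input.
if __name__ == "__main__":
    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "alvanlii/whisper-small-cantonese"
    )
    model = WhisperAudioClassifier()
    model.eval()

    # device_map="auto" may have placed the encoder on a GPU; move the
    # classification layers (but not the already-dispatched encoder) to the
    # same device so the forward pass runs end to end.
    device = model.whisper_encoder.device
    for name, module in model.named_children():
        if name != "whisper_encoder":
            module.to(device)

    waveform = torch.zeros(16000)  # one second of silence at 16 kHz
    inputs = feature_extractor(
        waveform.numpy(), sampling_rate=16000, return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(inputs.input_features.to(device))

    initial_out1, nucleus_out1, coda_out1, tone_out1, *_ = outputs
    print(initial_out1.shape)  # torch.Size([1, len(jyutping_initials)])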