import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import MelSpectrogram


class FeatureExtractor(nn.Module):
    """Turn raw waveforms into frame-level Mel-spectrogram features.

    Args:
        n_mels: Number of Mel filterbanks, i.e. the feature size per frame.
        sample_rate: Sampling rate of the incoming audio in Hz.
        frame_size_ms: Nominal frame size in milliseconds from which the
            analysis window and hop lengths are derived.
    """

    def __init__(self, n_mels=13, sample_rate=16000, frame_size_ms=20):
        super().__init__()
        # NOTE(review): dividing by 2000 (rather than 1000) makes the window
        # and hop half of ``frame_size_ms`` -- 160 samples (10 ms) with the
        # defaults. Kept as-is so existing behaviour and trained weights are
        # unaffected; confirm that this is the intended frame rate.
        frame_length = int(sample_rate * frame_size_ms / 2000)
        self.mel_spec = MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            win_length=frame_length,
            hop_length=frame_length,
            normalized=True,
        )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """Return Mel features laid out as ``(batch, frames, n_mels)``.

        Args:
            audio: Waveform tensor; assumed to be ``(batch, samples)`` --
                TODO confirm against callers.
        """
        mel_features = self.mel_spec(audio)
        # MelSpectrogram places the Mel axis before the time axis; swap them
        # so the result is frames-first, which is the layout that
        # FrameLevelEmbedding.forward expects (it transposes back for Conv1d).
        return mel_features.transpose(1, 2)


class FrameLevelEmbedding(nn.Module):
    """Conv/ResBlock stack mapping per-frame features to 240-dim embeddings.

    Args:
        in_channels: Feature size of each input frame. Must equal the
            ``n_mels`` of the upstream feature extractor (13 by default).
    """

    def __init__(self, in_channels=13):
        super().__init__()
        self.cnn1 = nn.Conv1d(
            in_channels=in_channels, out_channels=512, kernel_size=5, padding=2
        )
        self.res_blocks = nn.Sequential(*[ResBlock(512) for _ in range(6)])
        self.cnn2 = nn.Conv1d(in_channels=512, out_channels=240, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``(batch, seq_len, features)`` into ``(batch, seq_len, 240)``."""
        # Conv1d wants channels before time:
        # (batch, seq_len, features) -> (batch, features, seq_len)
        x = x.transpose(1, 2)
        x = self.cnn1(x)
        x = self.res_blocks(x)
        x = self.cnn2(x)
        # Back to frames-first: (batch, 240, seq_len) -> (batch, seq_len, 240)
        return x.transpose(1, 2)


class ResBlock(nn.Module):
    """Residual block of two pointwise (kernel size 1) Conv1d + BatchNorm layers.

    Args:
        channels: Number of input and output channels (the block preserves
            the tensor shape so the skip connection can be added directly).
    """

    def __init__(self, channels):
        super().__init__()
        # Convolutions carry no bias; each is followed by BatchNorm1d in forward.
        self.conv1 = nn.Conv1d(
            in_channels=channels, out_channels=channels, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=channels, out_channels=channels, kernel_size=1, bias=False
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-bn-relu, conv-bn, add the input, then a final ReLU."""
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)


class FrameLevelClassifier(nn.Module):
    """BiLSTM head producing one sigmoid score per frame."""

    def __init__(self):
        super().__init__()
        # NOTE: this encoder is constructed (so its weights are part of the
        # state dict) but it is NOT applied in forward(). It is also built
        # with the default sequence-first layout, whereas the rest of the
        # model is batch-first -- if it is ever re-enabled, either construct
        # the layer with batch_first=True or transpose around the call.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=240, nhead=4, dim_feedforward=1024),
            num_layers=2,
        )
        self.bilstm = nn.LSTM(
            input_size=240,
            hidden_size=128,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Bidirectional LSTM -> 2 * hidden_size = 256 input features.
        self.fc = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, seq_len, 240)`` embeddings to ``(batch, seq_len, 1)`` scores in [0, 1]."""
        # x = self.transformer(x)  # disabled; see the note in __init__
        x, _ = self.bilstm(x)
        x = self.fc(x)
        return torch.sigmoid(x)


class BoundaryDetectionModel(nn.Module):
    """End-to-end model: waveform -> Mel features -> embeddings -> frame scores.

    Args:
        n_mels: Number of Mel bands; forwarded to both the feature extractor
            and the embedding network so their channel counts always agree.
        sample_rate: Sampling rate in Hz expected by the feature extractor.
        frame_size_ms: Nominal frame size in milliseconds for the extractor.
    """

    def __init__(self, n_mels=13, sample_rate=16000, frame_size_ms=20):
        super().__init__()
        self.feature_extractor = FeatureExtractor(
            n_mels=n_mels, sample_rate=sample_rate, frame_size_ms=frame_size_ms
        )
        self.frame_embedding = FrameLevelEmbedding(in_channels=n_mels)
        self.classifier = FrameLevelClassifier()

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """Return per-frame sigmoid scores of shape ``(batch, frames, 1)``.

        Args:
            audio: Waveform tensor; assumed to be ``(batch, samples)`` --
                TODO confirm against callers.
        """
        features = self.feature_extractor(audio)
        embeddings = self.frame_embedding(features)
        return self.classifier(embeddings)


# Example usage (left commented out so importing this module has no side effects):
# model = BoundaryDetectionModel()
# audio, sr = torchaudio.load("new_files/Extrinsic_Partial_Fakes/extrinsic_partial_fake_RFP_R_00001.wav")
# if sr != 16000:
#     resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
#     audio = resampler(audio)
# # audio = audio.mean(dim=0).unsqueeze(0)  # Convert to mono and add batch dimension
# output = model(audio)
# print(output.squeeze(2).shape)