# Scraped from a Hugging Face Space page (page banner: "Spaces: Sleeping").
import torch | |
import torch.nn as nn | |
import torchaudio | |
from torchaudio.transforms import MelSpectrogram | |
class FeatureExtractor(nn.Module):
    """Turn raw waveforms into time-major mel-spectrogram features.

    Args:
        n_mels: Number of mel bins per frame. The default (13) matches the
            hard-coded ``in_channels=13`` of ``FrameLevelEmbedding`` in this file.
        sample_rate: Sample rate (Hz) the input waveform is assumed to have;
            forwarded to ``MelSpectrogram``.
        frame_size_ms: Frame size in milliseconds used to derive the STFT
            window and hop length (see the NOTE in ``__init__``).
    """

    def __init__(self, n_mels=13, sample_rate=16000, frame_size_ms=20):
        super().__init__()
        # NOTE(review): dividing by 2000 (not 1000) makes both the window and
        # the hop span *half* of ``frame_size_ms`` (160 samples = 10 ms with
        # the defaults). That may be unintentional, but it is kept unchanged:
        # presumably torchaudio stores the STFT window as a buffer sized by
        # ``win_length``, so altering it would break loading existing
        # checkpoints — confirm before changing.
        frame_samples = int(sample_rate * frame_size_ms / 2000)
        self.mel_spec = MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            win_length=frame_samples,
            hop_length=frame_samples,
            normalized=True,
        )

    def forward(self, audio):
        """Compute mel features for ``audio``.

        Args:
            audio: Waveform tensor whose last dimension is time, e.g.
                ``(batch, time)``; an unbatched ``(time,)`` tensor also works.

        Returns:
            Tensor of shape ``(..., frames, n_mels)``, i.e.
            ``(batch, frames, n_mels)`` for batched input: time first, mel bins
            last. ``FrameLevelEmbedding`` transposes back to channels-first
            before its Conv1d.
        """
        mel_features = self.mel_spec(audio)  # (..., n_mels, frames)
        # Negative axes swap (n_mels, frames) regardless of how many leading
        # (batch/channel) dimensions the input has.
        return mel_features.transpose(-2, -1)
# Frame-level embedding stage: maps per-frame acoustic features to embeddings.
class FrameLevelEmbedding(nn.Module):
    """Map per-frame acoustic features to frame-level embeddings.

    A kernel-5 Conv1d (padding 2) lifts the input to ``hidden_channels``. A
    stack of ``ResBlock`` modules refines it, and a pointwise Conv1d projects
    it to ``out_channels``. All convolutions preserve the sequence length.

    Args:
        in_channels: Features per input frame. The default of 13 matches the
            default ``n_mels`` of ``FeatureExtractor``.
        hidden_channels: Width of the convolutional trunk.
        out_channels: Size of each output embedding. The default of 240 matches
            the input size expected by ``FrameLevelClassifier``.
        num_res_blocks: Number of residual blocks in the trunk.
    """

    def __init__(self, in_channels=13, hidden_channels=512, out_channels=240,
                 num_res_blocks=6):
        super().__init__()
        self.cnn1 = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels,
                              kernel_size=5, padding=2)
        self.res_blocks = nn.Sequential(
            *[ResBlock(hidden_channels) for _ in range(num_res_blocks)]
        )
        self.cnn2 = nn.Conv1d(in_channels=hidden_channels, out_channels=out_channels,
                              kernel_size=1)

    def forward(self, x):
        """Embed a batch of feature sequences.

        Args:
            x: Time-major features of shape ``(batch, seq_len, in_channels)``.

        Returns:
            Embeddings of shape ``(batch, seq_len, out_channels)``.
        """
        x = x.transpose(1, 2)  # -> (batch, in_channels, seq_len) for Conv1d
        x = self.cnn1(x)
        x = self.res_blocks(x)
        x = self.cnn2(x)
        return x.transpose(1, 2)  # -> (batch, seq_len, out_channels)
# Remaining components: ResBlock, FrameLevelClassifier, and BoundaryDetectionModel.
class ResBlock(nn.Module):
    """Residual unit built from two pointwise Conv1d + BatchNorm1d layers.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Both convolutions
    use ``kernel_size=1`` without bias, so the output has the same shape as the
    channels-first input (e.g. ``(batch, channels, length)``).

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        def pointwise_conv():
            return nn.Conv1d(channels, channels, kernel_size=1, bias=False)

        # Keep this registration order (conv1, conv2, bn1, bn2, relu): it fixes
        # the state_dict key order and the RNG draw order for weight init.
        self.conv1 = pointwise_conv()
        self.conv2 = pointwise_conv()
        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the residual branch and add the skip connection."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + x)
class FrameLevelClassifier(nn.Module):
    """Score each frame embedding with a value in (0, 1).

    A 2-layer bidirectional LSTM (batch-first) runs over 240-d frame
    embeddings. A linear head followed by a sigmoid then yields one score per
    frame.

    A 2-layer Transformer encoder is also constructed. It only runs when
    ``use_transformer=True``; by default ``forward`` skips it. It is built
    either way so the module's parameters and state_dict keys do not depend on
    the flag. Presumably this keeps existing checkpoints loadable; confirm
    against the training code.

    Args:
        use_transformer: If True, apply the Transformer encoder before the LSTM.
    """

    def __init__(self, use_transformer=False):
        super().__init__()
        self.use_transformer = use_transformer
        # batch_first=True: inputs are (batch, seq_len, features), the same
        # layout the batch-first LSTM below consumes. With the default
        # seq-first layout the encoder would attend across the batch axis.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=240, nhead=4, dim_feedforward=1024,
                                       batch_first=True),
            num_layers=2,
        )
        self.bilstm = nn.LSTM(input_size=240, hidden_size=128, num_layers=2,
                              bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, 1)  # bidirectional LSTM -> 2 * hidden_size

    def forward(self, x):
        """Return per-frame scores.

        Args:
            x: Frame embeddings of shape ``(batch, seq_len, 240)``.

        Returns:
            Sigmoid scores of shape ``(batch, seq_len, 1)``.
        """
        if self.use_transformer:
            x = self.transformer(x)
        x, _ = self.bilstm(x)
        return torch.sigmoid(self.fc(x))
class BoundaryDetectionModel(nn.Module):
    """End-to-end frame-level boundary detector.

    Pipeline:
        1. ``FeatureExtractor``: waveform to time-major mel features.
        2. ``FrameLevelEmbedding``: features to 240-d per-frame embeddings.
        3. ``FrameLevelClassifier``: embeddings to a sigmoid score per frame.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.frame_embedding = FrameLevelEmbedding()
        self.classifier = FrameLevelClassifier()

    def forward(self, audio):
        """Run ``audio`` through the three stages in order and return the result."""
        result = audio
        for stage in (self.feature_extractor, self.frame_embedding, self.classifier):
            result = stage(result)
        return result
# model = BoundaryDetectionModel() | |
# audio, sr = torchaudio.load("new_files/Extrinsic_Partial_Fakes/extrinsic_partial_fake_RFP_R_00001.wav") | |
# if sr != 16000: | |
# resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) | |
# audio = resampler(audio) | |
# # audio = audio.mean(dim=0).unsqueeze(0) # Convert to mono and add batch dimension | |
# output = model(audio) | |
# print(output.squeeze(2).shape) |