# AudioSpoofing / model.py
# ujalaarshad17's picture
# Added files
# 384e020
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import MelSpectrogram
class FeatureExtractor(nn.Module):
    """Compute mel-spectrogram features with the time axis on dim 1.

    Wraps ``torchaudio.transforms.MelSpectrogram`` (with ``normalized=True``)
    and swaps axes 1 and 2 of its output so frames come before mel bins.

    Args:
        n_mels: Mel bins per frame. Must match the input width of the
            downstream embedding network (13 for this file's default
            ``FrameLevelEmbedding``).
        sample_rate: Sample rate of the incoming waveform, in Hz.
        frame_size_ms: Nominal frame size in milliseconds.

    NOTE(review): the analysis window and the hop are both set to
    ``int(sample_rate * frame_size_ms / 2000)`` samples, which is half of
    ``frame_size_ms`` (160 samples = 10 ms at the defaults). As a result,
    windows do not overlap and are half the nominal frame length. The
    ``/ 2000`` looks intended only for the hop. Any checkpoint trained with
    this code was fit to these features, so confirm before changing it.
    """

    def __init__(self, n_mels=13, sample_rate=16000, frame_size_ms=20):
        super().__init__()
        half_frame = int(sample_rate * frame_size_ms / 2000)
        self.mel_spec = MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            win_length=half_frame,
            hop_length=half_frame,
            normalized=True,
        )

    def forward(self, audio):
        """Return mel features laid out as (batch, num_frames, n_mels).

        Assumes ``audio`` is 2-D (batch, num_samples) -- TODO confirm with
        callers. ``MelSpectrogram`` yields (batch, n_mels, num_frames), and
        axes 1 and 2 are swapped. The result is the (batch, seq_len, features)
        layout that ``FrameLevelEmbedding`` consumes. It is *not* the
        channels-first layout that ``nn.Conv1d`` expects.
        """
        return self.mel_spec(audio).transpose(1, 2)
# FrameLevelEmbedding: convolutional per-frame embedding network (built from ResBlock, defined below).
class FrameLevelEmbedding(nn.Module):
    """Map per-frame features to per-frame embeddings with a 1-D conv trunk.

    The trunk has three stages:

    1. A width-5 convolution (padding 2) lifts the feature channels to
       ``hidden_channels``.
    2. ``num_res_blocks`` ``ResBlock`` modules refine them.
    3. A pointwise (kernel size 1) convolution projects to ``out_channels``.

    Every stage preserves the sequence length.

    The defaults (13 -> 512 -> 240, six residual blocks) reproduce the
    original hard-coded architecture and parameter names, so existing
    checkpoints still load. ``in_channels`` must equal the number of features
    per frame produced upstream (e.g. ``FeatureExtractor``'s ``n_mels``).

    Args:
        in_channels: Features per input frame.
        hidden_channels: Channel width of the convolutional trunk.
        out_channels: Size of each output embedding.
        num_res_blocks: Number of ``ResBlock`` modules in the trunk.
    """

    def __init__(self, in_channels=13, hidden_channels=512, out_channels=240, num_res_blocks=6):
        super().__init__()
        # kernel_size=5 with padding=2 keeps the sequence length unchanged.
        self.cnn1 = nn.Conv1d(
            in_channels=in_channels, out_channels=hidden_channels, kernel_size=5, padding=2
        )
        self.res_blocks = nn.Sequential(*[ResBlock(hidden_channels) for _ in range(num_res_blocks)])
        self.cnn2 = nn.Conv1d(in_channels=hidden_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Embed a batch of frame sequences.

        Args:
            x: Tensor of shape (batch, seq_len, in_channels).

        Returns:
            Tensor of shape (batch, seq_len, out_channels).
        """
        # nn.Conv1d wants channels first: (batch, seq_len, C) -> (batch, C, seq_len).
        x = x.transpose(1, 2)
        x = self.cnn1(x)
        x = self.res_blocks(x)
        x = self.cnn2(x)
        # Back to frame-major: (batch, C, seq_len) -> (batch, seq_len, C).
        return x.transpose(1, 2)
# Remaining components: ResBlock, FrameLevelClassifier, and the end-to-end BoundaryDetectionModel.
class ResBlock(nn.Module):
    """Residual block of two pointwise (kernel size 1) convolutions.

    Computes ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``. Both convolutions
    keep the channel count, so the output has the same shape as the input,
    expected to be (batch, channels, length).

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()

        def pointwise_conv():
            # bias=False, presumably because the following BatchNorm
            # supplies its own learnable shift.
            return nn.Conv1d(
                in_channels=channels, out_channels=channels, kernel_size=1, bias=False
            )

        # Keep creation/registration order (conv1, conv2, bn1, bn2, relu).
        # This keeps random init and state_dict ordering unchanged.
        self.conv1 = pointwise_conv()
        self.conv2 = pointwise_conv()
        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the residual branch, add the skip connection, then ReLU."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + x)
class FrameLevelClassifier(nn.Module):
    """Score each frame embedding with a 2-layer BiLSTM and a linear head.

    Input: (batch, seq_len, 240) (the LSTM is batch-first, input_size=240).
    Output: (batch, seq_len, 1) sigmoid probabilities in (0, 1).

    NOTE(review): ``self.transformer`` is constructed but never applied in
    ``forward``. Its parameters still appear in ``state_dict``, so removing it
    would break strict loading of existing checkpoints. If it is re-enabled,
    note that the layer was built without ``batch_first=True``. It therefore
    expects (seq, batch, feature), while this module receives batch-first
    input.
    """

    def __init__(self):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=240, nhead=4, dim_feedforward=1024)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.bilstm = nn.LSTM(
            input_size=240,
            hidden_size=128,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        # Bidirectional LSTM concatenates both directions: 2 * 128 = 256.
        self.fc = nn.Linear(256, 1)

    def forward(self, x):
        """Return per-frame probabilities for a batch of embedding sequences."""
        sequence, _ = self.bilstm(x)
        logits = self.fc(sequence)
        return torch.sigmoid(logits)
class BoundaryDetectionModel(nn.Module):
    """End-to-end frame-level model: waveform -> per-frame scores.

    Chains three stages: ``FeatureExtractor`` (mel features),
    ``FrameLevelEmbedding`` (convolutional per-frame embeddings) and
    ``FrameLevelClassifier`` (BiLSTM + sigmoid), all with default settings.
    """

    def __init__(self):
        super().__init__()
        # Registration order (feature_extractor, frame_embedding, classifier)
        # fixes the random-init order and the state_dict key prefixes.
        self.feature_extractor = FeatureExtractor()
        self.frame_embedding = FrameLevelEmbedding()
        self.classifier = FrameLevelClassifier()

    def forward(self, audio):
        """Run the three stages in sequence.

        Assumes ``audio`` is a 2-D waveform tensor (batch, num_samples) at
        the extractor's 16 kHz default -- TODO confirm with callers. The
        result is (batch, num_frames, 1) with values in (0, 1).
        """
        result = audio
        for stage in (self.feature_extractor, self.frame_embedding, self.classifier):
            result = stage(result)
        return result
# model = BoundaryDetectionModel()
# audio, sr = torchaudio.load("new_files/Extrinsic_Partial_Fakes/extrinsic_partial_fake_RFP_R_00001.wav")
# if sr != 16000:
# resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
# audio = resampler(audio)
# # audio = audio.mean(dim=0).unsqueeze(0) # Convert to mono and add batch dimension
# output = model(audio)
# print(output.squeeze(2).shape)