import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import MelSpectrogram

class FeatureExtractor(nn.Module):
    def __init__(self, n_mels=13, sample_rate=16000, frame_size_ms=20):
        super(FeatureExtractor, self).__init__()
        self.mel_spec = MelSpectrogram(
            sample_rate=sample_rate,
            # full frame window: 20 ms -> 320 samples at 16 kHz
            win_length=int(sample_rate * frame_size_ms / 1000),
            # half-frame hop (10 ms) gives 50% overlap between frames
            hop_length=int(sample_rate * frame_size_ms / 2000),
            n_mels=n_mels,
            normalized=True
        )

    def forward(self, audio):
        # Compute the Mel spectrogram: (batch, n_mels, time)
        mel_features = self.mel_spec(audio)
        # Transpose to (batch, time, n_mels), i.e. the (batch, seq_len, features)
        # layout that FrameLevelEmbedding expects
        mel_features = mel_features.transpose(1, 2)
        return mel_features
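
    # Shape check (illustrative, using the defaults above): 1 s of 16 kHz audio
    # yields one frame per 10 ms hop, e.g.
    #   fe = FeatureExtractor()
    #   fe(torch.randn(1, 16000)).shape  # -> (1, 101, 13): (batch, frames, n_mels)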


# Frame-level embedding: 1D CNN front-end followed by a stack of residual blocks
class FrameLevelEmbedding(nn.Module):
    def __init__(self):
        super(FrameLevelEmbedding, self).__init__()
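        # 5-frame context window over the 13 Mel bins, expanded to 512 channels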
        self.cnn1 = nn.Conv1d(in_channels=13, out_channels=512, kernel_size=5, padding=2)
        self.res_blocks = nn.Sequential(*[ResBlock(512) for _ in range(6)])
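        # 1x1 projection down to the 240-dim embedding consumed by the classifier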
        self.cnn2 = nn.Conv1d(in_channels=512, out_channels=240, kernel_size=1)

    def forward(self, x):
        x = x.transpose(1, 2)  # (batch_size, seq_len, features) -> (batch_size, features, seq_len)
        x = self.cnn1(x)
        x = self.res_blocks(x)
        x = self.cnn2(x)
        x = x.transpose(1, 2)  # (batch_size, features, seq_len) -> (batch_size, seq_len, features)
        return x

# Residual block with pointwise convolutions, batch norm, and a skip connection
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
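        # kernel_size=1 convs mix channels per frame; no extra temporal context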
        self.conv1 = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)

class FrameLevelClassifier(nn.Module):
    def __init__(self):
        super(FrameLevelClassifier, self).__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=240, nhead=4, dim_feedforward=1024, batch_first=True),
            num_layers=2,
        )
        self.bilstm = nn.LSTM(input_size=240, hidden_size=128, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, 1)  # Bidirectional LSTM -> 2 * hidden_size

    def forward(self, x):
        # Transformer branch is currently disabled; with batch_first=True it
        # accepts the same (batch, seq_len, 240) layout as the LSTM if re-enabled.
        # x = self.transformer(x)
        x, _ = self.bilstm(x)       # (batch, seq_len, 256)
        x = self.fc(x)              # (batch, seq_len, 1)
        return torch.sigmoid(x)     # per-frame boundary probability
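
    # Design note: returning probabilities here pairs with nn.BCELoss downstream;
    # a numerically safer variant would drop the sigmoid and train with
    # nn.BCEWithLogitsLoss instead.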


class BoundaryDetectionModel(nn.Module):
    def __init__(self):
        super(BoundaryDetectionModel, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.frame_embedding = FrameLevelEmbedding()
        self.classifier = FrameLevelClassifier()

    def forward(self, audio):
        features = self.feature_extractor(audio)
        embeddings = self.frame_embedding(features)
        output = self.classifier(embeddings)
        return output
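

# Minimal single-step training sketch (an illustrative assumption, not part of
# the original code): `labels` is a (batch, seq_len, 1) tensor of 0/1 per-frame
# boundary targets aligned to the 10 ms feature hop; the criterion and optimizer
# are the caller's choice, e.g. nn.BCELoss() and torch.optim.Adam.
def train_step(model, audio, labels, optimizer, criterion):
    optimizer.zero_grad()
    preds = model(audio)              # (batch, seq_len, 1) sigmoid probabilities
    loss = criterion(preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item()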


if __name__ == "__main__":
    # Smoke test on a sample file; prints one boundary probability per frame
    model = BoundaryDetectionModel()
    audio, sr = torchaudio.load("new_files/Extrinsic_Partial_Fakes/extrinsic_partial_fake_RFP_R_00001.wav")
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio = resampler(audio)
    audio = audio.mean(dim=0, keepdim=True)  # collapse to mono, keep the batch dimension
    output = model(audio)
    print(output.squeeze(2).shape)  # (1, n_frames)