import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    WavLMModel,
)
from transformers.modeling_outputs import BaseModelOutput


class WavLMRawNetSVBaseConfig(PretrainedConfig):
    model_type = "wavlm-rawnet-sv-base"

    def __init__(
        self,
        sample_rate=16000,
        sinc_out_channels=128,
        sinc_kernel_size=251,
        sinc_in_channels=1,
        rawnet_embedding_dim=256,
        use_se=True,
        num_blocks=4,
        pooling_attention_dim=128,
        wavlm_path="microsoft/wavlm-large",
        wavlm_embedding_dim=1024,
        fusion_embedding_dim=256,
        **kwargs
    ):
        """
        A Transformers-based config class.

        Parameters:
        -----------
        - sample_rate : int
        - sinc_out_channels : int
        - sinc_kernel_size : int
        - sinc_in_channels : int
        - rawnet_embedding_dim : int
        - use_se : bool (whether to use SEBlock)
        - num_blocks : int (number of residual blocks inside RawNet)
        - pooling_attention_dim : int (attention size for AttentiveStatsPooling)
        - wavlm_path : str (e.g. "microsoft/wavlm-large")
        - wavlm_embedding_dim : int (WavLM last_hidden_state size, usually 1024)
        - fusion_embedding_dim : int (final size after fusing WavLM + RawNet)
        """
        super().__init__(**kwargs)

        self.sample_rate = sample_rate

        self.sinc_out_channels = sinc_out_channels
        self.sinc_kernel_size = sinc_kernel_size
        self.sinc_in_channels = sinc_in_channels

        self.rawnet_embedding_dim = rawnet_embedding_dim
        self.use_se = use_se
        self.num_blocks = num_blocks
        self.pooling_attention_dim = pooling_attention_dim

        self.wavlm_path = wavlm_path
        self.wavlm_embedding_dim = wavlm_embedding_dim

        self.fusion_embedding_dim = fusion_embedding_dim
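

# A minimal usage sketch for the config (hypothetical values; any of the
# keyword arguments above can be overridden the same way):
#
#     >>> cfg = WavLMRawNetSVBaseConfig(fusion_embedding_dim=192)
#     >>> cfg.fusion_embedding_dim
#     192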


class WavLMRawNetSVBase(PreTrainedModel):
    config_class = WavLMRawNetSVBaseConfig
    base_model_prefix = "wavlm_rawnet_sv_base"

    def __init__(self, config: WavLMRawNetSVBaseConfig):
        """
        Combines the WavLM + RawNet architecture under the PreTrainedModel API.
        """
        super().__init__(config)

        # Note: this loads the pretrained WavLM weights at construction time,
        # even when the wrapper itself is later restored with from_pretrained().
        self.wavlm_model = WavLMModel.from_pretrained(config.wavlm_path)

        # RawNet-style front end operating directly on the raw waveform.
        self.sinc_conv = SincConv(
            out_channels=config.sinc_out_channels,
            kernel_size=config.sinc_kernel_size,
            sample_rate=config.sample_rate,
            in_channels=config.sinc_in_channels,
            padding=config.sinc_kernel_size // 2
        )

        self.first_norm = nn.InstanceNorm1d(config.sinc_out_channels, affine=True)
        self.first_act = nn.Mish()

        self.blocks = ResidualStack(
            channels=config.sinc_out_channels,
            kernel_size=3,
            dilation=1,
            use_se=config.use_se,
            num_blocks=config.num_blocks
        )

        self.pooling_layer = AttentiveStatsPooling(
            in_dim=config.sinc_out_channels,
            attention_dim=config.pooling_attention_dim
        )

        # The pooling layer concatenates mean and std, hence the factor of 2.
        self.emb_fc = nn.Linear(config.sinc_out_channels * 2, config.rawnet_embedding_dim)
        self.emb_act = nn.ReLU()

        # Projects the concatenated WavLM + RawNet embeddings to the final size.
        self.fusion_fc = nn.Linear(
            config.wavlm_embedding_dim + config.rawnet_embedding_dim,
            config.fusion_embedding_dim
        )
        self.fusion_act = nn.ReLU()

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs
    ):
        """
        input_values: (batch_size, time), float32 waveform

        Returns: BaseModelOutput with last_hidden_state = (batch_size, 1, fusion_embedding_dim)
        """
        # WavLM branch: mean-pool the frame-level hidden states over time.
        wavlm_out = self.wavlm_model(input_values, **kwargs)
        wavlm_emb = wavlm_out.last_hidden_state.mean(dim=1)

        # RawNet branch: (batch, time) -> (batch, 1, time) for Conv1d.
        x = input_values.unsqueeze(1)
        x = self.sinc_conv(x)
        x = self.first_norm(x)
        x = self.first_act(x)
        x = self.blocks(x)
        x = self.pooling_layer(x)  # (batch, 2 * sinc_out_channels)
        x = self.emb_fc(x)
        x = self.emb_act(x)
        rawnet_emb = x

        # Late fusion of the two utterance-level embeddings.
        fused = torch.cat([wavlm_emb, rawnet_emb], dim=1)
        fused = self.fusion_fc(fused)
        fused = self.fusion_act(fused)

        return BaseModelOutput(
            last_hidden_state=fused.unsqueeze(1)
        )


class SincConv(nn.Module):
    def __init__(
        self,
        out_channels,
        kernel_size,
        sample_rate=16000,
        in_channels=1,
        stride=1,
        padding=0,
        min_low_hz=50,
        min_band_hz=50
    ):
        super().__init__()
        if in_channels != 1:
            raise ValueError("SincConv only supports one input channel.")
        assert kernel_size % 2 == 1, "kernel_size must be odd."

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.stride = stride
        self.padding = padding
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Linearly spaced initial cutoff frequencies and bandwidths.
        low_hz = 30
        high_hz = sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        band = (high_hz - low_hz) / self.out_channels

        self.low_hz_ = nn.Parameter(torch.linspace(low_hz, high_hz, self.out_channels))
        self.band_hz_ = nn.Parameter(torch.full((self.out_channels,), band))

        # Hamming window over the kernel taps.
        n_lin = torch.linspace(0, kernel_size - 1, kernel_size)
        window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / (kernel_size - 1))
        self.register_buffer("window_", window_.float())

        # n_ holds 2*pi*n / sample_rate for n in [-half_kernel, half_kernel].
        half_kernel = (kernel_size - 1) // 2
        n_arr = torch.arange(-half_kernel, half_kernel + 1)
        n_ = 2 * math.pi * n_arr / self.sample_rate
        self.register_buffer("n_", n_.float())

    def forward(self, x):
        # Learned band edges, kept inside [min_low_hz, sample_rate / 2].
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2
        )
        band = (high - low)[:, None]

        # Phases 2*pi*f*n / sample_rate of the two band edges.
        f_low = low[:, None] * self.n_
        f_high = high[:, None] * self.n_

        # Ideal band-pass impulse response: difference of two low-pass sincs.
        # The epsilon avoids 0/0 at the centre tap (n = 0), which is then set
        # to its analytic limit, 2 * band.
        band_pass = (torch.sin(f_high) - torch.sin(f_low)) / (self.n_ / 2 + 1e-8)
        band_pass[:, self.kernel_size // 2] = 2 * band.squeeze(-1)

        band_pass = band_pass * self.window_
        band_pass = band_pass / (2 * band + 1e-8)  # normalise the centre tap to 1

        # (out_channels, 1, kernel_size) filter bank for a single input channel.
        filters = band_pass.to(x.device).unsqueeze(1)

        return F.conv1d(
            x, filters,
            stride=self.stride,
            padding=self.padding
        )
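

# Shape sketch (hypothetical values): with "same" padding the filter bank
# preserves the time dimension.
#
#     >>> conv = SincConv(out_channels=128, kernel_size=251, padding=125)
#     >>> conv(torch.randn(2, 1, 16000)).shape
#     torch.Size([2, 128, 16000])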


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Squeeze: global average over time; excite: per-channel gates in (0, 1).
        b, c, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1)
        return x * y.expand_as(x)
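

# Shape sketch (hypothetical values): SE is shape-preserving, it only
# rescales channels.
#
#     >>> se = SEBlock(channel=128)
#     >>> se(torch.randn(2, 128, 200)).shape
#     torch.Size([2, 128, 200])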


class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=1, use_se=True):
        super().__init__()
        self.use_se = use_se
        # "Same" padding, so the residual addition always lines up.
        padding = dilation * (kernel_size - 1) // 2

        self.conv = nn.Conv1d(
            channels, channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation
        )
        self.norm = nn.InstanceNorm1d(channels, affine=True)
        self.se = SEBlock(channel=channels) if use_se else None
        self.act = nn.Mish()

    def forward(self, x):
        identity = x
        out = self.conv(x)
        out = self.norm(out)
        if self.se is not None:
            out = self.se(out)
        out = self.act(out)
        return identity + out


class ResidualStack(nn.Module):
    def __init__(self, channels, kernel_size, dilation, use_se, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList([
            ResidualBlock(
                channels=channels,
                kernel_size=kernel_size,
                dilation=dilation,
                use_se=use_se
            )
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
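

# Shape sketch (hypothetical values): the stack chains shape-preserving
# residual blocks, so input and output shapes match.
#
#     >>> stack = ResidualStack(channels=128, kernel_size=3, dilation=1,
#     ...                       use_se=True, num_blocks=4)
#     >>> stack(torch.randn(2, 128, 200)).shape
#     torch.Size([2, 128, 200])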


class AttentiveStatsPooling(nn.Module):
    def __init__(self, in_dim, attention_dim=128):
        super().__init__()
        self.linear = nn.Linear(in_dim, attention_dim)
        self.tanh = nn.Tanh()
        self.att_conv = nn.Conv1d(attention_dim, 1, kernel_size=1)

    def forward(self, x):
        # x: (batch, channels, time) -> attention weights over time.
        out = self.linear(x.transpose(1, 2))
        out = self.tanh(out)
        out = out.transpose(1, 2)
        w = self.att_conv(out)
        w = F.softmax(w, dim=2)

        # Attention-weighted mean and standard deviation; the clamp keeps the
        # variance non-negative under floating-point error.
        mean = torch.sum(x * w, dim=2)
        mean_sq = torch.sum((x ** 2) * w, dim=2)
        std = torch.sqrt(torch.clamp(mean_sq - mean ** 2, min=1e-9))

        return torch.cat([mean, std], dim=1)
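

# Shape sketch (hypothetical values): pooling collapses time and doubles the
# channel dimension (mean concatenated with std).
#
#     >>> pool = AttentiveStatsPooling(in_dim=128, attention_dim=64)
#     >>> pool(torch.randn(2, 128, 200)).shape
#     torch.Size([2, 256])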
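

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the model definition above).
    # Constructing the model loads the WavLM checkpoint named in the config,
    # so the weights must be available locally or downloadable.
    config = WavLMRawNetSVBaseConfig()
    model = WavLMRawNetSVBase(config)
    model.eval()

    waveform = torch.randn(2, 16000)  # two one-second clips at 16 kHz
    with torch.no_grad():
        out = model(input_values=waveform)

    # Expected: torch.Size([2, 1, 256]) with the default fusion_embedding_dim.
    print(out.last_hidden_state.shape)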