# WavLMRawNetSVBase / wavlmrawnetsvbase.py
# bunyaminergen's picture
# Initial
# e8adb69
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
PretrainedConfig,
WavLMModel,
)
from transformers.modeling_outputs import BaseModelOutput
############################################################
# 1) Config class
############################################################
class WavLMRawNetSVBaseConfig(PretrainedConfig):
    """Transformers-style configuration for the WavLM + RawNet speaker model.

    Each constructor argument is stored unchanged as an attribute of the same
    name; any additional keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        sample_rate: Sampling rate passed to the sinc front-end.
        sinc_out_channels: Number of sinc filters; also the channel width of
            the RawNet normalisation, residual stack and pooling layers.
        sinc_kernel_size: Length of each sinc filter (the model pads the
            sinc convolution with ``sinc_kernel_size // 2``).
        sinc_in_channels: Input channels of the sinc front-end.
        rawnet_embedding_dim: Output size of the RawNet embedding layer.
        use_se: Whether the residual blocks use an ``SEBlock``.
        num_blocks: Number of residual blocks inside the RawNet branch.
        pooling_attention_dim: Attention size used by
            ``AttentiveStatsPooling``.
        wavlm_path: Checkpoint identifier for WavLM, e.g.
            ``"microsoft/wavlm-large"``.
        wavlm_embedding_dim: Size of WavLM's ``last_hidden_state`` features
            (typically 1024).
        fusion_embedding_dim: Final embedding size after WavLM and RawNet
            features are fused.
    """

    model_type = "wavlm-rawnet-sv-base"

    def __init__(
        self,
        sample_rate: int = 16000,
        sinc_out_channels: int = 128,
        sinc_kernel_size: int = 251,
        sinc_in_channels: int = 1,
        rawnet_embedding_dim: int = 256,
        use_se: bool = True,
        num_blocks: int = 4,
        pooling_attention_dim: int = 128,
        wavlm_path: str = "microsoft/wavlm-large",
        wavlm_embedding_dim: int = 1024,
        fusion_embedding_dim: int = 256,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Sinc front-end settings.
        self.sample_rate = sample_rate
        self.sinc_out_channels = sinc_out_channels
        self.sinc_kernel_size = sinc_kernel_size
        self.sinc_in_channels = sinc_in_channels

        # RawNet branch settings.
        self.rawnet_embedding_dim = rawnet_embedding_dim
        self.use_se = use_se
        self.num_blocks = num_blocks
        self.pooling_attention_dim = pooling_attention_dim

        # WavLM branch and fusion settings.
        self.wavlm_path = wavlm_path
        self.wavlm_embedding_dim = wavlm_embedding_dim
        self.fusion_embedding_dim = fusion_embedding_dim
############################################################
# 2) Model class: WavLM + RawNet end-to-end
############################################################
class WavLMRawNetSVBase(PreTrainedModel):
    """Fuse a WavLM encoder and a RawNet-style branch into one embedding.

    Both branches consume the same raw waveform. The WavLM hidden states are
    averaged over time, the RawNet branch produces an attentive-statistics
    embedding, and the two vectors are concatenated and projected to
    ``config.fusion_embedding_dim``.
    """

    config_class = WavLMRawNetSVBaseConfig
    base_model_prefix = "wavlm_rawnet_sv_base"

    def __init__(self, config: WavLMRawNetSVBaseConfig):
        """Build the WavLM encoder, the RawNet branch and the fusion head.

        Args:
            config: Model configuration; see ``WavLMRawNetSVBaseConfig``.
        """
        super().__init__(config)

        # --- WavLM branch -------------------------------------------------
        # NOTE(review): this loads the pretrained checkpoint named by
        # ``config.wavlm_path`` every time the model is constructed.
        self.wavlm_model = WavLMModel.from_pretrained(config.wavlm_path)

        # --- RawNet branch ------------------------------------------------
        self.sinc_conv = SincConv(
            out_channels=config.sinc_out_channels,
            kernel_size=config.sinc_kernel_size,
            sample_rate=config.sample_rate,
            in_channels=config.sinc_in_channels,
            # Half-kernel padding keeps the time length for odd kernels.
            padding=config.sinc_kernel_size // 2
        )
        self.first_norm = nn.InstanceNorm1d(config.sinc_out_channels, affine=True)
        self.first_act = nn.Mish()
        self.blocks = ResidualStack(
            channels=config.sinc_out_channels,
            kernel_size=3,
            dilation=1,
            use_se=config.use_se,
            num_blocks=config.num_blocks
        )
        self.pooling_layer = AttentiveStatsPooling(
            in_dim=config.sinc_out_channels,
            attention_dim=config.pooling_attention_dim
        )
        # Pooling concatenates mean and std, hence twice the channel count.
        self.emb_fc = nn.Linear(config.sinc_out_channels * 2, config.rawnet_embedding_dim)
        self.emb_act = nn.ReLU()

        # --- Fusion (WavLM + RawNet) --------------------------------------
        self.fusion_fc = nn.Linear(
            config.wavlm_embedding_dim + config.rawnet_embedding_dim,
            config.fusion_embedding_dim
        )
        self.fusion_act = nn.ReLU()

    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs
    ):
        """Compute the fused embedding for a batch of waveforms.

        Args:
            input_values: Float waveform tensor of shape (batch_size, time).
            **kwargs: Forwarded unchanged to the wrapped ``WavLMModel``.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has shape
            (batch_size, 1, fusion_embedding_dim).
        """
        # 1) WavLM branch.
        wavlm_out = self.wavlm_model(input_values, **kwargs)
        # Index 0 is the last hidden state for both the ModelOutput and the
        # plain-tuple form, so ``return_dict=False`` in kwargs no longer fails.
        # -> (batch, seq_len, wavlm_emb_dim)
        wavlm_hidden = wavlm_out[0]
        # NOTE(review): the average runs over every frame with no masking;
        # confirm that callers do not rely on padded batches.
        wavlm_emb = wavlm_hidden.mean(dim=1)  # (batch, wavlm_emb_dim)

        # 2) RawNet branch.
        x = input_values.unsqueeze(1)  # (batch, 1, time)
        x = self.sinc_conv(x)
        x = self.first_norm(x)
        x = self.first_act(x)
        x = self.blocks(x)
        # Pooling already returns (batch, 2 * channels); no squeeze is needed.
        x = self.pooling_layer(x)
        x = self.emb_fc(x)
        rawnet_emb = self.emb_act(x)  # (batch, rawnet_embedding_dim)

        # 3) Fusion.
        fused = torch.cat([wavlm_emb, rawnet_emb], dim=1)
        fused = self.fusion_fc(fused)
        fused = self.fusion_act(fused)

        return BaseModelOutput(
            last_hidden_state=fused.unsqueeze(1)  # -> (batch, 1, fusion_dim)
        )
############################################################
# 3) RawNet components (SincConv, ResidualStack, etc.)
############################################################
class SincConv(nn.Module):
    """1-D convolution whose filters are windowed sinc band-pass kernels.

    Only the lower cut-off (``low_hz_``) and bandwidth (``band_hz_``) of each
    filter are learnable; the kernels are rebuilt from them on every forward
    pass and multiplied by a fixed Hamming window.

    Args:
        out_channels: Number of band-pass filters.
        kernel_size: Filter length in samples; must be odd.
        sample_rate: Sampling rate used to map frequencies to radians.
        in_channels: Must be 1.
        stride: Stride of the convolution.
        padding: Padding of the convolution.
        min_low_hz: Minimum lower cut-off frequency.
        min_band_hz: Minimum bandwidth.

    Raises:
        ValueError: If ``in_channels`` is not 1 or ``kernel_size`` is even.
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        sample_rate=16000,
        in_channels=1,
        stride=1,
        padding=0,
        min_low_hz=50,
        min_band_hz=50
    ):
        super().__init__()
        if in_channels != 1:
            raise ValueError("SincConv only supports one input channel.")
        # Validate with an exception (not ``assert``) so the check survives
        # ``python -O`` and runs before any buffers are built.
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size tek sayı olmalı.")

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.stride = stride
        self.padding = padding
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Initial cut-offs are spaced linearly between 30 Hz and just below
        # the Nyquist frequency; every filter starts with the same bandwidth.
        low_hz = 30
        high_hz = sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        band = (high_hz - low_hz) / self.out_channels
        self.low_hz_ = nn.Parameter(torch.linspace(low_hz, high_hz, self.out_channels))
        self.band_hz_ = nn.Parameter(torch.full((self.out_channels,), band))

        # Fixed Hamming window over the kernel.
        n_lin = torch.linspace(0, kernel_size - 1, kernel_size)
        window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / (kernel_size - 1))
        window_ = window_.float()
        self.register_buffer("window_", window_)

        # Symmetric sample indices scaled to radians per Hz.
        half_kernel = (kernel_size - 1) // 2
        n_arr = torch.arange(-half_kernel, half_kernel + 1)
        n_ = 2 * math.pi * n_arr / self.sample_rate
        n_ = n_.float()
        self.register_buffer("n_", n_)

    def forward(self, x):
        """Filter ``x`` with the current band-pass kernels.

        Args:
            x: Input passed to ``F.conv1d`` with a single input channel,
                i.e. (batch, 1, time).

        Returns:
            The convolution output with ``out_channels`` channels.
        """
        # abs() plus the minimum offsets keeps the lower cut-off at or above
        # ``min_low_hz``; the upper cut-off is clamped to the Nyquist limit.
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2
        )
        band = (high - low)[:, None]

        f_low = low[:, None] * self.n_
        f_high = high[:, None] * self.n_

        # The small epsilon avoids 0/0 at the centre tap, where n_ == 0.
        band_pass_left = (2 * f_high).sin() / (1e-8 + 2 * f_high)
        band_pass_right = (2 * f_low).sin() / (1e-8 + 2 * f_low)
        band_pass = band_pass_left - band_pass_right

        band_pass = band_pass * self.window_
        band_pass = band_pass / (2 * band + 1e-8)

        # (out_channels, kernel_size) -> (out_channels, 1, kernel_size)
        filters = band_pass.to(x.device)
        filters = filters.unsqueeze(1)

        return F.conv1d(
            x, filters,
            stride=self.stride,
            padding=self.padding
        )
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a (B, C, T) tensor.

    Args:
        channel: Number of channels in the input.
        reduction: Divisor applied to ``channel`` for the bottleneck width.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        bottleneck = channel // reduction
        gate_layers = [
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Scale every channel of ``x`` by its learned gate in (0, 1)."""
        batch, channels, _ = x.size()
        # Squeeze: average over time -> (B, C).
        pooled = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel gate -> (B, C, 1).
        gate = self.fc(pooled).unsqueeze(-1)
        return x * gate.expand_as(x)
class ResidualBlock(nn.Module):
    """Conv1d -> InstanceNorm -> optional SE gate -> Mish, plus a skip path.

    Args:
        channels: Number of input and output channels.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.
        use_se: When true, an ``SEBlock`` is applied after normalisation.
    """

    def __init__(self, channels, kernel_size=3, dilation=1, use_se=True):
        super().__init__()
        self.use_se = use_se

        # Padding that keeps the time length unchanged for odd kernel sizes,
        # which the skip connection in ``forward`` relies on.
        same_padding = (dilation * (kernel_size - 1)) // 2
        self.conv = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=same_padding,
            dilation=dilation,
        )
        self.norm = nn.InstanceNorm1d(channels, affine=True)
        if use_se:
            self.se = SEBlock(channel=channels)
        else:
            self.se = None
        self.act = nn.Mish()

    def forward(self, x):
        """Return ``x`` plus the activated convolutional branch."""
        branch = self.norm(self.conv(x))
        if self.se is not None:
            branch = self.se(branch)
        return x + self.act(branch)
class ResidualStack(nn.Module):
    """Sequential stack of identically configured ``ResidualBlock`` layers.

    Args:
        channels: Channel count shared by every block.
        kernel_size: Kernel size shared by every block.
        dilation: Dilation shared by every block.
        use_se: Whether each block includes an ``SEBlock``.
        num_blocks: How many blocks to stack.
    """

    def __init__(self, channels, kernel_size, dilation, use_se, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(
                ResidualBlock(
                    channels=channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    use_se=use_se,
                )
            )

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden
class AttentiveStatsPooling(nn.Module):
    """Attention-weighted mean and standard deviation over the time axis.

    Args:
        in_dim: Number of channels C in the (B, C, T) input.
        attention_dim: Hidden size of the attention scorer.
    """

    def __init__(self, in_dim, attention_dim=128):
        super().__init__()
        self.linear = nn.Linear(in_dim, attention_dim)
        self.tanh = nn.Tanh()
        self.att_conv = nn.Conv1d(attention_dim, 1, kernel_size=1)

    def forward(self, x):
        """Pool ``x`` of shape (B, C, T) into a (B, 2 * C) statistics vector.

        Returns:
            The weighted mean and weighted standard deviation of every
            channel, concatenated along dim 1.
        """
        # x: (B, C, T)
        out = self.linear(x.transpose(1, 2))  # -> (B, T, att_dim)
        out = self.tanh(out)
        out = out.transpose(1, 2)             # -> (B, att_dim, T)

        w = self.att_conv(out)                # -> (B, 1, T)
        w = F.softmax(w, dim=2)               # weights sum to 1 over time

        mean = torch.sum(x * w, dim=2)
        mean_sq = torch.sum((x ** 2) * w, dim=2)
        # E[x^2] - E[x]^2 can dip below zero through float rounding (for
        # example on near-constant features); clamp it so sqrt never sees a
        # negative value and returns NaN. Non-negative variances are
        # unaffected, so results match the unclamped formula there.
        variance = torch.clamp(mean_sq - mean ** 2, min=0.0)
        std = torch.sqrt(variance + 1e-9)

        return torch.cat([mean, std], dim=1)  # (B, C*2)
#################################################################
# End
#################################################################