from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def init_layer(layer: nn.Module) -> None:
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)
    if hasattr(layer, "bias") and layer.bias is not None:
        layer.bias.data.zero_()


def init_bn(bn: nn.Module) -> None:
    """Initialize a Batchnorm layer."""
    bn.bias.data.zero_()
    bn.weight.data.fill_(1.0)
    bn.running_mean.data.zero_()
    bn.running_var.data.fill_(1.0)


def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    """Apply the named activation function."""
    funcs = {"relu": F.relu_, "leaky_relu": lambda x: F.leaky_relu_(x, 0.01), "swish": lambda x: x * torch.sigmoid(x)}
    if activation not in funcs:
        raise ValueError(f"Incorrect activation: {activation}")
    return funcs[activation](x)


class Res2DAVPBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, avp_kernel_size, activation):
        """Convolutional residual block modified from bytedance/music_source_separation."""
        super().__init__()

        padding = kernel_size[0] // 2, kernel_size[1] // 2

        self.activation = activation
        self.bn1, self.bn2 = nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding, bias=False)

        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

        self.avp = nn.AvgPool2d(avp_kernel_size)
        self.init_weights()

    def init_weights(self):
        for m in [self.conv1, self.conv2] + ([self.shortcut] if self.is_shortcut else []):
            init_layer(m)
        for m in [self.bn1, self.bn2]:
            init_bn(m)

    def forward(self, x):
        origin = x
        x = act(self.bn1(self.conv1(x)), self.activation)
        x = self.bn2(self.conv2(x))
        x += self.shortcut(origin) if self.is_shortcut else origin
        x = act(x, self.activation)
        return self.avp(x)


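# Minimal shape-check sketch for Res2DAVPBlock: with stride-1 'same'-padded
# convolutions and avp_kernel_size=(1, 2), the block should keep the time axis
# and halve the frequency axis. The concrete sizes below are illustrative only.
def test_res2davp_block():
    blk = Res2DAVPBlock(in_channels=1, out_channels=8, kernel_size=(3, 3), avp_kernel_size=(1, 2), activation='relu')
    y = blk(torch.randn(2, 1, 16, 64))  # (b, c, t, f)
    assert y.shape == (2, 8, 16, 32)

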
class PreEncoderBlockRes3B(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), avp_kernel_size=(1, 2), activation='relu'):
        """Pre-Encoder with 3 Res2DAVPBlocks."""
        super().__init__()

        self.blocks = nn.ModuleList([
            Res2DAVPBlock(in_channels if i == 0 else out_channels, out_channels, kernel_size, avp_kernel_size,
                          activation) for i in range(3)
        ])

    def forward(self, x):
        x = rearrange(x, 'b t f -> b 1 t f')
        for block in self.blocks:
            x = block(x)
        return rearrange(x, 'b c t f -> b t f c')


def test_res3b():
    x = torch.randn(2, 256, 512)
    pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128)
    x = pre(x)

    x = torch.randn(2, 110, 1024)
    pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128)
    x = pre(x)
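    # Expected-shape sketch: the three (1, 2) average-pooling stages preserve
    # time, reduce the frequency axis 8x, and the channel axis becomes
    # out_channels, so the second call should yield (b, t, f // 8, c).
    assert x.shape == (2, 110, 128, 128)

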
class PreEncoderBlockHFTT(nn.Module):

    def __init__(self, margin_pre=15, margin_post=16) -> None:
        """Pre-Encoder with hFT-Transformer-like convolutions."""
        super().__init__()

        self.margin_pre, self.margin_post = margin_pre, margin_post
        self.conv = nn.Conv2d(1, 4, kernel_size=(1, 5), padding='same', padding_mode='zeros')
        self.emb_freq = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = rearrange(x, 'b t f -> b 1 f t')
        x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7)  # pad the time axis
        x = self.conv(x)  # (b, 4, f, t + margin_pre + margin_post)
        x = x.unfold(dimension=3, size=32, step=1)  # sliding 32-frame windows over time
        x = rearrange(x, 'b c1 f t c2 -> b t f (c1 c2)')  # (b, t, f, 4 * 32)
        return self.emb_freq(x)


def test_hftt():
    x = torch.randn(2, 110, 128)
    pre_enc_hftt = PreEncoderBlockHFTT()
    y = pre_enc_hftt(x)
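    # Expected-shape sketch: padding adds margin_pre + margin_post = 31 frames,
    # so unfolding with 32-frame windows restores the 110 time steps, and
    # 4 conv channels x 32 frames give a 128-dim embedding per frequency bin.
    assert y.shape == (2, 110, 128, 128)

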
class PreEncoderBlockRes3BHFTT(nn.Module):

    def __init__(self, margin_pre: int = 15, margin_post: int = 16) -> None:
        """Pre-Encoder with Res2DAVPBlocks and hFT-Transformer-like windowing.

        Args:
            margin_pre (int): padding before the input
            margin_post (int): padding after the input
        """
        super().__init__()
        self.margin_pre, self.margin_post = margin_pre, margin_post
        self.res3b = PreEncoderBlockRes3B(in_channels=1, out_channels=4)
        self.emb_freq = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = rearrange(x, 'b t f -> b f t')
        x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7)  # pad the time axis
        x = rearrange(x, 'b f t -> b t f')
        x = self.res3b(x)  # (b, t + margins, f // 8, 4)
        x = x.unfold(dimension=1, size=32, step=1)  # sliding 32-frame windows over time
        x = rearrange(x, 'b t f c1 c2 -> b t f (c1 c2)')  # (b, t, f // 8, 4 * 32)
        return self.emb_freq(x)


def test_res3b_hftt():
    x = torch.randn(2, 110, 1024)
    pre_enc_res3b_hftt = PreEncoderBlockRes3BHFTT()
    y = pre_enc_res3b_hftt(x)
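    # Expected-shape sketch: padding adds 31 frames, unfolding with 32-frame
    # windows restores the 110 time steps, the Res3B stack reduces 1024
    # frequency bins to 128, and 4 channels x 32 frames give a 128-dim embedding.
    assert y.shape == (2, 110, 128, 128)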