YourMT3 / amt /src /model /conv_block.py
mimbres's picture
.
a03c9b4
raw
history blame
9.3 kB
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
from typing import Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def init_layer(layer: nn.Module) -> None:
    """Reset a Linear or Convolutional layer in place.

    The weight is drawn from a Xavier-uniform distribution. The bias, when the
    layer has one, is set to zero.

    Args:
        layer: module exposing a ``weight`` tensor and optionally a ``bias``.
    """
    nn.init.xavier_uniform_(layer.weight)
    # Layers built with ``bias=False`` expose ``bias`` as None; other modules
    # may not have the attribute at all. Both cases are skipped.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.zero_()
def init_bn(bn: nn.Module) -> None:
    """Reset a Batchnorm layer to the identity transform, in place.

    The affine parameters become ``weight = 1`` and ``bias = 0``; the running
    statistics become ``running_mean = 0`` and ``running_var = 1``.

    Args:
        bn: batch-norm module with affine parameters and running statistics.
    """
    for shift in (bn.bias, bn.running_mean):
        shift.data.zero_()
    for scale in (bn.weight, bn.running_var):
        scale.data.fill_(1.0)
def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    """Apply the named activation function to ``x``.

    Args:
        x: input tensor. ``"relu"`` and ``"leaky_relu"`` modify it in place
            and return it; ``"swish"`` returns a new tensor.
        activation: one of ``"relu"``, ``"leaky_relu"`` (negative slope 0.01)
            or ``"swish"`` (``x * sigmoid(x)``).

    Returns:
        The activated tensor.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    funcs = {
        "relu": F.relu_,
        "leaky_relu": lambda t: F.leaky_relu_(t, 0.01),
        "swish": lambda t: t * torch.sigmoid(t),
    }
    try:
        fn = funcs[activation]
    except KeyError:
        # Fail loudly: returning an Exception object (as before) would let a
        # non-tensor value flow into the rest of the network.
        raise ValueError(
            f"Incorrect activation! Got {activation!r}, expected one of {sorted(funcs)}."
        ) from None
    return fn(x)
class Res2DAVPBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, avp_kernel_size, activation):
        """Convolutional residual block modified from bytedance/music_source_separation.

        Two conv + batch-norm stages with a residual connection, followed by
        average pooling.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: 2-element (height, width) kernel size of both convolutions.
            avp_kernel_size: kernel size passed to the final ``nn.AvgPool2d``.
            activation: activation name forwarded to ``act``.
        """
        super().__init__()
        # Half-kernel padding; for odd kernel sizes the convolutions keep the
        # spatial size unchanged.
        pad_h = kernel_size[0] // 2
        pad_w = kernel_size[1] // 2
        self.activation = activation

        # NOTE: sub-modules are registered in the same order as before so that
        # state_dict / parameter ordering is unaffected.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(pad_h, pad_w), bias=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(pad_h, pad_w), bias=False)

        # A 1x1 projection is only needed when the residual branch changes
        # the channel count.
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

        self.avp = nn.AvgPool2d(avp_kernel_size)
        self.init_weights()

    def init_weights(self):
        """Re-initialize every convolution and batch-norm layer of the block."""
        conv_layers = [self.conv1, self.conv2]
        if self.is_shortcut:
            conv_layers.append(self.shortcut)
        for conv in conv_layers:
            init_layer(conv)

        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, x):
        """Run the residual block and pool the result.

        Args:
            x: 4-D input accepted by ``nn.Conv2d`` with ``in_channels`` channels.

        Returns:
            The average-pooled output with ``out_channels`` channels.
        """
        identity = x

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = act(hidden, self.activation)

        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        if self.is_shortcut:
            identity = self.shortcut(identity)
        hidden += identity

        hidden = act(hidden, self.activation)
        return self.avp(hidden)
class PreEncoderBlockRes3B(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), avp_kernerl_size=(1, 2), activation='relu'):
        """Pre-Encoder with 3 Res2DAVPBlocks.

        Args:
            in_channels: channel count fed to the first block.
            out_channels: channel count produced by every block.
            kernel_size: convolution kernel size used in each block.
            avp_kernerl_size: average-pooling kernel size used in each block
                (the parameter name is kept as-is for caller compatibility).
            activation: activation name used in each block.
        """
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(3):
            stages.append(Res2DAVPBlock(stage_in, out_channels, kernel_size, avp_kernerl_size, activation))
            # Only the first block sees the raw input channel count.
            stage_in = out_channels
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):  # (B, T, F)
        """Encode a (B, T, F) input into a channels-last (B, T, F', C) feature map."""
        feat = rearrange(x, 'b t f -> b 1 t f')  # add a singleton channel axis
        for stage in self.blocks:
            feat = stage(feat)
        # Move channels to the last axis.
        return rearrange(feat, 'b c t f -> b t f c')
def test_res3b():
    """Check PreEncoderBlockRes3B output shapes on random spectrogram batches.

    With the default (1, 2) average pooling, each of the three blocks halves
    the frequency axis, so F shrinks by a factor of 8 while T is preserved.
    """
    # mel-spec input
    x = torch.randn(2, 256, 512)  # (B, T, F)
    pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128)
    y = pre(x)  # (B, T, F, C)
    assert tuple(y.shape) == (2, 256, 64, 128), f"unexpected shape: {tuple(y.shape)}"

    x = torch.randn(2, 110, 1024)  # (B, T, F)
    pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128)
    y = pre(x)  # (B, T, F, C)
    assert tuple(y.shape) == (2, 110, 128, 128), f"unexpected shape: {tuple(y.shape)}"
# ====================================================================================================================
# PreEncoderBlockHFTT: hFT-Transformer-like Pre-encoder
# ====================================================================================================================
class PreEncoderBlockHFTT(nn.Module):

    def __init__(self, margin_pre=15, margin_post=16) -> None:
        """Pre-Encoder with hFT-Transformer-like convolutions.

        Args:
            margin_pre: number of frames padded before the input along time.
            margin_post: number of frames padded after the input along time.
        """
        super().__init__()
        self.margin_pre = margin_pre
        self.margin_post = margin_post
        # 1 -> 4 channels, convolving over 5 time frames per frequency bin.
        self.conv = nn.Conv2d(1, 4, kernel_size=(1, 5), padding='same', padding_mode='zeros')
        # Acts on the last axis: 4 conv channels x 32-frame window = 128.
        self.emb_freq = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a (B, T, F) spectrogram into a (B, T, F, C) feature map.

        Shape examples in the comments use B=2, T=110, F=128 and the default
        margins.
        """
        # x: (B, T, F)
        spec = rearrange(x, 'b t f -> b 1 f t')  # (B, 1, F, T) or (2, 1, 128, 110)
        # Pad the time axis (last dim) on both sides with a small constant.
        padded = F.pad(spec, (self.margin_pre, self.margin_post), value=1e-7)  # (B, 1, F, T+margin) or (2, 1, 128, 141)
        feat = self.conv(padded)  # (B, C, F, T+margin) or (2, 4, 128, 141)
        # Sliding 32-frame windows over time; 32 == margin_pre + margin_post + 1
        # for the default margins, so T is preserved.
        windows = feat.unfold(dimension=3, size=32, step=1)  # (B, c1, F, T, c2) or (2, 4, 128, 110, 32)
        stacked = rearrange(windows, 'b c1 f t c2 -> b t f (c1 c2)')  # (B, T, F, C) or (2, 110, 128, 128)
        return self.emb_freq(stacked)  # (B, T, F, C) or (2, 110, 128, 128)
def test_hftt():
    """Check that PreEncoderBlockHFTT maps (B, T, F) input to (B, T, F, C)."""
    # To feed a real mel-spectrogram instead of random input:
    # from model.spectrogram import get_spectrogram_layer_from_audio_cfg
    # from config.config import audio_cfg as default_audio_cfg
    # audio_cfg = default_audio_cfg
    # audio_cfg['codec'] = 'melspec'
    # audio_cfg['hop_length'] = 300
    # audio_cfg['n_mels'] = 128
    # x = torch.randn(2, 1, 32767)
    # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg)
    # x = mspec(x)
    x = torch.randn(2, 110, 128)  # (B, T, F)
    pre_enc_hftt = PreEncoderBlockHFTT()
    y = pre_enc_hftt(x)  # (B, T, F, C)
    assert tuple(y.shape) == (2, 110, 128, 128), f"unexpected shape: {tuple(y.shape)}"
# ====================================================================================================================
# PreEncoderBlockRes3BHFTT: hFT-Transformer-like Pre-encoder with Res2DAVPBlock and spec input
# ====================================================================================================================
class PreEncoderBlockRes3BHFTT(nn.Module):

    def __init__(self, margin_pre: int = 15, margin_post: int = 16) -> None:
        """Pre-Encoder with hFT-Transformer-like convolutions.

        Uses a 3-block residual pre-encoder (``PreEncoderBlockRes3B``) on a
        time-padded spectrogram, then stacks sliding time windows into the
        channel axis.

        Args:
            margin_pre (int): padding before the input (time axis)
            margin_post (int): padding after the input (time axis)
        """
        super().__init__()
        self.margin_pre = margin_pre
        self.margin_post = margin_post
        self.res3b = PreEncoderBlockRes3B(in_channels=1, out_channels=4)
        # Acts on the last axis: 4 res3b channels x 32-frame window = 128.
        self.emb_freq = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a (B, T, F) spectrogram into a (B, T, F', C) feature map.

        Shape examples in the comments use B=2, T=110, F=1024 and the default
        margins.
        """
        # x: (B, T, F) or (2, 110, 1024), input spectrogram
        # F.pad pads the last axis, so bring time last, pad, then swap back.
        time_last = rearrange(x, 'b t f -> b f t')  # (2, 1024, 110): B,F,T
        time_last = F.pad(time_last, (self.margin_pre, self.margin_post), value=1e-7)  # (2, 1024, 141): B,F,T+margin
        padded = rearrange(time_last, 'b f t -> b t f')  # (2, 141, 1024): B,T+margin,F

        feat = self.res3b(padded)  # (2, 141, 128, 4): B,T+margin,F,C
        # Sliding 32-frame windows over time; the window axis is appended last.
        windows = feat.unfold(dimension=1, size=32, step=1)  # (B, T, F, C1, C2) or (2, 110, 128, 4, 32)
        stacked = rearrange(windows, 'b t f c1 c2 -> b t f (c1 c2)')  # (B, T, F, C) or (2, 110, 128, 128)
        return self.emb_freq(stacked)  # (B, T, F, C) or (2, 110, 128, 128)
def test_res3b_hftt():
    """Check that PreEncoderBlockRes3BHFTT maps (2, 110, 1024) to (2, 110, 128, 128)."""
    # To feed a real spectrogram instead of random input:
    # from model.spectrogram import get_spectrogram_layer_from_audio_cfg
    # from config.config import audio_cfg as default_audio_cfg
    # audio_cfg = default_audio_cfg
    # audio_cfg['codec'] = 'spec'
    # audio_cfg['hop_length'] = 300
    # x = torch.randn(2, 1, 32767)
    # spec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg)
    # x = spec(x) # (2, 110, 1024): B,T,F
    x = torch.randn(2, 110, 1024)  # (B, T, F)
    pre_enc_res3b_hftt = PreEncoderBlockRes3BHFTT()
    y = pre_enc_res3b_hftt(x)  # (B, T, F, C)
    assert tuple(y.shape) == (2, 110, 128, 128), f"unexpected shape: {tuple(y.shape)}"
# # ====================================================================================================================
# # PreEncoderBlockConv1D: Pre-encoder without activation, with Melspec input
# # ====================================================================================================================
# class PreEncoderBlockConv1D(nn.Module):
# def __init__(self,
# in_channels,
# out_channels,
# kernel_size=3) -> None:
# """Pre-Encoder with 1D convolution."""
# super().__init__()
# self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1)
# self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1)
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# # x: (B, T, F) or (2, 128, 256), input melspec
# x = rearrange(x, 'b t f -> b f t') # (2, 256, 128): B,F,T
# x = self.conv1(x) # (2, 128, 128): B,F,T
# return rearrange(x, 'b f t -> b t f') # (2, 110, 128): B,T,F
# def test_conv1d():
# # from model.spectrogram import get_spectrogram_layer_from_audio_cfg
# # from config.config import audio_cfg as default_audio_cfg
# # audio_cfg = default_audio_cfg
# # audio_cfg['codec'] = 'melspec'
# # audio_cfg['hop_length'] = 256
# # audio_cfg['n_mels'] = 512
# # x = torch.randn(2, 1, 32767)
# # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg)
# # x = mspec(x)
# x = torch.randn(2, 128, 128) # (B, T, F)
# pre_enc_conv1d = PreEncoderBlockConv1D(in_channels=1, out_channels=128)
# y = pre_enc_conv1d(x) # (2, 110, 128, 128): B, T, F, C