from functools import partial
from typing import Final, List, Optional, Tuple, Union

import torch
from loguru import logger
from torch import Tensor, nn

from df_local.config import Csv, DfParams, config
from df_local.modules import (
    Conv2dNormAct,
    ConvTranspose2dNormAct,
    DfOp,
    GroupedGRU,
    GroupedLinear,
    GroupedLinearEinsum,
    Mask,
    SqueezedGRU,
    erb_fb,
    get_device,
)
from df_local.multiframe import MF_METHODS, MultiFrameModule
from libdf import DF


class ModelParams(DfParams):
    """Model hyper-parameters read from the ``deepfilternet`` config section.

    Every attribute is resolved through ``config(...)`` with the default shown
    below; the base ``DfParams`` supplies the signal-processing parameters
    (``sr``, ``fft_size``, ``nb_erb``, ``nb_df``, ...) used elsewhere in this file.
    """

    section = "deepfilternet"

    def __init__(self):
        super().__init__()
        self.conv_lookahead: int = config(
            "CONV_LOOKAHEAD", cast=int, default=0, section=self.section
        )
        self.conv_ch: int = config("CONV_CH", cast=int, default=16, section=self.section)
        self.conv_depthwise: bool = config(
            "CONV_DEPTHWISE", cast=bool, default=True, section=self.section
        )
        self.convt_depthwise: bool = config(
            "CONVT_DEPTHWISE", cast=bool, default=True, section=self.section
        )
        self.conv_kernel: List[int] = config(
            "CONV_KERNEL", cast=Csv(int), default=(1, 3), section=self.section  # type: ignore
        )
        self.conv_kernel_inp: List[int] = config(
            "CONV_KERNEL_INP", cast=Csv(int), default=(3, 3), section=self.section  # type: ignore
        )
        self.emb_hidden_dim: int = config(
            "EMB_HIDDEN_DIM", cast=int, default=256, section=self.section
        )
        self.emb_num_layers: int = config(
            "EMB_NUM_LAYERS", cast=int, default=2, section=self.section
        )
        self.df_hidden_dim: int = config(
            "DF_HIDDEN_DIM", cast=int, default=256, section=self.section
        )
        self.df_gru_skip: str = config("DF_GRU_SKIP", default="none", section=self.section)
        self.df_output_layer: str = config(
            "DF_OUTPUT_LAYER", default="linear", section=self.section
        )
        self.df_pathway_kernel_size_t: int = config(
            "DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section
        )
        self.enc_concat: bool = config("ENC_CONCAT", cast=bool, default=False, section=self.section)
        self.df_num_layers: int = config("DF_NUM_LAYERS", cast=int, default=3, section=self.section)
        self.df_n_iter: int = config("DF_N_ITER", cast=int, default=2, section=self.section)
        self.gru_type: str = config("GRU_TYPE", default="grouped", section=self.section)
        self.gru_groups: int = config("GRU_GROUPS", cast=int, default=1, section=self.section)
        self.lin_groups: int = config("LINEAR_GROUPS", cast=int, default=1, section=self.section)
        self.group_shuffle: bool = config(
            "GROUP_SHUFFLE", cast=bool, default=True, section=self.section
        )
        self.dfop_method: str = config("DFOP_METHOD", cast=str, default="df", section=self.section)
        self.mask_pf: bool = config("MASK_PF", cast=bool, default=False, section=self.section)


def init_model(df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True):
    """Build a ``DfNet`` and move it to the device returned by ``get_device()``.

    Args:
        df_state: Existing ``DF`` state used to obtain the ERB band widths. When
            ``None``, a new one is created from the configured ``sr``,
            ``fft_size``, ``hop_size`` and ``nb_erb``.
        run_df: Forwarded to ``DfNet``; disables the deep-filtering stage when False.
        train_mask: Forwarded to ``DfNet``.

    Returns:
        The constructed ``DfNet`` on the target device.
    """
    p = ModelParams()
    if df_state is None:
        df_state = DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
    # Forward and inverse ERB filterbanks are both derived from the same band widths.
    erb = erb_fb(df_state.erb_widths(), p.sr, inverse=False)
    erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
    model = DfNet(erb, erb_inverse, run_df, train_mask)
    return model.to(device=get_device())


class Add(nn.Module):
    """Combine two tensors by element-wise addition."""

    def forward(self, a, b):
        return a + b


class Concat(nn.Module):
    """Combine two tensors by concatenating them along the last dimension."""

    def forward(self, a, b):
        return torch.cat((a, b), dim=-1)


class Encoder(nn.Module):
    """Encode ERB and complex-spectrogram features into a shared embedding.

    Two convolutional pathways (ERB and DF/complex) are flattened over
    channels and frequency, combined via ``Add`` or ``Concat`` (``enc_concat``),
    and passed through a single-layer GRU. A small head on the embedding
    predicts the local SNR.
    """

    def __init__(self):
        super().__init__()
        p = ModelParams()
        assert p.nb_erb % 4 == 0, "erb_bins should be divisible by 4"

        self.erb_conv0 = Conv2dNormAct(
            1, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
        )
        # All subsequent encoder convolutions keep the channel count at `conv_ch`.
        conv_layer = partial(
            Conv2dNormAct,
            in_ch=p.conv_ch,
            out_ch=p.conv_ch,
            kernel_size=p.conv_kernel,
            bias=False,
            separable=True,
        )
        self.erb_conv1 = conv_layer(fstride=2)
        self.erb_conv2 = conv_layer(fstride=2)
        self.erb_conv3 = conv_layer(fstride=1)
        self.df_conv0 = Conv2dNormAct(
            2, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
        )
        self.df_conv1 = conv_layer(fstride=2)
        self.erb_bins = p.nb_erb
        self.emb_in_dim = p.conv_ch * p.nb_erb // 4
        self.emb_out_dim = p.emb_hidden_dim
        # Projects the flattened DF pathway (C * nb_df / 2) to the ERB embedding width
        # so both pathways can be combined.
        if p.gru_type == "grouped":
            self.df_fc_emb = GroupedLinear(
                p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups
            )
        else:
            df_fc_emb = GroupedLinearEinsum(
                p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups
            )
            self.df_fc_emb = nn.Sequential(df_fc_emb, nn.ReLU(inplace=True))
        if p.enc_concat:
            # Concatenation doubles the GRU input width.
            self.emb_in_dim *= 2
            self.combine = Concat()
        else:
            self.combine = Add()
        self.emb_n_layers = p.emb_num_layers
        assert p.gru_type in ("grouped", "squeeze"), f"But got {p.gru_type}"
        if p.gru_type == "grouped":
            self.emb_gru = GroupedGRU(
                self.emb_in_dim,
                self.emb_out_dim,
                num_layers=1,
                batch_first=True,
                groups=p.gru_groups,
                shuffle=p.group_shuffle,
                add_outputs=True,
            )
        else:
            self.emb_gru = SqueezedGRU(
                self.emb_in_dim,
                self.emb_out_dim,
                num_layers=1,
                batch_first=True,
                linear_groups=p.lin_groups,
                linear_act_layer=partial(nn.ReLU, inplace=True),
            )
        # Sigmoid output in [0, 1] is rescaled to [lsnr_min, lsnr_max] in forward().
        self.lsnr_fc = nn.Sequential(nn.Linear(self.emb_out_dim, 1), nn.Sigmoid())
        self.lsnr_scale = p.lsnr_max - p.lsnr_min
        self.lsnr_offset = p.lsnr_min

    def forward(
        self, feat_erb: Tensor, feat_spec: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Run both encoder pathways.

        Args:
            feat_erb: ERB features of shape [B, 1, T, Fe]; should be in dB scale
                and normalized. Fe is the number of ERB bands.
            feat_spec: Complex spectrogram features of shape [B, 2, T, Fc].

        Returns:
            ``(e0, e1, e2, e3, emb, c0, lsnr)``: the four ERB pathway feature
            maps, the GRU embedding, the first DF pathway feature map, and the
            local SNR estimate.
        """
        # Shape comments assume `fstride=2` halves the frequency axis, which is what
        # the linear layer dimensions configured in __init__ expect.
        e0 = self.erb_conv0(feat_erb)  # [B, C, T, F]
        e1 = self.erb_conv1(e0)  # [B, C, T, F/2]
        e2 = self.erb_conv2(e1)  # [B, C, T, F/4]
        e3 = self.erb_conv3(e2)  # [B, C, T, F/4]
        c0 = self.df_conv0(feat_spec)  # [B, C, T, Fc]
        c1 = self.df_conv1(c0)  # [B, C, T, Fc/2]
        cemb = c1.permute(0, 2, 3, 1).flatten(2)  # [B, T, C * Fc/2]
        cemb = self.df_fc_emb(cemb)  # [B, T, C * F/4]
        emb = e3.permute(0, 2, 3, 1).flatten(2)  # [B, T, C * F/4]
        emb = self.combine(emb, cemb)
        emb, _ = self.emb_gru(emb)  # [B, T, -1]
        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        return e0, e1, e2, e3, emb, c0, lsnr


class ErbDecoder(nn.Module):
    """Decode the embedding and encoder skip connections into an ERB mask.

    The embedding is passed through a GRU, reshaped to a feature map and then
    merged step by step with the encoder outputs ``e3`` .. ``e0``; the final
    convolution uses a sigmoid activation.
    """

    def __init__(self):
        super().__init__()
        p = ModelParams()
        assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8"

        self.emb_out_dim = p.emb_hidden_dim

        if p.gru_type == "grouped":
            self.emb_gru = GroupedGRU(
                p.conv_ch * p.nb_erb // 4,  # For compat
                self.emb_out_dim,
                num_layers=p.emb_num_layers - 1,
                batch_first=True,
                groups=p.gru_groups,
                shuffle=p.group_shuffle,
                add_outputs=True,
            )
            # SqueezedGRU uses GroupedLinearEinsum, so let's use it here as well
            # NOTE(review): the layer instantiated below is GroupedLinear, not
            # GroupedLinearEinsum; confirm which one is intended.
            fc_emb = GroupedLinear(
                p.emb_hidden_dim,
                p.conv_ch * p.nb_erb // 4,
                groups=p.lin_groups,
                shuffle=p.group_shuffle,
            )
            self.fc_emb = nn.Sequential(fc_emb, nn.ReLU(inplace=True))
        else:
            # SqueezedGRU already projects to the feature-map width via `output_size`,
            # so no separate linear layer is needed.
            self.emb_gru = SqueezedGRU(
                self.emb_out_dim,
                self.emb_out_dim,
                output_size=p.conv_ch * p.nb_erb // 4,
                num_layers=p.emb_num_layers - 1,
                batch_first=True,
                gru_skip_op=nn.Identity,
                linear_groups=p.lin_groups,
                linear_act_layer=partial(nn.ReLU, inplace=True),
            )
            self.fc_emb = nn.Identity()
        tconv_layer = partial(
            ConvTranspose2dNormAct,
            kernel_size=p.conv_kernel,
            bias=False,
            separable=True,
        )
        conv_layer = partial(
            Conv2dNormAct,
            bias=False,
            separable=True,
        )
        # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions
        self.conv3p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
        self.convt3 = conv_layer(p.conv_ch, p.conv_ch, kernel_size=p.conv_kernel)
        self.conv2p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
        self.convt2 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2)
        self.conv1p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
        self.convt1 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2)
        self.conv0p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
        self.conv0_out = conv_layer(
            p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid
        )

    def forward(self, emb, e3, e2, e1, e0) -> Tensor:
        """Estimate the ERB mask.

        Args:
            emb: Encoder embedding of shape [B, T, H].
            e3, e2, e1, e0: Encoder feature maps, deepest first.

        Returns:
            ERB mask of shape [B, 1, T, F].
        """
        # Shape comments assume the transposed convolutions with `fstride=2` double
        # the frequency axis, mirroring the encoder.
        b, _, t, f = e3.shape  # f is the frequency size of the deepest feature map
        emb, _ = self.emb_gru(emb)
        emb = self.fc_emb(emb)
        emb = emb.view(b, t, f, -1).permute(0, 3, 1, 2)  # [B, C, T, F/4]
        e3 = self.convt3(self.conv3p(e3) + emb)  # [B, C, T, F/4]
        e2 = self.convt2(self.conv2p(e2) + e3)  # [B, C, T, F/2]
        e1 = self.convt1(self.conv1p(e1) + e2)  # [B, C, T, F]
        m = self.conv0_out(self.conv0p(e0) + e1)  # [B, 1, T, F]
        return m


class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule

    Takes coefficients of shape [B, T, F, O*2] and returns them as
    [B, O, T, F, 2], i.e. the last axis is split into (order, real/imag) pairs
    and the order axis is moved in front of time.
    """

    def __init__(self, df_order: int, df_bins: int):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: Tensor) -> Tensor:
        # [B, T, F, O*2] -> [B, T, F, O, 2]
        coefs = coefs.view(*coefs.shape[:-1], -1, 2)
        # [B, T, F, O, 2] -> [B, O, T, F, 2]
        coefs = coefs.permute(0, 3, 1, 2, 4)
        return coefs


class DfDecoder(nn.Module):
    """Predict deep-filtering coefficients from the encoder embedding.

    A GRU (optionally with a skip connection selected by ``df_gru_skip``)
    feeds an output layer with ``tanh`` activation; its output is added to a
    convolutional pathway over the encoder feature map ``c0``. A second head
    produces ``alpha`` through a sigmoid.

    Args:
        out_channels: Number of coefficient channels per frequency bin. When
            not positive, ``df_order * 2`` is used.
    """

    def __init__(self, out_channels: int = -1):
        super().__init__()
        p = ModelParams()
        layer_width = p.conv_ch

        self.emb_dim = p.emb_hidden_dim
        self.df_n_hidden = p.df_hidden_dim
        self.df_n_layers = p.df_num_layers
        self.df_order = p.df_order
        self.df_bins = p.nb_df
        self.gru_groups = p.gru_groups
        self.df_out_ch = out_channels if out_channels > 0 else p.df_order * 2

        conv_layer = partial(Conv2dNormAct, separable=True, bias=False)
        kt = p.df_pathway_kernel_size_t
        self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1))
        if p.gru_type == "grouped":
            self.df_gru = GroupedGRU(
                p.emb_hidden_dim,
                p.df_hidden_dim,
                num_layers=self.df_n_layers,
                batch_first=True,
                groups=p.gru_groups,
                shuffle=p.group_shuffle,
                add_outputs=True,
            )
        else:
            self.df_gru = SqueezedGRU(
                p.emb_hidden_dim,
                p.df_hidden_dim,
                num_layers=self.df_n_layers,
                batch_first=True,
                gru_skip_op=nn.Identity,
                linear_act_layer=partial(nn.ReLU, inplace=True),
            )
        # The skip type is matched case-insensitively.
        p.df_gru_skip = p.df_gru_skip.lower()
        assert p.df_gru_skip in ("none", "identity", "groupedlinear")
        self.df_skip: Optional[nn.Module]
        if p.df_gru_skip == "none":
            self.df_skip = None
        elif p.df_gru_skip == "identity":
            assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match"
            self.df_skip = nn.Identity()
        elif p.df_gru_skip == "groupedlinear":
            self.df_skip = GroupedLinearEinsum(
                p.emb_hidden_dim, p.df_hidden_dim, groups=p.lin_groups
            )
        else:
            raise NotImplementedError()
        assert p.df_output_layer in ("linear", "groupedlinear")
        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch
        if p.df_output_layer == "linear":
            df_out = nn.Linear(self.df_n_hidden, out_dim)
        elif p.df_output_layer == "groupedlinear":
            df_out = GroupedLinearEinsum(self.df_n_hidden, out_dim, groups=p.lin_groups)
        else:
            raise NotImplementedError
        self.df_out = nn.Sequential(df_out, nn.Tanh())
        self.df_fc_a = nn.Sequential(nn.Linear(self.df_n_hidden, 1), nn.Sigmoid())
        self.out_transform = DfOutputReshapeMF(self.df_order, self.df_bins)

    def forward(self, emb: Tensor, c0: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the filter coefficients and ``alpha``.

        Args:
            emb: Encoder embedding of shape [B, T, E].
            c0: First DF pathway feature map from the encoder.

        Returns:
            ``(coefs, alpha)`` where ``coefs`` is the output of
            ``DfOutputReshapeMF`` and ``alpha`` has shape [B, T, 1].
        """
        b, t, _ = emb.shape
        c, _ = self.df_gru(emb)  # [B, T, H], H: df_n_hidden
        if self.df_skip is not None:
            c += self.df_skip(emb)
        c0 = self.df_convp(c0).permute(0, 2, 3, 1)  # [B, T, F, O*2], channels_last
        alpha = self.df_fc_a(c)  # [B, T, 1]
        c = self.df_out(c)  # [B, T, F*O*2], O: df_order
        c = c.view(b, t, self.df_bins, self.df_out_ch) + c0  # [B, T, F, O*2]
        c = self.out_transform(c)
        return c, alpha


class DfNet(nn.Module):
    """Full network: encoder, ERB mask decoder and deep-filtering decoder.

    Args:
        erb_fb: ERB filterbank, registered as the ``erb_fb`` buffer.
        erb_inv_fb: Inverse ERB filterbank handed to ``Mask``.
        run_df: When False, the deep-filtering stage is skipped in ``forward``.
        train_mask: Stored as ``self.train_mask``; not read anywhere in this file.
    """

    run_df: Final[bool]
    pad_specf: Final[bool]

    def __init__(
        self,
        erb_fb: Tensor,
        erb_inv_fb: Tensor,
        run_df: bool = True,
        train_mask: bool = True,
    ):
        super().__init__()
        p = ModelParams()
        layer_width = p.conv_ch
        assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8"
        self.df_lookahead = p.df_lookahead if p.pad_mode == "model" else 0
        self.nb_df = p.nb_df
        self.freq_bins: int = p.fft_size // 2 + 1
        self.emb_dim: int = layer_width * p.nb_erb
        self.erb_bins: int = p.nb_erb
        # Lookahead is implemented with negative/positive constant padding on the
        # time axis: frames are cropped at the start and zeros appended at the end.
        if p.conv_lookahead > 0 and p.pad_mode.startswith("input"):
            self.pad_feat = nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0)
        else:
            self.pad_feat = nn.Identity()
        self.pad_specf = p.pad_mode.endswith("specf")
        if p.df_lookahead > 0 and self.pad_specf:
            self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -p.df_lookahead, p.df_lookahead), 0.0)
        else:
            self.pad_spec = nn.Identity()
        if (p.conv_lookahead > 0 or p.df_lookahead > 0) and p.pad_mode.startswith("output"):
            # A single output pad is shared by the mask and the DF coefficients,
            # so both lookaheads have to agree.
            assert p.conv_lookahead == p.df_lookahead
            pad = (0, 0, 0, 0, -p.conv_lookahead, p.conv_lookahead)
            self.pad_out = nn.ConstantPad3d(pad, 0.0)
        else:
            self.pad_out = nn.Identity()
        self.register_buffer("erb_fb", erb_fb)
        self.enc = Encoder()
        self.erb_dec = ErbDecoder()
        self.mask = Mask(erb_inv_fb, post_filter=p.mask_pf)

        self.df_order = p.df_order
        self.df_bins = p.nb_df
        self.df_op: Union[DfOp, MultiFrameModule]
        if p.dfop_method == "real_unfold":
            raise ValueError("RealUnfold DF OP is now unsupported.")
        assert p.df_output_layer != "linear", "Must be used with `groupedlinear`"
        self.df_op = MF_METHODS[p.dfop_method](
            num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead
        )
        # The decoder emits as many channels as the selected DF operation consumes.
        n_ch_out = self.df_op.num_channels()
        self.df_dec = DfDecoder(out_channels=n_ch_out)

        self.run_df = run_df
        if not run_df:
            logger.warning("Runing without DF")
        self.train_mask = train_mask
        assert p.df_n_iter == 1

    def forward(
        self,
        spec: Tensor,
        feat_erb: Tensor,
        feat_spec: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Forward method of DeepFilterNet2.

        Args:
            spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
            feat_erb (Tensor): ERB features of shape [B, 1, T, E]
            feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2]

        Returns:
            spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
            m (Tensor): ERB mask estimate of shape [B, 1, T, E]
            lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
            df_alpha (Tensor): Output of the DF decoder's alpha head of shape
                [B, T, 1], or a scalar zero tensor when ``run_df`` is False.
        """
        # [B, 1, T, F', 2] -> [B, 2, T, F']: real/imag become the channel axis.
        feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)

        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr = self.enc(feat_erb, feat_spec)
        m = self.erb_dec(emb, e3, e2, e1, e0)

        # pad_out may be a 3-d pad, so the mask temporarily gets a trailing axis.
        m = self.pad_out(m.unsqueeze(-1)).squeeze(-1)
        spec = self.mask(spec, m)

        if self.run_df:
            df_coefs, df_alpha = self.df_dec(emb, c0)
            df_coefs = self.pad_out(df_coefs)

            if self.pad_specf:
                # Only pad the lower part of the spectrum.
                spec_f = self.pad_spec(spec)
                spec_f = self.df_op(spec_f, df_coefs)
                spec[..., : self.nb_df, :] = spec_f[..., : self.nb_df, :]
            else:
                spec = self.pad_spec(spec)
                spec = self.df_op(spec, df_coefs)
        else:
            # Placeholder on the same device as the other outputs.
            df_alpha = torch.zeros((), device=spec.device)

        return spec, m, lsnr, df_alpha