Spaces:
Sleeping
Sleeping
# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition | |
# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn | |
# Some layers are re-designed for CLAP | |
import os | |
os.environ['NUMBA_CACHE_DIR'] = '/tmp/' | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchlibrosa.stft import Spectrogram, LogmelFilterBank | |
from torchlibrosa.augmentation import SpecAugmentation | |
from .utils import do_mixup, interpolate, pad_framewise_output | |
from .feature_fusion import iAFF, AFF, DAF | |
def init_layer(layer): | |
"""Initialize a Linear or Convolutional layer. """ | |
nn.init.xavier_uniform_(layer.weight) | |
if hasattr(layer, 'bias'): | |
if layer.bias is not None: | |
layer.bias.data.fill_(0.) | |
def init_bn(bn): | |
"""Initialize a Batchnorm layer. """ | |
bn.bias.data.fill_(0.) | |
bn.weight.data.fill_(1.) | |
class ConvBlock(nn.Module): | |
def __init__(self, in_channels, out_channels): | |
super(ConvBlock, self).__init__() | |
self.conv1 = nn.Conv2d(in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=(3, 3), stride=(1, 1), | |
padding=(1, 1), bias=False) | |
self.conv2 = nn.Conv2d(in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=(3, 3), stride=(1, 1), | |
padding=(1, 1), bias=False) | |
self.bn1 = nn.BatchNorm2d(out_channels) | |
self.bn2 = nn.BatchNorm2d(out_channels) | |
self.init_weight() | |
def init_weight(self): | |
init_layer(self.conv1) | |
init_layer(self.conv2) | |
init_bn(self.bn1) | |
init_bn(self.bn2) | |
def forward(self, input, pool_size=(2, 2), pool_type='avg'): | |
x = input | |
x = F.relu_(self.bn1(self.conv1(x))) | |
x = F.relu_(self.bn2(self.conv2(x))) | |
if pool_type == 'max': | |
x = F.max_pool2d(x, kernel_size=pool_size) | |
elif pool_type == 'avg': | |
x = F.avg_pool2d(x, kernel_size=pool_size) | |
elif pool_type == 'avg+max': | |
x1 = F.avg_pool2d(x, kernel_size=pool_size) | |
x2 = F.max_pool2d(x, kernel_size=pool_size) | |
x = x1 + x2 | |
else: | |
raise Exception('Incorrect argument!') | |
return x | |
class ConvBlock5x5(nn.Module): | |
def __init__(self, in_channels, out_channels): | |
super(ConvBlock5x5, self).__init__() | |
self.conv1 = nn.Conv2d(in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=(5, 5), stride=(1, 1), | |
padding=(2, 2), bias=False) | |
self.bn1 = nn.BatchNorm2d(out_channels) | |
self.init_weight() | |
def init_weight(self): | |
init_layer(self.conv1) | |
init_bn(self.bn1) | |
def forward(self, input, pool_size=(2, 2), pool_type='avg'): | |
x = input | |
x = F.relu_(self.bn1(self.conv1(x))) | |
if pool_type == 'max': | |
x = F.max_pool2d(x, kernel_size=pool_size) | |
elif pool_type == 'avg': | |
x = F.avg_pool2d(x, kernel_size=pool_size) | |
elif pool_type == 'avg+max': | |
x1 = F.avg_pool2d(x, kernel_size=pool_size) | |
x2 = F.max_pool2d(x, kernel_size=pool_size) | |
x = x1 + x2 | |
else: | |
raise Exception('Incorrect argument!') | |
return x | |
class AttBlock(nn.Module): | |
def __init__(self, n_in, n_out, activation='linear', temperature=1.): | |
super(AttBlock, self).__init__() | |
self.activation = activation | |
self.temperature = temperature | |
self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) | |
self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) | |
self.bn_att = nn.BatchNorm1d(n_out) | |
self.init_weights() | |
def init_weights(self): | |
init_layer(self.att) | |
init_layer(self.cla) | |
init_bn(self.bn_att) | |
def forward(self, x): | |
# x: (n_samples, n_in, n_time) | |
norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) | |
cla = self.nonlinear_transform(self.cla(x)) | |
x = torch.sum(norm_att * cla, dim=2) | |
return x, norm_att, cla | |
def nonlinear_transform(self, x): | |
if self.activation == 'linear': | |
return x | |
elif self.activation == 'sigmoid': | |
return torch.sigmoid(x) | |
class Cnn14(nn.Module): | |
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, | |
fmax, classes_num, enable_fusion=False, fusion_type='None'): | |
super(Cnn14, self).__init__() | |
window = 'hann' | |
center = True | |
pad_mode = 'reflect' | |
ref = 1.0 | |
amin = 1e-10 | |
top_db = None | |
self.enable_fusion = enable_fusion | |
self.fusion_type = fusion_type | |
# Spectrogram extractor | |
self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, | |
win_length=window_size, window=window, center=center, pad_mode=pad_mode, | |
freeze_parameters=True) | |
# Logmel feature extractor | |
self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, | |
n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, | |
freeze_parameters=True) | |
# Spec augmenter | |
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, | |
freq_drop_width=8, freq_stripes_num=2) | |
self.bn0 = nn.BatchNorm2d(64) | |
if (self.enable_fusion) and (self.fusion_type == 'channel_map'): | |
self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) | |
else: | |
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) | |
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) | |
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) | |
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) | |
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) | |
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) | |
self.fc1 = nn.Linear(2048, 2048, bias=True) | |
self.fc_audioset = nn.Linear(2048, classes_num, bias=True) | |
if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']): | |
self.mel_conv1d = nn.Sequential( | |
nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), | |
nn.BatchNorm1d(64) # No Relu | |
) | |
if self.fusion_type == 'daf_1d': | |
self.fusion_model = DAF() | |
elif self.fusion_type == 'aff_1d': | |
self.fusion_model = AFF(channels=64, type='1D') | |
elif self.fusion_type == 'iaff_1d': | |
self.fusion_model = iAFF(channels=64, type='1D') | |
if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']): | |
self.mel_conv2d = nn.Sequential( | |
nn.Conv2d(1, 64, kernel_size=(5,5), stride=(6, 2), padding=(2,2)), | |
nn.BatchNorm2d(64), | |
nn.ReLU(inplace=True) | |
) | |
if self.fusion_type == 'daf_2d': | |
self.fusion_model = DAF() | |
elif self.fusion_type == 'aff_2d': | |
self.fusion_model = AFF(channels=64, type='2D') | |
elif self.fusion_type == 'iaff_2d': | |
self.fusion_model = iAFF(channels=64, type='2D') | |
self.init_weight() | |
def init_weight(self): | |
init_bn(self.bn0) | |
init_layer(self.fc1) | |
init_layer(self.fc_audioset) | |
def forward(self, input, mixup_lambda=None, device=None): | |
""" | |
Input: (batch_size, data_length)""" | |
if self.enable_fusion and input["longer"].sum() == 0: | |
# if no audio is longer than 10s, then randomly select one audio to be longer | |
input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True | |
if not self.enable_fusion: | |
x = self.spectrogram_extractor(input['waveform'].to(device=device, non_blocking=True)) # (batch_size, 1, time_steps, freq_bins) | |
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) | |
x = x.transpose(1, 3) | |
x = self.bn0(x) | |
x = x.transpose(1, 3) | |
else: | |
longer_list = input["longer"].to(device=device, non_blocking=True) | |
x = input["mel_fusion"].to(device=device, non_blocking=True) | |
longer_list_idx = torch.where(longer_list)[0] | |
x = x.transpose(1, 3) | |
x = self.bn0(x) | |
x = x.transpose(1, 3) | |
if self.fusion_type in ['daf_1d','aff_1d','iaff_1d']: | |
new_x = x[:,0:1,:,:].clone().contiguous() | |
# local processing | |
if len(longer_list_idx) > 0: | |
fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous() | |
FB,FC,FT,FF = fusion_x_local.size() | |
fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) | |
fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous() | |
fusion_x_local = self.mel_conv1d(fusion_x_local) | |
fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1)) | |
fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2) | |
if fusion_x_local.size(-1) < FT: | |
fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1) | |
else: | |
fusion_x_local = fusion_x_local[:,:,:FT] | |
# 1D fusion | |
new_x = new_x.squeeze(1).permute((0,2,1)).contiguous() | |
new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local) | |
x = new_x.permute((0,2,1)).contiguous()[:,None,:,:] | |
else: | |
x = new_x | |
elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']: | |
x = x # no change | |
if self.training: | |
x = self.spec_augmenter(x) | |
# Mixup on spectrogram | |
if self.training and mixup_lambda is not None: | |
x = do_mixup(x, mixup_lambda) | |
if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']): | |
global_x = x[:,0:1,:,:] | |
# global processing | |
B, C, H, W = global_x.shape | |
global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type='avg') | |
if len(longer_list_idx) > 0: | |
local_x = x[longer_list_idx,1:,:,:].contiguous() | |
TH = global_x.size(-2) | |
# local processing | |
B, C, H, W = local_x.shape | |
local_x = local_x.view(B*C,1,H,W) | |
local_x = self.mel_conv2d(local_x) | |
local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3)) | |
local_x = local_x.permute((0,2,1,3,4)).contiguous().flatten(2,3) | |
TB,TC,_,TW = local_x.size() | |
if local_x.size(-2) < TH: | |
local_x = torch.cat([local_x, torch.zeros((TB,TC,TH-local_x.size(-2),TW), device=global_x.device)], dim=-2) | |
else: | |
local_x = local_x[:,:,:TH,:] | |
global_x[longer_list_idx] = self.fusion_model(global_x[longer_list_idx],local_x) | |
x = global_x | |
else: | |
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = torch.mean(x, dim=3) | |
latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) | |
latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) | |
latent_x = latent_x1 + latent_x2 | |
latent_x = latent_x.transpose(1, 2) | |
latent_x = F.relu_(self.fc1(latent_x)) | |
latent_output = interpolate(latent_x, 32) | |
(x1, _) = torch.max(x, dim=2) | |
x2 = torch.mean(x, dim=2) | |
x = x1 + x2 | |
x = F.dropout(x, p=0.5, training=self.training) | |
x = F.relu_(self.fc1(x)) | |
embedding = F.dropout(x, p=0.5, training=self.training) | |
clipwise_output = torch.sigmoid(self.fc_audioset(x)) | |
output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output} | |
return output_dict | |
class Cnn6(nn.Module): | |
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, | |
fmax, classes_num, enable_fusion=False, fusion_type='None'): | |
super(Cnn6, self).__init__() | |
window = 'hann' | |
center = True | |
pad_mode = 'reflect' | |
ref = 1.0 | |
amin = 1e-10 | |
top_db = None | |
self.enable_fusion = enable_fusion | |
self.fusion_type = fusion_type | |
# Spectrogram extractor | |
self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, | |
win_length=window_size, window=window, center=center, pad_mode=pad_mode, | |
freeze_parameters=True) | |
# Logmel feature extractor | |
self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, | |
n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, | |
freeze_parameters=True) | |
# Spec augmenter | |
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, | |
freq_drop_width=8, freq_stripes_num=2) | |
self.bn0 = nn.BatchNorm2d(64) | |
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) | |
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) | |
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) | |
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) | |
self.fc1 = nn.Linear(512, 512, bias=True) | |
self.fc_audioset = nn.Linear(512, classes_num, bias=True) | |
self.init_weight() | |
def init_weight(self): | |
init_bn(self.bn0) | |
init_layer(self.fc1) | |
init_layer(self.fc_audioset) | |
def forward(self, input, mixup_lambda=None, device=None): | |
""" | |
Input: (batch_size, data_length)""" | |
x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) | |
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) | |
x = x.transpose(1, 3) | |
x = self.bn0(x) | |
x = x.transpose(1, 3) | |
if self.training: | |
x = self.spec_augmenter(x) | |
# Mixup on spectrogram | |
if self.training and mixup_lambda is not None: | |
x = do_mixup(x, mixup_lambda) | |
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = torch.mean(x, dim=3) | |
latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) | |
latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) | |
latent_x = latent_x1 + latent_x2 | |
latent_x = latent_x.transpose(1, 2) | |
latent_x = F.relu_(self.fc1(latent_x)) | |
latent_output = interpolate(latent_x, 16) | |
(x1, _) = torch.max(x, dim=2) | |
x2 = torch.mean(x, dim=2) | |
x = x1 + x2 | |
x = F.dropout(x, p=0.5, training=self.training) | |
x = F.relu_(self.fc1(x)) | |
embedding = F.dropout(x, p=0.5, training=self.training) | |
clipwise_output = torch.sigmoid(self.fc_audioset(x)) | |
output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output} | |
return output_dict | |
class Cnn10(nn.Module): | |
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, | |
fmax, classes_num, enable_fusion=False, fusion_type='None'): | |
super(Cnn10, self).__init__() | |
window = 'hann' | |
center = True | |
pad_mode = 'reflect' | |
ref = 1.0 | |
amin = 1e-10 | |
top_db = None | |
self.enable_fusion = enable_fusion | |
self.fusion_type = fusion_type | |
# Spectrogram extractor | |
self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, | |
win_length=window_size, window=window, center=center, pad_mode=pad_mode, | |
freeze_parameters=True) | |
# Logmel feature extractor | |
self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, | |
n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, | |
freeze_parameters=True) | |
# Spec augmenter | |
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, | |
freq_drop_width=8, freq_stripes_num=2) | |
self.bn0 = nn.BatchNorm2d(64) | |
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) | |
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) | |
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) | |
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) | |
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) | |
self.fc1 = nn.Linear(1024, 1024, bias=True) | |
self.fc_audioset = nn.Linear(1024, classes_num, bias=True) | |
self.init_weight() | |
def init_weight(self): | |
init_bn(self.bn0) | |
init_layer(self.fc1) | |
init_layer(self.fc_audioset) | |
def forward(self, input, mixup_lambda=None, device=None): | |
""" | |
Input: (batch_size, data_length)""" | |
x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) | |
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) | |
x = x.transpose(1, 3) | |
x = self.bn0(x) | |
x = x.transpose(1, 3) | |
if self.training: | |
x = self.spec_augmenter(x) | |
# Mixup on spectrogram | |
if self.training and mixup_lambda is not None: | |
x = do_mixup(x, mixup_lambda) | |
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = torch.mean(x, dim=3) | |
latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) | |
latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) | |
latent_x = latent_x1 + latent_x2 | |
latent_x = latent_x.transpose(1, 2) | |
latent_x = F.relu_(self.fc1(latent_x)) | |
latent_output = interpolate(latent_x, 32) | |
(x1, _) = torch.max(x, dim=2) | |
x2 = torch.mean(x, dim=2) | |
x = x1 + x2 | |
x = F.dropout(x, p=0.5, training=self.training) | |
x = F.relu_(self.fc1(x)) | |
embedding = F.dropout(x, p=0.5, training=self.training) | |
clipwise_output = torch.sigmoid(self.fc_audioset(x)) | |
output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output} | |
return output_dict | |
def create_pann_model(audio_cfg, enable_fusion=False, fusion_type='None'): | |
try: | |
ModelProto = eval(audio_cfg.model_name) | |
model = ModelProto( | |
sample_rate = audio_cfg.sample_rate, | |
window_size = audio_cfg.window_size, | |
hop_size =audio_cfg.hop_size, | |
mel_bins = audio_cfg.mel_bins, | |
fmin = audio_cfg.fmin, | |
fmax = audio_cfg.fmax, | |
classes_num = audio_cfg.class_num, | |
enable_fusion = enable_fusion, | |
fusion_type = fusion_type | |
) | |
return model | |
except: | |
raise RuntimeError(f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.') | |