import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel
BirdAST_FEATURE_EXTRACTOR = ASTFeatureExtractor()
DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728
DEFAULT_ACTIVATION = "silu"
DEFAULT_N_MLP_LAYERS = 1
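
# Note: the AudioSet-pretrained AST backbone and its feature extractor operate on
# 16 kHz audio, so recordings at other sample rates should be resampled to DEFAULT_SR
# before calling birdast_seq_preprocess.
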
def birdast_seq_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess an audio array for the BirdAST_Seq model.

    audio_array: np.array, audio array of the recording, shape (n_samples,)
    sr: int, sampling rate of the audio array (default: 16_000)

    Note:
        1. The audio array should be normalized to [-1, 1].
        2. The audio length should be 10 seconds (or 10.24 seconds). Longer audio will be truncated.
    """
    # Extract log-mel features; padding="max_length" pads/truncates to the extractor's fixed length
    features = BirdAST_FEATURE_EXTRACTOR(audio_array, sampling_rate=sr, padding="max_length", return_tensors="pt")

    # return_tensors="pt" already yields a torch.Tensor; just drop the batch dimension
    spectrogram = features['input_values'].squeeze(0)

    return spectrogram
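
# Note: with the default ASTFeatureExtractor settings (128 mel bins, spectrograms
# padded/truncated to 1024 frames), birdast_seq_preprocess returns a tensor of shape
# (1024, 128); add a batch dimension (unsqueeze(0)) before passing it to the model.
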
def birdast_seq_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=DEFAULT_BACKBONE,
    n_classes=DEFAULT_N_CLASSES,
    activation=DEFAULT_ACTIVATION,
    n_mlp_layers=DEFAULT_N_MLP_LAYERS
):
    """
    Perform inference with the BirdAST_Seq model using one or more weight checkpoints.

    model_weights: list, paths to model weight files
    spectrogram: torch.Tensor, spectrogram tensor, shape (batch_size, n_frames, n_mels)
    device: str, device to run inference on (default: 'cpu')
    backbone_name: str, name of the backbone model (default: 'MIT/ast-finetuned-audioset-10-10-0.4593')
    n_classes: int, number of classes (default: 728)
    activation: str, activation function (default: 'silu')
    n_mlp_layers: int, number of MLP layers (default: 1)

    Returns:
        predictions: np.array, array of class probabilities, shape (n_models, batch_size, n_classes)
    """
    model = BirdAST_Seq(
        backbone_name=backbone_name,
        n_classes=n_classes,
        n_mlp_layers=n_mlp_layers,
        activation=activation
    )

    predict_collects = []
    for _weight in model_weights:
        model.load_state_dict(torch.load(_weight, map_location=device))
        if device != 'cpu':
            model.to(device)
        model.eval()

        with torch.no_grad():
            if device != 'cpu':
                spectrogram = spectrogram.to(device)

            # Check that the input tensor has a batch dimension
            if spectrogram.dim() == 2:
                spectrogram = spectrogram.unsqueeze(0)  # -> (batch_size, n_frames, n_mels)

            output = model(spectrogram)
            logits = output['logits']
            predictions = F.softmax(logits, dim=1)
            predict_collects.append(predictions)

    if device != 'cpu':
        predict_collects = [pred.cpu() for pred in predict_collects]

    # Stack along a new leading axis so the result matches (n_models, batch_size, n_classes)
    predict_collects = torch.stack(predict_collects, dim=0).numpy()
    return predict_collects
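
# Example (hypothetical checkpoint names): the per-fold probabilities returned above
# can be averaged into a single ensemble prediction per sample, e.g.:
#
#   probs = birdast_seq_inference(['fold_0.pth', 'fold_1.pth'], spec)  # (n_models, batch, n_classes)
#   ensemble = probs.mean(axis=0)                                      # (batch, n_classes)
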
class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """

    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: hidden dimension

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep: size (N, H)
        """
        att_w = self.softmax(self.W(batch_rep).squeeze(-1)).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep

class BirdAST_Seq(nn.Module):

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        super(BirdAST_Seq, self).__init__()

        # Pre-trained backbone
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # Set activation function
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'silu':
            self.activation = nn.SiLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        else:
            raise ValueError("Unsupported activation function. Choose 'relu', 'silu' or 'gelu'.")

        # Define self-attention pooling layer
        self.sa_pool = SelfAttentionPooling(self.hidden_size)

        # Define MLP classification head with activation
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram):
        # spectrogram: (batch_size, n_frames, n_mels)
        # output: {'logits': (batch_size, n_classes)}
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        hidden_state = ast_output.last_hidden_state
        pool_output = self.sa_pool(hidden_state)
        logits = self.mlp(pool_output)
        return {'logits': logits}
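
# Example (sketch): constructing the model directly and running a dummy forward pass.
# The input shape assumes the default AST feature extractor output (1024 frames x 128 mels);
# the first call downloads the pretrained backbone from the Hugging Face Hub.
#
#   model = BirdAST_Seq(DEFAULT_BACKBONE, DEFAULT_N_CLASSES)
#   out = model(torch.randn(2, 1024, 128))
#   out['logits'].shape  # -> torch.Size([2, 728])
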

if __name__ == '__main__':

    import numpy as np
    import matplotlib.pyplot as plt

    # Example usage of BirdAST_Seq

    # Create a random 10-second audio array (16 kHz sampling rate)
    audio_array = np.random.randn(16_000 * 10)

    # Preprocess the audio array
    spectrogram = birdast_seq_preprocess(audio_array)

    # Collect the fold checkpoints to ensemble
    model_weights_dir = '/workspace/voice_of_jungle/training_logs'
    model_weights = [f'{model_weights_dir}/BirdAST_SeqPool_GroupKFold_fold_{i}.pth' for i in range(5)]

    # Perform inference; add a batch dimension to the spectrogram first
    predictions = birdast_seq_inference(model_weights, spectrogram.unsqueeze(0))

    # Plot each model's class probabilities for the single sample in the batch
    fig, ax = plt.subplots()
    for i, pred in enumerate(predictions):
        ax.plot(pred[0], label=f'model_{i}')
    ax.legend()
    fig.savefig('test_BirdAST_Seq.png')

    print("Inference completed successfully!")