inLine-XJY's picture
Upload 335 files
2b5b9ef verified
raw
history blame
2.2 kB
import numpy as np
import torch.nn.functional as F
from torch import nn
from .model import MLPLayers
class LinearProbe(nn.Module):
def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
"""
Args:
model: nn.Module
mlp: bool, if True, then use the MLP layer as the linear probe module
freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe
in_ch: int, the output channel from CLAP model
out_ch: int, the output channel from linear probe (class_num)
act: torch.nn.functional, the activation function before the loss function
"""
super().__init__()
in_ch = 512
self.clap_model = model
self.clap_model.text_branch = None # to save memory
self.freeze = freeze
if mlp:
self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
else:
self.lp_layer = nn.Linear(in_ch, out_ch)
if self.freeze:
for param in self.clap_model.parameters():
param.requires_grad = False
if act == 'None':
self.act = None
elif act == 'relu':
self.act = nn.ReLU()
elif act == 'elu':
self.act = nn.ELU()
elif act == 'prelu':
self.act = nn.PReLU(num_parameters=in_ch)
elif act == 'softmax':
self.act = nn.Softmax(dim=-1)
elif act == 'sigmoid':
self.act = nn.Sigmoid()
def forward(self, x, mix_lambda=None, device=None):
"""
Args:
x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
mix_lambda: torch.tensor [batch], the mixup lambda
Returns:
class_prob: torch.tensor [batch, class_num]
"""
# batchnorm cancel grandient
if self.freeze:
self.clap_model.eval()
x = self.clap_model.audio_projection(
self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"])
out = self.lp_layer(x)
if self.act is not None:
out = self.act(out)
return out