import numpy as np
import torch
import torch.nn as nn
from einops import reduce
from transformers import ASTModel, PretrainedConfig, PreTrainedModel, ViTModel


class MuVis(nn.Module):
    """CLIP-style two-tower model pairing an AST audio encoder with a ViT image encoder."""

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable temperature for contrastive logits, initialised to log(1/0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Audio tower: Audio Spectrogram Transformer plus a projection into the shared latent space.
        self.ast = ASTModel.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True
        )
        self.wav_lin = nn.Linear(embed_dims, latent_dims)

        # Image tower: ViT plus its own projection into the same latent space.
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True
        )
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        wav_out = None
        img_out = None

        if wav is not None:
            # Encode audio, project every token into the latent space, mean-pool over
            # the sequence, then L2-normalise so embeddings lie on the unit hypersphere.
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            wav_out = reduce(wav_out, "b n d -> b d", "mean")
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)

        if img is not None:
            # Same recipe for the image tower.
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = reduce(img_out, "b n d -> b d", "mean")
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)

        assert wav_out is not None or img_out is not None, "pass at least one of wav or img"

        # A single-modality call returns one embedding; a joint call returns both.
        if wav_out is None or img_out is None:
            return wav_out if img_out is None else img_out
        return (wav_out, img_out)
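

# --- Usage sketch (illustrative, not part of the model definition) ---
# A minimal inference example under a few assumptions: ASTFeatureExtractor and
# ViTImageProcessor are taken to be the processors matching the two pretrained
# towers above, audio is a mono 16 kHz waveform, and the random arrays below are
# stand-ins for real data.
if __name__ == "__main__":
    from transformers import ASTFeatureExtractor, ViTImageProcessor

    model = MuVis().eval()
    audio_proc = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    image_proc = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

    waveform = np.random.randn(16000)                             # 1 s of fake mono audio
    image = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)  # fake RGB image

    wav_inputs = audio_proc(waveform, sampling_rate=model.sampling_rate, return_tensors="pt")
    img_inputs = image_proc(images=image, return_tensors="pt")

    with torch.no_grad():
        wav_emb, img_emb = model(wav=wav_inputs, img=img_inputs)

    # CLIP-style similarity logits, scaled by the learned temperature; since both
    # embeddings are unit-norm, this is scaled cosine similarity. For training, a
    # symmetric cross-entropy over these logits (audio-to-image and image-to-audio)
    # would give the usual contrastive objective that logit_scale exists for.
    logits = model.logit_scale.exp() * wav_emb @ img_emb.t()
    print(logits.shape)  # torch.Size([1, 1])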


class MuVisConfig(PretrainedConfig):
    """Configuration class storing MuVis hyper-parameters for the Hugging Face API."""

    model_type = "muvis"

    def __init__(
        self,
        embed_dims=768,
        latent_dims=128,
        sampling_rate=16000,
        **kwargs,
    ):
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)


class MuVisModel(PreTrainedModel):
    """PreTrainedModel wrapper so MuVis works with save_pretrained/from_pretrained."""

    config_class = MuVisConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        return self.model(wav=wav, img=img)
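

# --- Serialisation sketch (illustrative) ---
# Because MuVisModel follows the PreTrainedModel contract, the standard Hugging Face
# save/load round-trip applies. Registering the custom "muvis" model_type with the
# Auto classes is optional but lets AutoModel resolve checkpoints of this type in
# the same session. "muvis-checkpoint" is a hypothetical local directory.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel

    AutoConfig.register("muvis", MuVisConfig)
    AutoModel.register(MuVisConfig, MuVisModel)

    model = MuVisModel(MuVisConfig())
    model.save_pretrained("muvis-checkpoint")  # writes config.json + weights
    reloaded = AutoModel.from_pretrained("muvis-checkpoint")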