# muvis/modeling.py
# Originally created by juliagsy (commit 0db0a20); header reconstructed from
# a web-viewer scrape of the file page.
from transformers import ASTModel, ViTModel, PretrainedConfig, PreTrainedModel
import numpy as np
import torch
import torch.nn as nn
from einops import reduce
class MuVis(nn.Module):
    """Dual-encoder (CLIP-style) audio/image embedding model.

    Audio is encoded with a pretrained AST backbone and images with a
    pretrained ViT backbone; each is projected to a shared ``latent_dims``
    space, mean-pooled over tokens, and L2-normalized so the two modalities
    can be compared with cosine similarity / scaled dot products.
    """

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        """
        Args:
            embed_dims: hidden size of both backbones (768 for the default
                AST and ViT checkpoints).
            latent_dims: dimensionality of the shared embedding space.
            sampling_rate: audio sampling rate the AST feature extractor
                expects; stored for callers (not used in forward).
        """
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable temperature, initialized to 1/0.07 as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.ast = ASTModel.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True
        )
        self.wav_lin = nn.Linear(embed_dims, latent_dims)
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True
        )
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        """Embed audio and/or image inputs into the shared latent space.

        Args:
            wav: dict of processed audio features, passed through to the AST
                backbone as keyword arguments (or None).
            img: dict of processed image features, passed through to the ViT
                backbone as keyword arguments (or None).

        Returns:
            A single L2-normalized ``(batch, latent_dims)`` tensor when only
            one modality is given, or a ``(wav_out, img_out)`` tuple when
            both are given.

        Raises:
            ValueError: if neither ``wav`` nor ``img`` is provided.
        """
        # Validate eagerly with a real exception: `assert` is stripped
        # under `python -O`, so it must not guard public API input.
        if wav is None and img is None:
            raise ValueError("MuVis.forward requires at least one of `wav` or `img`.")
        wav_out = None
        img_out = None
        if wav is not None:
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            # Mean-pool over the token axis: (b, n, d) -> (b, d).
            wav_out = wav_out.mean(dim=1)
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)
        if img is not None:
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = img_out.mean(dim=1)
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)
        if wav_out is None or img_out is None:
            return wav_out if img_out is None else img_out
        return (wav_out, img_out)
class MuVisConfig(PretrainedConfig):
    """HuggingFace configuration object for :class:`MuVisModel`.

    Stores the hyperparameters needed to rebuild the underlying ``MuVis``
    network; everything else is forwarded to ``PretrainedConfig``.
    """

    model_type = "muvis"

    def __init__(
        self,
        embed_dims=768,
        latent_dims=128,
        sampling_rate=16000,
        **kwargs,
    ):
        """
        Args:
            embed_dims: backbone hidden size.
            latent_dims: shared embedding dimensionality.
            sampling_rate: expected audio sampling rate.
            **kwargs: standard ``PretrainedConfig`` keyword arguments.
        """
        # Record model hyperparameters before delegating the generic
        # configuration handling to the parent class.
        for attr_name, attr_value in (
            ("embed_dims", embed_dims),
            ("latent_dims", latent_dims),
            ("sampling_rate", sampling_rate),
        ):
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
class MuVisModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ``MuVis`` through the
    HuggingFace save/load machinery (``from_pretrained`` etc.)."""

    config_class = MuVisConfig

    def __init__(self, config):
        """Build the inner ``MuVis`` network from a ``MuVisConfig``."""
        super().__init__(config)
        muvis_kwargs = {
            "embed_dims": config.embed_dims,
            "latent_dims": config.latent_dims,
            "sampling_rate": config.sampling_rate,
        }
        self.model = MuVis(**muvis_kwargs)

    def forward(self, wav=None, img=None):
        """Delegate straight to ``MuVis.forward`` (same arguments/returns)."""
        return self.model(wav=wav, img=img)