from lib import *
import contextlib
import io
from typing import Optional

import laion_clap
import torch
from transformers import GPT2LMHeadModel


class AudioCaptioner(torch.nn.Module):
    """ClipCap-style captioner: CLAP audio embedding -> MLP mapping network -> GPT-2 prefix."""

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        # placeholder label tokens for the prefix positions
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def embed_waveform(self, waveform):
        # compute the prefix from the raw waveform
        input_dict = {
            'waveform': waveform  # you can add more key-values
        }
        audio_embeds = self.clap_model.model.encode_audio(
            input_dict, device=waveform.device
        )
        # get the BxD embedding (last layer); D = 1024 -> 512 after the audio projection
        audio_embedding = torch.nn.functional.normalize(
            self.clap_model.model.audio_projection(audio_embeds['embedding']), dim=-1
        )
        return audio_embedding

    def create_prefix(self, waveform, batch_size):
        if waveform is not None:
            audio_embedding = self.embed_waveform(waveform)
        else:
            # no audio given: fall back to a zero embedding (assumes GPU training)
            audio_embedding = torch.zeros(batch_size, self.prefix_size).cuda()
        # project the embedding through the mapping network and reshape into prefix tokens
        prefix_projections = self.clip_project(audio_embedding).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        return prefix_projections

    def forward(self, tokens: torch.Tensor, waveform: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                freeze_gpt: bool = False):
        # embed the text and prepend the audio prefix
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.create_prefix(waveform, tokens.shape[0])
        embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
        # offset labels so they line up with the prefix positions
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        # push through GPT-2 (optionally without tracking gradients for its weights)
        if freeze_gpt:
            with torch.no_grad():
                out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        else:
            out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None,
                 prefix_size: int = 512, num_layers: int = 8):
        super(AudioCaptioner, self).__init__()
        self.prefix_size = prefix_size
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.clip_project = MLP((prefix_size,
                                 (self.gpt_embedding_size * prefix_length) // 2,
                                 self.gpt_embedding_size * prefix_length))
        self.clap_model = laion_clap.CLAP_Module(
            enable_fusion=False,
            amodel='HTSAT-base'
        )
        with contextlib.redirect_stdout(io.StringIO()):
            self.clap_model.load_ckpt(ckpt='checkpoints/music_audioset_epoch_15_esc_90.14.pt')
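

# Minimal usage sketch (not part of the model): shows one training step with the
# class above. The batch, prefix length, clip duration, and 48 kHz mono input are
# illustrative assumptions, and a CUDA device plus the checkpoint referenced in
# __init__ are assumed to be available.
if __name__ == '__main__':
    from transformers import GPT2Tokenizer

    prefix_length = 10  # number of prefix tokens fed to GPT-2 (assumed hyperparameter)
    model = AudioCaptioner(prefix_length=prefix_length).cuda().train()

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    captions = ["a solo piano piece in a minor key"]
    tokens = tokenizer(captions, return_tensors='pt').input_ids.cuda()

    # dummy 10-second mono waveform at 48 kHz (sample rate assumed for the CLAP encoder)
    waveform = torch.randn(len(captions), 48000 * 10).cuda()

    # attention mask covers the prefix positions plus the caption tokens
    mask = torch.ones(tokens.shape[0], prefix_length + tokens.shape[1]).cuda()

    out = model(tokens, waveform, mask=mask, labels=tokens)
    loss = out.loss  # GPT-2 language-modeling loss over prefix + caption positions
    loss.backward()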