from lib import *  # project helpers; MLP is expected to come from here

import contextlib
import io
from typing import Optional

import laion_clap
import torch
from transformers import GPT2LMHeadModel

class AudioCaptioner(torch.nn.Module):

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        # placeholder label tokens that fill the prefix positions when offsetting labels
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def embed_waveform(self, waveform):
        # encode the raw waveform with the CLAP audio branch
        input_dict = {
            'waveform': waveform  # further key-value pairs can be added if needed
        }
        audio_embeds = self.clap_model.model.encode_audio(
            input_dict,
            device=waveform.device
        )

        # take the BxD last-layer embedding (D = 1024), project to the 512-d joint space and L2-normalise
        audio_embedding = torch.nn.functional.normalize(
            self.clap_model.model.audio_projection(audio_embeds['embedding']),
            dim=-1
        )
        return audio_embedding

    def create_prefix(self, waveform, batch_size):
        if waveform is not None:
            audio_embedding = self.embed_waveform(waveform)
        else:
            # no audio given: fall back to a zero embedding on the model's device
            audio_embedding = torch.zeros(
                batch_size, self.prefix_size,
                device=next(self.parameters()).device
            )
        # project the audio embedding through the mapping network into prefix_length GPT-2 token embeddings
        prefix_projections = self.clip_project(audio_embedding).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        return prefix_projections

    def forward(self, tokens: torch.Tensor, waveform: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, freeze_gpt: bool = False):
        # embed the caption tokens and prepend the audio prefix
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.create_prefix(waveform, tokens.shape[0])
        embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
        # offset the labels so the prefix positions are covered by dummy tokens
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        # push through GPT-2; with freeze_gpt the whole forward runs without gradient tracking
        if freeze_gpt:
            with torch.no_grad():
                out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        else:
            out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8):
        super().__init__()
        self.prefix_size = prefix_size
        self.prefix_length = prefix_length
        # GPT-2 language model that generates the caption conditioned on the audio prefix
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        # mapping network: 512-d CLAP embedding -> prefix_length x gpt_embedding_size prefix
        self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                 self.gpt_embedding_size * prefix_length))
        # CLAP audio encoder (HTSAT backbone, no feature fusion)
        self.clap_model = laion_clap.CLAP_Module(
            enable_fusion=False,
            amodel='HTSAT-base'
        )
        # load the music-tuned CLAP checkpoint, silencing its verbose stdout logging
        with contextlib.redirect_stdout(io.StringIO()):
            self.clap_model.load_ckpt(ckpt='checkpoints/music_audioset_epoch_15_esc_90.14.pt')
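
# Minimal usage sketch (not part of the original module): instantiates the model and runs a
# dummy forward pass. The prefix_length value, caption length, and the 10 s / 48 kHz waveform
# shape are illustrative assumptions, not values from the training code, and the CLAP
# checkpoint path referenced in __init__ must exist for this to run.
if __name__ == '__main__':
    model = AudioCaptioner(prefix_length=10)
    tokens = torch.randint(0, model.gpt.config.vocab_size, (2, 20))  # dummy caption token ids
    waveform = torch.randn(2, 480000)                                # dummy 10 s of 48 kHz audio
    out = model(tokens, waveform, freeze_gpt=True)
    print(out.logits.shape)  # (batch, prefix_length + caption_length, vocab_size)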