Feature Extraction
Transformers
Safetensors
English
custom_model
multi-modal
speech-language
custom_code
Eval Results
File size: 5,639 Bytes
392f8fa
 
 
e913847
392f8fa
 
 
 
 
 
e913847
 
392f8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e67d00
e913847
 
392f8fa
1bd6289
392f8fa
1e67d00
 
 
 
 
 
 
 
4cf34a2
392f8fa
90b0faf
 
392f8fa
 
90b0faf
392f8fa
 
4cf34a2
392f8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
90b0faf
 
392f8fa
 
 
90b0faf
392f8fa
 
 
 
 
 
 
 
 
 
 
 
90b0faf
 
 
392f8fa
 
 
 
 
90b0faf
392f8fa
 
 
 
e913b2a
392f8fa
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from torch import nn
import torchaudio
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, HubertModel, AutoProcessor, AutoConfig, AutoModel, AutoFeatureExtractor
from .config import SpeechLLMModelConfig
from peft import LoraConfig, get_peft_model

class TransformerAudioEnoder(nn.Module):
    def __init__(self, model_name='microsoft/wavlm-large', finetune=False):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        self.encoder =  AutoModel.from_config(config)  

    def forward(self, x):
        return self.encoder(x).last_hidden_state
    
    def return_device(self):
        return next(self.parameters()).device


class CNNConnector(nn.Module):
    def __init__(self, in_channels, out_channels, k=2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels//2, kernel_size=5,
                      stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(out_channels//2, out_channels, kernel_size=5,
                      stride=k, padding=0),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=5,
                      stride=1, padding=0),
        )

    def forward(self, x):
        return self.layer(x.transpose(1,2)).transpose(1,2)
    
    
class SpeechLLMModel(PreTrainedModel):
    config_class = SpeechLLMModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.audio_processor = AutoFeatureExtractor.from_pretrained(config.audio_processor_name)
        self.audio_encoder = TransformerAudioEnoder(config.audio_encoder_name)
        self.connector = CNNConnector(config.audio_enc_dim, config.llm_dim)
        
        llm_config = AutoConfig.from_pretrained(config.llm_model_name)
        self.llm_model =  AutoModelForCausalLM.from_config(llm_config)  
        self.llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token

        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        self.llm_model = get_peft_model(self.llm_model, peft_config)
        self.llm_model = self.llm_model.merge_and_unload()

    def encode(self, speech, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
        batch_size = speech.shape[0]

        with torch.no_grad():
            speech_embeds = self.audio_encoder(speech)
            speech_embeds = self.connector(speech_embeds)
            
        embedder = self.llm_model.model.embed_tokens
        pre_prompt_embeds = embedder(pre_tokenized_ids)
        post_prompt_embeds = embedder(post_tokenized_ids)
        output_prompt_embeds = embedder(output_tokenized_ids)

        combined_embeds = torch.cat([pre_prompt_embeds, speech_embeds, post_prompt_embeds, output_prompt_embeds], dim=1)
        atts = torch.ones(combined_embeds.size()[:-1], dtype=torch.long).to(combined_embeds.device)

        input_token_length = pre_tokenized_ids.shape[1] + speech_embeds.shape[1] + post_tokenized_ids.shape[1]
        label_ids = torch.cat([
            torch.ones([batch_size, input_token_length], device=combined_embeds.device) * -100,
            output_tokenized_ids
        ], 1).to(combined_embeds.device).to(torch.int64)
        return combined_embeds, atts, label_ids

    def forward(self, audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
        combined_embeds, atts, label_ids = self.encode(audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
        outputs = self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask)
        return outputs
    
    def generate_meta(self, audio_path=None, audio_tensor=None, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
        device = self.audio_encoder.return_device()
        pre_speech_prompt = f'''Instruction:
{instruction}

Input: 
<speech>'''
        post_speech_prompt = f'''</speech>

Output:'''
        output_prompt = '\n<s>'

        with torch.no_grad():
            if audio_tensor == None and audio_path != None:
                audio_tensor, sr = torchaudio.load(audio_path)
            audio_tensor = self.audio_processor(audio_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values

            pre_tokenized_ids = self.llm_tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
            post_tokenized_ids = self.llm_tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
            output_tokenized_ids = self.llm_tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]

            combined_embeds, atts, label_ids = self.encode(audio_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))

            out = self.llm_model.generate(
                inputs_embeds=combined_embeds,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.llm_tokenizer.pad_token_id
            ).cpu().tolist()[0]

            output_text = self.llm_tokenizer.decode(out, skip_special_tokens=True)
            return output_text