Feature Extraction
Transformers
Safetensors
English
custom_model
multi-modal
speech-language
custom_code
Eval Results
File size: 5,248 Bytes
1c6b840
 
 
c4ec89c
1c6b840
 
 
 
053eb5b
1c6b840
c4ec89c
 
 
1c6b840
 
 
 
 
 
 
 
 
 
 
 
 
053eb5b
 
 
1c6b840
 
 
 
 
 
053eb5b
 
 
c4ec89c
 
2c3d5a3
3f65edd
053eb5b
2c3d5a3
 
 
 
 
 
 
 
c4ec89c
 
1c6b840
8e999ad
 
1c6b840
 
8e999ad
c4ec89c
1c6b840
 
 
 
 
 
 
 
 
 
 
 
 
 
8e999ad
 
1c6b840
 
 
8e999ad
053eb5b
1c6b840
 
 
 
 
 
 
 
 
 
 
 
8e999ad
 
 
1c6b840
 
 
 
8e999ad
1c6b840
 
 
 
187d4d3
1c6b840
 
053eb5b
c4ec89c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torch import nn
import torchaudio
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, HubertModel, AutoProcessor, AutoConfig, AutoModel
from .config import SpeechLLMModelConfig
from peft import LoraConfig, get_peft_model

class HubertXCNNEnoder(nn.Module):
    """Audio encoder followed by a 1-D convolutional projection to the LLM width.

    The class name's spelling ("Enoder") is kept unchanged because
    ``SpeechLLMModel`` instantiates it by this name.

    Args:
        audio_enc_dim: channel size of the encoder's last hidden state
            (the input channels of the first ``Conv1d``).
        llm_dim: output channel size of the projection.
        encoder_name: name/path passed to ``AutoConfig.from_pretrained`` to
            obtain the encoder configuration.
    """

    def __init__(self, audio_enc_dim, llm_dim, encoder_name):
        super().__init__()
        # Build the encoder architecture from its config only; no pretrained
        # weights are loaded at this point.
        encoder_config = AutoConfig.from_pretrained(encoder_name)
        self.encoder = AutoModel.from_config(encoder_config)

        hidden_dim = llm_dim // 2
        # The layer order is significant: parameter names (cnn.1.*, cnn.3.*,
        # cnn.5.*) are Sequential indices and must match saved checkpoints.
        self.cnn = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, hidden_dim, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, llm_dim, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3, stride=1, padding=0),
        )

    def forward(self, x):
        """Encode ``x`` and project the encoder's last hidden state.

        The hidden state's dim 2 is treated as channels: it is moved to
        channels-first for ``Conv1d`` and moved back afterwards.
        """
        hidden = self.encoder(x).last_hidden_state
        channels_first = hidden.transpose(1, 2)
        projected = self.cnn(channels_first)
        return projected.transpose(1, 2)

    def return_device(self):
        """Return the device of this module's first parameter."""
        first_param = next(iter(self.parameters()))
        return first_param.device

class SpeechLLMModel(PreTrainedModel):
    """Causal language model conditioned on speech.

    Audio is encoded by ``HubertXCNNEnoder`` into vectors of width
    ``config.llm_dim``. Those vectors are spliced between the embedded prompt
    tokens that come before and after the speech segment, and the result is
    fed to the language model through ``inputs_embeds``.
    """

    config_class = SpeechLLMModelConfig

    # Sampling rate passed to the audio processor; audio loaded from a file
    # is resampled to this rate before processing.
    _SAMPLING_RATE = 16000

    def __init__(self, config):
        super().__init__(config)
        self.audio_processor = AutoProcessor.from_pretrained(config.audio_processor_name)
        self.audio_encoder = HubertXCNNEnoder(config.audio_enc_dim, config.llm_dim, config.audio_encoder_name)

        # The LLM is built from its config only (no pretrained weights loaded
        # here). NOTE(review): presumably the weights come from this model's
        # own checkpoint; confirm.
        llm_config = AutoConfig.from_pretrained(config.llm_model_name)
        self.llm_model = AutoModelForCausalLM.from_config(llm_config)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)
        # Reuse EOS as the padding token.
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token

        peft_config = LoraConfig(
            r=4,
            lora_alpha=8,
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        # Wrap with LoRA adapters, then merge them back into the base weights,
        # leaving a plain causal-LM module. NOTE(review): presumably this
        # reproduces the training-time module layout so checkpoint keys line
        # up; confirm.
        self.llm_model = get_peft_model(self.llm_model, peft_config)
        self.llm_model = self.llm_model.merge_and_unload()

    def encode(self, speech, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
        """Build the LLM input embeddings, attention mask and labels.

        Args:
            speech: input for ``self.audio_encoder``; dim 0 is the batch size.
            pre_tokenized_ids: token ids of the prompt before the speech.
            post_tokenized_ids: token ids of the prompt after the speech.
            output_tokenized_ids: token ids of the output/target text.

        Returns:
            ``(combined_embeds, atts, label_ids)``:
            - ``combined_embeds``: ``[pre, speech, post, output]`` embeddings
              concatenated along dim 1.
            - ``atts``: an all-ones ``long`` mask over those positions.
            - ``label_ids``: ``-100`` for every pre/speech/post position,
              followed by ``output_tokenized_ids``.
        """
        batch_size = speech.shape[0]

        # The audio encoder is run without gradient tracking.
        with torch.no_grad():
            speech_embeds = self.audio_encoder(speech)

        # Generic accessor instead of the architecture-specific
        # ``llm_model.model.embed_tokens`` attribute path.
        embedder = self.llm_model.get_input_embeddings()
        pre_prompt_embeds = embedder(pre_tokenized_ids)
        post_prompt_embeds = embedder(post_tokenized_ids)
        output_prompt_embeds = embedder(output_tokenized_ids)

        combined_embeds = torch.cat(
            [pre_prompt_embeds, speech_embeds, post_prompt_embeds, output_prompt_embeds], dim=1
        )
        device = combined_embeds.device
        atts = torch.ones(combined_embeds.shape[:-1], dtype=torch.long, device=device)

        input_token_length = pre_tokenized_ids.shape[1] + speech_embeds.shape[1] + post_tokenized_ids.shape[1]
        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so
        # prompt and speech positions are excluded from a loss on these labels.
        prompt_labels = torch.full((batch_size, input_token_length), -100, dtype=torch.long, device=device)
        label_ids = torch.cat(
            [prompt_labels, output_tokenized_ids.to(device=device, dtype=torch.long)], dim=1
        )
        return combined_embeds, atts, label_ids

    def forward(self, audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
        """Run the LLM over the combined prompt/speech/output embeddings.

        If ``attention_mask`` is None, the all-ones mask built by ``encode`` is
        used. The labels from ``encode`` are not passed to the LLM, so no loss
        is computed here.
        """
        combined_embeds, atts, _label_ids = self.encode(
            audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids
        )
        if attention_mask is None:
            attention_mask = atts
        return self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask)

    def _load_audio(self, audio_path):
        """Load ``audio_path`` as a mono waveform at ``self._SAMPLING_RATE``."""
        waveform, sr = torchaudio.load(audio_path)
        # torchaudio.load returns (channels, frames). Average the channels so
        # a multi-channel file yields one utterance rather than a batch.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # The processor is told the audio is at _SAMPLING_RATE, so resample
        # instead of silently mislabelling the rate.
        if sr != self._SAMPLING_RATE:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=self._SAMPLING_RATE)
        return waveform

    def generate_meta(self, audio_path=None, audio_tensor=None, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
        """Generate text for one audio clip given an instruction.

        Args:
            audio_path: file to load when ``audio_tensor`` is not given. It is
                mixed down to mono and resampled to ``_SAMPLING_RATE``.
            audio_tensor: raw audio; takes precedence over ``audio_path``.
                It is passed to the processor as-is (after ``squeeze()``).
            instruction: text inserted into the instruction prompt.
            max_new_tokens: generation length limit.

        Returns:
            The decoded generated text, with special tokens skipped.

        Raises:
            ValueError: if neither ``audio_path`` nor ``audio_tensor`` is given.
        """
        device = self.audio_encoder.return_device()
        pre_speech_prompt = f"Instruction:\n{instruction}\n\nInput: \n<speech>"
        post_speech_prompt = "</speech>\n\nOutput:"
        output_prompt = '\n<s>'

        if audio_tensor is None:
            if audio_path is None:
                raise ValueError("generate_meta requires either audio_path or audio_tensor")
            audio_tensor = self._load_audio(audio_path)

        def _tokenize(text):
            # Tokenize a prompt fragment verbatim: no padding, truncation or BOS/EOS.
            return self.llm_tokenizer(
                text, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False
            )["input_ids"].to(device)

        with torch.no_grad():
            input_values = self.audio_processor(
                audio_tensor.squeeze(), return_tensors="pt", sampling_rate=self._SAMPLING_RATE
            ).input_values

            combined_embeds, atts, _label_ids = self.encode(
                input_values.to(device),
                _tokenize(pre_speech_prompt),
                _tokenize(post_speech_prompt),
                _tokenize(output_prompt),
            )

            out = self.llm_model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=atts,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.llm_tokenizer.pad_token_id,
            ).cpu().tolist()[0]

        return self.llm_tokenizer.decode(out, skip_special_tokens=True)