shangeth committed on
Commit c8fbf2f · verified · 1 Parent(s): 4e326ae

Create trainer.py

Files changed (1)
  1. trainer.py +68 -0
trainer.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ from torch import nn
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import LoraConfig, get_peft_model, PeftModel
+ import pytorch_lightning as pl
+ from model import HubertXCNNEnoder
+
+ class SpeechLLMLightning(pl.LightningModule):
+     def __init__(self, audio_enc_dim=512, llm_dim=2048, llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
+         super().__init__()
+         self.save_hyperparameters()
+
+         self.audio_enc_dim = audio_enc_dim
+         self.llm_dim = llm_dim
+         self.llm_name = llm_name
+
+         self.audio_encoder = HubertXCNNEnoder(self.audio_enc_dim, self.llm_dim)
+         self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_name)
+         self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
+         self.llm_model = AutoModelForCausalLM.from_pretrained(
+             self.llm_name,
+             device_map="auto",
+         )
+
+         peft_config = LoraConfig(
+             r=4,
+             lora_alpha=8,
+             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
+             lora_dropout=0.05,
+             task_type="CAUSAL_LM",
+         )
+
+         self.llm_model = get_peft_model(self.llm_model, peft_config)
+         self.llm_model.print_trainable_parameters()
+
+         for param in self.llm_model.parameters():
+             param.requires_grad = False
+
+         self.audio_encoder.eval()
+         self.llm_model.eval()
+
+
+     def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
+         batch_size = mel.shape[0]
+
+         speech_embeds = self.audio_encoder(mel)
+         embedder = self.llm_model.model.model.embed_tokens
+         pre_prompt_embeds = embedder(pre_tokenized_ids)
+         post_prompt_embeds = embedder(post_tokenized_ids)
+         output_prompt_embeds = embedder(output_tokenized_ids)
+
+         combined_embeds = torch.cat([pre_prompt_embeds, speech_embeds, post_prompt_embeds, output_prompt_embeds], dim=1)
+         atts = torch.ones(combined_embeds.size()[:-1], dtype=torch.long).to(combined_embeds.device)
+
+         input_token_length = pre_tokenized_ids.shape[1] + speech_embeds.shape[1] + post_tokenized_ids.shape[1]
+         label_ids = torch.cat([
+             torch.ones([batch_size, input_token_length], device=combined_embeds.device)*-100,
+             output_tokenized_ids
+         ], 1).to(combined_embeds.device).to(torch.int64)
+         return combined_embeds, atts, label_ids
+
+     def forward(self, embeds, atts, label_ids):
+         return self.llm_model(
+             inputs_embeds=embeds,
+             attention_mask=atts,
+             labels=label_ids,
+         )
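
As committed, trainer.py wires the audio features into the LLM (encode builds the combined embedding, attention-mask, and label tensors; forward passes them to the PEFT-wrapped model, which returns a loss when labels are supplied), but it defines no training_step or configure_optimizers, so a pl.Trainer cannot fit it as-is. Below is a minimal sketch, not part of this commit, of how those hooks might look; the subclass name, the assumed batch layout (mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids), and the learning rate are illustrative assumptions.

import torch
import pytorch_lightning as pl

from trainer import SpeechLLMLightning  # the module added in this commit


class SpeechLLMTrainable(SpeechLLMLightning):
    def training_step(self, batch, batch_idx):
        # Assumed batch layout: mel features plus the tokenized prompt/response pieces.
        mel, pre_ids, post_ids, out_ids = batch
        embeds, atts, label_ids = self.encode(mel, pre_ids, post_ids, out_ids)
        outputs = self(embeds, atts, label_ids)  # HF causal-LM output; .loss is set because labels are passed
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self):
        # __init__ above freezes every LLM parameter (including the LoRA adapters),
        # so only the audio encoder's parameters remain trainable here.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=1e-4)  # illustrative learning rate


if __name__ == "__main__":
    model = SpeechLLMTrainable()
    trainer = pl.Trainer(max_epochs=1, accelerator="auto")
    # train_loader is assumed to yield (mel, pre_ids, post_ids, out_ids) batches:
    # trainer.fit(model, train_dataloaders=train_loader)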