farzadab commited on
Commit
3e41a18
1 Parent(s): 8cbab83

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. ultravox_model.py +407 -0
ultravox_model.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, Optional, Set, Tuple, Union
3
+
4
+ import peft
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import transformers
9
+ import transformers.activations
10
+ import transformers.modeling_outputs
11
+ import transformers.models
12
+
13
+ # We must use relative import in this directory to allow uploading to HF Hub
14
+ # Even "from . import X" pattern doesn't work (undocumented and unclear why)
15
+ from .ultravox_config import UltravoxConfig
16
+ from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
17
+
18
+
19
+ class UltravoxModel(
20
+ transformers.LlamaPreTrainedModel,
21
+ transformers.GenerationMixin,
22
+ ):
23
+ """
24
+ The Ultravox model which consists of an audio encoder and a language model.
25
+
26
+ Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
27
+ projected to the language model's embedding space using a few linear layers.
28
+ The text is embedded by the language model as usual and then the audio and text embeddings are merged together.
29
+
30
+ A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.
31
+
32
+ Parameters:
33
+ config: Model configuration class with all the parameters of the model.
34
+ """
35
+
36
+ config_class = UltravoxConfig
37
+ config: UltravoxConfig # for type hinting
38
+ _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
39
+
40
+ def __init__(self, config: UltravoxConfig):
41
+ super().__init__(config)
42
+
43
+ self.keep_params: Set[str] = set()
44
+ self.vocab_size = config.vocab_size
45
+
46
+ self.audio_tower = self._create_audio_tower(config)
47
+ self.multi_modal_projector = UltravoxProjector(config)
48
+ self.language_model = self._create_language_model(config)
49
+
50
+ self.post_init()
51
+
52
+ def get_input_embeddings(self):
53
+ return self.language_model.get_input_embeddings()
54
+
55
+ def set_input_embeddings(self, value):
56
+ self.language_model.set_input_embeddings(value)
57
+
58
+ def get_output_embeddings(self):
59
+ return self.language_model.get_output_embeddings()
60
+
61
+ def set_output_embeddings(self, new_embeddings):
62
+ self.language_model.set_output_embeddings(new_embeddings)
63
+
64
+ def set_decoder(self, decoder):
65
+ self.language_model.set_decoder(decoder)
66
+
67
+ def get_decoder(self):
68
+ return self.language_model.get_decoder()
69
+
70
+ def tie_weights(self):
71
+ return self.language_model.tie_weights()
72
+
73
+ def _setup_cache(
74
+ self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
75
+ ):
76
+ self.language_model._setup_cache(cache_cls, max_batch_size, max_cache_len)
77
+
78
+ def _reorder_cache(self, past_key_values, beam_idx):
79
+ return self.language_model._reorder_cache(past_key_values, beam_idx)
80
+
81
+ def resize_token_embeddings(
82
+ self,
83
+ new_num_tokens: Optional[int] = None,
84
+ pad_to_multiple_of: Optional[int] = None,
85
+ ) -> nn.Embedding:
86
+ model_embeds = self.language_model.resize_token_embeddings(
87
+ new_num_tokens, pad_to_multiple_of
88
+ )
89
+ # update vocab size
90
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
91
+ self.config.vocab_size = model_embeds.num_embeddings
92
+ self.vocab_size = model_embeds.num_embeddings
93
+ return model_embeds
94
+
95
+ def forward(
96
+ self,
97
+ input_ids: torch.Tensor,
98
+ audio_values: Optional[torch.FloatTensor] = None,
99
+ inputs_embeds: Optional[torch.FloatTensor] = None,
100
+ labels: Optional[torch.Tensor] = None,
101
+ attention_mask: Optional[torch.Tensor] = None,
102
+ audio_token_start_idx: Optional[torch.Tensor] = None,
103
+ audio_token_len: Optional[torch.Tensor] = None,
104
+ past_key_values: Optional[Tuple] = None,
105
+ **kwargs,
106
+ ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
107
+ """
108
+ Forward pass for the Ultravox model.
109
+
110
+ `input_ids` are the tokenized text input. They are embedded by the language model as usual.
111
+ `audio_values` are processed by the audio encoder and then every `stack_factor` frames are stacked together and
112
+ projected to the language model's embedding space using a few linear layers.
113
+ The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
114
+ of the audio embeddings in the merged embeddings.
115
+
116
+ Args:
117
+ input_ids: The tokenized text input.
118
+ audio_values: The processed audio values.
119
+ inputs_embeds: The embeddings for the input tokens.
120
+ labels: The tokenized text labels.
121
+ attention_mask: The attention mask for the input.
122
+ position_ids: The position ids for the input.
123
+ past_key_values: The past key value cache for the language model attention layers.
124
+ **kwargs: Additional keyword arguments. Passed directly to the language model.
125
+ """
126
+ if inputs_embeds is None:
127
+ # B x T -> B x T x D
128
+ inputs_embeds = self.get_input_embeddings().forward(input_ids)
129
+
130
+ if audio_values is not None:
131
+ assert (
132
+ audio_token_start_idx is not None and audio_token_len is not None
133
+ ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
134
+ assert (
135
+ len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
136
+ ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
137
+
138
+ # B x A/3200 x D
139
+ audio_tower_output = self.audio_tower.forward(
140
+ audio_values
141
+ ).last_hidden_state
142
+ audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
143
+
144
+ audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
145
+
146
+ # combine audio and text embeddings
147
+ for i, (audio, start, length) in enumerate(
148
+ zip(audio_embeds, audio_token_start_idx, audio_token_len)
149
+ ):
150
+ length = min(length, audio.shape[0])
151
+ inputs_embeds[i, start : start + length] = audio[:length]
152
+
153
+ lm_output = self.language_model.forward(
154
+ inputs_embeds=inputs_embeds,
155
+ labels=labels,
156
+ attention_mask=attention_mask,
157
+ past_key_values=past_key_values,
158
+ **kwargs,
159
+ )
160
+
161
+ return lm_output
162
+
163
+ def prepare_inputs_for_generation(
164
+ self,
165
+ input_ids: torch.Tensor,
166
+ audio_values: Optional[torch.FloatTensor] = None,
167
+ audio_token_start_idx: Optional[torch.Tensor] = None,
168
+ audio_token_len: Optional[torch.Tensor] = None,
169
+ past_key_values: Optional[Tuple] = None,
170
+ attention_mask: Optional[torch.Tensor] = None,
171
+ inputs_embeds: Optional[torch.Tensor] = None,
172
+ **kwargs,
173
+ ) -> Dict[str, Any]:
174
+ model_input = self.language_model.prepare_inputs_for_generation(
175
+ input_ids=input_ids,
176
+ past_key_values=past_key_values,
177
+ attention_mask=attention_mask,
178
+ inputs_embeds=inputs_embeds,
179
+ **kwargs,
180
+ )
181
+
182
+ if past_key_values is None and audio_values is not None:
183
+ # We only want to use audio features in the 1st generation step
184
+ model_input["audio_values"] = audio_values
185
+ model_input["audio_token_start_idx"] = audio_token_start_idx
186
+ model_input["audio_token_len"] = audio_token_len
187
+
188
+ return model_input
189
+
190
+ @classmethod
191
+ def _create_audio_tower(
192
+ cls, config: UltravoxConfig
193
+ ) -> Union[transformers.Wav2Vec2Model, ModifiedWhisperEncoder]:
194
+ if config.audio_model_id is not None:
195
+ if "whisper" in config.audio_model_id is not None:
196
+ audio_tower = ModifiedWhisperEncoder.from_pretrained(
197
+ config.audio_model_id
198
+ )
199
+ else:
200
+ audio_tower = transformers.AutoModel.from_pretrained(
201
+ config.audio_model_id
202
+ )
203
+ else:
204
+ if "whisper" in config.audio_config._name_or_path:
205
+ audio_tower = ModifiedWhisperEncoder(config.audio_config)
206
+ else:
207
+ audio_tower = transformers.AutoModel.from_config(config.audio_config)
208
+
209
+ if isinstance(
210
+ audio_tower,
211
+ (transformers.Wav2Vec2BertModel, transformers.WhisperModel),
212
+ ):
213
+ # For these models we only need the encoder part
214
+ # Wav2Vec2BertModel -> Wav2Vec2BertEncoder
215
+ # WhisperModel -> WhisperEncoder
216
+ audio_tower = audio_tower.encoder
217
+
218
+ audio_tower = apply_lora(audio_tower, config.audio_model_lora_config)
219
+ return audio_tower
220
+
221
+ @classmethod
222
+ def _create_language_model(
223
+ cls, config: UltravoxConfig
224
+ ) -> transformers.LlamaForCausalLM:
225
+ if config.text_model_id is not None:
226
+ language_model = transformers.AutoModelForCausalLM.from_pretrained(
227
+ config.text_model_id, attn_implementation=config._attn_implementation
228
+ )
229
+ else:
230
+ language_model = transformers.AutoModelForCausalLM.from_config(
231
+ config.text_config, attn_implementation=config._attn_implementation
232
+ )
233
+
234
+ language_model = apply_lora(language_model, config.text_model_lora_config)
235
+ return language_model
236
+
237
+ def merge_and_unload(self):
238
+ if isinstance(self.language_model, peft.PeftModel):
239
+ self.language_model = self.language_model.merge_and_unload()
240
+ # no need to download base language model weights anymore, so we can remove the id
241
+ self.config.text_model_id = None
242
+ self.keep_params.update(
243
+ set(
244
+ [
245
+ f"language_model.{name}"
246
+ for name, _ in self.language_model.named_parameters()
247
+ ]
248
+ )
249
+ )
250
+
251
+ if isinstance(self.audio_tower, peft.PeftModel):
252
+ self.audio_tower = self.audio_tower.merge_and_unload()
253
+ # no need to download base audio model weights anymore, so we can remove the id
254
+ self.config.audio_model_id = None
255
+ self.keep_params.update(
256
+ set(
257
+ [
258
+ f"audio_tower.{name}"
259
+ for name, _ in self.audio_tower.named_parameters()
260
+ ]
261
+ )
262
+ )
263
+
264
+ for param in ["text_model_lora_config", "audio_model_lora_config"]:
265
+ if hasattr(self.config, param):
266
+ delattr(self.config, param)
267
+
268
+ def push_to_hub(self, *args, **kwargs):
269
+ self.merge_and_unload()
270
+ self.to(self.language_model.dtype)
271
+ return super().push_to_hub(*args, **kwargs)
272
+
273
+ def state_dict(self, *args, **kwargs):
274
+ named_params = dict(self.named_parameters())
275
+ state_dict = super().state_dict(*args, **kwargs)
276
+
277
+ state_dict = {
278
+ k: v
279
+ for k, v in state_dict.items()
280
+ if k in self.keep_params
281
+ or (k in named_params and named_params[k].requires_grad)
282
+ }
283
+ return state_dict
284
+
285
+ def load_state_dict(
286
+ self,
287
+ state_dict: Dict[str, Any],
288
+ *args,
289
+ **kwargs,
290
+ ):
291
+ self.keep_params.update(set(state_dict.keys()))
292
+ return super().load_state_dict(state_dict, *args, **kwargs)
293
+
294
+ def print_trainable_parameters(self):
295
+ """
296
+ Prints the number of trainable parameters in the model (reuses Peft model's method)
297
+ """
298
+ count_params = peft.peft_model.PeftModel.get_nb_trainable_parameters
299
+
300
+ trainable_params, all_param = count_params(self)
301
+
302
+ logging.info(
303
+ f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
304
+ f" || trainable%: {100 * trainable_params / all_param:.1f}%"
305
+ )
306
+
307
+ lm_trainable_params, lm_all_params = count_params(self.language_model)
308
+ audio_trainable_params, audio_all_params = count_params(self.audio_tower)
309
+
310
+ projector_trainable_params = (
311
+ trainable_params - lm_trainable_params - audio_trainable_params
312
+ )
313
+ projector_all_params = all_param - lm_all_params - audio_all_params
314
+
315
+ logging.info(
316
+ f"Trainable%: "
317
+ f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
318
+ f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%"
319
+ f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%"
320
+ )
321
+
322
+
323
+ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
324
+ """
325
+ Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
326
+ """
327
+ lora_config = peft.LoraConfig(**lora_config or {})
328
+
329
+ if lora_config.r == 0:
330
+ # freeze the model entirely
331
+ for param in model.parameters():
332
+ param.requires_grad = False
333
+ else:
334
+ model = peft.get_peft_model(model, lora_config)
335
+
336
+ return model
337
+
338
+
339
+ class StackAudioFrames(nn.Module):
340
+ """
341
+ Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
342
+
343
+ The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
344
+ NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
345
+ we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
346
+ In most cases this extra padding will get removed in the model's forward function so it has no effect.
347
+ """
348
+
349
+ def __init__(self, stack_factor: int = 8):
350
+ super().__init__()
351
+ self.stack_factor = stack_factor
352
+
353
+ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
354
+ B, T, C = audio_embeds.shape
355
+ T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
356
+ audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
357
+ B, T, C = audio_embeds.shape
358
+ audio_embeds = audio_embeds.view(
359
+ B, T // self.stack_factor, C * self.stack_factor
360
+ )
361
+ return audio_embeds
362
+
363
+
364
+ class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
365
+ def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
366
+ super().__init__(hidden_size=hidden_size, eps=eps)
367
+ self.weight.data.fill_(init)
368
+
369
+
370
+ class SwiGLU(nn.Module):
371
+ def forward(self, x):
372
+ x, gate = x.chunk(2, dim=-1)
373
+ return F.silu(gate) * x
374
+
375
+
376
+ class UltravoxProjector(nn.Sequential):
377
+ def __init__(self, config: UltravoxConfig):
378
+ super().__init__()
379
+ self.hidden_dim = config.hidden_size
380
+ self._pad_and_stack = StackAudioFrames(config.stack_factor)
381
+ dim = config.audio_config.hidden_size * config.stack_factor
382
+ self.ln_pre = RMSNorm(dim, init=config.norm_init)
383
+ self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
384
+ dim = self.hidden_dim
385
+ self.act = transformers.activations.get_activation(config.projector_act)
386
+ dim = dim // 2 if config.projector_act == "swiglu" else dim
387
+ self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
388
+ self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
389
+
390
+ def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
391
+ audio_features = self._pad_and_stack(audio_features)
392
+ audio_features = self.ln_pre(audio_features)
393
+ hidden_states = self.linear_1(audio_features)
394
+ hidden_states = self.act(hidden_states)
395
+ hidden_states = self.linear_2(hidden_states)
396
+ hidden_states = self.ln_post(hidden_states)
397
+ return hidden_states
398
+
399
+
400
+ UltravoxConfig.register_for_auto_class()
401
+ UltravoxModel.register_for_auto_class()
402
+
403
+ transformers.AutoConfig.register("ultravox", UltravoxConfig)
404
+ transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
405
+ # transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make processo work standalone
406
+
407
+ transformers.activations.ACT2FN["swiglu"] = SwiGLU