farzadab committed
Commit 27da03b · verified · 1 Parent(s): fb702a8

Delete ultravox_model.py

Files changed (1)
  1. ultravox_model.py +0 -402
ultravox_model.py DELETED
@@ -1,402 +0,0 @@
- import logging
- from typing import Any, Dict, Optional, Set, Tuple, Union
-
- import peft
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import transformers
- import transformers.activations
- import transformers.modeling_outputs
- import transformers.models
-
- # We must use relative import in this directory to allow uploading to HF Hub
- from . import ultravox_config
- from . import whisper_model_modified
-
-
- class UltravoxModel(
-     transformers.LlamaPreTrainedModel,
-     transformers.GenerationMixin,
- ):
-     """
-     The Ultravox model, which consists of an audio encoder and a language model.
-
-     Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
-     projected to the language model's embedding space using a few linear layers.
-     The text is embedded by the language model as usual, and then the audio and text embeddings are merged together.
-
-     A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.
-
-     Parameters:
-         config: Model configuration class with all the parameters of the model.
-     """
-
-     config_class = ultravox_config.UltravoxConfig
-     config: ultravox_config.UltravoxConfig  # for type hinting
-     _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
-
-     def __init__(self, config: ultravox_config.UltravoxConfig):
-         super().__init__(config)
-
-         self.keep_params: Set[str] = set()
-         self.vocab_size = config.vocab_size
-
-         self.audio_tower = self._create_audio_tower(config)
-         self.multi_modal_projector = UltravoxProjector(config)
-         self.language_model = self._create_language_model(config)
-
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.language_model.get_input_embeddings()
-
-     def set_input_embeddings(self, value):
-         self.language_model.set_input_embeddings(value)
-
-     def get_output_embeddings(self):
-         return self.language_model.get_output_embeddings()
-
-     def set_output_embeddings(self, new_embeddings):
-         self.language_model.set_output_embeddings(new_embeddings)
-
-     def set_decoder(self, decoder):
-         self.language_model.set_decoder(decoder)
-
-     def get_decoder(self):
-         return self.language_model.get_decoder()
-
-     def tie_weights(self):
-         return self.language_model.tie_weights()
-
-     def _setup_cache(
-         self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
-     ):
-         self.language_model._setup_cache(cache_cls, max_batch_size, max_cache_len)
-
-     def _reorder_cache(self, past_key_values, beam_idx):
-         return self.language_model._reorder_cache(past_key_values, beam_idx)
-
-     def resize_token_embeddings(
-         self,
-         new_num_tokens: Optional[int] = None,
-         pad_to_multiple_of: Optional[int] = None,
-     ) -> nn.Embedding:
-         model_embeds = self.language_model.resize_token_embeddings(
-             new_num_tokens, pad_to_multiple_of
-         )
-         # update vocab size
-         self.config.text_config.vocab_size = model_embeds.num_embeddings
-         self.config.vocab_size = model_embeds.num_embeddings
-         self.vocab_size = model_embeds.num_embeddings
-         return model_embeds
-
-     def forward(
-         self,
-         input_ids: torch.Tensor,
-         audio_values: Optional[torch.FloatTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         audio_token_start_idx: Optional[torch.Tensor] = None,
-         audio_token_len: Optional[torch.Tensor] = None,
-         past_key_values: Optional[Tuple] = None,
-         **kwargs,
-     ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
-         """
-         Forward pass for the Ultravox model.
-
-         `input_ids` are the tokenized text input. They are embedded by the language model as usual.
-         `audio_values` are processed by the audio encoder, then every `stack_factor` frames are stacked together and
-         projected to the language model's embedding space using a few linear layers.
-         The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
-         of the audio embeddings in the merged embeddings.
-
-         Args:
-             input_ids: The tokenized text input.
-             audio_values: The processed audio values.
-             inputs_embeds: The embeddings for the input tokens.
-             labels: The tokenized text labels.
-             attention_mask: The attention mask for the input.
-             audio_token_start_idx: The index in each sequence where the audio embeddings should be placed.
-             audio_token_len: The number of audio placeholder tokens in each sequence.
-             past_key_values: The past key value cache for the language model attention layers.
-             **kwargs: Additional keyword arguments. Passed directly to the language model.
-         """
-         if inputs_embeds is None:
-             # B x T -> B x T x D
-             inputs_embeds = self.get_input_embeddings().forward(input_ids)
-
-         if audio_values is not None:
-             assert (
-                 audio_token_start_idx is not None and audio_token_len is not None
-             ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
-             assert (
-                 len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
-             ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
-
-             # B x A/3200 x D
-             audio_tower_output = self.audio_tower.forward(
-                 audio_values
-             ).last_hidden_state
-             audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
-
-             audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
-
-             # combine audio and text embeddings
-             for i, (audio, start, length) in enumerate(
-                 zip(audio_embeds, audio_token_start_idx, audio_token_len)
-             ):
-                 length = min(length, audio.shape[0])
-                 inputs_embeds[i, start : start + length] = audio[:length]
-
-         lm_output = self.language_model.forward(
-             inputs_embeds=inputs_embeds,
-             labels=labels,
-             attention_mask=attention_mask,
-             past_key_values=past_key_values,
-             **kwargs,
-         )
-
-         return lm_output
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: torch.Tensor,
-         audio_values: Optional[torch.FloatTensor] = None,
-         audio_token_start_idx: Optional[torch.Tensor] = None,
-         audio_token_len: Optional[torch.Tensor] = None,
-         past_key_values: Optional[Tuple] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> Dict[str, Any]:
-         model_input = self.language_model.prepare_inputs_for_generation(
-             input_ids=input_ids,
-             past_key_values=past_key_values,
-             attention_mask=attention_mask,
-             inputs_embeds=inputs_embeds,
-             **kwargs,
-         )
-
-         if past_key_values is None and audio_values is not None:
-             # We only want to use audio features in the 1st generation step
-             model_input["audio_values"] = audio_values
-             model_input["audio_token_start_idx"] = audio_token_start_idx
-             model_input["audio_token_len"] = audio_token_len
-
-         return model_input
-
-     @classmethod
-     def _create_audio_tower(cls, config: ultravox_config.UltravoxConfig) -> Union[
-         transformers.Wav2Vec2Model,
-         transformers.models.whisper.modeling_whisper.WhisperEncoder,
-     ]:
-         if config.audio_model_id is not None:
-             if "whisper" in config.audio_model_id:
-                 audio_tower = whisper_model_modified.WhisperEncoder.from_pretrained(
-                     config.audio_model_id
-                 )
-             else:
-                 audio_tower = transformers.AutoModel.from_pretrained(
-                     config.audio_model_id
-                 )
-         else:
-             if "whisper" in config.audio_config._name_or_path:
-                 audio_tower = whisper_model_modified.WhisperEncoder(config.audio_config)
-             else:
-                 audio_tower = transformers.AutoModel.from_config(config.audio_config)
-
-         if isinstance(
-             audio_tower,
-             (transformers.Wav2Vec2BertModel, transformers.WhisperModel),
-         ):
-             # For these models we only need the encoder part
-             # Wav2Vec2BertModel -> Wav2Vec2BertEncoder
-             # WhisperModel -> WhisperEncoder
-             audio_tower = audio_tower.encoder
-
-         audio_tower = apply_lora(audio_tower, config.audio_model_lora_config)
-         return audio_tower
-
-     @classmethod
-     def _create_language_model(
-         cls, config: ultravox_config.UltravoxConfig
-     ) -> transformers.LlamaForCausalLM:
-         if config.text_model_id is not None:
-             language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                 config.text_model_id, attn_implementation=config._attn_implementation
-             )
-         else:
-             language_model = transformers.AutoModelForCausalLM.from_config(
-                 config.text_config, attn_implementation=config._attn_implementation
-             )
-
-         language_model = apply_lora(language_model, config.text_model_lora_config)
-         return language_model
-
-     def merge_and_unload(self):
-         if isinstance(self.language_model, peft.PeftModel):
-             self.language_model = self.language_model.merge_and_unload()
-             # no need to download base language model weights anymore, so we can remove the id
-             self.config.text_model_id = None
-             self.keep_params.update(
-                 set(
-                     [
-                         f"language_model.{name}"
-                         for name, _ in self.language_model.named_parameters()
-                     ]
-                 )
-             )
-
-         if isinstance(self.audio_tower, peft.PeftModel):
-             self.audio_tower = self.audio_tower.merge_and_unload()
-             # no need to download base audio model weights anymore, so we can remove the id
-             self.config.audio_model_id = None
-             self.keep_params.update(
-                 set(
-                     [
-                         f"audio_tower.{name}"
-                         for name, _ in self.audio_tower.named_parameters()
-                     ]
-                 )
-             )
-
-         for param in ["text_model_lora_config", "audio_model_lora_config"]:
-             if hasattr(self.config, param):
-                 delattr(self.config, param)
-
-     def push_to_hub(self, *args, **kwargs):
-         self.merge_and_unload()
-         self.to(self.language_model.dtype)
-         return super().push_to_hub(*args, **kwargs)
-
-     def state_dict(self, *args, **kwargs):
-         named_params = dict(self.named_parameters())
-         state_dict = super().state_dict(*args, **kwargs)
-
-         state_dict = {
-             k: v
-             for k, v in state_dict.items()
-             if k in self.keep_params
-             or (k in named_params and named_params[k].requires_grad)
-         }
-         return state_dict
-
-     def load_state_dict(
-         self,
-         state_dict: Dict[str, Any],
-         *args,
-         **kwargs,
-     ):
-         self.keep_params.update(set(state_dict.keys()))
-         return super().load_state_dict(state_dict, *args, **kwargs)
-
-     def print_trainable_parameters(self):
-         """
-         Prints the number of trainable parameters in the model (reuses the PEFT model's method).
-         """
-         count_params = peft.peft_model.PeftModel.get_nb_trainable_parameters
-
-         trainable_params, all_param = count_params(self)
-
-         logging.info(
-             f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
-             f" || trainable%: {100 * trainable_params / all_param:.1f}%"
-         )
-
-         lm_trainable_params, lm_all_params = count_params(self.language_model)
-         audio_trainable_params, audio_all_params = count_params(self.audio_tower)
-
-         projector_trainable_params = (
-             trainable_params - lm_trainable_params - audio_trainable_params
-         )
-         projector_all_params = all_param - lm_all_params - audio_all_params
-
-         logging.info(
-             f"Trainable%: "
-             f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
-             f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%"
-             f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%"
-         )
-
-
- def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
-     """
-     Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
-     """
-     lora_config = peft.LoraConfig(**lora_config or {})
-
-     if lora_config.r == 0:
-         # freeze the model entirely
-         for param in model.parameters():
-             param.requires_grad = False
-     else:
-         model = peft.get_peft_model(model, lora_config)
-
-     return model
-
-
- class StackAudioFrames(nn.Module):
-     """
-     Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
-
-     The number of output frames will be `ceil(T / stack_factor) + 1`, where `T` is the number of input frames.
-     NOTE: the extra +1 is intentional: in case the number of audio tokens is over-estimated by the processor,
-     we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
-     In most cases this extra padding will get removed in the model's forward function, so it has no effect.
-     """
-
-     def __init__(self, stack_factor: int = 8):
-         super().__init__()
-         self.stack_factor = stack_factor
-
-     def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
-         B, T, C = audio_embeds.shape
-         T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
-         audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
-         B, T, C = audio_embeds.shape
-         audio_embeds = audio_embeds.view(
-             B, T // self.stack_factor, C * self.stack_factor
-         )
-         return audio_embeds
-
-
- class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
-     def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
-         super().__init__(hidden_size=hidden_size, eps=eps)
-         self.weight.data.fill_(init)
-
-
- class SwiGLU(nn.Module):
-     def forward(self, x):
-         x, gate = x.chunk(2, dim=-1)
-         return F.silu(gate) * x
-
-
- class UltravoxProjector(nn.Sequential):
-     def __init__(self, config: ultravox_config.UltravoxConfig):
-         super().__init__()
-         self.hidden_dim = config.hidden_size
-         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-         dim = config.audio_config.hidden_size * config.stack_factor
-         self.ln_pre = RMSNorm(dim, init=config.norm_init)
-         self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-         dim = self.hidden_dim
-         self.act = transformers.activations.get_activation(config.projector_act)
-         dim = dim // 2 if config.projector_act == "swiglu" else dim
-         self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
-         self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
-
-     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
-         audio_features = self._pad_and_stack(audio_features)
-         audio_features = self.ln_pre(audio_features)
-         hidden_states = self.linear_1(audio_features)
-         hidden_states = self.act(hidden_states)
-         hidden_states = self.linear_2(hidden_states)
-         hidden_states = self.ln_post(hidden_states)
-         return hidden_states
-
-
- UltravoxModel.register_for_auto_class("AutoModelForCausalLM")
-
- transformers.activations.ACT2FN["swiglu"] = SwiGLU
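
For context on what the deleted module did, the sketch below re-creates its two central steps in isolation: the pad-and-stack operation from StackAudioFrames (which yields ceil(T / stack_factor) + 1 output frames, as its docstring states) and the splice of projected audio embeddings into the text embeddings at the <|audio|> placeholder positions, as in UltravoxModel.forward. This is illustrative only and not part of the commit: the toy sizes (stack_factor, audio_dim, text_dim, seq_len) are made up, and a single bias-free nn.Linear stands in for the full UltravoxProjector (which also applies RMSNorm and an activation such as SwiGLU).

import math

import torch
import torch.nn.functional as F

# Toy sizes for illustration only; the real values come from UltravoxConfig.
stack_factor = 8    # encoder frames stacked per projected audio token
audio_dim = 16      # audio encoder hidden size
text_dim = 32       # language model hidden size
B, T = 1, 19        # batch size and number of encoder frames

audio_embeds = torch.randn(B, T, audio_dim)

# 1) Pad and stack, mirroring StackAudioFrames.forward: pad T up to a multiple
#    of stack_factor, plus one extra stacked frame of padding.
T_pad = math.ceil(T / stack_factor) * stack_factor
padded = F.pad(audio_embeds, (0, 0, 0, T_pad - T + stack_factor))
stacked = padded.view(B, -1, audio_dim * stack_factor)
assert stacked.shape[1] == math.ceil(T / stack_factor) + 1  # 19 frames -> 3 + 1 tokens

# 2) Project into the text embedding space (a Linear standing in for the
#    projector) and overwrite the <|audio|> placeholder positions in the
#    text embeddings, as UltravoxModel.forward does.
projector = torch.nn.Linear(audio_dim * stack_factor, text_dim, bias=False)
audio_tokens = projector(stacked)  # B x num_audio_tokens x text_dim

seq_len = 12
inputs_embeds = torch.randn(B, seq_len, text_dim)  # text embeddings with placeholders
audio_token_start_idx = torch.tensor([4])          # where the <|audio|> tokens start
audio_token_len = torch.tensor([3])                # how many placeholder tokens there are

for i, (audio, start, length) in enumerate(
    zip(audio_tokens, audio_token_start_idx, audio_token_len)
):
    length = min(length, audio.shape[0])  # drop the over-estimated extra frame(s)
    inputs_embeds[i, start : start + length] = audio[:length]

In the deleted model, the resulting inputs_embeds tensor is then passed directly to the language model's forward pass.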