farzadab committed
Commit 6917624
1 Parent(s): d612b34

Upload UltravoxPipeline

Files changed (4)
  1. config.json +64 -0
  2. tokenizer_config.json +42 -0
  3. ultravox_model.py +402 -0
  4. ultravox_pipeline.py +110 -0
config.json ADDED
@@ -0,0 +1,64 @@
{
  "architectures": [
    "UltravoxModel"
  ],
  "audio_config": {
    "_name_or_path": "facebook/wav2vec2-base-960h",
    "architectures": [
      "Wav2Vec2ForCTC"
    ],
    "feat_extract_dropout": 0.0,
    "feat_proj_dropout": 0.1,
    "gradient_checkpointing": false,
    "hidden_dropout_prob": 0.1,
    "model_type": "wav2vec2"
  },
  "audio_model_id": "facebook/wav2vec2-base-960h",
  "audio_token_index": 32000,
  "auto_map": {
    "AutoConfig": "ultravox_config.UltravoxConfig",
    "AutoModel": "ultravox_model.UltravoxModel"
  },
  "custom_pipelines": {
    "ultravox-pipeline": {
      "default": {
        "model": {
          "pt": [
            "fixie-ai/ultravox-v0.2",
            "main"
          ]
        }
      },
      "impl": "ultravox_pipeline.UltravoxPipeline",
      "pt": [
        "UltravoxModel"
      ],
      "tf": [],
      "type": "multimodal"
    }
  },
  "hidden_size": 4096,
  "ignore_index": -100,
  "initializer_range": 0.02,
  "model_type": "ultravox",
  "norm_init": 0.4,
  "projector_act": "swiglu",
  "stack_factor": 8,
  "text_config": {
    "_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "hidden_size": 2048,
    "intermediate_size": 5632,
    "model_type": "llama",
    "num_hidden_layers": 22,
    "num_key_value_heads": 4,
    "rms_norm_eps": 1e-05,
    "torch_dtype": "bfloat16"
  },
  "text_model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "torch_dtype": "float32",
  "transformers_version": "4.41.2",
  "vocab_size": 32000
}
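
Illustrative note (not part of this commit): the auto_map entries above are what let transformers resolve this repo's custom UltravoxModel class when trust_remote_code=True is passed; the repo id below is the default listed under custom_pipelines and is assumed here for the sketch.

import transformers

# auto_map points AutoConfig/AutoModel at the classes defined in this repo,
# so loading only requires trust_remote_code=True (assumed repo id).
model = transformers.AutoModel.from_pretrained(
    "fixie-ai/ultravox-v0.2", trust_remote_code=True
)
print(type(model).__name__)  # expected: UltravoxModel
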
tokenizer_config.json ADDED
@@ -0,0 +1,42 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": false,
  "model_max_length": 2048,
  "pad_token": "</s>",
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
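
Illustrative sketch (not part of this commit): the chat_template above is the standard TinyLlama/Zephyr format that the pipeline's preprocess step applies to the user turn; the rendered text is shown approximately.

import transformers

# Render the chat template for a single user turn containing the <|audio|> placeholder.
tok = transformers.AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
turns = [{"role": "user", "content": "Transcribe this clip: <|audio|>"}]
print(tok.apply_chat_template(turns, tokenize=False))
# Expected output, roughly:
#   <|user|>
#   Transcribe this clip: <|audio|></s>
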
ultravox_model.py ADDED
@@ -0,0 +1,402 @@
import logging
from typing import Any, Dict, Optional, Set, Tuple, Union

import peft
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import transformers.activations
import transformers.modeling_outputs
import transformers.models

# We must use relative import in this directory to allow uploading to HF Hub
from . import ultravox_config
from . import whisper_model_modified


class UltravoxModel(
    transformers.LlamaPreTrainedModel,
    transformers.GenerationMixin,
):
    """
    The Ultravox model which consists of an audio encoder and a language model.

    Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
    projected to the language model's embedding space using a few linear layers.
    The text is embedded by the language model as usual and then the audio and text embeddings are merged together.

    A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.

    Parameters:
        config: Model configuration class with all the parameters of the model.
    """

    config_class = ultravox_config.UltravoxConfig
    config: ultravox_config.UltravoxConfig  # for type hinting
    _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]

    def __init__(self, config: ultravox_config.UltravoxConfig):
        super().__init__(config)

        self.keep_params: Set[str] = set()
        self.vocab_size = config.vocab_size

        self.audio_tower = self._create_audio_tower(config)
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = self._create_language_model(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def _setup_cache(
        self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
    ):
        self.language_model._setup_cache(cache_cls, max_batch_size, max_cache_len)

    def _reorder_cache(self, past_key_values, beam_idx):
        return self.language_model._reorder_cache(past_key_values, beam_idx)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        audio_values: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple] = None,
        **kwargs,
    ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
        """
        Forward pass for the Ultravox model.

        `input_ids` are the tokenized text input. They are embedded by the language model as usual.
        `audio_values` are processed by the audio encoder and then every `stack_factor` frames are stacked together and
        projected to the language model's embedding space using a few linear layers.
        The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
        of the audio embeddings in the merged embeddings.

        Args:
            input_ids: The tokenized text input.
            audio_values: The processed audio values.
            inputs_embeds: The embeddings for the input tokens.
            labels: The tokenized text labels.
            attention_mask: The attention mask for the input.
            audio_token_start_idx: For each sample in the batch, the index in `input_ids` where the audio embeddings start.
            audio_token_len: For each sample in the batch, the number of positions reserved for the audio embeddings.
            past_key_values: The past key value cache for the language model attention layers.
            **kwargs: Additional keyword arguments (e.g. position_ids). Passed directly to the language model.
        """
        if inputs_embeds is None:
            # B x T -> B x T x D
            inputs_embeds = self.get_input_embeddings().forward(input_ids)

        if audio_values is not None:
            assert (
                audio_token_start_idx is not None and audio_token_len is not None
            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
            assert (
                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."

            # B x A/3200 x D
            audio_tower_output = self.audio_tower.forward(
                audio_values
            ).last_hidden_state
            audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

            audio_embeds = self.multi_modal_projector.forward(audio_tower_output)

            # combine audio and text embeddings
            for i, (audio, start, length) in enumerate(
                zip(audio_embeds, audio_token_start_idx, audio_token_len)
            ):
                length = min(length, audio.shape[0])
                inputs_embeds[i, start : start + length] = audio[:length]

        lm_output = self.language_model.forward(
            inputs_embeds=inputs_embeds,
            labels=labels,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )

        return lm_output

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        model_input = self.language_model.prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        if past_key_values is None and audio_values is not None:
            # We only want to use audio features in the 1st generation step
            model_input["audio_values"] = audio_values
            model_input["audio_token_start_idx"] = audio_token_start_idx
            model_input["audio_token_len"] = audio_token_len

        return model_input

    @classmethod
    def _create_audio_tower(cls, config: ultravox_config.UltravoxConfig) -> Union[
        transformers.Wav2Vec2Model,
        transformers.models.whisper.modeling_whisper.WhisperEncoder,
    ]:
        if config.audio_model_id is not None:
            if "whisper" in config.audio_model_id:
                audio_tower = whisper_model_modified.WhisperEncoder.from_pretrained(
                    config.audio_model_id
                )
            else:
                audio_tower = transformers.AutoModel.from_pretrained(
                    config.audio_model_id
                )
        else:
            if "whisper" in config.audio_config._name_or_path:
                audio_tower = whisper_model_modified.WhisperEncoder(config.audio_config)
            else:
                audio_tower = transformers.AutoModel.from_config(config.audio_config)

        if isinstance(
            audio_tower,
            (transformers.Wav2Vec2BertModel, transformers.WhisperModel),
        ):
            # For these models we only need the encoder part
            # Wav2Vec2BertModel -> Wav2Vec2BertEncoder
            # WhisperModel -> WhisperEncoder
            audio_tower = audio_tower.encoder

        audio_tower = apply_lora(audio_tower, config.audio_model_lora_config)
        return audio_tower

    @classmethod
    def _create_language_model(
        cls, config: ultravox_config.UltravoxConfig
    ) -> transformers.LlamaForCausalLM:
        if config.text_model_id is not None:
            language_model = transformers.AutoModelForCausalLM.from_pretrained(
                config.text_model_id, attn_implementation=config._attn_implementation
            )
        else:
            language_model = transformers.AutoModelForCausalLM.from_config(
                config.text_config, attn_implementation=config._attn_implementation
            )

        language_model = apply_lora(language_model, config.text_model_lora_config)
        return language_model

    def merge_and_unload(self):
        if isinstance(self.language_model, peft.PeftModel):
            self.language_model = self.language_model.merge_and_unload()
            # no need to download base language model weights anymore, so we can remove the id
            self.config.text_model_id = None
            self.keep_params.update(
                set(
                    [
                        f"language_model.{name}"
                        for name, _ in self.language_model.named_parameters()
                    ]
                )
            )

        if isinstance(self.audio_tower, peft.PeftModel):
            self.audio_tower = self.audio_tower.merge_and_unload()
            # no need to download base audio model weights anymore, so we can remove the id
            self.config.audio_model_id = None
            self.keep_params.update(
                set(
                    [
                        f"audio_tower.{name}"
                        for name, _ in self.audio_tower.named_parameters()
                    ]
                )
            )

        for param in ["text_model_lora_config", "audio_model_lora_config"]:
            if hasattr(self.config, param):
                delattr(self.config, param)

    def push_to_hub(self, *args, **kwargs):
        self.merge_and_unload()
        self.to(self.language_model.dtype)
        return super().push_to_hub(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        named_params = dict(self.named_parameters())
        state_dict = super().state_dict(*args, **kwargs)

        state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in self.keep_params
            or (k in named_params and named_params[k].requires_grad)
        }
        return state_dict

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        *args,
        **kwargs,
    ):
        self.keep_params.update(set(state_dict.keys()))
        return super().load_state_dict(state_dict, *args, **kwargs)

    def print_trainable_parameters(self):
        """
        Prints the number of trainable parameters in the model (reuses Peft model's method)
        """
        count_params = peft.peft_model.PeftModel.get_nb_trainable_parameters

        trainable_params, all_param = count_params(self)

        logging.info(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
            f" || trainable%: {100 * trainable_params / all_param:.1f}%"
        )

        lm_trainable_params, lm_all_params = count_params(self.language_model)
        audio_trainable_params, audio_all_params = count_params(self.audio_tower)

        projector_trainable_params = (
            trainable_params - lm_trainable_params - audio_trainable_params
        )
        projector_all_params = all_param - lm_all_params - audio_all_params

        logging.info(
            f"Trainable%: "
            f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
            f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%"
            f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%"
        )


def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
    """
    Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
    """
    lora_config = peft.LoraConfig(**lora_config or {})

    if lora_config.r == 0:
        # freeze the model entirely
        for param in model.parameters():
            param.requires_grad = False
    else:
        model = peft.get_peft_model(model, lora_config)

    return model


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.

    The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
    NOTE: the extra +1 is intentional: in case the number of audio tokens is over-estimated by the processor,
    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
    In most cases this extra padding will get removed in the model's forward function so it has no effect.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
        B, T, C = audio_embeds.shape
        audio_embeds = audio_embeds.view(
            B, T // self.stack_factor, C * self.stack_factor
        )
        return audio_embeds


class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
    def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
        super().__init__(hidden_size=hidden_size, eps=eps)
        self.weight.data.fill_(init)


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


class UltravoxProjector(nn.Sequential):
    def __init__(self, config: ultravox_config.UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim, init=config.norm_init)
        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
        dim = self.hidden_dim
        self.act = transformers.activations.get_activation(config.projector_act)
        dim = dim // 2 if config.projector_act == "swiglu" else dim
        self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
        self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


UltravoxModel.register_for_auto_class("AutoModelForCausalLM")

transformers.activations.ACT2FN["swiglu"] = SwiGLU
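
Illustrative sketch (not part of the commit): a shape walk-through of the frame stacking and projector above, using the values from config.json (wav2vec2-base hidden size 768, stack_factor 8, projector hidden_size 4096, TinyLlama hidden size 2048).

import torch
import torch.nn.functional as F

# Reproduce the StackAudioFrames arithmetic for a 100-frame clip.
B, T, C, stack = 1, 100, 768, 8

x = torch.randn(B, T, C)                    # audio tower output: B x T x C
T_pad = (T + stack - 1) // stack * stack    # 104 (next multiple of stack_factor)
x = F.pad(x, (0, 0, 0, T_pad - T + stack))  # pad to 112 frames (ceil + 1 extra group)
x = x.view(B, -1, C * stack)                # 1 x 14 x 6144
print(x.shape)                              # torch.Size([1, 14, 6144])

# UltravoxProjector then maps 6144 -> 4096 (linear_1), SwiGLU halves it to 2048,
# and linear_2 maps 2048 -> 2048, the TinyLlama embedding size.
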
ultravox_pipeline.py ADDED
@@ -0,0 +1,110 @@
import logging
from typing import Any, Dict, List, Optional

import transformers

# We must use relative import in this directory to allow uploading to HF Hub
from . import ultravox_model
from . import ultravox_processing


class UltravoxPipeline(transformers.Pipeline):
    def __init__(
        self,
        model: ultravox_model.UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.config._name_or_path
            )

        if audio_processor is None:
            audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
                model.config.audio_model_id
            )

        self.processor = ultravox_processing.UltravoxProcessor(
            audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        generation_kwargs = {}
        if "temperature" in kwargs:
            generation_kwargs["temperature"] = kwargs["temperature"]
        if "max_new_tokens" in kwargs:
            generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"]
        if "repetition_penalty" in kwargs:
            generation_kwargs["repetition_penalty"] = kwargs["repetition_penalty"]
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        if "turns" in inputs:
            turns = inputs["turns"]
        else:
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns = [{"role": "user", "content": prompt}]

        text = self.processor.tokenizer.apply_chat_template(turns, tokenize=False)

        # TODO: allow text-only mode?
        assert "audio" in inputs, "Audio input is required"

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        return self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        temperature = temperature or None
        do_sample = temperature is not None

        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text


transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=ultravox_model.UltravoxModel,
    default={"pt": ("fixie-ai/ultravox-v0.2", "main")},
    type="multimodal",
)
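
Illustrative sketch (not part of this commit): a minimal end-to-end call of the registered pipeline, assuming the default repo listed above loads with remote code enabled. The input dict mirrors what preprocess expects: audio, an optional turns list or prompt, and sampling_rate.

import numpy as np
import transformers

# Load the custom "ultravox-pipeline" registered in config.json (assumed repo id).
pipe = transformers.pipeline(
    model="fixie-ai/ultravox-v0.2",
    trust_remote_code=True,
)

audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz as a stand-in
result = pipe(
    {
        "audio": audio,
        "sampling_rate": 16000,
        "turns": [{"role": "user", "content": "Listen to <|audio|> and describe it."}],
    },
    max_new_tokens=30,
)
print(result)
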