farzadab commited on
Commit
d9b04dc
1 Parent(s): 416504a

Update ultravox_processing.py

Browse files
Files changed (1) hide show
  1. ultravox_processing.py +28 -0
ultravox_processing.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import torch
5
  import transformers
6
 
 
 
7
 
8
  class UltravoxProcessor(transformers.ProcessorMixin):
9
  """
@@ -59,6 +61,29 @@ class UltravoxProcessor(transformers.ProcessorMixin):
59
 
60
  super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def __call__(
63
  self,
64
  text: Optional[str] = None,
@@ -178,3 +203,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
178
  tokenizer_input_names = self.tokenizer.model_input_names
179
  audio_processor_input_names = self.audio_processor.model_input_names
180
  return list(set(tokenizer_input_names + audio_processor_input_names))
 
 
 
 
4
  import torch
5
  import transformers
6
 
7
+ from .ultravox_config import UltravoxConfig
8
+
9
 
10
  class UltravoxProcessor(transformers.ProcessorMixin):
11
  """
 
61
 
62
  super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
63
 
64
+ @classmethod
65
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
66
+ config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
67
+ pretrained_model_name_or_path, **kwargs
68
+ )
69
+ audio_processor = transformers.AutoProcessor.from_pretrained(
70
+ config.audio_model_id
71
+ or config.audio_config._name_or_path
72
+ or "facebook/wav2vec2-base-960h"
73
+ )
74
+
75
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
76
+ pretrained_model_name_or_path, **kwargs
77
+ )
78
+ tokenizer.padding_side = "left"
79
+ tokenizer.pad_token = tokenizer.eos_token
80
+
81
+ return cls(
82
+ audio_processor=audio_processor,
83
+ tokenizer=tokenizer,
84
+ stack_factor=config.stack_factor,
85
+ )
86
+
87
  def __call__(
88
  self,
89
  text: Optional[str] = None,
 
203
  tokenizer_input_names = self.tokenizer.model_input_names
204
  audio_processor_input_names = self.audio_processor.model_input_names
205
  return list(set(tokenizer_input_names + audio_processor_input_names))
206
+
207
+
208
+ transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)