farzadab committed on
Commit
3c2070f
·
verified ·
1 Parent(s): d66aec1

Update ultravox_model.py

Browse files
Files changed (1) hide show
  1. ultravox_model.py +15 -17
ultravox_model.py CHANGED
@@ -11,9 +11,9 @@ import transformers.modeling_outputs
11
  import transformers.models
12
 
13
  # We must use relative import in this directory to allow uploading to HF Hub
14
- from . import ultravox_config
15
- from . import ultravox_processing
16
- from . import whisper_model_modified
17
 
18
 
19
  class UltravoxModel(
@@ -33,11 +33,11 @@ class UltravoxModel(
33
  config: Model configuration class with all the parameters of the model.
34
  """
35
 
36
- config_class = ultravox_config.UltravoxConfig
37
- config: ultravox_config.UltravoxConfig # for type hinting
38
  _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
39
 
40
- def __init__(self, config: ultravox_config.UltravoxConfig):
41
  super().__init__(config)
42
 
43
  self.keep_params: Set[str] = set()
@@ -188,13 +188,13 @@ class UltravoxModel(
188
  return model_input
189
 
190
  @classmethod
191
- def _create_audio_tower(cls, config: ultravox_config.UltravoxConfig) -> Union[
192
  transformers.Wav2Vec2Model,
193
  transformers.models.whisper.modeling_whisper.WhisperEncoder,
194
  ]:
195
  if config.audio_model_id is not None:
196
  if "whisper" in config.audio_model_id is not None:
197
- audio_tower = whisper_model_modified.WhisperEncoder.from_pretrained(
198
  config.audio_model_id
199
  )
200
  else:
@@ -203,7 +203,7 @@ class UltravoxModel(
203
  )
204
  else:
205
  if "whisper" in config.audio_config._name_or_path:
206
- audio_tower = whisper_model_modified.WhisperEncoder(config.audio_config)
207
  else:
208
  audio_tower = transformers.AutoModel.from_config(config.audio_config)
209
 
@@ -221,7 +221,7 @@ class UltravoxModel(
221
 
222
  @classmethod
223
  def _create_language_model(
224
- cls, config: ultravox_config.UltravoxConfig
225
  ) -> transformers.LlamaForCausalLM:
226
  if config.text_model_id is not None:
227
  language_model = transformers.AutoModelForCausalLM.from_pretrained(
@@ -375,7 +375,7 @@ class SwiGLU(nn.Module):
375
 
376
 
377
  class UltravoxProjector(nn.Sequential):
378
- def __init__(self, config: ultravox_config.UltravoxConfig):
379
  super().__init__()
380
  self.hidden_dim = config.hidden_size
381
  self._pad_and_stack = StackAudioFrames(config.stack_factor)
@@ -398,15 +398,13 @@ class UltravoxProjector(nn.Sequential):
398
  return hidden_states
399
 
400
 
401
- transformers.AutoConfig.register("ultravox", ultravox_config.UltravoxConfig)
402
- transformers.AutoModel.register(ultravox_config.UltravoxConfig, UltravoxModel)
403
  # transformers.AutoModelForCausalLM.register(
404
- # ultravox_config.UltravoxConfig, UltravoxModel
405
  # )
406
  UltravoxModel.register_for_auto_class()
407
- transformers.AutoProcessor.register(
408
- ultravox_config.UltravoxConfig, ultravox_processing.UltravoxProcessor
409
- )
410
  # UltravoxModel.register_for_auto_class("AutoModelForCausalLM")
411
 
412
 
 
11
  import transformers.models
12
 
13
  # We must use relative import in this directory to allow uploading to HF Hub
14
+ from .ultravox_config import UltravoxConfig
15
+ from .ultravox_processing import UltravoxProcessor
16
+ from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
17
 
18
 
19
  class UltravoxModel(
 
33
  config: Model configuration class with all the parameters of the model.
34
  """
35
 
36
+ config_class = UltravoxConfig
37
+ config: UltravoxConfig # for type hinting
38
  _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
39
 
40
+ def __init__(self, config: UltravoxConfig):
41
  super().__init__(config)
42
 
43
  self.keep_params: Set[str] = set()
 
188
  return model_input
189
 
190
  @classmethod
191
+ def _create_audio_tower(cls, config: UltravoxConfig) -> Union[
192
  transformers.Wav2Vec2Model,
193
  transformers.models.whisper.modeling_whisper.WhisperEncoder,
194
  ]:
195
  if config.audio_model_id is not None:
196
  if "whisper" in config.audio_model_id is not None:
197
+ audio_tower = ModifiedWhisperEncoder.from_pretrained(
198
  config.audio_model_id
199
  )
200
  else:
 
203
  )
204
  else:
205
  if "whisper" in config.audio_config._name_or_path:
206
+ audio_tower = ModifiedWhisperEncoder(config.audio_config)
207
  else:
208
  audio_tower = transformers.AutoModel.from_config(config.audio_config)
209
 
 
221
 
222
  @classmethod
223
  def _create_language_model(
224
+ cls, config: UltravoxConfig
225
  ) -> transformers.LlamaForCausalLM:
226
  if config.text_model_id is not None:
227
  language_model = transformers.AutoModelForCausalLM.from_pretrained(
 
375
 
376
 
377
  class UltravoxProjector(nn.Sequential):
378
+ def __init__(self, config: UltravoxConfig):
379
  super().__init__()
380
  self.hidden_dim = config.hidden_size
381
  self._pad_and_stack = StackAudioFrames(config.stack_factor)
 
398
  return hidden_states
399
 
400
 
401
+ transformers.AutoConfig.register("ultravox", UltravoxConfig)
402
+ transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
403
  # transformers.AutoModelForCausalLM.register(
404
+ # UltravoxConfig, UltravoxModel
405
  # )
406
  UltravoxModel.register_for_auto_class()
407
+ transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
 
 
408
  # UltravoxModel.register_for_auto_class("AutoModelForCausalLM")
409
 
410