Update ultravox_model.py
Browse files- ultravox_model.py +15 -17
ultravox_model.py
CHANGED
@@ -11,9 +11,9 @@ import transformers.modeling_outputs
|
|
11 |
import transformers.models
|
12 |
|
13 |
# We must use relative import in this directory to allow uploading to HF Hub
|
14 |
-
from . import ultravox_config
|
15 |
-
from . import ultravox_processing
|
16 |
-
from . import whisper_model_modified
|
17 |
|
18 |
|
19 |
class UltravoxModel(
|
@@ -33,11 +33,11 @@ class UltravoxModel(
|
|
33 |
config: Model configuration class with all the parameters of the model.
|
34 |
"""
|
35 |
|
36 |
-
config_class = ultravox_config.UltravoxConfig
|
37 |
-
config: ultravox_config.UltravoxConfig # for type hinting
|
38 |
_no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
|
39 |
|
40 |
-
def __init__(self, config: ultravox_config.UltravoxConfig):
|
41 |
super().__init__(config)
|
42 |
|
43 |
self.keep_params: Set[str] = set()
|
@@ -188,13 +188,13 @@ class UltravoxModel(
|
|
188 |
return model_input
|
189 |
|
190 |
@classmethod
|
191 |
-
def _create_audio_tower(cls, config: ultravox_config.UltravoxConfig) -> Union[
|
192 |
transformers.Wav2Vec2Model,
|
193 |
transformers.models.whisper.modeling_whisper.WhisperEncoder,
|
194 |
]:
|
195 |
if config.audio_model_id is not None:
|
196 |
if "whisper" in config.audio_model_id is not None:
|
197 |
-
audio_tower = whisper_model_modified.WhisperEncoder.from_pretrained(
|
198 |
config.audio_model_id
|
199 |
)
|
200 |
else:
|
@@ -203,7 +203,7 @@ class UltravoxModel(
|
|
203 |
)
|
204 |
else:
|
205 |
if "whisper" in config.audio_config._name_or_path:
|
206 |
-
audio_tower = whisper_model_modified.WhisperEncoder(config.audio_config)
|
207 |
else:
|
208 |
audio_tower = transformers.AutoModel.from_config(config.audio_config)
|
209 |
|
@@ -221,7 +221,7 @@ class UltravoxModel(
|
|
221 |
|
222 |
@classmethod
|
223 |
def _create_language_model(
|
224 |
-
cls, config: ultravox_config.UltravoxConfig
|
225 |
) -> transformers.LlamaForCausalLM:
|
226 |
if config.text_model_id is not None:
|
227 |
language_model = transformers.AutoModelForCausalLM.from_pretrained(
|
@@ -375,7 +375,7 @@ class SwiGLU(nn.Module):
|
|
375 |
|
376 |
|
377 |
class UltravoxProjector(nn.Sequential):
|
378 |
-
def __init__(self, config: ultravox_config.UltravoxConfig):
|
379 |
super().__init__()
|
380 |
self.hidden_dim = config.hidden_size
|
381 |
self._pad_and_stack = StackAudioFrames(config.stack_factor)
|
@@ -398,15 +398,13 @@ class UltravoxProjector(nn.Sequential):
|
|
398 |
return hidden_states
|
399 |
|
400 |
|
401 |
-
transformers.AutoConfig.register("ultravox", ultravox_config.UltravoxConfig)
|
402 |
-
transformers.AutoModel.register(ultravox_config.UltravoxConfig, UltravoxModel)
|
403 |
# transformers.AutoModelForCausalLM.register(
|
404 |
-
# ultravox_config.UltravoxConfig, UltravoxModel
|
405 |
# )
|
406 |
UltravoxModel.register_for_auto_class()
|
407 |
-
transformers.AutoProcessor.register(
|
408 |
-
ultravox_config.UltravoxConfig, ultravox_processing.UltravoxProcessor
|
409 |
-
)
|
410 |
# UltravoxModel.register_for_auto_class("AutoModelForCausalLM")
|
411 |
|
412 |
|
|
|
11 |
import transformers.models
|
12 |
|
13 |
# We must use relative import in this directory to allow uploading to HF Hub
|
14 |
+
from .ultravox_config import UltravoxConfig
|
15 |
+
from .ultravox_processing import UltravoxProcessor
|
16 |
+
from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
|
17 |
|
18 |
|
19 |
class UltravoxModel(
|
|
|
33 |
config: Model configuration class with all the parameters of the model.
|
34 |
"""
|
35 |
|
36 |
+
config_class = UltravoxConfig
|
37 |
+
config: UltravoxConfig # for type hinting
|
38 |
_no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
|
39 |
|
40 |
+
def __init__(self, config: UltravoxConfig):
|
41 |
super().__init__(config)
|
42 |
|
43 |
self.keep_params: Set[str] = set()
|
|
|
188 |
return model_input
|
189 |
|
190 |
@classmethod
|
191 |
+
def _create_audio_tower(cls, config: UltravoxConfig) -> Union[
|
192 |
transformers.Wav2Vec2Model,
|
193 |
transformers.models.whisper.modeling_whisper.WhisperEncoder,
|
194 |
]:
|
195 |
if config.audio_model_id is not None:
|
196 |
if "whisper" in config.audio_model_id is not None:
|
197 |
+
audio_tower = ModifiedWhisperEncoder.from_pretrained(
|
198 |
config.audio_model_id
|
199 |
)
|
200 |
else:
|
|
|
203 |
)
|
204 |
else:
|
205 |
if "whisper" in config.audio_config._name_or_path:
|
206 |
+
audio_tower = ModifiedWhisperEncoder(config.audio_config)
|
207 |
else:
|
208 |
audio_tower = transformers.AutoModel.from_config(config.audio_config)
|
209 |
|
|
|
221 |
|
222 |
@classmethod
|
223 |
def _create_language_model(
|
224 |
+
cls, config: UltravoxConfig
|
225 |
) -> transformers.LlamaForCausalLM:
|
226 |
if config.text_model_id is not None:
|
227 |
language_model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
|
375 |
|
376 |
|
377 |
class UltravoxProjector(nn.Sequential):
|
378 |
+
def __init__(self, config: UltravoxConfig):
|
379 |
super().__init__()
|
380 |
self.hidden_dim = config.hidden_size
|
381 |
self._pad_and_stack = StackAudioFrames(config.stack_factor)
|
|
|
398 |
return hidden_states
|
399 |
|
400 |
|
401 |
+
transformers.AutoConfig.register("ultravox", UltravoxConfig)
|
402 |
+
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
|
403 |
# transformers.AutoModelForCausalLM.register(
|
404 |
+
# UltravoxConfig, UltravoxModel
|
405 |
# )
|
406 |
UltravoxModel.register_for_auto_class()
|
407 |
+
transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
|
|
|
|
|
408 |
# UltravoxModel.register_for_auto_class("AutoModelForCausalLM")
|
409 |
|
410 |
|