Update ultravox_model.py
Browse files- ultravox_model.py +18 -8
ultravox_model.py
CHANGED
@@ -16,10 +16,7 @@ from .ultravox_config import UltravoxConfig
|
|
16 |
from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
|
17 |
|
18 |
|
19 |
-
class UltravoxModel(
|
20 |
-
transformers.LlamaPreTrainedModel,
|
21 |
-
transformers.GenerationMixin,
|
22 |
-
):
|
23 |
"""
|
24 |
The Ultravox model which consists of an audio encoder and a language model.
|
25 |
|
@@ -101,7 +98,7 @@ class UltravoxModel(
|
|
101 |
attention_mask: Optional[torch.Tensor] = None,
|
102 |
audio_token_start_idx: Optional[torch.Tensor] = None,
|
103 |
audio_token_len: Optional[torch.Tensor] = None,
|
104 |
-
past_key_values: Optional[Tuple] = None,
|
105 |
**kwargs,
|
106 |
) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
|
107 |
"""
|
@@ -166,7 +163,7 @@ class UltravoxModel(
|
|
166 |
audio_values: Optional[torch.FloatTensor] = None,
|
167 |
audio_token_start_idx: Optional[torch.Tensor] = None,
|
168 |
audio_token_len: Optional[torch.Tensor] = None,
|
169 |
-
past_key_values: Optional[Tuple] = None,
|
170 |
attention_mask: Optional[torch.Tensor] = None,
|
171 |
inputs_embeds: Optional[torch.Tensor] = None,
|
172 |
**kwargs,
|
@@ -179,7 +176,7 @@ class UltravoxModel(
|
|
179 |
**kwargs,
|
180 |
)
|
181 |
|
182 |
-
if past_key_values is None:
|
183 |
# We only want to use audio features in the 1st generation step
|
184 |
model_input["audio_values"] = audio_values
|
185 |
model_input["audio_token_start_idx"] = audio_token_start_idx
|
@@ -320,6 +317,19 @@ class UltravoxModel(
|
|
320 |
)
|
321 |
|
322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
|
324 |
"""
|
325 |
Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
|
@@ -402,6 +412,6 @@ UltravoxModel.register_for_auto_class()
|
|
402 |
|
403 |
transformers.AutoConfig.register("ultravox", UltravoxConfig)
|
404 |
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
|
405 |
-
# transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make
|
406 |
|
407 |
transformers.activations.ACT2FN["swiglu"] = SwiGLU
|
|
|
16 |
from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
|
17 |
|
18 |
|
19 |
+
class UltravoxModel(transformers.LlamaPreTrainedModel):
|
|
|
|
|
|
|
20 |
"""
|
21 |
The Ultravox model which consists of an audio encoder and a language model.
|
22 |
|
|
|
98 |
attention_mask: Optional[torch.Tensor] = None,
|
99 |
audio_token_start_idx: Optional[torch.Tensor] = None,
|
100 |
audio_token_len: Optional[torch.Tensor] = None,
|
101 |
+
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
|
102 |
**kwargs,
|
103 |
) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
|
104 |
"""
|
|
|
163 |
audio_values: Optional[torch.FloatTensor] = None,
|
164 |
audio_token_start_idx: Optional[torch.Tensor] = None,
|
165 |
audio_token_len: Optional[torch.Tensor] = None,
|
166 |
+
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
|
167 |
attention_mask: Optional[torch.Tensor] = None,
|
168 |
inputs_embeds: Optional[torch.Tensor] = None,
|
169 |
**kwargs,
|
|
|
176 |
**kwargs,
|
177 |
)
|
178 |
|
179 |
+
if is_cache_empty(past_key_values) and audio_values is not None:
|
180 |
# We only want to use audio features in the 1st generation step
|
181 |
model_input["audio_values"] = audio_values
|
182 |
model_input["audio_token_start_idx"] = audio_token_start_idx
|
|
|
317 |
)
|
318 |
|
319 |
|
320 |
+
def is_cache_empty(
    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
) -> bool:
    """Return True when no past key/value states have been cached yet.

    Handles the three cache representations seen at generation time:
    ``None`` (no cache at all), the legacy tuple-of-layer-tuples format,
    and the new-style ``transformers.cache_utils.Cache`` object.
    """
    if past_key_values is None:
        return True
    if isinstance(past_key_values, tuple):
        # Legacy format: one entry per layer. The cache counts as empty
        # only if every layer entry has zero length (an empty outer tuple
        # is vacuously empty).
        for layer_cache in past_key_values:
            if len(layer_cache) != 0:
                return False
        return True
    # New-style Cache object reports its fill level directly.
    return past_key_values.get_seq_length() == 0
|
331 |
+
|
332 |
+
|
333 |
def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
|
334 |
"""
|
335 |
Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
|
|
|
412 |
|
413 |
transformers.AutoConfig.register("ultravox", UltravoxConfig)
|
414 |
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
|
415 |
+
# transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make processor work standalone
|
416 |
|
417 |
transformers.activations.ACT2FN["swiglu"] = SwiGLU
|