farzadab committed on
Commit
ffd36f6
·
verified ·
1 Parent(s): 1d58827

Update ultravox_model.py

Browse files
Files changed (1) hide show
  1. ultravox_model.py +18 -8
ultravox_model.py CHANGED
@@ -16,10 +16,7 @@ from .ultravox_config import UltravoxConfig
16
  from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
17
 
18
 
19
- class UltravoxModel(
20
- transformers.LlamaPreTrainedModel,
21
- transformers.GenerationMixin,
22
- ):
23
  """
24
  The Ultravox model which consists of an audio encoder and a language model.
25
 
@@ -101,7 +98,7 @@ class UltravoxModel(
101
  attention_mask: Optional[torch.Tensor] = None,
102
  audio_token_start_idx: Optional[torch.Tensor] = None,
103
  audio_token_len: Optional[torch.Tensor] = None,
104
- past_key_values: Optional[Tuple] = None,
105
  **kwargs,
106
  ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
107
  """
@@ -166,7 +163,7 @@ class UltravoxModel(
166
  audio_values: Optional[torch.FloatTensor] = None,
167
  audio_token_start_idx: Optional[torch.Tensor] = None,
168
  audio_token_len: Optional[torch.Tensor] = None,
169
- past_key_values: Optional[Tuple] = None,
170
  attention_mask: Optional[torch.Tensor] = None,
171
  inputs_embeds: Optional[torch.Tensor] = None,
172
  **kwargs,
@@ -179,7 +176,7 @@ class UltravoxModel(
179
  **kwargs,
180
  )
181
 
182
- if past_key_values is None and audio_values is not None:
183
  # We only want to use audio features in the 1st generation step
184
  model_input["audio_values"] = audio_values
185
  model_input["audio_token_start_idx"] = audio_token_start_idx
@@ -320,6 +317,19 @@ class UltravoxModel(
320
  )
321
 
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
324
  """
325
  Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
@@ -402,6 +412,6 @@ UltravoxModel.register_for_auto_class()
402
 
403
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
404
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
405
- # transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make processo work standalone
406
 
407
  transformers.activations.ACT2FN["swiglu"] = SwiGLU
 
16
  from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
17
 
18
 
19
+ class UltravoxModel(transformers.LlamaPreTrainedModel):
 
 
 
20
  """
21
  The Ultravox model which consists of an audio encoder and a language model.
22
 
 
98
  attention_mask: Optional[torch.Tensor] = None,
99
  audio_token_start_idx: Optional[torch.Tensor] = None,
100
  audio_token_len: Optional[torch.Tensor] = None,
101
+ past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
102
  **kwargs,
103
  ) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
104
  """
 
163
  audio_values: Optional[torch.FloatTensor] = None,
164
  audio_token_start_idx: Optional[torch.Tensor] = None,
165
  audio_token_len: Optional[torch.Tensor] = None,
166
+ past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
167
  attention_mask: Optional[torch.Tensor] = None,
168
  inputs_embeds: Optional[torch.Tensor] = None,
169
  **kwargs,
 
176
  **kwargs,
177
  )
178
 
179
+ if is_cache_empty(past_key_values) and audio_values is not None:
180
  # We only want to use audio features in the 1st generation step
181
  model_input["audio_values"] = audio_values
182
  model_input["audio_token_start_idx"] = audio_token_start_idx
 
317
  )
318
 
319
 
320
def is_cache_empty(
    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
) -> bool:
    """
    Report whether `past_key_values` holds no cached entries yet.

    Returns True when the argument is None, when it is a tuple whose elements
    all have length zero (an empty tuple included), or when it is any other
    object whose `get_seq_length()` returns 0. Returns False otherwise.
    """
    if past_key_values is None:
        # No cache object was supplied at all.
        return True
    if not isinstance(past_key_values, tuple):
        # Anything that is not a tuple is asked for its sequence length.
        return past_key_values.get_seq_length() == 0
    # Tuple form: empty only if none of its elements contains anything.
    return not any(len(entry) != 0 for entry in past_key_values)
331
+
332
+
333
  def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
334
  """
335
  Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
 
412
 
413
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
414
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
415
+ # transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make processor work standalone
416
 
417
  transformers.activations.ACT2FN["swiglu"] = SwiGLU