farzadab commited on
Commit
6b22e1f
·
verified ·
1 Parent(s): fb6249a

Update ultravox_pipeline.py

Browse files
Files changed (1) hide show
  1. ultravox_pipeline.py +20 -19
ultravox_pipeline.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  from typing import Any, Dict, List, Optional
3
 
 
4
  import transformers
5
 
6
  # We must use relative import in this directory to allow uploading to HF Hub
@@ -41,27 +42,30 @@ class UltravoxPipeline(transformers.Pipeline):
41
  super().__init__(model=model, tokenizer=tokenizer, **kwargs)
42
 
43
  def _sanitize_parameters(self, **kwargs):
44
- generation_kwargs = {}
45
- if "temperature" in kwargs:
46
- generation_kwargs["temperature"] = kwargs["temperature"]
47
- if "max_new_tokens" in kwargs:
48
- generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"]
49
- if "repetition_penalty" in kwargs:
50
- generation_kwargs["repetition_penalty"] = kwargs["repetition_penalty"]
51
  return {}, generation_kwargs, {}
52
 
53
  def preprocess(self, inputs: Dict[str, Any]):
54
- if "turns" in inputs:
55
- turns = inputs["turns"]
56
- else:
57
- turns = []
58
-
59
- if not turns or turns[-1]["role"] != "user":
 
 
 
 
 
 
 
60
  prompt = inputs.get("prompt", "<|audio|>")
61
  if "<|audio|>" not in prompt:
62
  logging.warning(
63
  "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
64
  )
 
65
  prompt += " <|audio|>"
66
  turns.append({"role": "user", "content": prompt})
67
 
@@ -69,17 +73,14 @@ class UltravoxPipeline(transformers.Pipeline):
69
  turns, add_generation_prompt=True, tokenize=False
70
  )
71
 
72
- # TODO: allow text-only mode?
73
- assert "audio" in inputs, "Audio input is required"
74
-
75
- if "sampling_rate" not in inputs:
76
  logging.warning(
77
  "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
78
  )
79
 
80
  output = self.processor(
81
  text=text,
82
- audio=inputs["audio"],
83
  sampling_rate=inputs.get("sampling_rate", 16000),
84
  )
85
  if "audio_values" in output:
@@ -123,4 +124,4 @@ transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
123
  pipeline_class=UltravoxPipeline,
124
  pt_model=transformers.AutoModel,
125
  type="multimodal",
126
- )
 
1
  import logging
2
  from typing import Any, Dict, List, Optional
3
 
4
+ import numpy as np
5
  import transformers
6
 
7
  # We must use relative import in this directory to allow uploading to HF Hub
 
42
  super().__init__(model=model, tokenizer=tokenizer, **kwargs)
43
 
44
  def _sanitize_parameters(self, **kwargs):
45
+ generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
46
+ generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
 
 
 
 
 
47
  return {}, generation_kwargs, {}
48
 
49
  def preprocess(self, inputs: Dict[str, Any]):
50
+ turns: list = inputs.get("turns", [])
51
+
52
+ audio = inputs.get("audio", None)
53
+ # Convert to float32 if needed.
54
+ if isinstance(audio, np.ndarray):
55
+ if audio.dtype == np.float64:
56
+ audio = audio.astype(np.float32)
57
+ elif audio.dtype == np.int16:
58
+ audio = audio.astype(np.float32) / np.float32(32768.0)
59
+ elif audio.dtype == np.int32:
60
+ audio = audio.astype(np.float32) / np.float32(2147483648.0)
61
+
62
+ if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
63
  prompt = inputs.get("prompt", "<|audio|>")
64
  if "<|audio|>" not in prompt:
65
  logging.warning(
66
  "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
67
  )
68
+
69
  prompt += " <|audio|>"
70
  turns.append({"role": "user", "content": prompt})
71
 
 
73
  turns, add_generation_prompt=True, tokenize=False
74
  )
75
 
76
+ if "sampling_rate" not in inputs and audio is not None:
 
 
 
77
  logging.warning(
78
  "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
79
  )
80
 
81
  output = self.processor(
82
  text=text,
83
+ audio=audio,
84
  sampling_rate=inputs.get("sampling_rate", 16000),
85
  )
86
  if "audio_values" in output:
 
124
  pipeline_class=UltravoxPipeline,
125
  pt_model=transformers.AutoModel,
126
  type="multimodal",
127
+ )