Update ultravox_pipeline.py
Files changed: ultravox_pipeline.py (+20 -19)
ultravox_pipeline.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
from typing import Any, Dict, List, Optional
|
3 |
|
|
|
4 |
import transformers
|
5 |
|
6 |
# We must use relative import in this directory to allow uploading to HF Hub
|
@@ -41,27 +42,30 @@ class UltravoxPipeline(transformers.Pipeline):
|
|
41 |
super().__init__(model=model, tokenizer=tokenizer, **kwargs)
|
42 |
|
43 |
def _sanitize_parameters(self, **kwargs):
|
44 |
-
|
45 |
-
if
|
46 |
-
generation_kwargs["temperature"] = kwargs["temperature"]
|
47 |
-
if "max_new_tokens" in kwargs:
|
48 |
-
generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"]
|
49 |
-
if "repetition_penalty" in kwargs:
|
50 |
-
generation_kwargs["repetition_penalty"] = kwargs["repetition_penalty"]
|
51 |
return {}, generation_kwargs, {}
|
52 |
|
53 |
def preprocess(self, inputs: Dict[str, Any]):
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
prompt = inputs.get("prompt", "<|audio|>")
|
61 |
if "<|audio|>" not in prompt:
|
62 |
logging.warning(
|
63 |
"Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
|
64 |
)
|
|
|
65 |
prompt += " <|audio|>"
|
66 |
turns.append({"role": "user", "content": prompt})
|
67 |
|
@@ -69,17 +73,14 @@ class UltravoxPipeline(transformers.Pipeline):
|
|
69 |
turns, add_generation_prompt=True, tokenize=False
|
70 |
)
|
71 |
|
72 |
-
|
73 |
-
assert "audio" in inputs, "Audio input is required"
|
74 |
-
|
75 |
-
if "sampling_rate" not in inputs:
|
76 |
logging.warning(
|
77 |
"No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
|
78 |
)
|
79 |
|
80 |
output = self.processor(
|
81 |
text=text,
|
82 |
-
audio=
|
83 |
sampling_rate=inputs.get("sampling_rate", 16000),
|
84 |
)
|
85 |
if "audio_values" in output:
|
@@ -123,4 +124,4 @@ transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
|
|
123 |
pipeline_class=UltravoxPipeline,
|
124 |
pt_model=transformers.AutoModel,
|
125 |
type="multimodal",
|
126 |
-
)
|
|
|
1 |
import logging
|
2 |
from typing import Any, Dict, List, Optional
|
3 |
|
4 |
+
import numpy as np
|
5 |
import transformers
|
6 |
|
7 |
# We must use relative import in this directory to allow uploading to HF Hub
|
|
|
42 |
super().__init__(model=model, tokenizer=tokenizer, **kwargs)
|
43 |
|
44 |
def _sanitize_parameters(self, **kwargs):
    """Split call-time keyword arguments into the three pipeline parameter dicts.

    Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty`` are
    kept; any other keyword argument is dropped. The kept entries are placed
    in the second dict of the returned tuple, in the order the caller passed
    them. The first and third dicts are always empty.

    Per the ``transformers.Pipeline`` contract the tuple is
    ``(preprocess_params, forward_params, postprocess_params)``.
    """
    recognised = ("temperature", "max_new_tokens", "repetition_penalty")
    forward_params = {}
    for name, value in kwargs.items():
        if name in recognised:
            forward_params[name] = value
    return {}, forward_params, {}
|
48 |
|
49 |
def preprocess(self, inputs: Dict[str, Any]):
|
50 |
+
turns: list = inputs.get("turns", [])
|
51 |
+
|
52 |
+
audio = inputs.get("audio", None)
|
53 |
+
# Convert to float32 if needed.
|
54 |
+
if isinstance(audio, np.ndarray):
|
55 |
+
if audio.dtype == np.float64:
|
56 |
+
audio = audio.astype(np.float32)
|
57 |
+
elif audio.dtype == np.int16:
|
58 |
+
audio = audio.astype(np.float32) / np.float32(32768.0)
|
59 |
+
elif audio.dtype == np.int32:
|
60 |
+
audio = audio.astype(np.float32) / np.float32(2147483648.0)
|
61 |
+
|
62 |
+
if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
|
63 |
prompt = inputs.get("prompt", "<|audio|>")
|
64 |
if "<|audio|>" not in prompt:
|
65 |
logging.warning(
|
66 |
"Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
|
67 |
)
|
68 |
+
|
69 |
prompt += " <|audio|>"
|
70 |
turns.append({"role": "user", "content": prompt})
|
71 |
|
|
|
73 |
turns, add_generation_prompt=True, tokenize=False
|
74 |
)
|
75 |
|
76 |
+
if "sampling_rate" not in inputs and audio is not None:
|
|
|
|
|
|
|
77 |
logging.warning(
|
78 |
"No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
|
79 |
)
|
80 |
|
81 |
output = self.processor(
|
82 |
text=text,
|
83 |
+
audio=audio,
|
84 |
sampling_rate=inputs.get("sampling_rate", 16000),
|
85 |
)
|
86 |
if "audio_values" in output:
|
|
|
124 |
pipeline_class=UltravoxPipeline,
|
125 |
pt_model=transformers.AutoModel,
|
126 |
type="multimodal",
|
127 |
+
)
|