Upload KotobaWhisperPipeline
kotoba_whisper.py CHANGED (+97 -18)
@@ -1,5 +1,6 @@
-from typing import Union, Optional, Dict, List, Any
 import requests
+from typing import Union, Optional, Dict, List, Any
+from collections import defaultdict
 
 import torch
 import numpy as np
@@ -38,12 +39,12 @@ class Punctuator:
         return [
             {
                 'timestamp': c['timestamp'],
+                'speaker': c['speaker'],
                 'text': validate_punctuation(c['text'], "".join(e))
             } for c, e in zip(pipeline_chunk, text_edit)
         ]
 
 
-
 class SpeakerDiarization:
 
     def __init__(self,
@@ -58,7 +59,12 @@ class SpeakerDiarization:
             model_id_diarizers
         ).to_pyannote_model().to(self.device)
 
-    def __call__(self,
+    def __call__(self,
+                 audio: Union[torch.Tensor, np.ndarray],
+                 sampling_rate: int,
+                 num_speakers: Optional[int] = None,
+                 min_speakers: Optional[int] = None,
+                 max_speakers: Optional[int] = None) -> Annotation:
         if sampling_rate is None:
             raise ValueError("sampling_rate must be provided")
         if type(audio) is np.ndarray:
@@ -69,7 +75,7 @@ class SpeakerDiarization:
         elif len(audio.shape) > 3:
             raise ValueError("audio shape must be (channel, time)")
         audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
-        output = self.pipeline(audio)
+        output = self.pipeline(audio, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
         return output
 
 
@@ -84,8 +90,6 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                  device: Union[int, "torch.device"] = None,
                  device_pyannote: Union[int, "torch.device"] = None,
                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-                 return_unique_speaker: bool = True,
-                 punctuator: bool = False,
                  **kwargs):
         self.type = "seq2seq_whisper"
         if device is None:
@@ -99,11 +103,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             model_id=model_pyannote,
             model_id_diarizers=model_diarizers
         )
-        self.
-        if punctuator:
-            self.punctuator = Punctuator()
-        else:
-            self.punctuator = None
+        self.punctuator = None
         super().__init__(
             model=model,
             feature_extractor=feature_extractor,
@@ -113,6 +113,71 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             **kwargs
         )
 
+    def _sanitize_parameters(self,
+                             chunk_length_s=None,
+                             stride_length_s=None,
+                             ignore_warning=None,
+                             decoder_kwargs=None,
+                             return_timestamps=None,
+                             return_language=None,
+                             generate_kwargs=None,
+                             max_new_tokens=None,
+                             add_punctuation: bool =False,
+                             return_unique_speaker: bool =True,
+                             num_speakers: Optional[int] = None,
+                             min_speakers: Optional[int] = None,
+                             max_speakers: Optional[int] = None):
+        # No parameters on this pipeline right now
+        preprocess_params = {}
+        if chunk_length_s is not None:
+            preprocess_params["chunk_length_s"] = chunk_length_s
+        if stride_length_s is not None:
+            preprocess_params["stride_length_s"] = stride_length_s
+
+        forward_params = defaultdict(dict)
+        if max_new_tokens is not None:
+            forward_params["max_new_tokens"] = max_new_tokens
+        if generate_kwargs is not None:
+            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
+                raise ValueError(
+                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
+                    " only 1 version"
+                )
+            forward_params.update(generate_kwargs)
+
+        postprocess_params = {}
+        if decoder_kwargs is not None:
+            postprocess_params["decoder_kwargs"] = decoder_kwargs
+        if return_timestamps is not None:
+            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
+            if self.type == "seq2seq" and return_timestamps:
+                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
+            if self.type == "ctc_with_lm" and return_timestamps != "word":
+                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
+            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
+                raise ValueError(
+                    "CTC can either predict character level timestamps, or word level timestamps. "
+                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
+                )
+            if self.type == "seq2seq_whisper" and return_timestamps == "char":
+                raise ValueError(
+                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
+                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
+                )
+            forward_params["return_timestamps"] = return_timestamps
+            postprocess_params["return_timestamps"] = return_timestamps
+        if return_language is not None:
+            if self.type != "seq2seq_whisper":
+                raise ValueError("Only Whisper can return language for now.")
+            postprocess_params["return_language"] = return_language
+        postprocess_params["return_language"] = return_language
+        postprocess_params["add_punctuation"] = add_punctuation
+        postprocess_params["return_unique_speaker"] = return_unique_speaker
+        postprocess_params["num_speakers"] = num_speakers
+        postprocess_params["min_speakers"] = min_speakers
+        postprocess_params["max_speakers"] = max_speakers
+        return preprocess_params, forward_params, postprocess_params
+
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -259,18 +324,31 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                     model_outputs,
                     decoder_kwargs: Optional[Dict] = None,
                     return_language=None,
+                    add_punctuation: bool = False,
+                    return_unique_speaker: bool = True,
+                    num_speakers: Optional[int] = None,
+                    min_speakers: Optional[int] = None,
+                    max_speakers: Optional[int] = None,
                     *args,
                     **kwargs):
         assert len(model_outputs) > 0
-        audio_array = list(model_outputs)[0]["audio_array"]
-        sd = self.model_speaker_diarization(audio_array, sampling_rate=self.feature_extractor.sampling_rate)
-        timelines = sd.get_timeline()
         outputs = super().postprocess(
             model_outputs=model_outputs,
             decoder_kwargs=decoder_kwargs,
             return_timestamps=True,
             return_language=return_language
         )
+        audio_array = outputs.pop("audio_array")[0]
+        sd = self.model_speaker_diarization(
+            audio_array,
+            num_speakers=num_speakers,
+            min_speakers=min_speakers,
+            max_speakers=max_speakers,
+            sampling_rate=self.feature_extractor.sampling_rate
+        )
+        diarization_result = {s: [[i.start, i.end] for i in sd.label_timeline(s)] for s in sd.labels()}
+        timelines = sd.get_timeline()
+
         pointer_ts = 0
         pointer_chunk = 0
         new_chunks = []
@@ -306,18 +384,19 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             pointer_ts += 1
         for i in new_chunks:
             if "speaker" in i:
-                if
+                if return_unique_speaker:
                     i["speaker"] = [i["speaker"][0]]
                 else:
                     i["speaker"] = list(set(i["speaker"]))
             else:
                 i["speaker"] = []
         outputs["chunks"] = new_chunks
-        if
+        if add_punctuation:
+            if self.punctuator is None:
+                self.punctuator = Punctuator()
             outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
         outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
         outputs["speakers"] = sd.labels()
-        outputs.pop("audio_array")
         speakers = []
        for s in outputs["speakers"]:
             chunk_s = [c for c in outputs["chunks"] if s in c["speaker"]]
@@ -326,5 +405,5 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
             speakers.append(s)
         outputs["speakers"] = speakers
+        outputs["diarization_result"] = diarization_result
         return outputs
-
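
Taken together, the commit moves the diarization and punctuation options from `__init__` arguments to per-call parameters routed through the new `_sanitize_parameters`, forwards speaker-count hints to the pyannote pipeline, and adds a `diarization_result` field to the output. A minimal usage sketch of the updated call pattern follows; the repo id, the audio file name, and the `trust_remote_code=True` loading style are assumptions about how this custom pipeline is typically served from the Hub, not part of the diff itself.

# Minimal usage sketch (assumed repo id and audio file; not part of this commit).
import torch
from transformers import pipeline

pipe = pipeline(
    model="kotoba-tech/kotoba-whisper-v2.2",  # assumed Hub repo shipping this custom pipeline
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    trust_remote_code=True,  # load KotobaWhisperPipeline from the repo instead of the stock ASR pipeline
)

# Options removed from __init__ are now passed per call; _sanitize_parameters routes them
# to postprocess (add_punctuation, return_unique_speaker, num/min/max_speakers).
result = pipe(
    "sample_audio.mp3",      # hypothetical local file; URLs and numpy arrays also go through preprocess
    chunk_length_s=15,
    add_punctuation=True,
    return_unique_speaker=True,
    num_speakers=2,          # forwarded to the pyannote diarization pipeline
)

print(result["text"])                # full transcription
print(result["speakers"])            # e.g. ["SPEAKER_00", "SPEAKER_01"]
print(result["diarization_result"])  # {speaker: [[start, end], ...]} added by this commit
for chunk in result["chunks"]:
    print(chunk["speaker"], chunk["timestamp"], chunk["text"])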