asahi417 committed on
Commit
3d49cbd
1 Parent(s): 0e5b68e

Upload KotobaWhisperPipeline

Files changed (1)
  1. kotoba_whisper.py +97 -18
kotoba_whisper.py CHANGED
@@ -1,5 +1,6 @@
-from typing import Union, Optional, Dict, List, Any
 import requests
+from typing import Union, Optional, Dict, List, Any
+from collections import defaultdict
 
 import torch
 import numpy as np
@@ -38,12 +39,12 @@ class Punctuator:
         return [
             {
                 'timestamp': c['timestamp'],
+                'speaker': c['speaker'],
                 'text': validate_punctuation(c['text'], "".join(e))
             } for c, e in zip(pipeline_chunk, text_edit)
         ]
 
 
-
 class SpeakerDiarization:
 
     def __init__(self,
@@ -58,7 +59,12 @@ class SpeakerDiarization:
             model_id_diarizers
         ).to_pyannote_model().to(self.device)
 
-    def __call__(self, audio: Union[torch.Tensor, np.ndarray], sampling_rate: int) -> Annotation:
+    def __call__(self,
+                 audio: Union[torch.Tensor, np.ndarray],
+                 sampling_rate: int,
+                 num_speakers: Optional[int] = None,
+                 min_speakers: Optional[int] = None,
+                 max_speakers: Optional[int] = None) -> Annotation:
         if sampling_rate is None:
             raise ValueError("sampling_rate must be provided")
         if type(audio) is np.ndarray:
@@ -69,7 +75,7 @@
         elif len(audio.shape) > 3:
             raise ValueError("audio shape must be (channel, time)")
         audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
-        output = self.pipeline(audio)
+        output = self.pipeline(audio, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
         return output
 
 
@@ -84,8 +90,6 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                  device: Union[int, "torch.device"] = None,
                  device_pyannote: Union[int, "torch.device"] = None,
                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-                 return_unique_speaker: bool = True,
-                 punctuator: bool = False,
                  **kwargs):
         self.type = "seq2seq_whisper"
         if device is None:
@@ -99,11 +103,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
            model_id=model_pyannote,
            model_id_diarizers=model_diarizers
        )
-        self.return_unique_speaker = return_unique_speaker
-        if punctuator:
-            self.punctuator = Punctuator()
-        else:
-            self.punctuator = None
+        self.punctuator = None
        super().__init__(
            model=model,
            feature_extractor=feature_extractor,
@@ -113,6 +113,71 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
            **kwargs
        )
 
+    def _sanitize_parameters(self,
+                             chunk_length_s=None,
+                             stride_length_s=None,
+                             ignore_warning=None,
+                             decoder_kwargs=None,
+                             return_timestamps=None,
+                             return_language=None,
+                             generate_kwargs=None,
+                             max_new_tokens=None,
+                             add_punctuation: bool = False,
+                             return_unique_speaker: bool = True,
+                             num_speakers: Optional[int] = None,
+                             min_speakers: Optional[int] = None,
+                             max_speakers: Optional[int] = None):
+        # Split the call-time kwargs into preprocess / forward / postprocess parameters.
+        preprocess_params = {}
+        if chunk_length_s is not None:
+            preprocess_params["chunk_length_s"] = chunk_length_s
+        if stride_length_s is not None:
+            preprocess_params["stride_length_s"] = stride_length_s
+
+        forward_params = defaultdict(dict)
+        if max_new_tokens is not None:
+            forward_params["max_new_tokens"] = max_new_tokens
+        if generate_kwargs is not None:
+            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
+                raise ValueError(
+                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
+                    " only 1 version"
+                )
+            forward_params.update(generate_kwargs)
+
+        postprocess_params = {}
+        if decoder_kwargs is not None:
+            postprocess_params["decoder_kwargs"] = decoder_kwargs
+        if return_timestamps is not None:
+            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
+            if self.type == "seq2seq" and return_timestamps:
+                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
+            if self.type == "ctc_with_lm" and return_timestamps != "word":
+                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
+            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
+                raise ValueError(
+                    "CTC can either predict character level timestamps, or word level timestamps. "
+                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
+                )
+            if self.type == "seq2seq_whisper" and return_timestamps == "char":
+                raise ValueError(
+                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
+                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
+                )
+            forward_params["return_timestamps"] = return_timestamps
+            postprocess_params["return_timestamps"] = return_timestamps
+        if return_language is not None:
+            if self.type != "seq2seq_whisper":
+                raise ValueError("Only Whisper can return language for now.")
+            postprocess_params["return_language"] = return_language
+        postprocess_params["return_language"] = return_language
+        postprocess_params["add_punctuation"] = add_punctuation
+        postprocess_params["return_unique_speaker"] = return_unique_speaker
+        postprocess_params["num_speakers"] = num_speakers
+        postprocess_params["min_speakers"] = min_speakers
+        postprocess_params["max_speakers"] = max_speakers
+        return preprocess_params, forward_params, postprocess_params
+
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -259,18 +324,31 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                    model_outputs,
                    decoder_kwargs: Optional[Dict] = None,
                    return_language=None,
+                    add_punctuation: bool = False,
+                    return_unique_speaker: bool = True,
+                    num_speakers: Optional[int] = None,
+                    min_speakers: Optional[int] = None,
+                    max_speakers: Optional[int] = None,
                    *args,
                    **kwargs):
         assert len(model_outputs) > 0
-        audio_array = list(model_outputs)[0]["audio_array"]
-        sd = self.model_speaker_diarization(audio_array, sampling_rate=self.feature_extractor.sampling_rate)
-        timelines = sd.get_timeline()
         outputs = super().postprocess(
            model_outputs=model_outputs,
            decoder_kwargs=decoder_kwargs,
            return_timestamps=True,
            return_language=return_language
        )
+        audio_array = outputs.pop("audio_array")[0]
+        sd = self.model_speaker_diarization(
+            audio_array,
+            num_speakers=num_speakers,
+            min_speakers=min_speakers,
+            max_speakers=max_speakers,
+            sampling_rate=self.feature_extractor.sampling_rate
+        )
+        diarization_result = {s: [[i.start, i.end] for i in sd.label_timeline(s)] for s in sd.labels()}
+        timelines = sd.get_timeline()
+
         pointer_ts = 0
         pointer_chunk = 0
         new_chunks = []
@@ -306,18 +384,19 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
            pointer_ts += 1
        for i in new_chunks:
            if "speaker" in i:
-                if self.return_unique_speaker:
+                if return_unique_speaker:
                    i["speaker"] = [i["speaker"][0]]
                else:
                    i["speaker"] = list(set(i["speaker"]))
            else:
                i["speaker"] = []
        outputs["chunks"] = new_chunks
-        if self.punctuator:
+        if add_punctuation:
+            if self.punctuator is None:
+                self.punctuator = Punctuator()
            outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
        outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
        outputs["speakers"] = sd.labels()
-        outputs.pop("audio_array")
        speakers = []
        for s in outputs["speakers"]:
            chunk_s = [c for c in outputs["chunks"] if s in c["speaker"]]
@@ -326,5 +405,5 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
            outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
            speakers.append(s)
        outputs["speakers"] = speakers
+        outputs["diarization_result"] = diarization_result
        return outputs
-
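For context, this commit moves the diarization and punctuation options from constructor arguments to call-time keyword arguments routed through the new _sanitize_parameters. A minimal usage sketch, assuming the file is served as a custom pipeline from a kotoba-whisper checkpoint loaded with trust_remote_code=True; the model id and audio path below are placeholders and are not part of this commit:

import torch
from transformers import pipeline

# Assumed checkpoint id; any repo shipping this KotobaWhisperPipeline should behave the same.
pipe = pipeline(
    model="kotoba-tech/kotoba-whisper-v2.2",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    trust_remote_code=True,
)

# The kwargs below are the ones this commit adds to _sanitize_parameters.
result = pipe(
    "sample_audio.wav",          # placeholder path
    chunk_length_s=15,
    min_speakers=1,              # forwarded to the pyannote diarization pipeline
    max_speakers=3,              # alternatively pass num_speakers to fix the speaker count
    add_punctuation=True,        # lazily builds the Punctuator inside postprocess
    return_unique_speaker=True,  # keep a single speaker label per chunk
)
print(result["text"])
print(result["speakers"])
print(result["diarization_result"])  # per-speaker [start, end] segments, added by this commit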