asahi417 committed (verified)
Commit aaccb5f · 1 Parent(s): 4f38470

Upload KotobaWhisperPipeline

Files changed (2):
  1. README.md  +2 -3
  2. kotoba_whisper.py  +76 -25
README.md CHANGED

@@ -2,15 +2,14 @@
 language: ja
 library_name: transformers
 license: apache-2.0
+pipeline_tag: automatic-speech-recognition
 tags:
 - audio
 - automatic-speech-recognition
 - hf-asr-leaderboard
 widget:
 - example_title: Sample 1
-  src: >-
-    https://huggingface.co/kotoba-tech/kotoba-whisper-v2.2/resolve/main/sample_audio/sample_diarization_japanese.mp3
-pipeline_tag: automatic-speech-recognition
+  src: https://huggingface.co/kotoba-tech/kotoba-whisper-v2.2/resolve/main/sample_audio/sample_diarization_japanese.mp3
 ---

 # Kotoba-Whisper-v2.2
kotoba_whisper.py CHANGED

@@ -12,29 +12,63 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 from pyannote.audio import Pipeline
 from pyannote.core.annotation import Annotation
+from punctuators.models import PunctCapSegModelONNX
+from diarizers import SegmentationModel
+
+
+class Punctuator:
+
+    ja_punctuations = ["!", "?", "、", "。"]
+
+    def __init__(self, model: str = "pcs_47lang"):
+        self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
+
+    def punctuate(self, pipeline_chunk: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+
+        def validate_punctuation(raw: str, punctuated: str):
+            if 'unk' in punctuated.lower() or any(p in raw for p in self.ja_punctuations):
+                return raw
+            if punctuated.count("。") > 1:
+                ind = punctuated.rfind("。")
+                punctuated = punctuated.replace("。", "")
+                punctuated = punctuated[:ind] + "。" + punctuated[ind:]
+            return punctuated
+
+        text_edit = self.punctuation_model.infer([c['text'] for c in pipeline_chunk])
+        return [
+            {
+                'timestamp': c['timestamp'],
+                'text': validate_punctuation(c['text'], "".join(e))
+            } for c, e in zip(pipeline_chunk, text_edit)
+        ]
+


 class SpeakerDiarization:

-    def __init__(self, model_id: str, device: torch.device):
+    def __init__(self,
+                 device: torch.device,
+                 model_id: str = "pyannote/speaker-diarization-3.1",
+                 model_id_diarizers: Optional[str] = None):
         self.device = device
         self.pipeline = Pipeline.from_pretrained(model_id)
         self.pipeline = self.pipeline.to(self.device)
+        if model_id_diarizers:
+            self.pipeline._segmentation.model = SegmentationModel().from_pretrained(
+                model_id_diarizers
+            ).to_pyannote_model().to(self.device)

-    def __call__(self,
-                 audio: Union[str, torch.Tensor, np.ndarray],
-                 sampling_rate: Optional[int] = None) -> Annotation:
-        if type(audio) is torch.Tensor or type(audio) is np.ndarray:
-            if sampling_rate is None:
-                raise ValueError("sampling_rate must be provided")
-            if type(audio) is np.ndarray:
-                audio = torch.as_tensor(audio)
-            audio = torch.as_tensor(audio, dtype=torch.float32)
-            if len(audio.shape) == 1:
-                audio = audio.unsqueeze(0)
-            elif len(audio.shape) > 3:
-                raise ValueError("audio shape must be (channel, time)")
-            audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
+    def __call__(self, audio: Union[torch.Tensor, np.ndarray], sampling_rate: int) -> Annotation:
+        if sampling_rate is None:
+            raise ValueError("sampling_rate must be provided")
+        if type(audio) is np.ndarray:
+            audio = torch.as_tensor(audio)
+        audio = torch.as_tensor(audio, dtype=torch.float32)
+        if len(audio.shape) == 1:
+            audio = audio.unsqueeze(0)
+        elif len(audio.shape) > 3:
+            raise ValueError("audio shape must be (channel, time)")
+        audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
         output = self.pipeline(audio)
         return output

@@ -43,23 +77,33 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):

     def __init__(self,
                  model: "PreTrainedModel",
-                 model_diarizarization: str="pyannote/speaker-diarization-3.1",
+                 model_pyannote: str = "pyannote/speaker-diarization-3.1",
+                 model_diarizers: Optional[str] = "diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn",
                  feature_extractor: Union["SequenceFeatureExtractor", str] = None,
                  tokenizer: Optional[PreTrainedTokenizer] = None,
                  device: Union[int, "torch.device"] = None,
-                 device_diarizarization: Union[int, "torch.device"] = None,
+                 device_pyannote: Union[int, "torch.device"] = None,
                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-                 return_unique_speaker: bool = False,
+                 return_unique_speaker: bool = True,
+                 punctuator: bool = False,
                  **kwargs):
         self.type = "seq2seq_whisper"
         if device is None:
             device = "cpu"
-        if device_diarizarization is None:
-            device_diarizarization = device
-        if type(device_diarizarization) is str:
-            device_diarizarization = torch.device(device_diarizarization)
-        self.model_speaker_diarization = SpeakerDiarization(model_diarizarization, device_diarizarization)
+        if device_pyannote is None:
+            device_pyannote = device
+        if type(device_pyannote) is str:
+            device_pyannote = torch.device(device_pyannote)
+        self.model_speaker_diarization = SpeakerDiarization(
+            device=device_pyannote,
+            model_id=model_pyannote,
+            model_id_diarizers=model_diarizers
+        )
         self.return_unique_speaker = return_unique_speaker
+        if punctuator:
+            self.punctuator = Punctuator()
+        else:
+            self.punctuator = None
         super().__init__(
             model=model,
             feature_extractor=feature_extractor,
@@ -269,11 +313,18 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             else:
                 i["speaker"] = []
         outputs["chunks"] = new_chunks
+        if self.punctuator:
+            outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
         outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
         outputs["speakers"] = sd.labels()
         outputs.pop("audio_array")
+        speakers = []
         for s in outputs["speakers"]:
-            outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
-            outputs[f"chunks/{s}"] = [c for c in outputs["chunks"] if s in c["speaker"]]
+            chunk_s = [c for c in outputs["chunks"] if s in c["speaker"]]
+            if len(chunk_s) != 0:
+                outputs[f"chunks/{s}"] = chunk_s
+                outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
+                speakers.append(s)
+        outputs["speakers"] = speakers
         return outputs
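
For reference, a minimal usage sketch of the pipeline after this change (not part of the commit). It assumes the standard transformers pipeline() loader with trust_remote_code=True forwards the extra keyword arguments (punctuator, return_unique_speaker) to the __init__ shown above; the audio file name is a placeholder.

import torch
from transformers import pipeline

# Hypothetical usage sketch: extra kwargs are forwarded to KotobaWhisperPipeline.__init__.
pipe = pipeline(
    "automatic-speech-recognition",
    model="kotoba-tech/kotoba-whisper-v2.2",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    trust_remote_code=True,
    punctuator=True,             # enable the new Punctuator post-processing step
    return_unique_speaker=True,  # keep a single speaker label per chunk
)

result = pipe("sample_audio.mp3")  # placeholder audio path
print(result["text"])              # full transcription
print(result["speakers"])          # e.g. ["SPEAKER_00", "SPEAKER_01"]
for speaker in result["speakers"]:
    print(speaker, result[f"text/{speaker}"])  # per-speaker transcription

Note that, per the last hunk, the returned "speakers" list now only contains labels that received at least one chunk, so the text/{speaker} and chunks/{speaker} keys always exist for every listed speaker.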