---
language: ja
license: apache-2.0
tags:
- speech
- speaker-diarization
datasets:
- callhome
---

# Fine-tuned XLSR-53 large model for speaker diarization in Japanese phone calls

A two-speaker diarization model obtained by fine-tuning [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Japanese phone-call data from [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).

## Usage

The model can be used directly as follows.

```python
import numpy as np
import torch
from pydub import AudioSegment

from transformers import Wav2Vec2ForAudioFrameClassification, Wav2Vec2FeatureExtractor


def _make_timegrid(sound_duration: float, total_len: int):
    # Split [0, sound_duration] into total_len equal frames and return
    # the start and end time (in seconds) of each frame.
    start_timegrid = np.linspace(0, sound_duration, total_len + 1)
    dt = start_timegrid[1] - start_timegrid[0]
    end_timegrid = start_timegrid + dt
    return start_timegrid[:total_len], end_timegrid[:total_len]


feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")
model.eval()

filepath = "/path/to/file.wav"
sound = AudioSegment.from_file(filepath)
# Resample to 16 kHz. Downmixing to mono is an added assumption: without it,
# get_array_of_samples() returns interleaved channels for stereo files.
sound = sound.set_frame_rate(16_000).set_channels(1)
sound_duration = sound.duration_seconds

samples = np.array(sound.get_array_of_samples(), dtype=np.float32)
feature = feature_extractor(samples, sampling_rate=16_000).input_values[0]
input_values = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
    logits = model(input_values).logits  # shape: (1, num_frames, num_labels)
    pred = logits.argmax(dim=-1).squeeze(0)  # one label id per frame
start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))

print("sec speaker_label")
for p, start_time in zip(pred, start_timegrid):
    print(f"{start_time:.4f} {p.item()}")
```
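
The snippet above prints one label per frame, which can be verbose for long recordings. As a minimal sketch, consecutive frames with the same label could be merged into segments. The helper `frames_to_segments` is hypothetical and not part of this model or of `transformers`. The meaning of each label id is not documented here, so the sketch treats labels as opaque integers; the toy values below are made up for illustration.

```python
from itertools import groupby


def frames_to_segments(labels, start_times, end_times):
    """Merge runs of identical consecutive frame labels into (start, end, label) tuples."""
    segments = []
    idx = 0  # index of the first frame of the current run
    for label, run in groupby(labels):
        run_len = len(list(run))
        segments.append(
            (float(start_times[idx]), float(end_times[idx + run_len - 1]), int(label))
        )
        idx += run_len
    return segments


# Toy example with hand-made per-frame labels and 20 ms frames.
labels = [0, 0, 1, 1, 1, 0]
starts = [0.00, 0.02, 0.04, 0.06, 0.08, 0.10]
ends = [0.02, 0.04, 0.06, 0.08, 0.10, 0.12]

print("start end label")
for start, end, label in frames_to_segments(labels, starts, ends):
    print(f"{start:.2f} {end:.2f} {label}")
```

With the variables from the usage snippet, the call would be `frames_to_segments(pred.tolist(), start_timegrid, end_timegrid)`.
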

## Training

The model was trained on the Japanese phone-call corpus [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).

## License

[The Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0)