---
language: ja
license: apache-2.0
tags:
  - speech
  - speaker-diarization
datasets:
  - callhome
---

Fine-tuned XLSR-53 large model for speaker diarization of Japanese phone calls

A two-speaker diarization model obtained by fine-tuning facebook/wav2vec2-large-xlsr-53 on Japanese phone-call data from the CallHome corpus.
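Because the model labels audio frame by frame, the number of labels it returns depends on the length of the clip. The sketch below assumes the standard wav2vec2 feature-encoder layout (convolution kernels `(10, 3, 3, 3, 3, 2, 2)` and strides `(5, 2, 2, 2, 2, 2, 2)`, the `transformers` defaults for `Wav2Vec2Config`); you can check the values for this checkpoint via `model.config.conv_kernel` and `model.config.conv_stride`. The helper names `num_output_frames` and `frame_hop_seconds` are hypothetical and not part of any library.

```python
# ASSUMPTION: default wav2vec2 feature-encoder layout; verify against
# model.config.conv_kernel and model.config.conv_stride for this checkpoint.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples: int) -> int:
    """Hypothetical helper: frames produced by the stacked 1-D convolutions (no padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


def frame_hop_seconds(sampling_rate: int = 16_000) -> float:
    """Hypothetical helper: time between consecutive frames (total stride / sampling rate)."""
    total_stride = 1
    for stride in CONV_STRIDES:
        total_stride *= stride
    return total_stride / sampling_rate


print(num_output_frames(16_000), frame_hop_seconds())
```

Under these assumptions, one second of 16 kHz audio yields 49 frames, with a hop of 20 ms between frames.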

Usage

The model can be used directly as follows.

```python
import numpy as np
import torch
from pydub import AudioSegment

from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForAudioFrameClassification


def _make_timegrid(sound_duration: float, total_len: int):
    """Return the start and end times (in seconds) of `total_len` equal-width frames."""
    boundaries = np.linspace(0, sound_duration, total_len + 1)
    return boundaries[:-1], boundaries[1:]


feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")

filepath = "/path/to/file.wav"
sound = AudioSegment.from_file(filepath)
# The model expects 16 kHz mono audio; downmix so stereo samples are not interleaved.
sound = sound.set_frame_rate(16_000).set_channels(1)
sound_duration = sound.duration_seconds

# Convert the integer PCM samples to float32 in [-1, 1].
samples = np.array(sound.get_array_of_samples(), dtype=np.float32)
samples /= float(1 << (8 * sound.sample_width - 1))

inputs = feature_extractor(samples, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits  # shape: (1, num_frames, num_labels)
pred = logits.argmax(dim=-1).squeeze(0).tolist()  # one integer label per frame
start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))

print("start_sec  end_sec  speaker_label")
for p, start_time, end_time in zip(pred, start_timegrid, end_timegrid):
    print(f"{start_time:.4f}  {end_time:.4f}  {p}")
```
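The loop above prints one label per frame. For a diarization result it is often more useful to merge consecutive frames that share a label into segments. A minimal sketch, assuming the labels are plain Python integers and every frame has the same duration; the function name `frames_to_segments` is hypothetical, and this card does not document which integer corresponds to which speaker (or whether a non-speech class exists).

```python
from itertools import groupby
from typing import List, Sequence, Tuple


def frames_to_segments(
    labels: Sequence[int],
    frame_duration: float,
) -> List[Tuple[float, float, int]]:
    """Hypothetical helper: merge runs of identical frame labels into (start_sec, end_sec, label)."""
    segments = []
    frame_index = 0
    for label, run in groupby(labels):
        run_length = sum(1 for _ in run)
        start = frame_index * frame_duration
        end = (frame_index + run_length) * frame_duration
        segments.append((start, end, label))
        frame_index += run_length
    return segments


print(frames_to_segments([0, 0, 1, 1, 1, 0], 0.5))
```

With the variables from the usage example, this could be called as `frames_to_segments([int(p) for p in pred], sound_duration / len(pred))`, which spreads frames evenly over the clip in the same way as `_make_timegrid`.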

Training

The model was trained on Japanese phone-call data from the CallHome corpus.

License

This model is released under the Apache 2.0 license.