Fine-tuned XLSR-53 large model for speaker diarization of Japanese phone calls

A two-speaker diarization model, fine-tuned from facebook/wav2vec2-large-xlsr-53 on Japanese phone-call data from the CallHome corpus.
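The model is loaded as a Wav2Vec2ForAudioFrameClassification, so it emits one label per output frame of the wav2vec2 convolutional feature encoder rather than one per audio sample. The sketch below estimates that frame count and the nominal hop size. It assumes the encoder uses the standard wav2vec2 kernel sizes (10, 3, 3, 3, 3, 2, 2) and strides (5, 2, 2, 2, 2, 2, 2) with no padding; you can confirm this against `model.config.conv_kernel` and `model.config.conv_stride`. The helpers `num_encoder_frames` and `frame_seconds` are hypothetical names used only for illustration.

```python
# Assumed wav2vec2 feature-encoder layout (check model.config.conv_kernel / conv_stride).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_encoder_frames(num_samples: int) -> int:
    """Hypothetical helper: frames produced for a clip of num_samples samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Output length of an unpadded 1-D convolution: floor((L - k) / s) + 1
        length = (length - kernel) // stride + 1
    return length


def frame_seconds(sampling_rate: int = 16_000) -> float:
    """Hypothetical helper: nominal hop between frames, in seconds."""
    total_stride = 1
    for stride in CONV_STRIDES:
        total_stride *= stride  # 5 * 2**6 = 320 samples
    return total_stride / sampling_rate
```

Under these assumptions, one second of 16 kHz audio yields 49 frames with a nominal hop of 20 ms. Spreading those 49 frames evenly over the clip, as `_make_timegrid` in the usage code does, gives about 0.0204 s per frame, close to the nominal hop.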

Usage

The model can be used directly as follows.

import numpy as np
import torch
from pydub import AudioSegment

from transformers import Wav2Vec2ForAudioFrameClassification, Wav2Vec2FeatureExtractor


def _make_timegrid(sound_duration: float, total_len: int):
    # Spread total_len frames evenly over the clip; frame i spans [start_i, start_i + dt).
    start_timegrid = np.linspace(0, sound_duration, total_len + 1)
    dt = start_timegrid[1] - start_timegrid[0]
    end_timegrid = start_timegrid + dt
    return start_timegrid[:total_len], end_timegrid[:total_len]

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")
filepath = "/path/to/file.wav"
sound = AudioSegment.from_file(filepath)
# The model expects 16 kHz mono input.
sound = sound.set_frame_rate(16_000).set_channels(1)
sound_duration = sound.duration_seconds

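# Raw PCM samples as float32; do_normalize=True rescales them to zero mean and unit variance.
samples = np.array(sound.get_array_of_samples(), dtype=np.float32)
feature = feature_extractor(samples, sampling_rate=16_000).input_values[0]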
input_values = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
    logits = model(input_values).logits
# One predicted label per encoder frame, converted to plain Python ints for printing.
pred = logits.argmax(dim=-1).squeeze(0).tolist()
start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))

print("sec     speaker_label")
for p, start_time in zip(pred, start_timegrid):
    print(f"{start_time:.4f}  {p}")
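The loop above prints one label per frame. To obtain speaker turns, consecutive frames with the same label can be merged into (start, end, label) segments using the start and end grids. The following is a minimal sketch: `frames_to_segments` is a hypothetical helper, not part of transformers, and this card does not document which speaker (or non-speech class) each label index denotes, so check `model.config.id2label` before interpreting the labels.

```python
from typing import List, Sequence, Tuple


def frames_to_segments(
    labels: Sequence[int],
    starts: Sequence[float],
    ends: Sequence[float],
) -> List[Tuple[float, float, int]]:
    """Merge runs of identical frame labels into (start, end, label) segments."""
    if not (len(labels) == len(starts) == len(ends)):
        raise ValueError("labels, starts and ends must have the same length")
    segments: List[Tuple[float, float, int]] = []
    for label, start, end in zip(labels, starts, ends):
        label = int(label)  # accepts ints or 0-d tensors
        if segments and segments[-1][2] == label:
            # Same label as the previous frame: extend the open segment.
            prev_start = segments[-1][0]
            segments[-1] = (prev_start, float(end), label)
        else:
            # Label changed (or first frame): open a new segment.
            segments.append((float(start), float(end), label))
    return segments


if __name__ == "__main__":
    demo = frames_to_segments([0, 0, 1, 1, 1, 0],
                              [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
                              [0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
    for seg_start, seg_end, seg_label in demo:
        print(f"{seg_start:.4f}-{seg_end:.4f}  {seg_label}")
```

In the usage code, this would be called as `frames_to_segments(pred, start_timegrid, end_timegrid)` after the prediction step. Very short runs (a frame or two) are common in frame-level output, so in practice you may want to smooth the labels before merging.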

Training

The model was trained on the Japanese CallHome phone-call corpus.

License

This model is released under the Apache License 2.0.
