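"""Split audio into per-speaker segments with pyannote speaker diarization.

Every input file is diarized with the pretrained "pyannote/speaker-diarization"
pipeline, and each speaker turn of at least one second is written out as its
own WAV file under an output directory that mirrors the input tree.
"""
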
from __future__ import annotations

from collections import defaultdict
from logging import getLogger
from pathlib import Path

import librosa
import soundfile as sf
import torch
from joblib import Parallel, delayed
from pyannote.audio import Pipeline
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib

LOG = getLogger(__name__)


def _process_one(
    input_path: Path,
    output_dir: Path,
    sr: int,
    *,
    min_speakers: int = 1,
    max_speakers: int = 1,
    huggingface_token: str | None = None,
) -> None:
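    # Decode the audio once at the target sample rate so that diarization
    # timestamps (in seconds) map directly to sample indices when slicing.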
    try:
        audio, sr = librosa.load(input_path, sr=sr, mono=True)
    except Exception as e:
        LOG.warning(f"Failed to read {input_path}: {e}")
        return
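    # Each joblib worker loads its own copy of the pretrained pipeline;
    # huggingface_hub caches the downloaded weights after the first run.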
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization", use_auth_token=huggingface_token
    )
    if pipeline is None:
        raise ValueError("Failed to load pipeline")

    LOG.info(f"Processing {input_path}. This may take a while...")
    diarization = pipeline(
        input_path, min_speakers=min_speakers, max_speakers=max_speakers
    )

    LOG.info(f"Found {len(diarization)} tracks, writing to {output_dir}")
    speaker_count = defaultdict(int)

    output_dir.mkdir(parents=True, exist_ok=True)
    for segment, track, speaker in tqdm(
        list(diarization.itertracks(yield_label=True)), desc=f"Writing {input_path}"
    ):
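        # Skip speaker turns shorter than one second.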
        if segment.duration < 1:
            continue
        speaker_count[speaker] += 1
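        # Convert segment boundaries from seconds to sample indices.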
        audio_cut = audio[int(segment.start * sr) : int(segment.end * sr)]
        sf.write(
            output_dir / f"{speaker}_{speaker_count[speaker]}.wav",
            audio_cut,
            sr,
        )

    LOG.info(f"Speaker count: {speaker_count}")


def preprocess_speaker_diarization(
    input_dir: Path | str,
    output_dir: Path | str,
    sr: int,
    *,
    min_speakers: int = 1,
    max_speakers: int = 1,
    huggingface_token: str | None = None,
    n_jobs: int = -1,
) -> None:
    if huggingface_token is not None and not huggingface_token.startswith("hf_"):
        LOG.warning(
            "Hugging Face tokens usually start with 'hf_'; "
            "the provided token may be invalid."
        )
    if not torch.cuda.is_available():
        LOG.warning("CUDA is not available. This will be extremely slow.")
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
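    # Create both directories up front; an empty input_dir simply yields no work.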
    input_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)
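    # Consider any file with an extension; unreadable files are skipped
    # inside _process_one.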
    input_paths = list(input_dir.rglob("*.*"))
    with tqdm_joblib(desc="Preprocessing speaker diarization", total=len(input_paths)):
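        # Mirror the input tree: each file's segments are written to a
        # directory named after its relative path and stem.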
        Parallel(n_jobs=n_jobs)(
            delayed(_process_one)(
                input_path,
                output_dir / input_path.relative_to(input_dir).parent / input_path.stem,
                sr,
                max_speakers=max_speakers,
                min_speakers=min_speakers,
                huggingface_token=huggingface_token,
            )
            for input_path in input_paths
        )
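

if __name__ == "__main__":
    # A minimal usage sketch, assuming the module is run directly. The paths,
    # sample rate, and job count below are illustrative, not part of the module.
    preprocess_speaker_diarization(
        "data/raw",  # hypothetical directory of mixed-speaker recordings
        "data/diarized",  # hypothetical output root, mirrored per input file
        sr=44100,
        max_speakers=2,
        huggingface_token=None,  # pass an "hf_..." token to access gated models
        n_jobs=1,  # diarization is GPU/CPU heavy; keep parallelism modest
    )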