Spaces:
Running
Running
File size: 4,766 Bytes
595b5f3 ddbe0b6 6cee2a2 595b5f3 32d4384 595b5f3 7d9eec3 ada247c ddbe0b6 595b5f3 be3301d 595b5f3 7d9eec3 595b5f3 6cee2a2 ddbe0b6 595b5f3 1ba51b4 ddbe0b6 595b5f3 ddbe0b6 595b5f3 1ba51b4 595b5f3 ddbe0b6 595b5f3 15b3a25 1ba51b4 15b3a25 595b5f3 ada247c 32d4384 595b5f3 ada247c 595b5f3 ddbe0b6 595b5f3 ddbe0b6 595b5f3 ddbe0b6 595b5f3 15b3a25 595b5f3 32d4384 595b5f3 32d4384 595b5f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import torch
from typing import List, Union, BinaryIO, Optional, Tuple
import numpy as np
import time
import logging
from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio
from modules.whisper.data_classes import *
class Diarizer:
    """Attach speaker labels to Whisper transcription segments.

    Holds a ``DiarizationPipeline`` that is built lazily by ``update_pipe`` on
    the first ``run`` call, and rebuilt whenever ``run`` is asked to use a
    different device than the one the current pipeline was built for.
    """

    def __init__(self,
                 model_dir: str = DIARIZATION_MODELS_DIR
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Directory passed to the pipeline as its cache dir; created if missing.
        """
        self.device = self.get_device()
        self.available_device = self.get_available_device()
        # NOTE(review): not read anywhere in this class; kept for external callers.
        self.compute_type = "float16"
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        # Built lazily by update_pipe().
        self.pipe = None

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            transcribed_result: List[Segment],
            use_auth_token: str,
            device: Optional[str] = None
            ) -> Tuple[List[Segment], float]:
        """
        Diarize a transcription result as a post-processing step.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be a file path, a binary stream or an array;
            it is normalised by ``load_audio``.
        transcribed_result: List[Segment]
            Transcription result produced by Whisper.
        use_auth_token: str
            Huggingface token with READ permission. This is only needed the first time you download the model.
            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
        device: Optional[str]
            Device for diarization. Defaults to ``self.device``.

        Returns
        ----------
        segments_result: List[Segment]
            Segments with start/end timestamps whose text is prefixed with
            ``"<speaker>|"`` (``"None|"`` when no speaker was assigned).
            If the diarization pipeline could not be loaded, ``transcribed_result``
            is returned unchanged.
        elapsed_time: float
            Elapsed time for running, in seconds.
        """
        start_time = time.perf_counter()

        if device is None:
            device = self.device

        # (Re)build the pipeline when it is missing or lives on another device.
        if device != self.device or self.pipe is None:
            self.update_pipe(
                device=device,
                use_auth_token=use_auth_token
            )

        if self.pipe is None:
            # update_pipe() bailed out (no token and no cached model) and already
            # printed why; fall back to the undiarized transcription instead of
            # crashing with "'NoneType' object is not callable".
            return transcribed_result, time.perf_counter() - start_time

        audio = load_audio(audio)

        diarization_segments = self.pipe(audio)
        diarized_result = assign_word_speakers(
            diarization_segments,
            {"segments": transcribed_result}
        )

        segments_result = [
            self._to_labelled_segment(segment)
            for segment in diarized_result["segments"]
        ]

        elapsed_time = time.perf_counter() - start_time
        return segments_result, elapsed_time

    @staticmethod
    def _to_labelled_segment(segment) -> Segment:
        """Build a Segment whose text is ``"<speaker>|<stripped text>"``.

        ``segment`` is indexed like a mapping with "start", "end", "text" and an
        optional "speaker" key, as returned by ``assign_word_speakers``.
        """
        speaker = segment["speaker"] if "speaker" in segment else "None"
        return Segment(
            start=segment["start"],
            end=segment["end"],
            text=f"{speaker}|{segment['text'].strip()}"
        )

    def update_pipe(self,
                    use_auth_token: str,
                    device: str
                    ):
        """
        Set pipeline for diarization.

        Leaves ``self.pipe`` untouched (and prints instructions) when no model
        is cached in ``self.model_dir`` and no token was supplied.

        Parameters
        ----------
        use_auth_token: str
            Huggingface token with READ permission. This is only needed the first time you download the model.
            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
        device: str
            Device for diarization.
        """
        os.makedirs(self.model_dir, exist_ok=True)
        if (not os.listdir(self.model_dir) and
                not use_auth_token):
            print(
                "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
                "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
            )
            return

        logger = logging.getLogger("speechbrain.utils.train_logger")
        was_disabled = logger.disabled
        # Disable redundant torchvision warning message while the pipeline loads;
        # restore the previous state even if loading raises.
        logger.disabled = True
        try:
            self.pipe = DiarizationPipeline(
                use_auth_token=use_auth_token,
                device=device,
                cache_dir=self.model_dir
            )
        finally:
            logger.disabled = was_disabled

        # Record the device only once a pipeline was actually built on it, so a
        # failed load does not make run() skip rebuilding a stale pipeline.
        self.device = device

    @staticmethod
    def get_device():
        """Return the preferred device name: "cuda", else "mps", else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def get_available_device():
        """Return selectable device names: always "cpu", plus "cuda" if available, otherwise "mps" if available."""
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        elif torch.backends.mps.is_available():
            devices.append("mps")
        return devices