|
|
|
|
|
|
|
|
|
|
|
import faster_whisper |
|
from typing import List, Union, Optional, NamedTuple |
|
import torch |
|
import numpy as np |
|
import tqdm |
|
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram |
|
from whisperx.types import TranscriptionResult, SingleSegment |
|
from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens |
|
|
|
|
|
class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
    """
    FasterWhisperModel without VAD.

    The base pipeline is constructed with ``vad=None``; instead of detecting
    speech regions itself, :meth:`transcribe` decodes exactly the time spans
    the caller supplies in ``vad_segments``.
    """

    def __init__(
        self,
        model,
        options: NamedTuple,
        tokenizer=None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """
        Initialize the VadFreeFasterWhisperPipeline.

        Args:
            model: The Whisper model instance.
            options: Transcription options (a NamedTuple; ``load_asr_model``
                passes a faster-whisper ``TranscriptionOptions``).
            tokenizer: The tokenizer instance, or None to build one lazily in
                :meth:`transcribe`.
            device: Device to run the model on.
            framework: The framework to use ('pt' for PyTorch).
            language: The language for transcription.
            suppress_numerals: Whether to suppress numeral/symbol tokens.
            **kwargs: Additional keyword arguments forwarded to the base
                pipeline.

        Returns:
            None
        """
        # No VAD model: segment boundaries are provided to transcribe().
        super().__init__(
            model=model,
            vad=None,
            vad_params={},
            options=options,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
            language=language,
            suppress_numerals=suppress_numerals,
            **kwargs,
        )

    def _build_tokenizer(self, task: str, language: Optional[str]):
        """Create a faster-whisper tokenizer for ``task``/``language`` on this model."""
        return faster_whisper.tokenizer.Tokenizer(
            self.model.hf_tokenizer,
            self.model.model.is_multilingual,
            task=task,
            language=language,
        )

    def detect_language(self, audio: np.ndarray):
        """
        Detect the language of the audio.

        If the audio is longer than ``N_SAMPLES``, a randomly placed window of
        ``N_SAMPLES`` samples is analysed; otherwise the whole signal is used
        and ``log_mel_spectrogram`` is asked to pad it by the shortfall.

        Args:
            audio (np.ndarray): The input audio signal (presumably mono
                samples along axis 0 at ``SAMPLE_RATE`` — confirm with caller).

        Returns:
            tuple: ``(language_code, probability)`` taken from the first
            (presumably most probable) detection result.
        """
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        num_samples = audio.shape[0]

        if num_samples > N_SAMPLES:
            # randint's upper bound is exclusive; +1 allows the final full
            # window (ending at the last sample) to be chosen as well.
            start_index = np.random.randint(0, num_samples - N_SAMPLES + 1)
            audio_sample = audio[start_index : start_index + N_SAMPLES]
        else:
            audio_sample = audio[:N_SAMPLES]

        padding = max(0, N_SAMPLES - num_samples)
        segment = log_mel_spectrogram(
            audio_sample,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=padding,
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        # Strip two delimiter characters from each end of the token
        # (presumably of the form "<|en|>").
        language = language_token[2:-2]
        return language, language_probability

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        vad_segments: List[dict],
        batch_size=None,
        num_workers=0,
        language=None,
        task=None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
    ) -> TranscriptionResult:
        """
        Transcribe the given segments of the audio into text.

        Args:
            audio (Union[str, np.ndarray]): The input audio signal or path to
                an audio file (loaded with ``load_audio``).
            vad_segments (List[dict]): Segments to transcribe; each needs
                ``"start"`` and ``"end"`` in seconds and may carry ``"speaker"``.
            batch_size (int, optional): Batch size for transcription. Falls
                back to the pipeline's ``_batch_size`` when falsy.
            num_workers (int, optional): Number of workers for loading data.
                Defaults to 0.
            language (str, optional): Language code. When None and no
                tokenizer is set, it is detected via :meth:`detect_language`.
            task (str, optional): 'transcribe' or 'translate'. Defaults to the
                current tokenizer's task, or 'transcribe'.
            chunk_size (int, optional): Currently unused; kept for API
                compatibility.
            print_progress (bool, optional): Whether to show a progress bar.
                Defaults to False.
            combined_progress (bool, optional): Currently unused; kept for API
                compatibility.

        Returns:
            TranscriptionResult: ``{"segments": [...], "language": str}`` where
            each segment has ``text``, ``start``, ``end`` (rounded to 3
            decimals) and ``speaker`` (None if the input segment had none).
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            # Slice each segment's samples out of the full waveform.
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {"inputs": audio[f1:f2]}

        previous_suppress_tokens = self.options.suppress_tokens
        segments: List[SingleSegment] = []

        # Temporary tokenizer / suppress_tokens changes are undone in
        # `finally`, so a failure mid-way cannot leak state into later calls.
        try:
            if self.tokenizer is None:
                if not language:
                    # detect_language returns (language, probability);
                    # only the language code is needed here.
                    language, _ = self.detect_language(audio)
                task = task or "transcribe"
                self.tokenizer = self._build_tokenizer(task, language)
            else:
                language = language or self.tokenizer.language_code
                task = task or self.tokenizer.task
                if (
                    task != self.tokenizer.task
                    or language != self.tokenizer.language_code
                ):
                    self.tokenizer = self._build_tokenizer(task, language)

            if self.suppress_numerals:
                numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
                new_suppressed_tokens = list(
                    set(numeral_symbol_tokens + self.options.suppress_tokens)
                )
                self.options = self.options._replace(
                    suppress_tokens=new_suppressed_tokens
                )

            batch_size = batch_size or self._batch_size

            # The bar is fully disabled (nothing drawn) unless requested, and
            # the context manager guarantees it is closed.
            with tqdm.tqdm(
                total=len(vad_segments),
                desc="Transcribing",
                disable=not print_progress,
            ) as progress:
                outputs = self.__call__(
                    data(audio, vad_segments),
                    batch_size=batch_size,
                    num_workers=num_workers,
                )
                for vad_seg, out in zip(vad_segments, outputs):
                    progress.update(1)
                    text = out["text"]
                    # Unbatched runs yield a single-element list of texts.
                    if batch_size in [0, 1, None]:
                        text = text[0]
                    segments.append(
                        {
                            "text": text,
                            "start": round(vad_seg["start"], 3),
                            "end": round(vad_seg["end"], 3),
                            "speaker": vad_seg.get("speaker", None),
                        }
                    )
        finally:
            # Without a preset language, forget the per-call tokenizer so the
            # next call detects its own language.
            if self.preset_language is None:
                self.tokenizer = None
            if self.suppress_numerals:
                self.options = self.options._replace(
                    suppress_tokens=previous_suppress_tokens
                )

        return {"segments": segments, "language": language}
|
|
|
|
|
def load_asr_model(
    whisper_arch: str,
    device: str,
    device_index: int = 0,
    compute_type: str = "float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model=None,
    vad_options=None,
    model: Optional[WhisperModel] = None,
    task: str = "transcribe",
    download_root: Optional[str] = None,
    threads: int = 4,
) -> VadFreeFasterWhisperPipeline:
    """
    Load a Whisper model and wrap it in a VAD-free transcription pipeline.

    Args:
        whisper_arch (str): Name (or path) of the Whisper model to load. Names
            ending in ".en" force ``language="en"``.
        device (str): The device to load the model on.
        device_index (int, optional): The device index. Defaults to 0.
        compute_type (str, optional): The compute type to use for the model.
            Defaults to "float16".
        asr_options (Optional[dict], optional): Overrides merged on top of the
            default decoding options below. The special key
            ``"suppress_numerals"`` is consumed by the pipeline rather than
            passed to ``TranscriptionOptions``. Defaults to None.
        language (Optional[str], optional): Language code for transcription.
            When None, the language is detected per audio file. Defaults to
            None.
        vad_model: Not used by this function. Defaults to None.
        vad_options: Not used by this function. Defaults to None.
        model (Optional[WhisperModel], optional): An already-loaded
            WhisperModel to use instead of loading ``whisper_arch``.
            Defaults to None.
        task (str, optional): The task type ('transcribe' or 'translate').
            Defaults to "transcribe".
        download_root (Optional[str], optional): The root directory to
            download the model to. Defaults to None.
        threads (int, optional): Passed to ``WhisperModel`` as
            ``cpu_threads``. Defaults to 4.

    Returns:
        VadFreeFasterWhisperPipeline: The loaded Whisper pipeline.
    """
    # Whisper ".en" checkpoints are English-only, so the language is forced
    # to "en" regardless of the argument.
    if whisper_arch.endswith(".en"):
        language = "en"

    if model is None:
        model = WhisperModel(
            whisper_arch,
            device=device,
            device_index=device_index,
            compute_type=compute_type,
            download_root=download_root,
            cpu_threads=threads,
        )

    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )
    else:
        print(
            "No language specified, language will be detected for each audio file (increases inference time)."
        )
        # The pipeline builds a tokenizer per call once it knows the language.
        tokenizer = None

    merged_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }

    if asr_options is not None:
        merged_options.update(asr_options)

    # suppress_numerals is a pipeline flag, not a TranscriptionOptions field,
    # so it is removed before constructing the options object.
    suppress_numerals = merged_options.pop("suppress_numerals")

    transcription_options = faster_whisper.transcribe.TranscriptionOptions(
        **merged_options
    )

    return VadFreeFasterWhisperPipeline(
        model=model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
    )
|
|