from abc import ABC, abstractmethod
from collections import Counter
from pprint import pprint
from typing import Any, Callable, Dict, Iterator, List, Optional

# NOTE(review): tensorflow is imported (when installed) before torch, exactly
# as in the original file; the ordering looks deliberate, so keep it.
try:
    import tensorflow as tf
except ModuleNotFoundError:
    # tensorflow is optional; nothing else in this module uses it.
    pass

import ffmpeg
import numpy as np
import torch

from src.utils import format_timestamp

# Defaults used by VadSileroTranscription.

# Passed as the `threshold` argument of Silero's get_speech_timestamps.
# NOTE: the misspelling ("TRESHOLD") is kept so existing references keep working.
SPEECH_TRESHOLD = 0.3
# Default `max_silent_period`: sections at most this many seconds apart are merged.
MAX_SILENT_PERIOD = 10
# Default `max_merge_size`: a section already longer than this many seconds is
# not extended any further by merging.
MAX_MERGE_SIZE = 150

# Default padding, in seconds, added before/after each detected section.
SEGMENT_PADDING_LEFT = 1
SEGMENT_PADDING_RIGHT = 1

# NOTE(review): not referenced in this file; presumably the default for
# `transcribe_non_speech` used by callers - confirm.
TRANSCRIBE_NON_SPEECH = False

# Sections shorter than this many seconds are skipped by
# AbstractTranscription.transcribe.
MIN_SEGMENT_DURATION = 1

# Longest stretch of audio, in seconds, handed to the VAD model at once (one hour).
VAD_MAX_PROCESSING_CHUNK = 60 * 60


class AbstractTranscription(ABC):
    """
    Base class for transcribing only selected sections of an audio file.

    Subclasses implement get_transcribe_timestamps() to choose the sections.
    transcribe() pads and merges those sections, invokes the Whisper callback
    on each of them and combines the per-section results.
    """

    def __init__(self, segment_padding_left: Optional[float] = None, segment_padding_right: Optional[float] = None,
                 max_silent_period: Optional[float] = None, max_merge_size: Optional[float] = None,
                 transcribe_non_speech: bool = False):
        """
        Parameters
        ----------
        segment_padding_left: float
            Seconds by which each section's start is moved earlier, or None for no change.
        segment_padding_right: float
            Seconds by which each section's end is moved later, or None for no change.
        max_silent_period: float
            Largest gap, in seconds, across which neighbouring sections are merged.
            None disables merging.
        max_merge_size: float
            A section already longer than this many seconds is not merged further.
            None removes the limit.
        transcribe_non_speech: bool
            When True, sections are stretched over the gaps between them so the
            whole file is passed to Whisper.
        """
        # Sample rate handed to load_audio when decoding audio for this instance.
        self.sampling_rate = 16000
        self.segment_padding_left = segment_padding_left
        self.segment_padding_right = segment_padding_right
        self.max_silent_period = max_silent_period
        self.max_merge_size = max_merge_size
        self.transcribe_non_speech = transcribe_non_speech

    def get_audio_segment(self, str, start_time: Optional[str] = None, duration: Optional[str] = None):
        """
        Load (part of) an audio file through load_audio at self.sampling_rate.

        Parameters
        ----------
        str:
            The audio file. The parameter name shadows the builtin; it is kept
            unchanged so existing callers keep working.
        start_time: str
            Where to start reading, forwarded to load_audio (None to read from the beginning).
        duration: str
            How much to read, forwarded to load_audio (None to read to the end).

        Returns
        -------
        Whatever load_audio returns for the requested range.
        """
        return load_audio(str, self.sampling_rate, start_time, duration)

    @abstractmethod
    def get_transcribe_timestamps(self, audio: str):
        """
        Get the start and end timestamps of the sections that should be transcribed by this VAD method.

        Parameters
        ----------
        audio: str
            The audio file.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds.
        """
        return

    def transcribe(self, audio: str, whisperCallable: Callable[[Any], Dict[str, Any]]):
        """
        Transcribe the given audio file.

        Parameters
        ----------
        audio: str
            The audio file.

        whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor]], dict[str, Union[dict, Any]]]
            The callback that is used to invoke Whisper on an audio file/buffer.
            Its result must provide the keys 'text', 'segments' and 'language'.

        Returns
        -------
        A dict with 'text' (the concatenated text of all sections), 'segments'
        (all segments, with timestamps relative to the start of the file) and
        'language' (the language reported for the most sections, or "" when
        nothing was transcribed).
        """
        seconds_timestamps = self.get_transcribe_timestamps(audio)

        padded = self.pad_timestamps(seconds_timestamps, self.segment_padding_left, self.segment_padding_right)
        merged = self.merge_timestamps(padded, self.max_silent_period, self.max_merge_size)

        print("Timestamps:")
        pprint(merged)

        if self.transcribe_non_speech:
            max_audio_duration = get_audio_duration(audio)

            # Stretch every section up to the next one so no audio is left out.
            merged = self.expand_gaps(merged, total_duration=max_audio_duration)

            print("Transcribing non-speech:")
            pprint(merged)

        text_parts = []
        all_segments = []
        languageCounter = Counter()

        for segment in merged:
            segment_start = segment['start']
            segment_end = segment['end']
            segment_expand_amount = segment.get('expand_amount', 0)
            segment_duration = segment_end - segment_start

            if segment_duration < MIN_SEGMENT_DURATION:
                continue

            segment_audio = self.get_audio_segment(audio, start_time=str(segment_start), duration=str(segment_duration))

            print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", segment_duration, "expanded: ", segment_expand_amount)
            segment_result = whisperCallable(segment_audio)

            # Whisper's timestamps are relative to the section; shift them to file time
            # and drop or clamp anything reported beyond the section's length.
            adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)

            text_parts.append(segment_result['text'])
            all_segments.extend(adjusted_segments)

            # One vote per transcribed section.
            languageCounter[segment_result['language']] += 1

        result = {
            'text': "".join(text_parts),
            'segments': all_segments,
            'language': ""
        }

        if languageCounter:
            result['language'] = languageCounter.most_common(1)[0][0]

        return result

    def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
        """
        Interleave the segments with explicit gap entries.

        Parameters
        ----------
        segments: Iterator[dict]
            Segments with 'start' and 'end', in chronological order.
        min_gap_length: float
            Gaps shorter than this are not reported. None reports every gap.
        total_duration: float
            End of the audio; a final gap up to this time is added when long enough.
            None disables the final gap.

        Returns
        -------
        A list holding the original segment dicts plus
        { 'start', 'end', 'gap': True } entries for the reported gaps.
        """
        result = []
        last_end_time = 0

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            delta = segment_start - last_end_time

            # Only a positive distance is a gap; overlapping segments produce none.
            if delta > 0 and (min_gap_length is None or delta >= min_gap_length):
                result.append({'start': last_end_time, 'end': segment_start, 'gap': True})

            last_end_time = segment_end
            result.append(segment)

        if total_duration is not None and last_end_time < total_duration:
            # Measure from the end of the last segment (0 when there were no segments).
            delta = total_duration - last_end_time

            if min_gap_length is None or delta >= min_gap_length:
                result.append({'start': last_end_time, 'end': total_duration, 'gap': True})

        return result

    def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
        """
        Stretch every segment to the start of the next one, so that the result
        covers the audio from 0 to total_duration without holes.

        Parameters
        ----------
        segments: List[Dict[str, Any]]
            Segments with 'start' and 'end', in chronological order. Not modified.
        total_duration: float
            End of the audio; the last segment is stretched up to it. None leaves
            the last segment untouched.

        Returns
        -------
        A new list. Audio before the first segment becomes a { 'gap': True } entry;
        stretched segments are copies with 'expand_amount' set to the added seconds.
        """
        result = []

        if len(segments) == 0:
            return result

        first_start = segments[0]['start']
        if first_start > 0:
            result.append({'start': 0, 'end': first_start, 'gap': True})

        for current_segment, next_segment in zip(segments, segments[1:]):
            delta = next_segment['start'] - current_segment['end']

            # Overlapping segments (negative delta) are passed through untouched.
            if delta >= 0:
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

        result.append(segments[-1])

        if total_duration is not None:
            last_segment = result[-1]

            if last_segment['end'] < total_duration:
                last_segment = last_segment.copy()
                last_segment['end'] = total_duration
                result[-1] = last_segment

        return result

    def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
        """
        Shift segment timestamps by a fixed offset.

        Parameters
        ----------
        segments: Iterator[dict]
            Segments with 'start' and 'end'. Not modified.
        adjust_seconds: float
            Offset added to every 'start' and 'end'.
        max_source_time: float
            Limit applied BEFORE shifting: segments starting after it are dropped,
            and ends beyond it are clamped to it. None disables the limit.

        Returns
        -------
        A list of shifted copies of the remaining segments.
        """
        result = []

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            if max_source_time is not None:
                if segment_start > max_source_time:
                    continue
                segment_end = min(max_source_time, segment_end)

            new_segment = segment.copy()
            new_segment['start'] = segment_start + adjust_seconds
            new_segment['end'] = segment_end + adjust_seconds
            result.append(new_segment)

        return result

    def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
        """
        Widen every entry by the given padding without overlapping its neighbours.

        Parameters
        ----------
        timestamps: List[Dict[str, Any]]
            Entries with 'start' and 'end', in chronological order. Not modified.
        padding_left: float
            Seconds subtracted from each start (None for no change). A start never
            moves before 0 or before the end of the previous padded entry.
        padding_right: float
            Seconds added to each end (None for no change). An end never moves past
            the original start of the next entry.

        Returns
        -------
        The input list itself when both paddings are 0, otherwise a new list of
        { 'start', 'end' } dicts.
        """
        if padding_left == 0 and padding_right == 0:
            return timestamps

        result = []
        prev_entry = None
        last_index = len(timestamps) - 1

        for index, curr_entry in enumerate(timestamps):
            next_entry = timestamps[index + 1] if index < last_index else None

            segment_start = curr_entry['start']
            segment_end = curr_entry['end']

            if padding_left is not None:
                earliest_start = prev_entry['end'] if prev_entry else 0
                segment_start = max(earliest_start, segment_start - padding_left)
            if padding_right is not None:
                segment_end = segment_end + padding_right

            # Never run into the next entry.
            if next_entry is not None:
                segment_end = min(next_entry['start'], segment_end)

            new_entry = {'start': segment_start, 'end': segment_end}
            prev_entry = new_entry
            result.append(new_entry)

        return result

    def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_merge_gap: float, max_merge_size: float):
        """
        Merge entries that are close to each other.

        Parameters
        ----------
        timestamps: List[Dict[str, Any]]
            Entries with 'start' and 'end', in chronological order. Not modified.
        max_merge_gap: float
            An entry is merged into the previous one when the distance between
            them is at most this value. None disables merging.
        max_merge_size: float
            An entry that is already longer than this is not extended any further
            (the limit is checked before merging). None removes the limit.

        Returns
        -------
        The input list itself when max_merge_gap is None, otherwise a new list of
        merged copies.
        """
        if max_merge_gap is None:
            return timestamps

        result = []
        current_entry = None

        for entry in timestamps:
            if current_entry is None:
                # Work on a copy so the caller's dicts are never modified.
                current_entry = entry.copy()
                continue

            distance = entry['start'] - current_entry['end']
            current_entry_size = current_entry['end'] - current_entry['start']

            if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
                # Absorb the entry into the current one.
                current_entry['end'] = entry['end']
            else:
                # Too far away, or the current entry is already large enough.
                result.append(current_entry)
                current_entry = entry.copy()

        if current_entry is not None:
            result.append(current_entry)

        return result

    def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
        """
        Scale the 'start' and 'end' of every entry by factor.

        Returns
        -------
        A new list of { 'start', 'end' } dicts; the input is not modified.
        """
        return [
            {'start': entry['start'] * factor, 'end': entry['end'] * factor}
            for entry in timestamps
        ]


class VadSileroTranscription(AbstractTranscription):
    """
    Chooses the sections to transcribe with the Silero VAD model, which is
    loaded through torch.hub from 'snakers4/silero-vad'.
    """

    def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
                 max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, transcribe_non_speech: bool = False,
                 copy=None):
        """
        Parameters
        ----------
        segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, transcribe_non_speech:
            Forwarded unchanged to AbstractTranscription.
        copy:
            Another instance whose already loaded model and get_speech_timestamps
            helper are reused instead of loading the model again. None loads a
            fresh model via torch.hub.
        """
        super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
                         max_silent_period=max_silent_period, max_merge_size=max_merge_size,
                         transcribe_non_speech=transcribe_non_speech)

        if copy is not None:
            self.model = copy.model
            self.get_speech_timestamps = copy.get_speech_timestamps
        else:
            self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
            # Only the first helper is needed; indexing avoids depending on how
            # many helpers the hub entry point returns.
            self.get_speech_timestamps = utils[0]

    def get_transcribe_timestamps(self, audio: str):
        """
        Run the VAD model over the file and collect the detected speech sections.

        The audio is processed in chunks of at most VAD_MAX_PROCESSING_CHUNK
        seconds, so that only one chunk is decoded at a time.

        Parameters
        ----------
        audio: str
            The audio file.

        Returns
        -------
        A list of { 'start', 'end' } dicts in fractional seconds, relative to the
        start of the file.
        """
        audio_duration = get_audio_duration(audio)
        result = []

        chunk_start = 0.0

        while chunk_start < audio_duration:
            chunk_duration = min(audio_duration - chunk_start, VAD_MAX_PROCESSING_CHUNK)

            print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
            wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))

            sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
            # The model's timestamps are divided by the sampling rate to obtain seconds.
            seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
            # max_source_time is compared against chunk-relative times (before the
            # shift), so the limit is the chunk's own duration.
            adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_duration)

            result.extend(adjusted)
            chunk_start += chunk_duration

        return result


class VadPeriodicTranscription(AbstractTranscription):
    """
    Splits the audio into consecutive sections of a fixed length instead of
    detecting speech.
    """

    def __init__(self, periodic_duration: float):
        """
        Parameters
        ----------
        periodic_duration: float
            Length of each section, in seconds. Must be positive.
        """
        super().__init__()
        self.periodic_duration = periodic_duration

    def get_transcribe_timestamps(self, audio: str):
        """
        Cut the file into back-to-back sections of self.periodic_duration seconds.

        Parameters
        ----------
        audio: str
            The audio file.

        Returns
        -------
        A list of { 'start', 'end' } dicts in seconds. A section shorter than
        MIN_SEGMENT_DURATION (only possible for the last one) is left out.

        Raises
        ------
        ValueError
            If self.periodic_duration is not positive; the loop below would
            otherwise never advance.
        """
        if self.periodic_duration <= 0:
            raise ValueError("periodic_duration must be positive, got {}".format(self.periodic_duration))

        audio_duration = get_audio_duration(audio)
        result = []

        start_timestamp = 0

        while start_timestamp < audio_duration:
            end_timestamp = min(start_timestamp + self.periodic_duration, audio_duration)
            segment_duration = end_timestamp - start_timestamp

            # Same minimum that AbstractTranscription.transcribe applies.
            if segment_duration >= MIN_SEGMENT_DURATION:
                result.append({'start': start_timestamp, 'end': end_timestamp})

            start_timestamp = end_timestamp

        return result


def get_audio_duration(file: str) -> float:
    """
    Return the duration of a media file as reported by ffmpeg.probe.

    Parameters
    ----------
    file: str
        The audio file to probe.

    Returns
    -------
    The "duration" entry of the probe's "format" section, converted to float.
    """
    return float(ffmpeg.probe(file)["format"]["duration"])


def load_audio(file: str, sample_rate: int = 16000,
               start_time: Optional[str] = None, duration: Optional[str] = None):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sample_rate: int
        The sample rate to resample the audio if necessary

    start_time: str
        The start time, using the standard FFMPEG time duration syntax, or None to disable.

    duration: str
        The duration, using the standard FFMPEG time duration syntax, or None to disable.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg fails; the message carries ffmpeg's stderr output and the
        original ffmpeg.Error is attached as the cause.
    """
    # Options for the ffmpeg input; 'ss' and 't' are only set when requested.
    input_args = {'threads': 0}

    if start_time is not None:
        input_args['ss'] = start_time
    if duration is not None:
        input_args['t'] = duration

    try:
        # Decode to raw signed 16-bit little-endian mono PCM on stdout.
        out, _ = (
            ffmpeg.input(file, **input_args)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Dividing the int16 samples by 32768 scales them to floats in [-1.0, 1.0).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0