|
|
|
|
|
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
import json |
|
import os |
|
import librosa |
|
import numpy as np |
|
import time |
|
import torch |
|
from pydub import AudioSegment |
|
import soundfile as sf |
|
import onnxruntime as ort |
|
import tqdm |
|
import subprocess |
|
import re |
|
|
|
from utils.logger import Logger, time_logger |
|
|
|
|
|
def load_cfg(cfg_path):
    """
    Load configuration from a JSON file.

    Args:
        cfg_path (str): Path to the configuration file.

    Returns:
        dict: Configuration dictionary.

    Raises:
        FileNotFoundError: If ``cfg_path`` does not exist.
        TypeError: If the file is not valid JSON (for example, it still
            contains the ``// TODO:`` placeholders from the example config).
            The underlying ``JSONDecodeError`` is attached as ``__cause__``.
    """
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(
            f"{cfg_path} not found. Please copy, configure, and rename `config.json.example` to `{cfg_path}`."
        )
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(cfg_path, "r", encoding="utf-8") as f:
        try:
            cfg = json.load(f)
        except json.decoder.JSONDecodeError as e:
            # TypeError is the exception type callers may already catch; chain
            # the decode error so the offending line/column is not lost.
            raise TypeError(
                "Please finish the `// TODO:` in the `config.json` file before running the script. Check README.md for details."
            ) from e
    return cfg
|
|
|
|
|
def write_wav(path, sr, x):
    """Save the sample array ``x`` to ``path`` at sample rate ``sr`` using soundfile."""
    sf.write(file=path, data=x, samplerate=sr)
|
|
|
|
|
def _to_int16(x):
    """Return ``x`` as int16 PCM, peak-normalising any non-int16 input.

    int16 input is returned unchanged. Empty or all-zero input yields an
    all-zero int16 array instead of dividing by a zero peak (0/0 -> NaN,
    whose int16 cast is undefined).
    """
    if x.dtype == np.int16:
        return x
    if x.size == 0:
        return np.zeros(x.shape, dtype=np.int16)
    peak = np.max(np.abs(x))
    if peak == 0:
        return np.zeros(x.shape, dtype=np.int16)
    return np.int16(x / peak * 32767)


def write_mp3(path, sr, x):
    """Encode a numpy audio array as an MP3 file (best effort).

    Non-int16 input is peak-normalised to the int16 range before encoding.
    Failures are logged and not raised, matching the original best-effort
    behaviour.

    Args:
        path (str): Destination MP3 path.
        sr (int): Sample rate in Hz.
        x (np.ndarray): Audio samples, handed to pydub as a single channel.
    """
    try:
        pcm = _to_int16(x)
        audio = AudioSegment(
            pcm.tobytes(), frame_rate=sr, sample_width=pcm.dtype.itemsize, channels=1
        )
        audio.export(path, format="mp3")
    except Exception as e:
        # Deliberately broad: one bad segment must not abort a batch export.
        Logger.get_logger().error(f"Error: Failed to write MP3 file {path}: {e}")
|
|
|
|
|
def get_audio_files(folder_path):
    """Recursively collect audio files under ``folder_path``.

    Files with ``.mp3``, ``.wav``, ``.flac``, ``.m4a`` or ``.aac`` extensions
    (case-sensitive) are returned. Files whose name contains ``.temp`` are
    skipped, and so are sub-directories whose name contains ``_processed``
    (and everything below them).

    Args:
        folder_path (str): Root directory to search.

    Returns:
        list[str]: Paths joined onto ``folder_path`` in ``os.walk`` order.
    """
    audio_exts = (".mp3", ".wav", ".flac", ".m4a", ".aac")
    audio_files = []
    for root, dirs, files in os.walk(folder_path):
        # Prune by directory *name*, not by the full path string: otherwise a
        # folder_path that itself contains "_processed" would yield nothing,
        # and skipped trees would still be walked.
        dirs[:] = [d for d in dirs if "_processed" not in d]
        for file in files:
            if ".temp" in file:
                continue
            if file.endswith(audio_exts):
                audio_files.append(os.path.join(root, file))
    return audio_files
|
|
|
|
|
def get_specific_files(folder_path, ext):
    """Recursively collect files under ``folder_path`` ending with ``ext``.

    Files whose name contains ``.temp`` are skipped, and so are
    sub-directories whose name contains ``_processed`` (and everything below
    them).

    Args:
        folder_path (str): Root directory to search.
        ext (str | tuple[str, ...]): Suffix or suffixes, as accepted by
            ``str.endswith``.

    Returns:
        list[str]: Paths joined onto ``folder_path`` in ``os.walk`` order.
    """
    matched_files = []
    for root, dirs, files in os.walk(folder_path):
        # Prune by directory *name*, not by the full path string: otherwise a
        # folder_path that itself contains "_processed" would yield nothing.
        dirs[:] = [d for d in dirs if "_processed" not in d]
        for file in files:
            if ".temp" in file:
                continue
            if file.endswith(ext):
                matched_files.append(os.path.join(root, file))
    return matched_files
|
|
|
|
|
def export_to_srt(asr_result, file_path):
    """Export an ASR result to an SRT subtitle file.

    Args:
        asr_result (list[dict]): Segments with ``start``/``end`` (seconds),
            ``speaker`` and ``text`` keys. They are written in list order and
            numbered from 1.
        file_path (str): Destination path; overwritten if it exists.
    """

    def format_time(seconds):
        # Work in integer milliseconds. Rounding avoids float artefacts such as
        # 1.001 * 1000 == 1000.999... (which truncated to ",000"), and divmod
        # lets hours exceed 23 where time.gmtime would wrap at 24 h.
        total_ms = int(round(seconds * 1000))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    # Write UTF-8 explicitly rather than the platform default encoding, since
    # transcripts may contain non-ASCII text.
    with open(file_path, "w", encoding="utf-8") as f:
        for idx, segment in enumerate(asr_result, start=1):
            f.write(f"{idx}\n")
            f.write(
                f"{format_time(segment['start'])} --> {format_time(segment['end'])}\n"
            )
            f.write(f"{segment['speaker']}: {segment['text']}\n\n")
|
|
|
|
|
def detect_gpu():
    """Log CUDA, cuDNN and ONNX Runtime GPU support via the project logger.

    Returns:
        bool: False if torch reports no CUDA device or cuDNN is unavailable,
        True otherwise. A missing ONNX Runtime ``CUDAExecutionProvider`` only
        logs a warning and does not change the result.
    """
    log = Logger.get_logger()

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        log.info("ENV: CUDA_VISIBLE_DEVICES not set, use default setting")
    else:
        log.info(f"ENV: CUDA_VISIBLE_DEVICES = {visible_devices}")

    # --- torch / CUDA ---
    if not torch.cuda.is_available():
        log.error("Torch CUDA: No GPU detected. torch.cuda.is_available() = False.")
        return False
    device_total = torch.cuda.device_count()
    log.debug(f"Torch CUDA: Detected {device_total} GPUs.")
    for device_index in range(device_total):
        log.debug(f" * GPU {device_index}: {torch.cuda.get_device_name(device_index)}")

    # --- cuDNN ---
    log.debug(f"Torch: CUDNN version = {torch.backends.cudnn.version()}")
    if not torch.backends.cudnn.is_available():
        log.error("Torch: CUDNN is not available.")
        return False
    log.debug("Torch: CUDNN is available.")

    # --- ONNX Runtime (advisory only) ---
    providers = ort.get_available_providers()
    log.debug(f"ORT: Available providers: {providers}")
    if "CUDAExecutionProvider" not in providers:
        log.warning(
            "ORT: CUDAExecutionProvider is not available. "
            "Please install a compatible version of ONNX Runtime. "
            "See https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html"
        )
    return True
|
|
|
|
|
def _count_gpu_lines(listing):
    """Count top-level ``GPU <n>: ...`` lines in ``nvidia-smi -L`` output.

    Indented MIG device lines and any non-GPU text are not counted.
    """
    return sum(1 for line in listing.splitlines() if line.startswith("GPU "))


def get_gpu_nums(default=8):
    """Get the number of GPUs reported by ``nvidia-smi -L``.

    Args:
        default (int): Value returned when ``nvidia-smi`` cannot be run or
            exits with an error status. Defaults to 8, the previous
            hard-coded fallback.

    Returns:
        int: Number of GPUs listed, or ``default`` on failure.
    """
    logger = Logger.get_logger()
    try:
        # Run nvidia-smi directly instead of through `| wc -l`: the shell
        # pipeline took its exit status from wc, so a missing binary or a
        # driver error was silently counted as a number of GPUs.
        result = subprocess.run(
            ["nvidia-smi", "-L"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.SubprocessError) as e:
        logger.error("Error occurred while getting GPU count: " + str(e))
        return default
    return _count_gpu_lines(result.stdout)
|
|
|
|
|
def check_env(logger):
    """Log proxy, Hugging Face endpoint, hostname and library-path settings.

    Credentials embedded in proxy URLs (``user:password@``) are masked as
    ``***@`` before logging.

    Args:
        logger: Logger-like object providing ``info`` and ``debug`` methods.
    """
    import socket

    def mask_credentials(url):
        # Replace the userinfo part of "scheme://user:pass@host" (or
        # "user:pass@host") so proxy passwords never reach log files.
        return re.sub(r"^([^/@]*//)?[^/@]+@", r"\1***@", url)

    for var in ("http_proxy", "https_proxy"):
        if var in os.environ:
            logger.info(f"ENV: {var} = {mask_credentials(os.environ[var])}")
        else:
            logger.info(f"ENV: {var} not set")

    if "HF_ENDPOINT" in os.environ:
        logger.info(
            f"ENV: HF_ENDPOINT = {os.environ['HF_ENDPOINT']}, if downloading slow, try `unset HF_ENDPOINT`"
        )
    else:
        logger.info("ENV: HF_ENDPOINT not set")

    # socket.gethostname() avoids spawning a shell and leaving a popen pipe unclosed.
    hostname = socket.gethostname()
    logger.debug(f"HOSTNAME: {hostname}")

    environ_path = os.environ.get("PATH", "")
    environ_ld_library = os.environ.get("LD_LIBRARY_PATH", "")
    logger.debug(f"ENV: PATH = {environ_path}, LD_LIBRARY_PATH = {environ_ld_library}")
|
|
|
|
|
@time_logger
def export_to_mp3(audio, asr_result, folder_path, file_name, max_workers=72):
    """Export each ASR segment of ``audio`` as its own MP3 file.

    Segment ``idx`` is written to ``{folder_path}/{file_name}_{idx}.mp3``.
    ``folder_path`` is created if missing, and segments are encoded on a
    thread pool.

    Args:
        audio (dict): Has ``"waveform"`` (an array accepted by
            ``librosa.to_mono``) and ``"sample_rate"`` (Hz) keys.
        asr_result (list[dict]): Segments with ``start``/``end`` in seconds.
        folder_path (str): Output directory.
        file_name (str): Prefix for the output file names.
        max_workers (int): Thread-pool size (the previously hard-coded 72).
    """
    from concurrent.futures import as_completed

    sr = audio["sample_rate"]
    waveform = audio["waveform"]

    os.makedirs(folder_path, exist_ok=True)

    def process_segment(idx, segment):
        start, end = int(segment["start"] * sr), int(segment["end"] * sr)
        # Slice the last (time) axis. This is identical for 1-D mono input and
        # also correct for channel-first (channels, samples) arrays, which is
        # the layout librosa.to_mono expects; axis-0 slicing would cut channels.
        split_audio = librosa.to_mono(waveform[..., start:end])
        out_path = os.path.join(folder_path, f"{file_name}_{idx}.mp3")
        write_mp3(out_path, sr, split_audio)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_segment, idx, segment)
            for idx, segment in enumerate(asr_result)
        ]
        # as_completed advances the bar as segments finish instead of in
        # submission order; result() re-raises any worker exception.
        for future in tqdm.tqdm(
            as_completed(futures), total=len(futures), desc="Exporting to MP3"
        ):
            future.result()
|
|
|
|
|
@time_logger
def export_to_wav(audio, asr_result, folder_path, file_name):
    """Export each ASR segment of ``audio`` as its own WAV file.

    Segment ``idx`` is written to ``{folder_path}/{file_name}_{idx}.wav``.
    ``folder_path`` is created if missing.

    Args:
        audio (dict): Has ``"waveform"`` (an array accepted by
            ``librosa.to_mono``) and ``"sample_rate"`` (Hz) keys.
        asr_result (list[dict]): Segments with ``start``/``end`` in seconds.
        folder_path (str): Output directory.
        file_name (str): Prefix for the output file names.
    """
    sr = audio["sample_rate"]
    waveform = audio["waveform"]

    os.makedirs(folder_path, exist_ok=True)

    for idx, segment in enumerate(tqdm.tqdm(asr_result, desc="Exporting to WAV")):
        start, end = int(segment["start"] * sr), int(segment["end"] * sr)
        # Slice the last (time) axis: same result for 1-D input, and correct
        # for the channel-first layout librosa.to_mono expects.
        split_audio = librosa.to_mono(waveform[..., start:end])
        out_path = os.path.join(folder_path, f"{file_name}_{idx}.wav")
        write_wav(out_path, sr, split_audio)
|
|
|
|
|
def get_char_count(text):
    """
    Get the number of characters in the text, ignoring spaces and punctuation.

    The removed characters are the ASCII ``, . ! ? " '``, the full-width
    ``。 ， ！ ？``, curly quotes ``“ ” ‘ ’`` and the space character.

    Args:
        text (str): Input text.

    Returns:
        int: Number of remaining characters.
    """
    # The previous class repeated the ASCII ",!?" where the full-width forms
    # were evidently intended, so CJK punctuation was counted as characters.
    cleaned_text = re.sub(r"[,.!?\"' 。，！？“”‘’]", "", text)
    return len(cleaned_text)
|
|
|
|
|
def calculate_audio_stats(
    data, min_duration=3, max_duration=30, min_dnsmos=3, min_char_count=2
):
    """
    Filter audio segments by duration, DNSMOS, text length and speaking rate.

    The speaking-rate check uses seconds-per-character and keeps values
    inside Tukey's fences (Q1 - 1.5*IQR, Q3 + 1.5*IQR). The fences are
    computed over all entries with at least one counted character; if there
    are none, the check is disabled.

    Args:
        data: Iterable of segment dicts with ``start``, ``end`` (seconds),
            ``text`` and ``dnsmos`` keys.
        min_duration: Minimum duration of the audio in seconds.
        max_duration: Maximum duration of the audio in seconds.
        min_dnsmos: Minimum DNSMOS value.
        min_char_count: Minimum number of characters (see ``get_char_count``).

    Returns:
        tuple: ``(valid_audio_stats, all_audio_stats)``. Both are lists of
        ``(index, duration)`` tuples: the entries that pass every filter, and
        every entry.
    """
    # Materialise once so any iterable works, and compute per-entry values a
    # single time instead of in both passes.
    entries = list(data)
    durations = [entry["end"] - entry["start"] for entry in entries]
    char_counts = [get_char_count(entry["text"]) for entry in entries]

    avg_durations = [d / c for d, c in zip(durations, char_counts) if c > 0]
    if avg_durations:
        q1, q3 = np.percentile(avg_durations, [25, 75])
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
    else:
        # No entry has countable text: accept any speaking rate.
        lower_bound, upper_bound = 0, np.inf

    all_audio_stats = []
    valid_audio_stats = []
    for idx, (entry, duration, char_count) in enumerate(
        zip(entries, durations, char_counts)
    ):
        dnsmos = entry["dnsmos"]  # read unconditionally: missing key still raises
        avg_char_duration = duration / char_count if char_count > 0 else 0

        all_audio_stats.append((idx, duration))
        if (
            min_duration <= duration <= max_duration
            and dnsmos >= min_dnsmos
            and char_count >= min_char_count
            and lower_bound <= avg_char_duration <= upper_bound
        ):
            valid_audio_stats.append((idx, duration))

    return valid_audio_stats, all_audio_stats
|
|