# kotoba-speech/fam/llm/enhancers.py
# Last commit: 565faca ("first commit") by yuta0306 — 2.95 kB
import os
from abc import ABC, abstractmethod
from typing import Literal, Optional

from df.enhance import enhance, init_df, load_audio, save_audio
from pydub import AudioSegment
def convert_to_wav(input_file: str, output_file: str):
    """Convert an audio file to WAV format.

    The input format is taken from the file extension (lower-cased). When the
    file has no extension, ``format=None`` is passed so pydub infers it.

    Args:
        input_file (str): path to input audio file
        output_file (str): path to output WAV file
    """
    # os.path.splitext only looks at the final path component, so dots in
    # directory names (e.g. "./data.v2/clip") are not mistaken for an
    # extension, as they were with input_file.split(".")[-1].
    extension = os.path.splitext(input_file)[1]
    # Strip the leading dot; an empty extension becomes None (let pydub infer).
    audio_format = extension[1:].lower() or None
    audio = AudioSegment.from_file(input_file, format=audio_format)
    audio.export(output_file, format="wav")
def make_output_file_path(
    audio_file: str,
    tag: str,
    ext: Optional[str] = None,
    directory: str = "./enhanced",
) -> str:
    """Generate the output file path for a processed audio file.

    The result is ``directory/<input stem><tag><extension>``. For example,
    ``"/data/clip.mp3"`` with tag ``"_df"`` becomes ``"./enhanced/clip_df.mp3"``.

    Args:
        audio_file (str): path to input audio file
        tag (str): tag to append to the output file name
        ext (str, optional): extension of the output file, with or without the
            leading dot. Defaults to None, which keeps the input file's extension.
        directory (str, optional): directory the output file is placed in.
            Defaults to "./enhanced".
    Returns:
        str: path to output file (the directory is not created here)
    """
    stem, input_extension = os.path.splitext(os.path.basename(audio_file))
    extension = ext or input_extension
    # splitext always yields "" or ".xxx", but a caller-supplied ext such as
    # "wav" would otherwise be glued directly onto the tag ("clip_dfwav").
    if extension and not extension.startswith("."):
        extension = "." + extension
    return os.path.join(directory, stem + tag + extension)
class BaseEnhancer(ABC):
    """Base class for audio enhancers.

    Subclasses must implement ``__init__`` (set-up) and ``__call__``, which
    takes an input audio path plus an optional output path and returns the
    output path as a string.
    """

    # Declared abstract so that the ABC base actually enforces the interface:
    # instantiating a subclass that misses either method fails immediately.
    @abstractmethod
    def __init__(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        raise NotImplementedError

    def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
        """Build the default output path for ``audio_file`` and create its directory.

        Args:
            audio_file (str): path to input audio file
            tag (str): tag appended to the output file name
            ext (str, optional): extension of the output file. Defaults to None
                (keep the input file's extension).
        Returns:
            str: path to output file
        """
        output_file = make_output_file_path(audio_file, tag, ext=ext)
        directory = os.path.dirname(output_file)
        # os.makedirs("") raises FileNotFoundError, so skip it when the path
        # has no directory component.
        if directory:
            os.makedirs(directory, exist_ok=True)
        return output_file
class DFEnhancer(BaseEnhancer):
    """Audio enhancer backed by the ``df.enhance`` (DeepFilterNet) API."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted for interface compatibility but are not used;
        # init_df() runs with its own defaults. Its third return value is discarded.
        self.model, self.df_state, _ = init_df()

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file`` and save the result.

        Args:
            audio_file (str): path to the audio to enhance
            output_file (str, optional): destination path. When empty or None,
                a "_df"-tagged path from ``get_output_file`` is used; that
                method also creates the destination directory.
        Returns:
            str: the path the enhanced audio was written to
        """
        if not output_file:
            output_file = self.get_output_file(audio_file, "_df")
        # The same sample rate (from the df state) is used for loading and saving.
        sample_rate = self.df_state.sr()
        waveform, _meta = load_audio(audio_file, sr=sample_rate)
        save_audio(output_file, enhance(self.model, self.df_state, waveform), sample_rate)
        return output_file
def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
    """Construct the audio enhancer registered under ``enhancer_name``.

    Args:
        enhancer_name (Literal["df"]): name of the audio enhancer; only "df"
            (DFEnhancer) is currently supported
    Raises:
        ValueError: if the enhancer name is not recognised
    Returns:
        BaseEnhancer: a newly constructed enhancer instance
    """
    if enhancer_name != "df":
        raise ValueError(f"Unknown enhancer name: {enhancer_name}")
    return DFEnhancer()