# sheet-demo / app.py
# Hugging Face Space by unilight; last commit: "Update app.py" (cad836c, verified).
import os
import torch.nn.functional as F
import torchaudio
from loguru import logger
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import yaml
# ---------- Settings ----------
# Device selection: GPU_ID is exported as CUDA_VISIBLE_DEVICES, and the
# sentinel "-1" makes DEVICE resolve to "cpu".
GPU_ID = "-1"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
DEVICE = "cpu" if GPU_ID == "-1" else "cuda"

# Server settings.
# NOTE(review): these three are not referenced by the launch call visible in
# this file; confirm whether they are still needed.
SERVER_NAME = "0.0.0.0"
SERVER_PORT = 42208
SSL_DIR = "./keyble_ssl"

# Audio settings used by read_wav.
FS = 16000  # target sampling rate (Hz) that inputs are resampled to
MIN_REQUIRED_WAV_LENGTH = 1040  # minimum waveform length (samples) targeted by zero-padding
resamplers = {}  # cache of resampler objects keyed by "<source rate>-<target rate>"

# Bundled example clips (currently disabled).
# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))

# ---------- Logging ----------
# Append to app.log so that messages from earlier runs are kept.
logger.add("app.log", mode="a")
logger.info("============================= App restarted =============================")
# ---------- Download models ----------
logger.info("============================= Download models ===========================")

# Hub repository that hosts the checkpoints and their configs.
_SHEET_REPO_ID = "unilight/sheet-models"

# Display name -> remote files (inside the repository) required by that model.
_MODEL_REMOTE_FILES = {
    "SSL-MOS, all training sets": {
        "ckpt": "bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl",
        "config": "bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml",
    },
}

# Display name -> {"ckpt": <path>, "config": <path>} as returned by
# hf_hub_download, called once per remote file at import time.
model_paths = {
    display_name: {
        file_kind: hf_hub_download(repo_id=_SHEET_REPO_ID, filename=remote_file)
        for file_kind, remote_file in remote_files.items()
    }
    for display_name, remote_files in _MODEL_REMOTE_FILES.items()
}
# ---------- Model ----------
# Build every model listed in model_paths once at start-up and keep it in
# `models`, keyed by its display name.
models = {}
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; this is
    # only acceptable because the config comes from a repository we trust.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    model_type = config["model_type"]
    if model_type == "SSLMOS":
        from models.sslmos import SSLMOS

        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        )
    else:
        # Fail loudly instead of leaving `model` undefined or silently
        # registering the previous iteration's model under this name.
        raise ValueError(
            f"Unsupported model_type {model_type!r} in {config_path} (model: {name})."
        )

    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source. Weights are read on CPU and moved to DEVICE afterwards.
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")
    models[name] = model
def read_wav(wav_path):
    """Load an audio file as a mono, 1-D waveform sampled at ``FS`` Hz.

    Args:
        wav_path: Path of the audio file, passed to ``torchaudio.load``.

    Returns:
        A ``(waveform, sample_rate)`` tuple. ``waveform`` is a 1-D tensor with
        at least ``MIN_REQUIRED_WAV_LENGTH`` samples. ``sample_rate`` is the
        sampling rate reported for the file, i.e. the rate BEFORE resampling.
    """
    # channels_first=False -> waveform: [T, C]
    waveform, sample_rate = torchaudio.load(wav_path, channels_first=False)

    # Reduce to one time axis ([T, C] -> [T]) before resampling: Resample works
    # on the LAST dimension, so it must be time, not channels. For mono input
    # the mean over a single channel is the same as squeezing it; for
    # multi-channel input this downmixes to mono.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=-1)

    # resample if needed, reusing one resampler per (source, target) rate pair
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        waveform = resamplers[resampler_key](waveform)

    # always pad to a minimum length; the right side takes the extra sample
    # when the deficit is odd, so the result is never one sample short
    missing = MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]
    if missing > 0:
        pad_left = missing // 2
        pad_right = missing - pad_left
        waveform = F.pad(waveform, (pad_left, pad_right), "constant", 0)

    return waveform, sample_rate
# Parsed model configs keyed by model display name (filled lazily).
_MODEL_CONFIG_CACHE = {}


def _load_model_config(model_name):
    """Return the parsed YAML config of ``model_name``, reading it at most once."""
    if model_name not in _MODEL_CONFIG_CACHE:
        # NOTE(review): yaml.Loader is kept so the result matches what the
        # model set-up code parsed; it is unsafe on untrusted files.
        with open(model_paths[model_name]["config"], encoding="utf-8") as f:
            _MODEL_CONFIG_CACHE[model_name] = yaml.load(f, Loader=yaml.Loader)
    return _MODEL_CONFIG_CACHE[model_name]


def predict(model_name, wav_file):
    """Predict the quality score of one speech file with the selected model.

    Args:
        model_name: Key into ``models`` / ``model_paths`` chosen in the UI.
        wav_file: Path of the recorded or uploaded audio file.

    Returns:
        The first element of the model's ``"scores"`` output, converted to
        NumPy.

    Raises:
        gr.Error: If no audio was provided or no known model was selected.
        ValueError: If the model config has an unsupported ``inference_mode``.
    """
    # Both UI inputs can be empty; report that to the user instead of failing
    # with a KeyError / loader error deep inside the pipeline.
    if wav_file is None:
        raise gr.Error("Please record or upload a speech file first.")
    if model_name not in models:
        raise gr.Error("Please select a model first.")

    x, _ = read_wav(wav_file)
    logger.info('wav file loaded')

    # Use the config that belongs to the SELECTED model rather than whatever
    # config the set-up loop happened to parse last.
    model = models[model_name]
    model_config = _load_model_config(model_name)
    input_key = model_config["model_input"]

    # set up model input: add a batch dimension and pass the length alongside
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    inputs = {
        input_key: model_input,
        input_key + "_lengths": model_lengths,
    }

    with torch.no_grad():
        # model forward
        inference_mode = model_config["inference_mode"]
        if inference_mode == "mean_listener":
            outputs = model.mean_listener_inference(inputs)
        elif inference_mode == "mean_net":
            outputs = model.mean_net_inference(inputs)
        else:
            raise ValueError(f"Unsupported inference_mode: {inference_mode!r}")

    pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]
    return pred_mean_scores
# ---------- UI ----------
# Introductory text shown at the top of the page (kept flush-left so that no
# line is indented inside the Markdown source).
_INTRO_MARKDOWN = """
# Demo for SHEET: Speech Human Evaluation Estimation Toolkit
### [[Paper (arXiv)]](https://arxiv.org/abs/2411.03715) [[Code]](https://github.com/unilight/sheet)
**SHEET** is a subjective speech quality assessment (SSQA) toolkit designed to conduct SSQA research. It was specifically designed to interact with MOS-Bench, a collection of datasets to benchmark SSQA models.
In this demo, you can record your own voice or upload speech files to assess the quality.
"""

# Models offered in the UI; the first one is pre-selected so that "Evaluate!"
# works without an extra click.
_MODEL_CHOICES = list(model_paths.keys())

# NOTE(review): the nesting below was reconstructed (input column and output
# column side by side in one row) -- confirm against the intended layout.
with gr.Blocks(title="SHEET: Speech Human Evaluation Estimation Toolkit demo") as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        # Left column: audio input, model choice and the trigger button.
        with gr.Column():
            gr.Markdown("## Record your speech here!")
            input_wav = gr.Audio(label="Input speech", type='filepath')
            gr.Markdown("## Select a model!")
            model_name = gr.Radio(
                label="Model",
                choices=_MODEL_CHOICES,
                value=_MODEL_CHOICES[0] if _MODEL_CHOICES else None,
            )
            evaluate_btn = gr.Button(value="Evaluate!")
            # gr.Markdown("### You can use these examples if using a microphone is too troublesome!")
            # gr.Markdown("I recorded the samples using my Macbook Pro, so there might be some noises.")
            # gr.Examples(
            #     examples=en_examples,
            #     inputs=input_wav,
            #     label="English examples"
            # )
            # gr.Examples(
            #     examples=jp_examples,
            #     inputs=input_wav,
            #     label="Japanese examples"
            # )
            # gr.Examples(
            #     examples=zh_examples,
            #     inputs=input_wav,
            #     label="Mandarin examples"
            # )
        # Right column: the prediction.
        with gr.Column():
            gr.Markdown("## The predicted score is here:")
            output_score = gr.Textbox(label="Prediction", interactive=False)
    # Run predict(model_name, input_wav) on click and show its result.
    evaluate_btn.click(predict, [model_name, input_wav], output_score)
if __name__ == "__main__":
    # Serve the demo when run as a script. Whatever ends the launch call
    # (normal return, Ctrl-C, or another error), the demo is closed afterwards.
    try:
        demo.launch(debug=True)
    except KeyboardInterrupt as interrupt:
        print(interrupt)
    finally:
        demo.close()