ConsistencyTTA / run_gradio.py
Bai-YT's picture
Update run_gradio.py
83cac0f verified
raw
history blame
2.93 kB
import torch
import gradio as gr
import soundfile as sf
import numpy as np
import random, os
import spaces
from consistencytta import ConsistencyTTA
def seed_all(seed):
    """Make random number generation reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    CUDA generators), exports ``PYTHONHASHSEED``, and switches cuDNN into
    deterministic mode.

    Args:
        seed: Seed value; anything ``int()`` accepts (e.g. ``7`` or ``"7"``).

    Raises:
        ValueError / TypeError: if ``seed`` cannot be converted with ``int()``.
    """
    seed_value = int(seed)
    rng_seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.random.manual_seed,
    )
    for seeder in rng_seeders:
        seeder(seed_value)
    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this
    # influences child processes, not str hashing in the current process.
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    # Give up cuDNN autotuning in exchange for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Pick the best available accelerator: first CUDA GPU, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
sr = 16_000  # sample rate in Hz, passed to the model and to the WAV writer

# Build the ConsistencyTTA model once, at import time, for inference only.
consistencytta = ConsistencyTTA().to(device)
consistencytta.eval()  # inference mode — assumes torch.nn.Module semantics
consistencytta.requires_grad_(False)  # parameters never need gradients here
@spaces.GPU()
def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """ Generate audio from a given prompt and write it to ``output.wav``.

    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str, optional): Random seed; any value ``int()`` accepts (the
            start-up warm-up passes an int). Defaults to '', which means no
            seed. An empty or None seed is skipped; an invalid one is reported
            and ignored so generation still proceeds.
        cfg_weight (float, optional): Classifier-free guidance strength,
            forwarded to the model as ``cfg_scale_input``. Defaults to 4.

    Returns:
        str: Path of the written 16-bit PCM WAV file ("output.wav").
    """
    if seed not in ('', None):
        try:
            seed_all(int(seed))
        except (TypeError, ValueError, RuntimeError) as err:
            # Best effort, as before: a bad seed must not block generation,
            # but report it instead of silently swallowing every exception.
            print(f"Ignoring invalid seed {seed!r} ({err}); generating unseeded.")

    with torch.no_grad():
        # bfloat16 autocast is only meaningful for the CUDA device; on other
        # hosts PyTorch would warn and disable it anyway, so say so explicitly.
        with torch.autocast(
            device_type="cuda", dtype=torch.bfloat16,
            enabled=device.type == "cuda",
        ):
            wav = consistencytta(
                [prompt], num_steps=1, cfg_scale_input=cfg_weight,
                cfg_scale_post=1., sr=sr,
            )
    # NOTE: every call overwrites the same file in the working directory.
    sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')
    return "output.wav"
# Smoke-test the whole pipeline once at start-up, with a fixed prompt and
# seed, before the web interface starts serving requests.
print("Generating test audio...")
generate(prompt="A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")
# Launch Gradio interface: inputs map positionally onto
# generate(prompt, seed, cfg_weight); the returned WAV path feeds the audio output.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text",
            value="Several people cheer and scream and speak as water flows hard.",
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0., maximum=8., value=3.5,
            label="Classifier-Free Guidance Strength",
        ),
    ],
    outputs="audio",
    # Adjacent string literals inside parentheses concatenate implicitly.
    title=(
        "ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio "
        "Generation with Consistency Distillation"
    ),
    # target='_blank' opens the project page in a new tab; the previous
    # curly-quote entities (&ldquo;/&rdquo;) are not valid attribute quotes.
    description=(
        "This is the official demo page for <a href='https://consistency-tta.github.io' "
        "target='_blank' rel='noopener noreferrer'>ConsistencyTTA</a>, a model that accelerates "
        "diffusion-based text-to-audio generation hundreds of times with consistency "
        "models. <br> Here, the audio is generated within a single non-autoregressive "
        "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since "
        "the training dataset does not include speech, the model is not expected to "
        "generate coherent speech. <br> Have fun!"
    ),
)
iface.launch()