Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
import soundfile as sf | |
import numpy as np | |
import random, os | |
from consistencytta import ConsistencyTTA | |
def seed_all(seed):
    """ Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with one value.

    Args:
        seed: Any value accepted by ``int()``; it is converted once and reused.

    Also exports PYTHONHASHSEED and switches cuDNN into deterministic mode.
    """
    seed_value = int(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.random.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Sample rate handed to the model and used when writing the output WAV.
sr = 16000

# Build ConsistencyTTA model once at import time and freeze it for inference.
consistencytta = ConsistencyTTA().to(device)
consistencytta.eval()
consistencytta.requires_grad_(False)
def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """ Generate audio from a given prompt.

    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str or int, optional): Random seed. Defaults to '', which means no seed.
            Surrounding whitespace is ignored. A value that is not an integer in
            [0, 2**32 - 1] (the range ``np.random.seed`` accepts) is reported and
            ignored, so generation proceeds unseeded instead of failing.
        cfg_weight (float, optional): Classifier-free guidance strength, forwarded to
            the model as ``cfg_scale_input``. Defaults to 4.

    Returns:
        str: Path of a newly created 16-bit PCM WAV file (sample rate ``sr``) that
        holds the generated audio.
    """
    import tempfile  # local import keeps this change self-contained

    if isinstance(seed, str):
        seed = seed.strip()
    if seed != '':
        try:
            seed_value = int(seed)
        except (TypeError, ValueError):
            seed_value = None
        # Validate before seeding so that either every RNG is seeded or none is;
        # otherwise np.random.seed raises partway through seed_all() on values
        # outside [0, 2**32 - 1], after random.seed has already been applied.
        if seed_value is not None and 0 <= seed_value < 2 ** 32:
            seed_all(seed_value)
        else:
            print(f"Ignoring invalid seed {seed!r}; generating without a fixed seed.")

    with torch.no_grad(), torch.autocast(
        device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
    ):
        wav = consistencytta(
            [prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
        )

    # Use a unique file per call. With one shared "output.wav", concurrent
    # requests could overwrite each other's audio before it is served. The file
    # is intentionally kept: the caller reads it after this function returns.
    fd, out_path = tempfile.mkstemp(prefix="consistencytta_", suffix=".wav")
    os.close(fd)  # soundfile reopens the path itself; don't leak this descriptor
    sf.write(out_path, wav.T, samplerate=sr, subtype='PCM_16')
    return out_path
# Generate test audio once at startup, so that a broken model or environment
# fails loudly before the web interface is exposed.
print("Generating test audio...")
generate("A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")

# Launch Gradio interface. The input components are passed to
# generate(prompt, seed, cfg_weight) in the order listed.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text", value="Several people cheer and scream and speak as water flows hard."
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
        ),
    ],
    outputs="audio",
    title="ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio "
          "Generation with Consistency Distillation",
    # ASCII quotes and the reserved `_blank` target make the link open in a new
    # tab. Typographic quotes are not HTML quote characters, so `target=“blank”`
    # named a window literally called “blank”.
    description="This is the official demo page for "
                "<a href='https://consistency-tta.github.io' target='_blank' "
                "rel='noopener noreferrer'>ConsistencyTTA</a>, a model that accelerates "
                "diffusion-based text-to-audio generation hundreds of times with consistency "
                "models. <br> Here, the audio is generated within a single non-autoregressive "
                "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since "
                "the training dataset does not include speech, the model is not expected to "
                "generate coherent speech. <br> Have fun!",
)
iface.launch(share=True)