import torch
import gradio as gr
import soundfile as sf
import numpy as np
import random, os
import spaces

from consistencytta import ConsistencyTTA


def seed_all(seed):
    """ Seed all random number generators. """
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.random.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


device = torch.device(
    "cuda:0" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else "cpu"
)
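# Output sampling rate (Hz), used both for generation and for the saved WAV file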
sr = 16000

# Build ConsistencyTTA model
consistencytta = ConsistencyTTA().to(device)
consistencytta.eval()
consistencytta.requires_grad_(False)


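# On Hugging Face ZeroGPU Spaces, spaces.GPU() requests a GPU for each call to this function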
@spaces.GPU()
def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """ Generate audio from a given prompt.
    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str, optional): Random seed. Defaults to '', which means no seed.
        cfg_weight (float, optional): Classifier-free guidance strength. Defaults to 4.
    """
    if seed != '':
        try:
            seed_all(int(seed))
        except ValueError:
            pass  # Ignore seeds that cannot be parsed as integers

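    # Single-step (num_steps=1) generation; bfloat16 autocast is enabled only on CUDA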
    with torch.no_grad():
        with torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
        ):
            wav = consistencytta(
                [prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
            )
        sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')

    return "output.wav"


# Generate test audio
print("Generating test audio...")
generate("A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")

# Launch Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text", value="Several people cheer and scream and speak as water flows hard."
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
        )],
    outputs="audio",
    title="ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio " \
          "Generation with Consistency Distillation",
    description="This is the official demo page for <a href='https://consistency-tta.github." \
                "io' target=&ldquo;blank&rdquo;>ConsistencyTTA</a>, a model that accelerates " \
                "diffusion-based text-to-audio generation hundreds of times with consistency " \
                "models. <br> Here, the audio is generated within a single non-autoregressive " \
                "forward pass from the  CLAP-finetuned ConsistencyTTA checkpoint. <br> Since " \
                "the training dataset does not include speech, the model is not expected to " \
                "generate coherent speech. <br> Have fun!"
)
iface.launch()