Spaces:
Starting
on
A10G
Starting
on
A10G
File size: 3,065 Bytes
ea270ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
#!/usr/bin/python3
import os
import torch
import logging
from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
import argparse
# Export TOKENIZERS_PARALLELISM=true for this process.
# NOTE(review): presumably consumed by the tokenizer library pulled in by
# audioldm2 -- confirm it is still read after the imports above have run.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Raise the 'matplotlib' logger threshold so only WARNING and above get through.
matplotlib_logger = logging.getLogger('matplotlib')
matplotlib_logger.setLevel(logging.WARNING)

# Cache location: $AUDIOLDM_CACHE_DIR when set, else ~/.cache/audioldm2.
_default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
CACHE_DIR = os.environ.get("AUDIOLDM_CACHE_DIR", _default_cache_dir)
# Command-line interface. Every option is optional and has a default, so the
# script can be launched with no arguments at all.
parser = argparse.ArgumentParser()

# One (flags, settings) entry per option, listed in registration order; that
# order is also the order in which the options appear in --help.
_CLI_OPTIONS = (
    (
        ("-t", "--text"),
        dict(
            help="Text prompt to the model for audio generation",
            type=str,
            default="",
        ),
    ),
    (
        ("-tl", "--text_list"),
        dict(
            help="A file that contains text prompt to the model for audio generation",
            type=str,
            default="",
        ),
    ),
    (
        ("-s", "--save_path"),
        dict(
            help="The path to save model output",
            type=str,
            default="./output",
        ),
    ),
    (
        ("--model_name",),
        dict(
            help="The checkpoint you gonna use",
            type=str,
            default="audioldm2-full",
            # Only one checkpoint name is accepted.
            choices=["audioldm2-full"],
        ),
    ),
    (
        ("-b", "--batchsize"),
        dict(
            help="Generate how many samples at the same time",
            type=int,
            default=1,
        ),
    ),
    (
        ("--ddim_steps",),
        dict(
            help="The sampling step for DDIM",
            type=int,
            default=200,
        ),
    ),
    (
        ("-gs", "--guidance_scale"),
        dict(
            help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
            type=float,
            default=3.5,
        ),
    ),
    (
        ("-n", "--n_candidate_gen_per_text"),
        dict(
            help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
            type=int,
            default=3,
        ),
    ),
    (
        ("--seed",),
        dict(
            help="Change this value (any integer number) will lead to a different generation result.",
            type=int,
            default=0,
        ),
    ),
)

for _flags, _settings in _CLI_OPTIONS:
    parser.add_argument(*_flags, required=False, **_settings)

# Parse the arguments of the current process.
args = parser.parse_args()
# Ask torch for "high" float32 matmul precision before the model is built.
torch.set_float32_matmul_precision("high")

# Outputs go into a sub-directory of --save_path named by get_time().
save_path = os.path.join(args.save_path, get_time())
os.makedirs(save_path, exist_ok=True)

audioldm2 = build_model(model_name=args.model_name)

# Settings shared by every prompt in this run. The duration is fixed here and
# is not exposed on the command line.
# NOTE(review): presumably seconds -- confirm against text_to_audio.
duration = 10
generation_settings = dict(
    seed=args.seed,
    duration=duration,
    guidance_scale=args.guidance_scale,
    ddim_steps=args.ddim_steps,
    n_candidate_gen_per_text=args.n_candidate_gen_per_text,
    batchsize=args.batchsize,
)

# A prompt file, when supplied, takes precedence over the single --text prompt.
if not args.text_list:
    prompt_todo = [args.text]
else:
    print("Generate audio based on the text prompts in %s" % args.text_list)
    prompt_todo = read_list(args.text_list)

# Generate and save one result per prompt; the same seed is reused for each.
for text in prompt_todo:
    waveform = text_to_audio(audioldm2, text, **generation_settings)
    # The prompt text itself is handed to save_wave as the output name.
    save_wave(waveform, save_path, name=text)
|