AudioLCM / app.py
liuhuadai's picture
Upload 340 files
6efc863 verified
raw
history blame
1.48 kB
import functools
import os
import re

import gradio
import torch

# NOTE(review): no import for OmegaConf, load_model_from_config, LCMSampler,
# VocoderBigVGAN or GenSamples is visible in this file; add them (omegaconf
# and the AudioLCM project modules) if they are not provided elsewhere.
# Locations used by the demo, kept together so they are easy to change.
_CONFIG_PATH = "configs/audiolcm.yaml"
_CKPT_PATH = "../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt"
_VOCODER_PATH = "../vocoder/logs/bigvnat16k93.5w"
_OUTPUT_DIR = "results/test"
# Upper bound on the file-name stem derived from a prompt. 60 characters keeps
# "<stem>_<n>.wav" under the common 255-byte file-name limit even when every
# character is multi-byte in UTF-8.
_MAX_NAME_LEN = 60


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Build the AudioLCM model, sampler, vocoder and sample generator once.

    Rebuilding these per request would repeat the checkpoint and vocoder
    loading on every Gradio call. Caching the result means only the first
    call pays that cost. lru_cache does not cache exceptions, so a failed
    load is retried on the next call.

    NOTE(review): assumes GenSamples and the model keep no per-request state
    that must be reset between calls -- confirm against their definitions.

    Returns:
        ``(model, generator)``: the loaded model (needed for ``ema_scope``)
        and the ``GenSamples`` instance that writes audio into ``_OUTPUT_DIR``.
    """
    config = OmegaConf.load(_CONFIG_PATH)
    model = load_model_from_config(config, _CKPT_PATH)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = LCMSampler(model)
    vocoder = VocoderBigVGAN(_VOCODER_PATH, device)
    generator = GenSamples(
        sampler, model, _OUTPUT_DIR, vocoder,
        save_mel=False, save_wav=True,
        original_inference_steps=config.model.params.num_ddim_timesteps,
    )
    return model, generator


def _wav_name(caption):
    """Return a file-name stem for *caption* that is safe to use in _OUTPUT_DIR.

    Surrounding whitespace is stripped. Every character that is not a word
    character, '.' or '-' (spaces, slashes and backslashes included) becomes
    '-', so text typed into the web UI cannot direct the output outside
    _OUTPUT_DIR. The result is truncated to _MAX_NAME_LEN characters.
    """
    return re.sub(r"[^\w.-]", "-", caption.strip())[:_MAX_NAME_LEN]


def _newest_wav(directory, stem):
    """Return the most recently modified ``.wav`` in *directory* starting with *stem*.

    Raises:
        FileNotFoundError: if *directory* holds no such file.
    """
    candidates = [
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if entry.startswith(stem) and entry.endswith(".wav")
    ]
    if not candidates:
        raise FileNotFoundError(f"no .wav starting with {stem!r} in {directory!r}")
    return max(candidates, key=os.path.getmtime)


def infer(prompt):
    """Generate audio for *prompt* and return the path of the written ``.wav``.

    Args:
        prompt: either a caption string or a dict with an ``'ori_caption'``
            key (as built by ``my_inference_function``). It is passed
            unchanged to ``generator.gen_test_sample``. The caption is used
            here only to name the output file.

    Returns:
        Path of the newest matching ``.wav`` in ``_OUTPUT_DIR`` after
        generation.

    Raises:
        ValueError: if the caption is empty or only whitespace.
    """
    caption = prompt["ori_caption"] if isinstance(prompt, dict) else prompt
    wav_name = _wav_name(caption)
    if not wav_name:
        raise ValueError("prompt must contain a non-empty caption")
    # Cheap, and re-creates the directory if it was removed while the app runs.
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    model, generator = _load_pipeline()
    with torch.no_grad(), model.ema_scope():
        generator.gen_test_sample(prompt, wav_name=wav_name)
    # NOTE(review): the exact file name gen_test_sample derives from wav_name
    # is not visible in this file, so take the newest .wav that starts with it.
    wav_path = _newest_wav(_OUTPUT_DIR, wav_name)
    print(f"Your sample is ready and waiting for you here: \n{wav_path} \nEnjoy.")
    return wav_path


def my_inference_function(prompt_oir):
    """Gradio callback: turn the textbox contents into audio.

    The same text fills both the ``'ori_caption'`` and ``'struct_caption'``
    entries of the prompt dict handed to ``infer``.

    Returns:
        Path of the generated ``.wav``, shown by the interface's audio output.
    """
    prompt = {'ori_caption': prompt_oir, 'struct_caption': prompt_oir}
    return infer(prompt)
# Web UI: a single free-text input wired to my_inference_function, whose
# return value is rendered by an audio output component. launch() is called
# at module level, so running (or importing) this file starts the server.
gradio_interface = gradio.Interface(my_inference_function, "text", "audio")
gradio_interface.launch()