# Source: Hugging Face Space `app.py` (raw view, 2.74 kB)
# Author: reach-vb (HF staff)
# Commit: e360f5c (verified) - "Update app.py"
import torch
import numpy as np
import gradio as gr
import soundfile as sf
import tempfile
from transformers import pipeline
from huggingface_hub import InferenceClient
def _grab_best_device(use_gpu=True):
if torch.cuda.device_count() > 0 and use_gpu:
device = "cuda"
else:
device = "cpu"
return device
# Detect the accelerator once; the pipeline below is placed on this device.
device = _grab_best_device()

# Markdown header rendered at the top of the demo page.
title = """# MusicGen Prompt Upsampling 🎢
MusicGen, a simple and controllable model for music generation.
**Model**: https://huggingface.co/facebook/musicgen-stereo-medium
"""

# Text-to-audio pipeline. It previously hard-coded device="cuda", which ignored
# the detected device and failed on CPU-only hosts. Half precision is kept on
# CUDA (unchanged behaviour there); on CPU we fall back to float32.
# NOTE(review): float32 on CPU is chosen because fp16 CPU kernels are
# presumably slow or missing - confirm against the torch build in use.
vibes = pipeline(
    "text-to-audio",
    "facebook/musicgen-stereo-medium",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device=device,
)

# Hosted LLM client used to enrich the user's prompt before music generation.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
# Inference
def generate_audio(text,):
    """Enrich a user prompt with the LLM client, then synthesise music from it.

    Args:
        text: The raw prompt typed by the user.

    Returns:
        A ``(wav_path, upsampled_prompt)`` tuple: the path of a temporary
        ``.wav`` file holding the generated audio, and the text returned by
        the LLM that was fed to the music pipeline.
    """
    prompt = f"Take the next sentence and enrich it with details. Keep it compact. {text}"
    output = client.text_generation(prompt, max_new_tokens=250)
    generated = vibes(output)

    # Reserve a unique path and close the handle *before* soundfile opens it:
    # writing to a still-open NamedTemporaryFile fails on Windows. The file is
    # created with delete=False so it outlives this function and can be served
    # to the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name

    # First item of the batch, transposed before writing.
    # NOTE(review): presumably the pipeline yields (channels, frames) and the
    # transpose gives soundfile its (frames, channels) layout - confirm.
    sf.write(wav_path, generated["audio"][0].T, generated["sampling_rate"])
    return wav_path, output
# Custom stylesheet handed to gr.Blocks(css=...). It targets the elements
# created with elem_id="container" (auto side margins, max width 80rem) and
# elem_id="intro" (full width, centred text).
css = """
#container{
margin: 0 auto;
max-width: 80rem;
}
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""
# Gradio blocks demo
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped - confirm the layout matches the intended UI.
with gr.Blocks(css=css) as demo_blocks:
    gr.Markdown(title, elem_id="intro")

    with gr.Row(elem_id="container"):
        # Left column: prompt entry and the trigger button.
        with gr.Column():
            inp_text = gr.Textbox(label="Input Prompt", info="What would you like MusicGen to synthesise?")
            btn = gr.Button("Generate Music! 🎢")

        # Right column: generated audio and the LLM-enriched prompt.
        with gr.Column():
            out = gr.Audio(autoplay=False, label="Generated Music", show_label=True)
            prompt_text = gr.Textbox(label="Upsampled Prompt")

    # Collapsed usage snippet. The accordion label's emoji was mojibake
    # ("πŸ€—") and is restored to the intended hugging-face emoji.
    with gr.Accordion("Use MusicGen with Transformers 🤗", open=False):
        gr.Markdown(
            """
```python
import torch
import soundfile as sf
from transformers import pipeline
synthesiser = pipeline("text-to-audio",
"facebook/musicgen-stereo-medium",
device="cuda:0",
torch_dtype=torch.float16)
music = synthesiser("lo-fi music with a soothing melody",
forward_params={"max_new_tokens": 256})
sf.write("musicgen_out.wav", music["audio"][0].T, music["sampling_rate"])
```
"""
        )

    # generate_audio returns (wav_path, upsampled_prompt), matching the two
    # output components in order.
    btn.click(generate_audio, inp_text, [out, prompt_text])

# Enable request queuing, then start the server.
demo_blocks.queue().launch()