from transformers import JukeboxModel, JukeboxTokenizer
import gradio as gr
import torch as t
import os

model_id = 'openai/jukebox-1b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']

sample_rate = 44100              # Jukebox operates on 44.1 kHz audio
total_duration_in_seconds = 200  # target length of the full composition
raw_to_tokens = 128              # raw audio samples per top-level VQ-VAE code
chunk_size = 32                  # sampling proceeds in chunks of this many tokens
max_batch_size = 16
cache_path = os.path.expanduser('~/.cache')  # expand ~ so cache_dir is a real path
def tokens_to_seconds(tokens, level = 2):
    # Inverse of seconds_to_tokens (up to chunk rounding). Levels 1 and 0 run at
    # 4x and 16x the top-level token rate, hence the 4 ** (2 - level) factor.
    return tokens * raw_to_tokens / sample_rate / 4 ** (2 - level)

def seconds_to_tokens(sec, level = 2):
    # Convert a duration to a token count, padded up to the next chunk boundary
    tokens = sec * sample_rate // raw_to_tokens
    tokens = ( (tokens // chunk_size) + 1 ) * chunk_size
    # For levels 1 and 0, multiply by 4 and 16 respectively
    tokens *= 4 ** (2 - level)
    return int(tokens)
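# Worked example of the two helpers above (illustrative, using the defaults
# sample_rate=44100, raw_to_tokens=128, chunk_size=32): one second is
# 44100 // 128 = 344 top-level codes, padded to the next chunk boundary:
#   seconds_to_tokens(1)           # -> 352
#   tokens_to_seconds(352)         # -> ~1.02 s (round trip is approximate)
#   seconds_to_tokens(1, level=0)  # -> 352 * 16 = 5632 bottom-level tokens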
# init() is run once on server startup.
# It loads the model to GPU as a global variable named "model".
def init():
    global model
    print(f"Loading model from/to {cache_path}...")
    model = JukeboxModel.from_pretrained(
        model_id,
        device_map = "auto",
        torch_dtype = t.float16,
        cache_dir = f"{cache_path}/jukebox/models",
        resume_download = True,
        min_duration = 0,
    ).eval()
    print("Model loaded:", model)
# inference() is run for every server call.
# It references the preloaded global model variable.
def inference(artist, genres, lyrics):
    global model, zs
    n_samples = 4
    generation_length = seconds_to_tokens(1)
    offset = 0
    level = 0  # index of the prior to sample with
    model.total_length = seconds_to_tokens(total_duration_in_seconds)
    sampling_kwargs = dict(
        temp = 0.98,
        chunk_size = chunk_size,
    )
    metas = dict(
        artist = artist,
        genres = genres,
        lyrics = lyrics,
    )
    # Tokenizer is loaded per call for simplicity; it could be cached at startup
    labels = JukeboxTokenizer.from_pretrained(model_id)(**metas)['input_ids'][level].repeat(n_samples, 1).cuda()
    print(f"Labels: {labels.shape}")
    # Start from empty code sequences at all three VQ-VAE levels
    zs = [ t.zeros(n_samples, 0, dtype=t.long, device='cuda') for _ in range(3) ]
    print(f"Zs: {[z.shape for z in zs]}")
    zs = model.sample_partial_window(
        zs, labels, offset, sampling_kwargs, level = level, tokens_to_sample = generation_length, max_batch_size = max_batch_size
    )
    print(f"Zs after sampling: {[z.shape for z in zs]}")
    # zs is a list of per-level tensors; return the sampled level as a numpy array
    return zs[level].cpu().numpy()
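# Sketch (an addition, not part of the app above): persist the sampled codes so
# a later run can resume or upsample them; plain torch serialization suffices,
# since zs is just a list of LongTensors. The filename "zs.pt" is arbitrary.
# Decoding codes back to raw audio would go through the model's VQ-VAE decoder
# (JukeboxVQVAE.decode in transformers); its exact argument and level ordering
# should be verified against the installed transformers version.
def save_codes(zs, path="zs.pt"):
    t.save([z.cpu() for z in zs], path)

def load_codes(path="zs.pt", device="cuda"):
    return t.load(path, map_location=device)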
with gr.Blocks() as ui:
    # Define UI components
    title = gr.Textbox(lines=1, label="Title")  # currently not passed to inference
    artist = gr.Textbox(lines=1, label="Artist")
    genres = gr.Textbox(lines=1, label="Genre(s)", placeholder="Separate with spaces")
    lyrics = gr.Textbox(lines=5, label="Lyrics", placeholder="Shift+Enter for new line")
    submit = gr.Button("Generate")  # Button takes its caption as the value, not a label kwarg
    output_zs = gr.Dataframe(label="zs")
    submit.click(
        inference,
        inputs = [ artist, genres, lyrics ],
        outputs = output_zs,
    )
if __name__ == "__main__":
    init()
    ui.launch()
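# Deployment note (an assumption, not in the original app): Jukebox sampling is
# slow enough that plain HTTP requests may time out; Gradio's built-in queue
# sidesteps this:
#   ui.queue().launch()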