import os

from transformers import JukeboxModel, JukeboxTokenizer
# NOTE(review): convert_jukebox is not used in this file; kept in case other code relies on the import.
from transformers.models.jukebox import convert_jukebox
import gradio as gr
import torch as t

model_id = 'openai/jukebox-1b-lyrics'  #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
sample_rate = 44100
total_duration_in_seconds = 200
raw_to_tokens = 128
chunk_size = 32
max_batch_size = 16
# Expand "~" up front: path APIs treat an unexpanded "~" as a literal
# directory name, which would create a "~" folder under the current directory.
cache_path = os.path.expanduser('~/.cache/')


def tokens_to_seconds(tokens, level=2):
    """Convert a token count at ``level`` into seconds of audio.

    At level 2 one token covers ``raw_to_tokens`` samples; each level below
    covers 4x fewer samples per token (level 1: /4, level 0: /16).
    """
    return tokens * raw_to_tokens / sample_rate / 4 ** (2 - level)


def seconds_to_tokens(sec, level=2):
    """Convert ``sec`` seconds of audio into a token count at ``level``.

    The level-2 count is rounded *up* to the next multiple of ``chunk_size``;
    an exact multiple still gains one extra chunk. The result is then scaled
    for lower levels.

    >>> seconds_to_tokens(1)
    352
    """
    tokens = sec * sample_rate // raw_to_tokens
    tokens = (tokens // chunk_size + 1) * chunk_size
    # For levels 1 and 0, multiply by 4 and 16 respectively
    tokens *= 4 ** (2 - level)
    return int(tokens)


# Init is run on server startup.
# It loads the model into the module-level global "model".
def init():
    """Load the Jukebox model (fp16, auto device placement) into ``model``."""
    global model
    print(f"Loading model from/to {cache_path}...")
    model = JukeboxModel.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=t.float16,
        cache_dir=os.path.join(cache_path, "jukebox", "models"),
        resume_download=True,
        # Presumably forwarded to the model config as an override — verify.
        min_duration=0,
    ).eval()
    print("Model loaded: ", model)
# Lazily-loaded tokenizer, shared across requests so it is not reloaded per call.
_tokenizer = None


def _get_tokenizer():
    """Return the Jukebox tokenizer for ``model_id``, loading it on first use."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = JukeboxTokenizer.from_pretrained(model_id)
    return _tokenizer


# Inference is run for every server call.
# It uses the global "model" that init() loads.
def inference(artist, genres, lyrics):
    """Sample a short window of music tokens conditioned on artist/genres/lyrics.

    Args:
        artist: Artist name passed to the tokenizer.
        genres: Genre string (the UI asks for space-separated genres).
        lyrics: Lyrics text.

    Returns:
        NumPy array of the tokens sampled at ``level`` — presumably shaped
        (n_samples, n_tokens); confirm against the installed transformers.
        The full per-level list is also stored in the module global ``zs``.
    """
    global zs
    n_samples = 4
    generation_length = seconds_to_tokens(1)
    offset = 0
    level = 0
    # NOTE(review): Jukebox's total_length metadata is measured in raw audio
    # samples (rounded to a multiple of raw_to_tokens), not in tokens. Passing a
    # token count here mis-conditions the prior — verify against your
    # transformers version.
    model.total_length = (int(total_duration_in_seconds * sample_rate) // raw_to_tokens) * raw_to_tokens
    sampling_kwargs = dict(
        temp=0.98,
        chunk_size=chunk_size,
    )
    metas = dict(
        artist=artist,
        genres=genres,
        lyrics=lyrics,
    )
    labels = _get_tokenizer()(**metas)['input_ids'][level].repeat(n_samples, 1).cuda()
    print(f"Labels: {labels.shape}")
    # One empty (n_samples, 0) token tensor per level; sampling fills `level`.
    zs = [t.zeros(n_samples, 0, dtype=t.long, device='cuda') for _ in range(3)]
    print(f"Zs: {[z.shape for z in zs]}")
    zs = model.sample_partial_window(
        zs,
        labels,
        offset,
        sampling_kwargs,
        level=level,
        tokens_to_sample=generation_length,
        max_batch_size=max_batch_size,
    )
    print(f"Zs after sampling: {[z.shape for z in zs]}")
    # zs is a list of per-level tensors (a list has no .cpu()); return the
    # tensor for the level that was just sampled.
    return zs[level].cpu().numpy()


with gr.Blocks() as ui:
    # Define UI components
    # NOTE(review): "title" is not passed to inference.
    title = gr.Textbox(lines=1, label="Title")
    artist = gr.Textbox(lines=1, label="Artist")
    genres = gr.Textbox(lines=1, label="Genre(s)", placeholder="Separate with spaces")
    lyrics = gr.Textbox(lines=5, label="Lyrics", placeholder="Shift+Enter for new line")
    submit = gr.Button("Generate")
    output_zs = gr.Dataframe(label="zs")
    submit.click(
        inference,
        inputs=[artist, genres, lyrics],
        outputs=output_zs,
    )

if __name__ == "__main__":
    init()
    ui.launch()