# A simple gradio app that converts music tokens to and from audio using
# JukeboxVQVAE as the model and Gradio as the UI.
#
# A music tokens file (.jt) is a `torch.save`d list of exactly 3 token tensors,
# one per VQ-VAE level; see `validate_tokens_list` for the exact checks.

import os
import sys
import tempfile

import torch as t
from transformers import JukeboxVQVAE
import gradio as gr

model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']

if 'google.colab' in sys.modules:
    cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:"string"}
    # Connect to your Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
else:
    cache_path = '~/.cache/'

# Sample rate used for audio exchanged with Gradio.
# NOTE(review): 44100 Hz is the `sampling_rate` default of transformers'
# JukeboxConfig; confirm it matches the chosen checkpoint.
SAMPLE_RATE = 44100


class Convert:
    """Converters between token lists, token files and audio.

    Every converter is a static function so it can be handed directly to a
    Gradio event listener.
    """

    class TokenList:

        @staticmethod
        def to_tokens_file(tokens_list):
            """Validate *tokens_list*, save it to a new temporary .jt file and return the file's path."""
            tokens_list = validate_tokens_list(tokens_list)
            # A unique temp file replaces the old `tmp/<random>.jt` name: that name
            # came from `t.randint(0, 1000000)`, which needs a `size` argument and
            # returns a tensor, and the `tmp/` directory was never created.
            with tempfile.NamedTemporaryFile(suffix='.jt', delete=False) as f:
                # Store on CPU so the file can be loaded on machines without the
                # device the tokens were produced on.
                t.save([tokens.cpu() for tokens in tokens_list], f)
            return f.name

        @staticmethod
        def to_audio(tokens_list):
            """Decode the level-2 tokens of *tokens_list* to raw audio.

            Returns the decoder output with its trailing size-1 dimension
            squeezed out (presumably ``(batch, time)`` -- TODO confirm).
            """
            # TODO: Implement converting other levels besides 2
            vqvae = _get_model()
            top_level = [tokens.to(vqvae.device) for tokens in validate_tokens_list(tokens_list)[2:]]
            with t.no_grad():  # pure inference: don't build an autograd graph
                return vqvae.decode(top_level, start_level=2).squeeze(-1)

    class TokensFile:

        @staticmethod
        def to_tokens_list(file):
            """Load and validate the token list stored in an uploaded .jt *file*.

            *file* may be a path or a file wrapper exposing ``.name`` (older
            Gradio versions pass the latter).
            """
            if file is None:
                raise gr.Error("Please upload a music tokens file first")
            path = getattr(file, 'name', file)
            # SECURITY(review): `t.load` unpickles the upload, so a malicious file
            # can execute arbitrary code. Consider `weights_only=True` (torch >= 1.13).
            # `map_location='cpu'` lets files saved on a GPU load on CPU-only machines.
            return validate_tokens_list(t.load(path, map_location='cpu'))

        @staticmethod
        def to_audio(file):
            """Decode an uploaded .jt *file* into the ``(sample_rate, samples)`` tuple `gr.Audio` expects."""
            waveforms = Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))
            return _to_gradio_audio(waveforms)

    class Audio:

        @staticmethod
        def to_tokens_list(audio):
            """Encode *audio* into a list of 3 token tensors (levels 0, 1 and 2).

            *audio* is either Gradio's ``(sample_rate, samples)`` tuple or a
            tensor that gets a batch dimension prepended (presumably
            ``(time, channels)`` -- TODO confirm).
            """
            if audio is None:
                raise gr.Error("Please provide some audio first")
            if isinstance(audio, (tuple, list)):
                audio = _from_gradio_audio(audio)
            vqvae = _get_model()
            # NOTE(review): transformers' JukeboxVQVAE may upcast its input to
            # float32 internally; with float16 weights encoding could then fail
            # -- confirm, or load the model in float32.
            batch = audio.unsqueeze(0).to(device=vqvae.device, dtype=vqvae.dtype)
            with t.no_grad():
                # Encode all levels (start_level defaults to 0). The previous
                # `start_level=2` returned a single tensor, which
                # `validate_tokens_list` always rejects.
                return vqvae.encode(batch)

        @staticmethod
        def to_tokens_file(audio):
            """Encode *audio* and save the tokens to a temporary .jt file; return its path."""
            return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))


def _to_gradio_audio(waveforms):
    """Convert a decoded waveform batch into the ``(sample_rate, samples)`` tuple `gr.Audio` plays.

    Only the first batch item is returned, because `gr.Audio` plays a single clip.
    """
    return SAMPLE_RATE, waveforms[0].float().cpu().numpy()


def _from_gradio_audio(audio):
    """Convert Gradio's ``(sample_rate, samples)`` tuple into a float ``(time, 1)`` tensor."""
    sample_rate, samples = audio
    if sample_rate != SAMPLE_RATE:
        # No resampler is available here, and encoding at the wrong rate would
        # silently change pitch/speed, so refuse instead.
        raise gr.Error(f"Expected audio sampled at {SAMPLE_RATE} Hz, got {sample_rate} Hz")
    samples = t.as_tensor(samples)
    if not samples.is_floating_point():
        # Integer PCM (Gradio typically delivers int16): scale to roughly [-1, 1].
        scale = t.iinfo(samples.dtype).max
        samples = samples.to(t.float32) / scale
    if samples.dim() == 1:
        samples = samples.unsqueeze(-1)  # (time,) -> (time, 1)
    # Downmix to a single channel, matching the size-1 channel dimension that
    # `TokenList.to_audio` squeezes off the decoder output.
    return samples.mean(dim=-1, keepdim=True)


def _get_model():
    """Return the loaded model, loading it on first use."""
    try:
        return model
    except NameError:
        return init()


def init():
    """Load the JukeboxVQVAE into the module-level `model` once, and return it.

    If `model` already exists (e.g. this Colab cell was re-run), the existing
    model is kept rather than reloaded.
    """
    global model
    try:
        model
    except NameError:
        model = JukeboxVQVAE.from_pretrained(
            model_id,
            torch_dtype=t.float16,
            # Expand `~` so the default cache does not end up in a literal "~"
            # directory; os.path.join also avoids a doubled "/".
            cache_dir=os.path.join(os.path.expanduser(cache_path), 'jukebox', 'models'),
        )
    else:
        print("Model already initialized")
    return model


def validate_tokens_list(tokens_list):
    """Check that *tokens_list* is a valid music token list and return it unchanged.

    Required structure: a list/tuple of exactly 3 tensors that
    - all have the same number of dimensions (at least 2),
    - agree in size along dimension 0, and
    - decrease (or stay the same) in size along dimension 1 from index 0 to 2.

    Raises:
        ValueError: if any check fails. Explicit raises are used instead of
            `assert`, which is stripped when Python runs with -O.
    """
    # - tokens_list is a list of exactly 3 torch tensors
    if (
        not isinstance(tokens_list, (list, tuple))
        or len(tokens_list) != 3
        or not all(isinstance(tokens, t.Tensor) for tokens in tokens_list)
    ):
        raise ValueError("Invalid file format: expecting a list of 3 tensors")
    first, second, third = tokens_list
    # - each has the same number of dimensions
    if not first.dim() == second.dim() == third.dim():
        raise ValueError("Invalid file format: each tensor in the list should have the same number of dimensions")
    # - ...and at least 2, so that dimension 1 below exists
    if first.dim() < 2:
        raise ValueError("Invalid file format: each tensor in the list should have at least 2 dimensions")
    # - the shape along dimension 0 is the same
    if not first.shape[0] == second.shape[0] == third.shape[0]:
        raise ValueError("Invalid file format: the shape along dimension 0 should be the same for all tensors in the list")
    # - the shape along dimension 1 decreases (or stays the same) as we go from 0 to 2
    if not first.shape[1] >= second.shape[1] >= third.shape[1]:
        raise ValueError("Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2")
    return tokens_list


with gr.Blocks() as ui:
    # File input to upload or download the music tokens file
    tokens = gr.File(label='music_tokens_file')
    # Audio output to play or upload the generated audio; 'numpy' means values
    # are exchanged as (sample_rate, samples) tuples
    audio = gr.Audio(label='audio', type='numpy')
    # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)
    gr.Button("Convert tokens to audio", variant='primary').click(Convert.TokensFile.to_audio, tokens, audio)
    gr.Button("Convert audio to tokens", variant='secondary').click(Convert.Audio.to_tokens_file, audio, tokens)

if __name__ == '__main__':
    init()
    ui.launch()