File size: 3,274 Bytes
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659a5e1
 
 
 
 
 
 
 
 
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659a5e1
 
631e673
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# A simple gradio app that converts music tokens to and from audio using JukeboxVQVAE as the model and Gradio as the UI

import os
import sys
import tempfile

import gradio as gr
import torch as t
from transformers import JukeboxVQVAE

# Jukebox checkpoint whose VQ-VAE is loaded by init() (exposed as a Colab form field)
model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']

# Directory under which model weights are cached. Detect Colab by whether its
# `google.colab` module is loaded; there, the cache is kept on Google Drive.
if 'google.colab' in sys.modules:

  cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:"string"}
  # Connect to your Google Drive
  from google.colab import drive
  drive.mount('/content/drive')

else:

  # Expand '~' here: the value is handed on as a plain string to
  # `from_pretrained(cache_dir=...)` in init(), and nothing in this file expands
  # it, so an unexpanded '~' could end up as a literal directory name relative
  # to the working directory (NOTE(review): confirm HF does not expand it).
  cache_path = os.path.expanduser('~/.cache/')

class Convert:
  """Namespace of converters between the representations the UI handles.

  * token list  -- a sequence of 3 tensors accepted by `validate_tokens_list`
    (element 0 is the largest along dimension 1, element 2 the smallest);
  * tokens file -- a `.jt` file holding such a list, written with `t.save`;
  * audio       -- the value exchanged with the Gradio audio component.

  Every converter is a static function. The ones that touch audio use the
  module-level `model` that `init()` loads.
  """

  class TokenList:

    @staticmethod
    def to_tokens_file(tokens_list):
      """Validate `tokens_list`, save it to a new temporary `.jt` file and return the file path.

      Raises:
        ValueError: if `tokens_list` fails `validate_tokens_list`.
      """
      tokens_list = validate_tokens_list(tokens_list)
      # A unique file in the system temp dir. The previous `tmp/<random>.jt`
      # path needed a `tmp/` folder in the CWD, and `t.randint(0, 1000000)`
      # raises TypeError because torch.randint requires a `size` argument.
      with tempfile.NamedTemporaryFile(suffix='.jt', delete=False) as f:
        t.save(tokens_list, f)
      return f.name

    @staticmethod
    def to_audio(tokens_list):
      """Decode the level-2 entry of `tokens_list` with the VQ-VAE and return the result.

      Raises:
        ValueError: if `tokens_list` fails `validate_tokens_list`.
      """
      # Only level 2 is decoded: the `[2:]` slice pairs with `start_level=2`.
      # TODO: Implement converting other levels besides 2
      # NOTE(review): this returns a torch tensor; gr.Audio outputs expect a
      # (sample_rate, numpy array) tuple or a file path -- confirm/convert.
      return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)

  class TokensFile:

    @staticmethod
    def to_tokens_list(file):
      """Load a `.jt` tokens file and return its validated token list.

      `file` may be a path (str / bytes / os.PathLike), a readable file object,
      or a temp-file wrapper whose real path is stored in `.name`. Some Gradio
      versions pass such a wrapper for gr.File inputs.

      Raises:
        ValueError: if the loaded object fails `validate_tokens_list`.
      """
      if not isinstance(file, (str, bytes, os.PathLike)):
        # Prefer the path behind a wrapper; objects without `.name` (e.g.
        # in-memory buffers) are passed to t.load unchanged.
        file = getattr(file, 'name', file)
      # SECURITY(review): t.load unpickles its input and this file is a user
      # upload; consider `weights_only=True` (torch >= 1.13).
      return validate_tokens_list(t.load(file))

    @staticmethod
    def to_audio(file):
      """Load a `.jt` tokens file and decode it to audio (see `TokenList.to_audio`)."""
      return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))

  class Audio:

    @staticmethod
    def to_tokens_list(audio):
      """Encode `audio` with the VQ-VAE into a token list covering all levels."""
      # Encode from level 0 so all 3 levels come back. With `start_level=2`
      # only one level was returned, which `validate_tokens_list` (called by
      # `to_tokens_file`) always rejected.
      # NOTE(review): assumes `audio` is a torch tensor that becomes the model's
      # expected input after `unsqueeze(0)`; gr.Audio's default numpy mode
      # passes a (sample_rate, array) tuple instead -- TODO confirm/convert.
      return model.encode(audio.unsqueeze(0), start_level=0)

    @staticmethod
    def to_tokens_file(audio):
      """Encode `audio`, save the tokens to a temporary `.jt` file and return its path."""
      return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))

def init():
  """Load the Jukebox VQ-VAE into the module-level `model`, at most once.

  If `model` is already bound, only a notice is printed, so repeated calls are
  cheap. Otherwise the `model_id` checkpoint is loaded in float16 with
  `<cache_path>/jukebox/models` as the download cache directory.
  """
  global model

  # Guard clause: skip the (expensive) load when the global already exists.
  if 'model' in globals():
    print("Model already initialized")
    return

  model = JukeboxVQVAE.from_pretrained(
    model_id,
    torch_dtype=t.float16,
    cache_dir=f"{cache_path}/jukebox/models",
  )

def validate_tokens_list(tokens_list):
  """Check that `tokens_list` has the music-tokens layout and return it unchanged.

  Requirements:
  - exactly 3 elements, each exposing a `.shape` (torch tensors in practice);
  - all elements have the same number of dimensions, and at least 2;
  - the size along dimension 0 is the same for all elements;
  - the size along dimension 1 decreases (or stays the same) from element 0
    to element 2.

  Raises:
    ValueError: if any requirement is violated. Explicit raises are used
      instead of `assert`, which is stripped under `python -O` and would let
      malformed uploaded files through.
  """
  # - exactly 3 elements (objects without a length are rejected the same way)
  try:
    count = len(tokens_list)
  except TypeError:
    count = None
  if count != 3:
    raise ValueError("Invalid file format: expecting a list of 3 tensors")

  # - every element must be tensor-like, i.e. expose a shape
  shapes = [getattr(tokens, 'shape', None) for tokens in tokens_list]
  if any(shape is None for shape in shapes):
    raise ValueError("Invalid file format: expecting a list of 3 tensors")
  shapes = [tuple(shape) for shape in shapes]

  # - each has the same number of dimensions
  if len({len(shape) for shape in shapes}) != 1:
    raise ValueError("Invalid file format: each tensor in the list should have the same number of dimensions")

  # - dimensions 0 and 1 are inspected below, so both must exist
  if len(shapes[0]) < 2:
    raise ValueError("Invalid file format: each tensor in the list should have at least 2 dimensions")

  # - the shape along dimension 0 is the same
  if not shapes[0][0] == shapes[1][0] == shapes[2][0]:
    raise ValueError("Invalid file format: the shape along dimension 0 should be the same for all tensors in the list")

  # - the shape along dimension 1 decreases (or stays the same) as we go from 0 to 2
  if not shapes[0][1] >= shapes[1][1] >= shapes[2][1]:
    raise ValueError("Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2")

  return tokens_list


with gr.Blocks() as ui:

  # Music-tokens file: upload one to decode, or download the one produced from audio
  tokens = gr.File(label='music_tokens_file')

  # Audio: plays the decoded result, or accepts an upload to encode into tokens
  audio = gr.Audio(label='audio')

  # Primary action: tokens file -> audio
  gr.Button("Convert tokens to audio", variant='primary').click(
    fn=Convert.TokensFile.to_audio,
    inputs=tokens,
    outputs=audio,
  )

  # Secondary action: audio -> tokens file
  gr.Button("Convert audio to tokens", variant='secondary').click(
    fn=Convert.Audio.to_tokens_file,
    inputs=audio,
    outputs=tokens,
  )

if __name__ == '__main__':
  # Load the model first: the conversion callbacks read the global `model`.
  init()
  ui.launch()