vovahimself committed
Commit
631e673
1 Parent(s): bf8bb24
Files changed (4)
  1. app.py +93 -0
  2. app_to_colab.py +83 -0
  3. jukwi-vqvae.ipynb +140 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,93 @@
+ # A simple Gradio app that converts music tokens to and from audio, using JukeboxVQVAE as the model and Gradio as the UI
+
+ import sys
+ from transformers import JukeboxVQVAE
+ import gradio as gr
+ import torch as t
+
+ model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
+
+ if 'google.colab' in sys.modules:
+
+     cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:"string"}
+     # Connect to your Google Drive
+     from google.colab import drive
+     drive.mount('/content/drive')
+
+ else:
+
+     cache_path = '~/.cache/'
+
+ class Convert:
+
+     class TokenList:
+
+         def to_tokens_file(tokens_list):
+             # Temporary random file name
+             filename = f"tmp/{t.randint(0, 1000000, (1,)).item()}.jt"
+             t.save(validate_tokens_list(tokens_list), filename)
+             return filename
+
+         def to_audio(tokens_list):
+             return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)
+             # TODO: Implement converting other levels besides 2
+
+     class TokensFile:
+
+         def to_tokens_list(file):
+             return validate_tokens_list(t.load(file))
+
+         def to_audio(file):
+             return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))
+
+     class Audio:
+
+         def to_tokens_list(audio):
+             return model.encode(audio.unsqueeze(0), start_level=2)
+             # (TODO: Generated by Copilot, check if it works)
+
+         def to_tokens_file(audio):
+             return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))
+
+ def init():
+     global model
+
+     model = JukeboxVQVAE.from_pretrained(
+         model_id,
+         device_map = "auto",
+         torch_dtype = t.float16,
+         cache_dir = f"{cache_path}/jukebox/models"
+     )
+
+ def validate_tokens_list(tokens_list):
+     # Make sure that:
+     # - tokens_list is a list of exactly 3 torch tensors
+     assert len(tokens_list) == 3, "Invalid file format: expecting a list of 3 tensors"
+
+     # - each tensor has the same number of dimensions
+     assert len(tokens_list[0].shape) == len(tokens_list[1].shape) == len(tokens_list[2].shape), "Invalid file format: each tensor in the list should have the same number of dimensions"
+
+     # - the shape along dimension 0 is the same for all tensors
+     assert tokens_list[0].shape[0] == tokens_list[1].shape[0] == tokens_list[2].shape[0], "Invalid file format: the shape along dimension 0 should be the same for all tensors in the list"
+
+     # - the shape along dimension 1 decreases (or stays the same) as we go from 0 to 2
+     assert tokens_list[0].shape[1] >= tokens_list[1].shape[1] >= tokens_list[2].shape[1], "Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2"
+
+     return tokens_list
+
+
+ with gr.Blocks() as ui:
+
+     # File input to upload or download the music tokens file
+     tokens = gr.File(label='music_tokens_file')
+
+     # Audio output to play or upload the generated audio
+     audio = gr.Audio(label='audio')
+
+     # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)
+     gr.Button("Convert tokens to audio", variant="primary").click(Convert.TokensFile.to_audio, tokens, audio)
+     gr.Button("Convert audio to tokens", variant="secondary").click(Convert.Audio.to_tokens_file, audio, tokens)
+
+ if __name__ == '__main__':
+     init()
+     ui.launch()
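
A rough usage sketch for the Convert helpers above, driven outside the Gradio UI (not part of the commit). The file name 'my_song.jt' is hypothetical, and this assumes the level-2 decode path behaves as written; the TODO in the code notes that other levels are not handled yet.

    # Sketch only: decode a previously saved tokens file to a waveform tensor
    init()                                                 # load JukeboxVQVAE once, as in the __main__ block
    waveform = Convert.TokensFile.to_audio('my_song.jt')   # 3-tensor token list -> level-2 decode
    print(waveform.shape)                                  # roughly (batch, samples) after squeeze(-1)
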
app_to_colab.py ADDED
@@ -0,0 +1,83 @@
+ # Script to create a notebook out of `requirements.txt` (installing the dependencies) and `app.py`
+
+ import json
+ from textwrap import dedent
+
+ def create_colab(requirements_file='requirements.txt', app_file='app.py'):
+
+     # cells = []
+
+     requirements_txt = open(requirements_file, 'r').read().replace('\n', '\\n')
+
+     def text_to_cell(text, cell_type='code'):
+
+         lines = dedent(text).splitlines()
+
+         # add a \n to the end of each line except the last one
+         lines = [ f'{line}\n' for line in lines[:-1] ] + [ lines[-1] ]
+
+         return dict(
+             metadata={},
+             execution_count=None,
+             outputs=[],
+             cell_type=cell_type,
+             source=lines
+         )
+
+     cells = [
+
+         # Cell to mount drive, install the dependencies etc.
+         text_to_cell(f"""\
+             from google.colab import drive
+             mount_drive = True #@param {{type:"boolean"}}
+             if mount_drive:
+                 drive.mount('/content/drive')
+
+             requirements_txt = "{requirements_txt}"
+
+             # Save the requirements.txt file
+             with open('requirements.txt', 'w') as f:
+                 f.write(requirements_txt)
+
+             # Install the dependencies
+             %pip install -r requirements.txt
+         """),
+
+         # Cell to run the app
+         text_to_cell(open(app_file, 'r').read())
+
+     ]
+
+     # Add notebook metadata
+     metadata = dict(
+         kernelspec = dict(
+             display_name = 'Python 3',
+             language = 'python',
+             name = 'python3'
+         ),
+         language_info = dict(
+             name = 'python',
+             version = '3.7.5',
+         ),
+         orig_nbformat = 4,
+     )
+
+     # Finalize the notebook
+     notebook = dict(
+         cells=cells,
+         metadata=metadata,
+         nbformat=4,
+         nbformat_minor=2,
+     )
+
+     # Name the notebook after the current directory
+     from pathlib import Path
+     notebook_name = f'{Path().absolute().name}.ipynb'
+
+     # Save the notebook in JSON format
+     open(notebook_name, 'w').write(json.dumps(notebook, indent=2))
+
+     return notebook
+
+ if __name__ == '__main__':
+     create_colab()
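
As a usage note (not part of the commit), a minimal sketch of regenerating the notebook from the current app.py and requirements.txt; per the code above, the notebook is named after the working directory, so from this repository's root it would produce jukwi-vqvae.ipynb (assuming the checkout is named jukwi-vqvae).

    # Sketch only: rebuild the notebook from the sources in the current directory
    from app_to_colab import create_colab
    nb = create_colab()          # writes <current-directory-name>.ipynb next to the sources
    print(len(nb['cells']))      # 2: the setup cell and the app cell
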
jukwi-vqvae.ipynb ADDED
@@ -0,0 +1,140 @@
+ {
+   "cells": [
+     {
+       "metadata": {},
+       "execution_count": null,
+       "outputs": [],
+       "cell_type": "code",
+       "source": [
+         "from google.colab import drive\n",
+         "mount_drive = True #@param {type:\"boolean\"}\n",
+         "if mount_drive:\n",
+         "    drive.mount('/content/drive')\n",
+         "\n",
+         "requirements_txt = \"git+https://github.com/ArthurZucker/transformers.git@jukebox\\naccelerate\\nbitsandbytes==0.31.8\\ngradio\"\n",
+         "\n",
+         "# Save the requirements.txt file\n",
+         "with open('requirements.txt', 'w') as f:\n",
+         "    f.write(requirements_txt)\n",
+         "\n",
+         "# Install the dependencies\n",
+         "%pip install -r requirements.txt"
+       ]
+     },
+     {
+       "metadata": {},
+       "execution_count": null,
+       "outputs": [],
+       "cell_type": "code",
+       "source": [
+         "# A simple Gradio app that converts music tokens to and from audio, using JukeboxVQVAE as the model and Gradio as the UI\n",
+         "\n",
+         "import sys\n",
+         "from transformers import JukeboxVQVAE\n",
+         "import gradio as gr\n",
+         "import torch as t\n",
+         "\n",
+         "model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']\n",
+         "\n",
+         "if 'google.colab' in sys.modules:\n",
+         "\n",
+         "    cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:\"string\"}\n",
+         "    # Connect to your Google Drive\n",
+         "    from google.colab import drive\n",
+         "    drive.mount('/content/drive')\n",
+         "\n",
+         "else:\n",
+         "\n",
+         "    cache_path = '~/.cache/'\n",
+         "\n",
+         "class Convert:\n",
+         "\n",
+         "    class TokenList:\n",
+         "\n",
+         "        def to_tokens_file(tokens_list):\n",
+         "            # Temporary random file name\n",
+         "            filename = f\"tmp/{t.randint(0, 1000000, (1,)).item()}.jt\"\n",
+         "            t.save(validate_tokens_list(tokens_list), filename)\n",
+         "            return filename\n",
+         "\n",
+         "        def to_audio(tokens_list):\n",
+         "            return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)\n",
+         "            # TODO: Implement converting other levels besides 2\n",
+         "\n",
+         "    class TokensFile:\n",
+         "\n",
+         "        def to_tokens_list(file):\n",
+         "            return validate_tokens_list(t.load(file))\n",
+         "\n",
+         "        def to_audio(file):\n",
+         "            return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))\n",
+         "\n",
+         "    class Audio:\n",
+         "\n",
+         "        def to_tokens_list(audio):\n",
+         "            return model.encode(audio.unsqueeze(0), start_level=2)\n",
+         "            # (TODO: Generated by Copilot, check if it works)\n",
+         "\n",
+         "        def to_tokens_file(audio):\n",
+         "            return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))\n",
+         "\n",
+         "def init():\n",
+         "    global model\n",
+         "\n",
+         "    model = JukeboxVQVAE.from_pretrained(\n",
+         "        model_id,\n",
+         "        device_map = \"auto\",\n",
+         "        torch_dtype = t.float16,\n",
+         "        cache_dir = f\"{cache_path}/jukebox/models\"\n",
+         "    )\n",
+         "\n",
+         "def validate_tokens_list(tokens_list):\n",
+         "    # Make sure that:\n",
+         "    # - tokens_list is a list of exactly 3 torch tensors\n",
+         "    assert len(tokens_list) == 3, \"Invalid file format: expecting a list of 3 tensors\"\n",
+         "\n",
+         "    # - each tensor has the same number of dimensions\n",
+         "    assert len(tokens_list[0].shape) == len(tokens_list[1].shape) == len(tokens_list[2].shape), \"Invalid file format: each tensor in the list should have the same number of dimensions\"\n",
+         "\n",
+         "    # - the shape along dimension 0 is the same for all tensors\n",
+         "    assert tokens_list[0].shape[0] == tokens_list[1].shape[0] == tokens_list[2].shape[0], \"Invalid file format: the shape along dimension 0 should be the same for all tensors in the list\"\n",
+         "\n",
+         "    # - the shape along dimension 1 decreases (or stays the same) as we go from 0 to 2\n",
+         "    assert tokens_list[0].shape[1] >= tokens_list[1].shape[1] >= tokens_list[2].shape[1], \"Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2\"\n",
+         "\n",
+         "    return tokens_list\n",
+         "\n",
+         "\n",
+         "with gr.Blocks() as ui:\n",
+         "\n",
+         "    # File input to upload or download the music tokens file\n",
+         "    tokens = gr.File(label='music_tokens_file')\n",
+         "\n",
+         "    # Audio output to play or upload the generated audio\n",
+         "    audio = gr.Audio(label='audio')\n",
+         "\n",
+         "    # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)\n",
+         "    gr.Button(\"Convert tokens to audio\", variant=\"primary\").click(Convert.TokensFile.to_audio, tokens, audio)\n",
+         "    gr.Button(\"Convert audio to tokens\", variant=\"secondary\").click(Convert.Audio.to_tokens_file, audio, tokens)\n",
+         "\n",
+         "if __name__ == '__main__':\n",
+         "    init()\n",
+         "    ui.launch()"
+       ]
+     }
+   ],
+   "metadata": {
+     "kernelspec": {
+       "display_name": "Python 3",
+       "language": "python",
+       "name": "python3"
+     },
+     "language_info": {
+       "name": "python",
+       "version": "3.7.5"
+     },
+     "orig_nbformat": 4
+   },
+   "nbformat": 4,
+   "nbformat_minor": 2
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ git+https://github.com/ArthurZucker/transformers.git@jukebox
+ accelerate
+ bitsandbytes==0.31.8
+ gradio
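
A small sanity check (a sketch, not part of the commit): after installing these requirements, the pinned transformers fork should expose the model class that app.py imports.

    # Sketch only: verify the jukebox branch of the transformers fork is installed
    from transformers import JukeboxVQVAE
    print(JukeboxVQVAE.__name__)   # 'JukeboxVQVAE'
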