asigalov61 committed on
Commit
7d723ab
1 Parent(s): 6ceed88

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +337 -0
app.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# `import os.path` also binds the top-level `os` package, which is what
# the os.getenv call below relies on.
import os.path

import time as reqtime  # aliased: the MIDI decoder below uses a local named `time`
import datetime
from pytz import timezone

import torch

import spaces  # Hugging Face Spaces GPU decorator (@spaces.GPU)
import gradio as gr

# Star-import supplies TransformerWrapper, Decoder and AutoregressiveWrapper
# used to build the model inside Text_to_Music.
from x_transformer_1_23_2 import *
import random
import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

# True when running inside a Hugging Face Space (the platform sets SYSTEM=spaces).
# NOTE(review): `in_space`, `random` and `tqdm` are not referenced anywhere in
# this file — presumably kept for parity with the Colab notebook; confirm.
in_space = os.getenv("SYSTEM") == "spaces"
22
+
23
+ # =================================================================================================
24
+
25
@spaces.GPU
def Text_to_Music(input_title, input_num_tokens, input_prompt_type):
    """Generate a MIDI composition from a song-title text prompt.

    Runs once per Gradio request on a Spaces GPU worker: builds the
    transformer, loads the trained checkpoint, autoregressively samples
    tokens conditioned on the encoded title, then decodes the token
    stream into a MIDI file, rendered audio and a score plot.

    Args:
        input_title: song title used as the prompt; an empty string
            yields unconditional generation.
        input_num_tokens: requested number of tokens to generate;
            clamped to [8, 2048] below.
        input_prompt_type: truthy means "open-ended" prompt — the
            end-of-title token is omitted so the model may extend the
            title itself before composing.

    Returns:
        Tuple (output_midi_title, output_midi_summary, output_midi,
        output_audio, output_plot) wired to the Gradio outputs in
        __main__.

    NOTE(review): reads module-level globals `PDT` and `soundfont` that
    are only assigned under `if __name__ == "__main__"` — confirm this
    function is never imported and called from another module.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()  # wall-clock timer for the whole request

    # NOTE(review): the model is re-instantiated and the checkpoint
    # re-loaded on every request; presumably a consequence of the
    # stateless @spaces.GPU execution model — confirm before hoisting.
    print('Loading model...')

    SEQ_LEN = 4096 # Models seq len
    PAD_IDX = 2571 # Models pad index
    DEVICE = 'cuda' # 'cuda'

    # instantiate the model

    model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = 2048, depth = 8, heads = 16, attn_flash = True)
        )

    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)

    model.to(DEVICE)
    print('=' * 70)

    print('Loading model checkpoint...')

    # NOTE(review): torch.load without weights_only=True unpickles
    # arbitrary objects; acceptable only because the checkpoint is a
    # trusted file bundled with the app.
    model.load_state_dict(
        torch.load('Text_to_Music_Transformer_Medium_Trained_Model_33934_steps_0.6093_loss_0.813_acc.pth',
                   map_location=DEVICE))
    print('=' * 70)

    model.eval()

    # bfloat16 autocast on CPU, float16 on GPU; DEVICE is hard-coded to
    # 'cuda' above, so the float16 branch is the one that runs.
    if DEVICE == 'cpu':
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

    print('Done!')
    print('=' * 70)

    # clamp the requested length to the slider's advertised range
    input_num_tokens = max(8, min(2048, input_num_tokens))

    print('-' * 70)
    print('Input title:', input_title)
    print('Req num toks:', input_num_tokens)
    print('Open-ended prompt:', input_prompt_type)
    print('-' * 70)

    #===============================================================================

    print('Setting up model patches and loading helper functions...')

    # @title Setup and load model channels MIDI patches
    # (General MIDI program names, one per model channel family; the
    # `# @param` comments are kept from the original Colab form)

    model_channel_0_piano_family = "Acoustic Grand" # @param ["Acoustic Grand", "Bright Acoustic", "Electric Grand", "Honky-Tonk", "Electric Piano 1", "Electric Piano 2", "Harpsichord", "Clav"]
    model_channel_1_chromatic_percussion_family = "Music Box" # @param ["Celesta", "Glockenspiel", "Music Box", "Vibraphone", "Marimba", "Xylophone", "Tubular Bells", "Dulcimer"]
    model_channel_2_organ_family = "Church Organ" # @param ["Drawbar Organ", "Percussive Organ", "Rock Organ", "Church Organ", "Reed Organ", "Accordion", "Harmonica", "Tango Accordion"]
    model_channel_3_guitar_family = "Acoustic Guitar(nylon)" # @param ["Acoustic Guitar(nylon)", "Acoustic Guitar(steel)", "Electric Guitar(jazz)", "Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar", "Guitar Harmonics"]
    model_channel_4_bass_family = "Fretless Bass" # @param ["Acoustic Bass", "Electric Bass(finger)", "Electric Bass(pick)", "Fretless Bass", "Slap Bass 1", "Slap Bass 2", "Synth Bass 1", "Synth Bass 2"]
    model_channel_5_strings_family = "Violin" # @param ["Violin", "Viola", "Cello", "Contrabass", "Tremolo Strings", "Pizzicato Strings", "Orchestral Harp", "Timpani"]
    model_channel_6_ensemble_family = "Choir Aahs" # @param ["String Ensemble 1", "String Ensemble 2", "SynthStrings 1", "SynthStrings 2", "Choir Aahs", "Voice Oohs", "Synth Voice", "Orchestra Hit"]
    model_channel_7_brass_family = "Trumpet" # @param ["Trumpet", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Brass Section", "SynthBrass 1", "SynthBrass 2"]
    model_channel_8_reed_family = "Alto Sax" # @param ["Soprano Sax", "Alto Sax", "Tenor Sax", "Baritone Sax", "Oboe", "English Horn", "Bassoon", "Clarinet"]
    model_channel_9_pipe_family = "Flute" # @param ["Piccolo", "Flute", "Recorder", "Pan Flute", "Blown Bottle", "Skakuhachi", "Whistle", "Ocarina"]
    model_channel_10_synth_lead_family = "Lead 8 (bass+lead)" # @param ["Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)", "Lead 4 (chiff)", "Lead 5 (charang)", "Lead 6 (voice)", "Lead 7 (fifths)", "Lead 8 (bass+lead)"]
    model_channel_11_synth_pad_family = "Pad 2 (warm)" # @param ["Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)", "Pad 4 (choir)", "Pad 5 (bowed)", "Pad 6 (metallic)", "Pad 7 (halo)", "Pad 8 (sweep)"]
    model_channel_12_synth_effects_family = "FX 3 (crystal)" # @param ["FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)", "FX 4 (atmosphere)", "FX 5 (brightness)", "FX 6 (goblins)", "FX 7 (echoes)", "FX 8 (sci-fi)"]
    model_channel_13_ethnic_family = "Banjo" # @param ["Sitar", "Banjo", "Shamisen", "Koto", "Kalimba", "Bagpipe", "Fiddle", "Shanai"]
    model_channel_14_percussive_family = "Melodic Tom" # @param ["Tinkle Bell", "Agogo", "Steel Drums", "Woodblock", "Taiko Drum", "Melodic Tom", "Synth Drum", "Reverse Cymbal"]
    model_channel_15_sound_effects_family = "Bird Tweet" # @param ["Guitar Fret Noise", "Breath Noise", "Seashore", "Bird Tweet", "Telephone Ring", "Helicopter", "Applause", "Gunshot"]
    model_channel_16_drums_family = "Drums" # @param ["Drums"]

    print('=' * 70)
    print('Loading helper functions...')

    def txt2tokens(txt):
        # Map ASCII characters (codes 1..127) of the lowercased title to
        # token ids 2441..2567; anything else collapses to token 2440.
        return [ord(char)+2440 if 0 < ord(char) < 128 else 0+2440 for char in txt.lower()]

    def tokens2txt(tokens):
        # Inverse of txt2tokens: keep only tokens in the text range and
        # convert them back to characters (returned as a list of chars).
        return [chr(tok-2440) for tok in tokens if 0+2440 < tok < 128+2440 ]

    print('=' * 70)
    print('Setting up patches...')
    print('=' * 70)

    # TMIDIX.Number2patch maps GM program number -> instrument name;
    # the name list below lets .index(name) recover the program number.
    instruments = [v[1] for v in TMIDIX.Number2patch.items()]

    # 16-entry MIDI patch list in channel order; MIDI channel 9 is the
    # GM drum channel, hence the hard-coded 9 in the middle.
    # NOTE(review): model_channel_14_percussive_family is never used —
    # confirm whether omitting the percussive family was intentional.
    patches = [instruments.index(model_channel_0_piano_family),
               instruments.index(model_channel_1_chromatic_percussion_family),
               instruments.index(model_channel_2_organ_family),
               instruments.index(model_channel_3_guitar_family),
               instruments.index(model_channel_4_bass_family),
               instruments.index(model_channel_5_strings_family),
               instruments.index(model_channel_6_ensemble_family),
               instruments.index(model_channel_7_brass_family),
               instruments.index(model_channel_8_reed_family),
               9, # Drums patch
               instruments.index(model_channel_9_pipe_family),
               instruments.index(model_channel_10_synth_lead_family),
               instruments.index(model_channel_11_synth_pad_family),
               instruments.index(model_channel_12_synth_effects_family),
               instruments.index(model_channel_13_ethnic_family),
               instruments.index(model_channel_15_sound_effects_family)
               ]

    print('Done!')
    print('=' * 70)

    print('Generating...')

    #@title Standard Text-to-Music Generator

    #@markdown Prompt settings

    song_title_prompt = input_title
    open_ended_prompt = input_prompt_type

    #@markdown Generation settings

    number_of_tokens_to_generate = input_num_tokens
    number_of_batches_to_generate = 1 #@param {type:"slider", min:1, max:16, step:1}
    temperature = 0.9 # @param {type:"slider", min:0.1, max:1, step:0.05}

    print('=' * 70)
    print('Text-to-Music Model Generator')
    print('=' * 70)

    # Token 2569 opens a title, 2570 closes it. Empty title -> fully
    # unconditional prompt; open-ended prompt omits the closing token so
    # the model may first complete the title itself.
    if song_title_prompt == '':
        outy = [2569]

    else:
        if open_ended_prompt:
            outy = [2569] + txt2tokens(song_title_prompt)

        else:
            outy = [2569] + txt2tokens(song_title_prompt) + [2570]

    print('Selected prompt sequence:')
    print(outy[:12])
    print('=' * 70)

    torch.cuda.empty_cache()

    # replicate the prompt across the batch dimension (batch size is 1 here)
    inp = [outy] * number_of_batches_to_generate

    inp = torch.LongTensor(inp).cuda()

    # sample under autocast; return_prime=True prepends the prompt to
    # the returned sequence
    with ctx:
        out = model.generate(inp,
                             number_of_tokens_to_generate,
                             temperature=temperature,
                             return_prime=True,
                             verbose=False)

    out0 = out.tolist()

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================
    print('Rendering results...')
    print('=' * 70)

    out1 = out0[0]  # only the first batch item is rendered

    print('Sample INTs', out1[:12])
    print('=' * 70)

    # recover the (possibly model-completed) title from the text tokens
    generated_song_title = ''.join(tokens2txt(out1)).title()

    print('Generated song title:', generated_song_title)
    print('=' * 70)

    # out1 always contains at least the prompt (return_prime=True), so
    # this branch is effectively always taken.
    if len(out1) != 0:

        song = out1
        song_f = []

        # decoder state carried across tokens
        time = 0
        dur = 0
        vel = 90
        pitch = 0
        channel = 0
        chan = 0

        # Token layout as decoded here:
        #   0..127      delta start-time, in 32 ms steps (accumulated)
        #   128..255    duration, in 32 ms steps
        #   256..2431   note: (tok-256) // 128 = model channel,
        #               (tok-256) %  128 = pitch
        #   2432..2439  velocity bucket; also flushes the pending note
        for ss in song:

            if 0 <= ss < 128:

                time += ss * 32

            if 128 <= ss < 256:

                dur = (ss-128) * 32

            if 256 <= ss < 2432:

                chan = (ss-256) // 128

                # Map model channel -> MIDI channel, keeping MIDI
                # channel 9 for drums (model channel 16).
                # NOTE(review): chan == 9 leaves `channel` at its
                # previous value and channel 10 is never produced —
                # confirm the intended mapping for those families.
                if chan < 9:
                    channel = chan
                elif 9 < chan < 15:
                    channel = chan+1
                elif chan == 15:
                    channel = 15
                elif chan == 16:
                    channel = 9

                pitch = (ss-256) % 128

            if 2432 <= ss < 2440:

                # velocities are quantized to 8 buckets: 14, 29, ... 119
                vel = (((ss-2432)+1) * 15)-1

                # emit the completed note; last field is a display
                # channel derived from the raw model channel
                song_f.append(['note', time, dur, channel, pitch, vel, chan*8 ])

    fn1 = "Text-to-Music-Transformer-Composition"

    # write the decoded score out as a standard MIDI file
    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Text-to-Music Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    new_fn = fn1+'.mid'

    # render the MIDI to 16 kHz audio using the bundled soundfont
    # (`soundfont` is a module-level global set in __main__)
    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = generated_song_title
    output_midi_summary = str(song_f[:3])  # first three decoded notes as a preview
    output_midi = str(new_fn)
    output_audio = (16000, audio)  # (sample_rate, samples) tuple for gr.Audio

    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)


    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
291
+
292
+ # =================================================================================================
293
+
294
+ if __name__ == "__main__":
295
+
296
+ PDT = timezone('US/Pacific')
297
+
298
+ print('=' * 70)
299
+ print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
300
+ print('=' * 70)
301
+
302
+ soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
303
+
304
+ app = gr.Blocks()
305
+ with app:
306
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Text-to-Music Transformer</h1>")
307
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate music based on a title of your imagination :)</h1>")
308
+ gr.Markdown(
309
+ "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Text-to-Music-Transformer&style=flat)\n\n"
310
+ "Generate music based on a title of your imagination :)\n\n"
311
+ "Check out [Text-to-Music Transformer](https://github.com/asigalov61/Text-to-Music-Transformer) on GitHub!\n\n"
312
+ "[Open In Colab]"
313
+ "(https://colab.research.google.com/github/asigalov61/Text-to-Music-Transformer/blob/main/Text_to_Music_Transformer.ipynb)"
314
+ " for faster execution and endless generation"
315
+ )
316
+ gr.Markdown("## Enter any desired song title")
317
+
318
+ input_title = gr.Textbox(value="Nothing Else Matters", label="Song title")
319
+ input_prompt_type = gr.Checkbox(label="Open-ended prompt")
320
+
321
+ input_num_tokens = gr.Slider(8, 2048, value=512, step=8, label="Number of tokens to generate")
322
+
323
+ run_btn = gr.Button("generate", variant="primary")
324
+
325
+ gr.Markdown("## Generation results")
326
+
327
+ output_midi_title = gr.Textbox(label="Generated MIDI title")
328
+ output_midi_summary = gr.Textbox(label="Output MIDI summary")
329
+ output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
330
+ output_plot = gr.Plot(label="Output MIDI score plot")
331
+ output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
332
+
333
+
334
+ run_event = run_btn.click(Text_to_Music, [input_title, input_num_tokens, input_prompt_type],
335
+ [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
336
+
337
+ app.queue().launch()