asigalov61
commited on
Commit
•
5837401
1
Parent(s):
5a9b440
Update app.py
Browse files
app.py
CHANGED
@@ -23,129 +23,104 @@ import TMIDIX
|
|
23 |
# =================================================================================================
|
24 |
|
25 |
@spaces.GPU
|
26 |
-
def
|
|
|
27 |
print('=' * 70)
|
28 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
29 |
start_time = reqtime.time()
|
30 |
-
|
|
|
31 |
print('Loading model...')
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
model = TransformerWrapper(
|
40 |
-
num_tokens = PAD_IDX+1,
|
41 |
-
max_seq_len = SEQ_LEN,
|
42 |
-
attn_layers = Decoder(dim = 1024, depth = 24, heads = 16, attn_flash = True)
|
43 |
-
)
|
44 |
-
|
45 |
-
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
46 |
-
|
47 |
-
model.to(DEVICE)
|
48 |
print('=' * 70)
|
49 |
|
50 |
print('Loading model checkpoint...')
|
51 |
|
52 |
-
|
53 |
-
torch.load('Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth',
|
54 |
-
map_location=DEVICE))
|
55 |
-
print('=' * 70)
|
56 |
-
|
57 |
-
model.eval()
|
58 |
-
|
59 |
-
if DEVICE == 'cpu':
|
60 |
-
dtype = torch.bfloat16
|
61 |
-
else:
|
62 |
-
dtype = torch.bfloat16
|
63 |
-
|
64 |
-
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
65 |
|
66 |
print('Done!')
|
67 |
print('=' * 70)
|
68 |
-
seed_melody = seed_melodies_data[input_melody_seed_number]
|
69 |
-
print('Input melody seed number:', input_melody_seed_number)
|
70 |
-
print('-' * 70)
|
71 |
|
72 |
-
|
73 |
|
74 |
-
print('=' * 70)
|
75 |
-
|
76 |
-
print('Sample output events', seed_melody[:16])
|
77 |
print('=' * 70)
|
78 |
print('Generating...')
|
79 |
|
80 |
-
|
81 |
|
82 |
-
|
83 |
-
with torch.inference_mode():
|
84 |
-
out = model.generate(x,
|
85 |
-
1024,
|
86 |
-
temperature=0.9,
|
87 |
-
return_prime=False,
|
88 |
-
verbose=False)
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
print('Done!')
|
94 |
print('=' * 70)
|
95 |
|
96 |
#===============================================================================
|
97 |
-
print('Rendering results...')
|
98 |
-
|
99 |
-
print('=' * 70)
|
100 |
-
print('Sample INTs', output[:15])
|
101 |
-
print('=' * 70)
|
102 |
-
|
103 |
-
out1 = output
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
song = out1
|
108 |
-
song_f = []
|
109 |
-
|
110 |
-
time = 0
|
111 |
-
dur = 0
|
112 |
-
vel = 90
|
113 |
-
pitch = 0
|
114 |
-
channel = 0
|
115 |
-
|
116 |
-
patches = [0] * 16
|
117 |
-
patches[3] = 40
|
118 |
-
|
119 |
-
for ss in song:
|
120 |
-
|
121 |
-
if 0 < ss < 128:
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
if 256 < ss < 512:
|
130 |
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
-
|
|
|
|
|
134 |
|
135 |
-
|
136 |
-
|
137 |
-
vel = 110 + (pitch % 12)
|
138 |
-
song_f.append(['note', time, dur, channel, pitch, vel, 40])
|
139 |
-
|
140 |
-
else:
|
141 |
-
vel = 80 + (pitch % 12)
|
142 |
-
channel = 0
|
143 |
-
song_f.append(['note', time, dur, channel, pitch, vel, 0])
|
144 |
|
145 |
-
fn1 = "
|
146 |
|
147 |
detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
|
148 |
-
output_signature = '
|
149 |
output_file_name = fn1,
|
150 |
track_name='Project Los Angeles',
|
151 |
list_of_MIDI_patches=patches
|
@@ -223,7 +198,7 @@ if __name__ == "__main__":
|
|
223 |
output_plot = gr.Plot(label="Output MIDI score plot")
|
224 |
output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
|
225 |
|
226 |
-
run_event = run_btn.click(
|
227 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
|
228 |
|
229 |
app.queue().launch()
|
|
|
23 |
# =================================================================================================
|
24 |
|
25 |
@spaces.GPU
|
26 |
+
def Generate_POP_Medley(input_num_medley_comps):
|
27 |
+
|
28 |
print('=' * 70)
|
29 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
30 |
start_time = reqtime.time()
|
31 |
+
print('=' * 70)
|
32 |
+
|
33 |
print('Loading model...')
|
34 |
|
35 |
+
DIM = 64
|
36 |
+
CHANS = 1
|
37 |
+
TSTEPS = 1000
|
38 |
+
DEVICE = 'cuda' # 'cpu'
|
39 |
+
|
40 |
+
unet = Unet(
|
41 |
+
dim = DIM,
|
42 |
+
dim_mults = (1, 2, 4, 8),
|
43 |
+
num_resnet_blocks = 1,
|
44 |
+
channels=CHANS,
|
45 |
+
layer_attns = (False, False, False, True),
|
46 |
+
layer_cross_attns = False
|
47 |
+
)
|
48 |
+
|
49 |
+
imagen = Imagen(
|
50 |
+
condition_on_text = False, # this must be set to False for unconditional Imagen
|
51 |
+
unets = unet,
|
52 |
+
channels=CHANS,
|
53 |
+
image_sizes = 128,
|
54 |
+
timesteps = TSTEPS
|
55 |
+
)
|
56 |
+
|
57 |
+
trainer = ImagenTrainer(
|
58 |
+
imagen = imagen,
|
59 |
+
split_valid_from_train = True # whether to split the validation dataset from the training
|
60 |
+
).to(DEVICE)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
print('=' * 70)
|
63 |
|
64 |
print('Loading model checkpoint...')
|
65 |
|
66 |
+
trainer.load('Imagen_POP909_64_dim_12638_steps_0.00983_loss.ckptt')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
print('Done!')
|
69 |
print('=' * 70)
|
|
|
|
|
|
|
70 |
|
71 |
+
print('Req number of medley compositions:', input_num_medley_comps)
|
72 |
|
|
|
|
|
|
|
73 |
print('=' * 70)
|
74 |
print('Generating...')
|
75 |
|
76 |
+
images = trainer.sample(batch_size = input_num_medley_comps, return_pil_images = True)
|
77 |
|
78 |
+
threshold = 128
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
imgs_array = []
|
81 |
+
|
82 |
+
for i in images:
|
83 |
+
arr = np.array(i)
|
84 |
+
farr = np.where(arr < threshold, 0, 1)
|
85 |
+
imgs_array.append(farr)
|
86 |
+
|
87 |
print('Done!')
|
88 |
print('=' * 70)
|
89 |
|
90 |
#===============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
+
print('Converting images to scores...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
+
|
95 |
+
medley_compositions_escores = []
|
96 |
+
|
97 |
+
for i in imgs_array:
|
98 |
+
|
99 |
+
bmatrix = TPLOTS.images_to_binary_matrix([i])
|
|
|
100 |
|
101 |
+
score = TMIDIX.binary_matrix_to_original_escore_notes(bmatrix)
|
102 |
+
|
103 |
+
medley_compositions_escores.append(score)
|
104 |
+
|
105 |
+
print('Done!')
|
106 |
+
print('=' * 70)
|
107 |
+
print('Creating medley score...')
|
108 |
+
|
109 |
+
medley_labels = ['Composition #' + str(i+1) for i in range(len(medley_compositions_escores))]
|
110 |
+
|
111 |
+
medley_escore = TMIDIX.escore_notes_medley(medley_compositions_escores, medley_labels)
|
112 |
|
113 |
+
#===============================================================================
|
114 |
+
print('Rendering results...')
|
115 |
+
print('=' * 70)
|
116 |
|
117 |
+
print('Sample INTs', medley_escore[:15])
|
118 |
+
print('=' * 70)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
+
fn1 = "Imagen-POP-Music-Medley-Diffusion-Transformer-Composition"
|
121 |
|
122 |
detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
|
123 |
+
output_signature = 'Imagen POP Music Medley',
|
124 |
output_file_name = fn1,
|
125 |
track_name='Project Los Angeles',
|
126 |
list_of_MIDI_patches=patches
|
|
|
198 |
output_plot = gr.Plot(label="Output MIDI score plot")
|
199 |
output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
|
200 |
|
201 |
+
run_event = run_btn.click(Generate_POP_Medley, [input_num_medley_comps],
|
202 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
|
203 |
|
204 |
app.queue().launch()
|