asigalov61 commited on
Commit
5837401
1 Parent(s): 5a9b440

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -92
app.py CHANGED
@@ -23,129 +23,104 @@ import TMIDIX
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
- def GenerateSong(input_melody_seed_number):
 
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = reqtime.time()
30
-
 
31
  print('Loading model...')
32
 
33
- SEQ_LEN = 2560
34
- PAD_IDX = 514
35
- DEVICE = 'cuda' # 'cuda'
36
-
37
- # instantiate the model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- model = TransformerWrapper(
40
- num_tokens = PAD_IDX+1,
41
- max_seq_len = SEQ_LEN,
42
- attn_layers = Decoder(dim = 1024, depth = 24, heads = 16, attn_flash = True)
43
- )
44
-
45
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
46
-
47
- model.to(DEVICE)
48
  print('=' * 70)
49
 
50
  print('Loading model checkpoint...')
51
 
52
- model.load_state_dict(
53
- torch.load('Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth',
54
- map_location=DEVICE))
55
- print('=' * 70)
56
-
57
- model.eval()
58
-
59
- if DEVICE == 'cpu':
60
- dtype = torch.bfloat16
61
- else:
62
- dtype = torch.bfloat16
63
-
64
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
65
 
66
  print('Done!')
67
  print('=' * 70)
68
- seed_melody = seed_melodies_data[input_melody_seed_number]
69
- print('Input melody seed number:', input_melody_seed_number)
70
- print('-' * 70)
71
 
72
- #==================================================================
73
 
74
- print('=' * 70)
75
-
76
- print('Sample output events', seed_melody[:16])
77
  print('=' * 70)
78
  print('Generating...')
79
 
80
- x = (torch.tensor(seed_melody, dtype=torch.long, device='cuda')[None, ...])
81
 
82
- with ctx:
83
- with torch.inference_mode():
84
- out = model.generate(x,
85
- 1024,
86
- temperature=0.9,
87
- return_prime=False,
88
- verbose=False)
89
 
90
- output = out[0].tolist()
91
-
92
- print('=' * 70)
 
 
 
 
93
  print('Done!')
94
  print('=' * 70)
95
 
96
  #===============================================================================
97
- print('Rendering results...')
98
-
99
- print('=' * 70)
100
- print('Sample INTs', output[:15])
101
- print('=' * 70)
102
-
103
- out1 = output
104
 
105
- if len(out1) != 0:
106
-
107
- song = out1
108
- song_f = []
109
-
110
- time = 0
111
- dur = 0
112
- vel = 90
113
- pitch = 0
114
- channel = 0
115
-
116
- patches = [0] * 16
117
- patches[3] = 40
118
-
119
- for ss in song:
120
-
121
- if 0 < ss < 128:
122
 
123
- time += (ss * 32)
124
-
125
- if 128 < ss < 256:
126
-
127
- dur = (ss-128) * 32
128
-
129
- if 256 < ss < 512:
130
 
131
- pitch = (ss-256) % 128
 
 
 
 
 
 
 
 
 
 
132
 
133
- channel = (ss-256) // 128
 
 
134
 
135
- if channel == 1:
136
- channel = 3
137
- vel = 110 + (pitch % 12)
138
- song_f.append(['note', time, dur, channel, pitch, vel, 40])
139
-
140
- else:
141
- vel = 80 + (pitch % 12)
142
- channel = 0
143
- song_f.append(['note', time, dur, channel, pitch, vel, 0])
144
 
145
- fn1 = "Melody2Song-Seq2Seq-Music-Transformer-Composition"
146
 
147
  detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
148
- output_signature = 'Melody2Song Seq2Seq Music Transformer',
149
  output_file_name = fn1,
150
  track_name='Project Los Angeles',
151
  list_of_MIDI_patches=patches
@@ -223,7 +198,7 @@ if __name__ == "__main__":
223
  output_plot = gr.Plot(label="Output MIDI score plot")
224
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
225
 
226
- run_event = run_btn.click(GenerateSong, [input_melody_seed_number],
227
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
228
 
229
  app.queue().launch()
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
+ def Generate_POP_Medley(input_num_medley_comps):
27
+
28
  print('=' * 70)
29
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
30
  start_time = reqtime.time()
31
+ print('=' * 70)
32
+
33
  print('Loading model...')
34
 
35
+ DIM = 64
36
+ CHANS = 1
37
+ TSTEPS = 1000
38
+ DEVICE = 'cuda' # 'cpu'
39
+
40
+ unet = Unet(
41
+ dim = DIM,
42
+ dim_mults = (1, 2, 4, 8),
43
+ num_resnet_blocks = 1,
44
+ channels=CHANS,
45
+ layer_attns = (False, False, False, True),
46
+ layer_cross_attns = False
47
+ )
48
+
49
+ imagen = Imagen(
50
+ condition_on_text = False, # this must be set to False for unconditional Imagen
51
+ unets = unet,
52
+ channels=CHANS,
53
+ image_sizes = 128,
54
+ timesteps = TSTEPS
55
+ )
56
+
57
+ trainer = ImagenTrainer(
58
+ imagen = imagen,
59
+ split_valid_from_train = True # whether to split the validation dataset from the training
60
+ ).to(DEVICE)
61
 
 
 
 
 
 
 
 
 
 
62
  print('=' * 70)
63
 
64
  print('Loading model checkpoint...')
65
 
66
+ trainer.load('Imagen_POP909_64_dim_12638_steps_0.00983_loss.ckptt')
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  print('Done!')
69
  print('=' * 70)
 
 
 
70
 
71
+ print('Req number of medley compositions:', input_num_medley_comps)
72
 
 
 
 
73
  print('=' * 70)
74
  print('Generating...')
75
 
76
+ images = trainer.sample(batch_size = input_num_medley_comps, return_pil_images = True)
77
 
78
+ threshold = 128
 
 
 
 
 
 
79
 
80
+ imgs_array = []
81
+
82
+ for i in images:
83
+ arr = np.array(i)
84
+ farr = np.where(arr < threshold, 0, 1)
85
+ imgs_array.append(farr)
86
+
87
  print('Done!')
88
  print('=' * 70)
89
 
90
  #===============================================================================
 
 
 
 
 
 
 
91
 
92
+ print('Converting images to scores...')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+
95
+ medley_compositions_escores = []
96
+
97
+ for i in imgs_array:
98
+
99
+ bmatrix = TPLOTS.images_to_binary_matrix([i])
 
100
 
101
+ score = TMIDIX.binary_matrix_to_original_escore_notes(bmatrix)
102
+
103
+ medley_compositions_escores.append(score)
104
+
105
+ print('Done!')
106
+ print('=' * 70)
107
+ print('Creating medley score...')
108
+
109
+ medley_labels = ['Composition #' + str(i+1) for i in range(len(medley_compositions_escores))]
110
+
111
+ medley_escore = TMIDIX.escore_notes_medley(medley_compositions_escores, medley_labels)
112
 
113
+ #===============================================================================
114
+ print('Rendering results...')
115
+ print('=' * 70)
116
 
117
+ print('Sample INTs', medley_escore[:15])
118
+ print('=' * 70)
 
 
 
 
 
 
 
119
 
120
+ fn1 = "Imagen-POP-Music-Medley-Diffusion-Transformer-Composition"
121
 
122
  detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
123
+ output_signature = 'Imagen POP Music Medley',
124
  output_file_name = fn1,
125
  track_name='Project Los Angeles',
126
  list_of_MIDI_patches=patches
 
198
  output_plot = gr.Plot(label="Output MIDI score plot")
199
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
200
 
201
+ run_event = run_btn.click(Generate_POP_Medley, [input_num_medley_comps],
202
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
203
 
204
  app.queue().launch()