asigalov61 commited on
Commit
b2efcdc
1 Parent(s): be28b0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -75
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # https://huggingface.co/spaces/asigalov61/Inpaint-Music-Transformer
2
 
3
  import os.path
4
 
@@ -21,64 +21,11 @@ import TMIDIX
21
  import matplotlib.pyplot as plt
22
 
23
  # =================================================================================================
24
-
25
- @spaces.GPU
26
- def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
27
- print('=' * 70)
28
- print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
- start_time = reqtime.time()
30
-
31
- print('Loading model...')
32
-
33
- SEQ_LEN = 8192 # Models seq len
34
- PAD_IDX = 19463 # Models pad index
35
- DEVICE = 'cuda' # 'cuda'
36
-
37
- # instantiate the model
38
-
39
- model = TransformerWrapper(
40
- num_tokens = PAD_IDX+1,
41
- max_seq_len = SEQ_LEN,
42
- attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True)
43
- )
44
-
45
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
46
-
47
- model.to(DEVICE)
48
- print('=' * 70)
49
-
50
- print('Loading model checkpoint...')
51
-
52
- model.load_state_dict(
53
- torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
54
- map_location=DEVICE))
55
- print('=' * 70)
56
-
57
- model.eval()
58
-
59
- if DEVICE == 'cpu':
60
- dtype = torch.bfloat16
61
- else:
62
- dtype = torch.bfloat16
63
 
64
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
65
-
66
- print('Done!')
67
- print('=' * 70)
68
-
69
- fn = os.path.basename(input_midi.name)
70
- fn1 = fn.split('.')[0]
71
-
72
- input_num_of_notes = max(8, min(2048, input_num_of_notes))
73
-
74
- print('-' * 70)
75
- print('Input file name:', fn)
76
- print('Req num of notes:', input_num_of_notes)
77
- print('Req patch number:', input_patch_number)
78
- print('-' * 70)
79
 
80
  #===============================================================================
81
- raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
82
 
83
  #===============================================================================
84
  # Enhanced score notes
@@ -191,35 +138,85 @@ def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
191
  melody_chords2.append([delta_time, dur_vel+256, pat_ptc+2304])
192
 
193
  pe = e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  #==================================================================
197
 
198
  print('=' * 70)
199
- print('Number of tokens:', len(melody_chords))
200
- print('Number of notes:', len(melody_chords2))
201
- print('Sample output events', melody_chords[:5])
202
  print('=' * 70)
203
  print('Generating...')
204
 
205
- #@title Pitches/Instruments Inpainting
206
-
207
- #@markdown You can stop the inpainting at any time to render partial results
208
-
209
- #@markdown Inpainting settings
210
-
211
- #@markdown Select MIDI patch present in the composition to inpaint
212
-
213
- inpaint_MIDI_patch = input_patch_number
214
-
215
- #@markdown Generation settings
216
- number_of_prime_notes = 24
217
- number_of_memory_tokens = 1024 # @param {type:"slider", min:3, max:8190, step:3}
218
- number_of_samples_per_inpainted_note = 1 #@param {type:"slider", min:1, max:16, step:1}
219
  temperature = 0.85
220
 
221
  print('=' * 70)
222
- print('Giant Music Transformer Inpainting Model Generator')
223
  print('=' * 70)
224
 
225
  #==========================================================================
 
1
+ # https://huggingface.co/spaces/asigalov61/Intelligent-MIDI-Comparator
2
 
3
  import os.path
4
 
 
21
  import matplotlib.pyplot as plt
22
 
23
  # =================================================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ def read_MIDI(input_midi)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  #===============================================================================
28
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
29
 
30
  #===============================================================================
31
  # Enhanced score notes
 
138
  melody_chords2.append([delta_time, dur_vel+256, pat_ptc+2304])
139
 
140
  pe = e
141
+
142
+ return melody_chords, melody_chords2
143
+
144
+ # =================================================================================================
145
+
146
+
147
+ @spaces.GPU
148
+ def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
149
+ print('=' * 70)
150
+ print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
151
+ start_time = reqtime.time()
152
+
153
+ print('Loading model...')
154
+
155
+ SEQ_LEN = 8192 # Models seq len
156
+ PAD_IDX = 19463 # Models pad index
157
+ DEVICE = 'cuda' # 'cuda'
158
+
159
+ # instantiate the model
160
+
161
+ model = TransformerWrapper(
162
+ num_tokens = PAD_IDX+1,
163
+ max_seq_len = SEQ_LEN,
164
+ attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True)
165
+ )
166
 
167
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
168
+
169
+ model.to(DEVICE)
170
+ print('=' * 70)
171
+
172
+ print('Loading model checkpoint...')
173
+
174
+ model.load_state_dict(
175
+ torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
176
+ map_location=DEVICE))
177
+ print('=' * 70)
178
+
179
+ model.eval()
180
+
181
+ if DEVICE == 'cpu':
182
+ dtype = torch.bfloat16
183
+ else:
184
+ dtype = torch.bfloat16
185
+
186
+ ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
187
+
188
+ print('Done!')
189
+ print('=' * 70)
190
+
191
+ fn = os.path.basename(input_midi.name)
192
+ fn1 = fn.split('.')[0]
193
+
194
+ input_num_of_notes = max(8, min(2048, input_num_of_notes))
195
+
196
+ print('-' * 70)
197
+ print('Input file name:', fn)
198
+ print('Req num of notes:', input_num_of_notes)
199
+ print('Req patch number:', input_patch_number)
200
+ print('-' * 70)
201
+
202
+ #===============================================================================
203
+
204
+ toekns, notes = read_MIDI(input_midi.name)
205
+
206
+
207
  #==================================================================
208
 
209
  print('=' * 70)
210
+ print('Number of tokens:', len(toekns))
211
+ print('Number of notes:', len(notes))
212
+ print('Sample output events', toekns[:5])
213
  print('=' * 70)
214
  print('Generating...')
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  temperature = 0.85
217
 
218
  print('=' * 70)
219
+ print('Giant Music Transformer MIDI Comparator')
220
  print('=' * 70)
221
 
222
  #==========================================================================