asigalov61 commited on
Commit
356c43a
·
verified ·
1 Parent(s): aba2db8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -38
app.py CHANGED
@@ -23,12 +23,9 @@ import matplotlib.pyplot as plt
23
  in_space = os.getenv("SYSTEM") == "spaces"
24
 
25
  # =================================================================================================
26
-
27
  @spaces.GPU
28
- def ClassifyMIDI(input_midi):
29
- print('=' * 70)
30
- print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
31
- start_time = reqtime.time()
32
 
33
  print('Loading model...')
34
 
@@ -70,6 +67,48 @@ def ClassifyMIDI(input_midi):
70
 
71
  #==================================================================
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  print('=' * 70)
74
 
75
  fn = os.path.basename(input_midi.name)
@@ -158,39 +197,8 @@ def ClassifyMIDI(input_midi):
158
  classification_summary_string += 'Number of notes in all composition chunks: ' + str(len(input_data) * 340) + '\n'
159
  classification_summary_string += '=' * 70
160
  classification_summary_string += '\n'
161
-
162
- number_of_batches = 100 # @param {type:"slider", min:1, max:100, step:1}
163
-
164
- # @markdown NOTE: You can increase the number of batches on high-ram GPUs for better classification
165
-
166
- print('=' * 70)
167
- print('Annotated MIDI Dataset Classifier')
168
- print('=' * 70)
169
- print('Classifying...')
170
-
171
- torch.cuda.empty_cache()
172
-
173
- model.eval()
174
-
175
- results = []
176
-
177
- for input in input_data:
178
-
179
- x = torch.tensor([input[:1022]] * number_of_batches, dtype=torch.long, device='cuda')
180
-
181
- with ctx:
182
- out = model.generate(x,
183
- 1,
184
- temperature=0.3,
185
- return_prime=False,
186
- verbose=False)
187
-
188
- y = out.tolist()
189
-
190
- output = [l[0] for l in y]
191
- result = mode(output)
192
-
193
- results.append(result)
194
 
195
  all_results_labels = [classifier_labels[0][r-384] for r in results]
196
  final_result = mode(results)
 
23
  in_space = os.getenv("SYSTEM") == "spaces"
24
 
25
  # =================================================================================================
26
+
27
  @spaces.GPU
28
+ def classify_GPU(input_data):
 
 
 
29
 
30
  print('Loading model...')
31
 
 
67
 
68
  #==================================================================
69
 
70
+ number_of_batches = 100 # @param {type:"slider", min:1, max:100, step:1}
71
+
72
+ # @markdown NOTE: You can increase the number of batches on high-ram GPUs for better classification
73
+
74
+ print('=' * 70)
75
+ print('Annotated MIDI Dataset Classifier')
76
+ print('=' * 70)
77
+ print('Classifying...')
78
+
79
+ torch.cuda.empty_cache()
80
+
81
+ model.eval()
82
+
83
+ results = []
84
+
85
+ for input in input_data:
86
+
87
+ x = torch.tensor([input[:1022]] * number_of_batches, dtype=torch.long, device='cuda')
88
+
89
+ with ctx:
90
+ out = model.generate(x,
91
+ 1,
92
+ temperature=0.3,
93
+ return_prime=False,
94
+ verbose=False)
95
+
96
+ y = out.tolist()
97
+
98
+ output = [l[0] for l in y]
99
+ result = mode(output)
100
+
101
+ results.append(result)
102
+
103
+ return results
104
+
105
+ # =================================================================================================
106
+
107
+ def ClassifyMIDI(input_midi):
108
+ print('=' * 70)
109
+ print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
110
+ start_time = reqtime.time()
111
+
112
  print('=' * 70)
113
 
114
  fn = os.path.basename(input_midi.name)
 
197
  classification_summary_string += 'Number of notes in all composition chunks: ' + str(len(input_data) * 340) + '\n'
198
  classification_summary_string += '=' * 70
199
  classification_summary_string += '\n'
200
+
201
+ results = classify_GPU(input_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  all_results_labels = [classifier_labels[0][r-384] for r in results]
204
  final_result = mode(results)