Allow Gradio to be accessible when running inside a container

#3
by janwari - opened
Files changed (1)
  1. app.py +18 -15
app.py CHANGED

@@ -11,13 +11,14 @@ import logging
 import json
 import os
 import re
+import os
 
 import pandas as pd
 
-import importlib
+import importlib
 modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT")
 
-from Prediction_Head.MTGGenre_head import MLPProberBase
+from Prediction_Head.MTGGenre_head import MLPProberBase
 # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
 
 
@@ -33,12 +34,12 @@ logger.addHandler(ch)
 
 
 inputs = [
-    gr.components.Audio(type="filepath", label="Add music audio file"),
+    gr.components.Audio(type="filepath", label="Add music audio file"),
 ]
 
 title = "Isai - toward better music understanding"
 description = "This space uses MERT-95M model to peform various music information retrieval tasks."
-
+
 audio_examples = [
     ["samples/143.mp3"],
     ["samples/205.mp3"],
@@ -78,7 +79,7 @@ MERT_BEST_LAYER_IDX = {
     'NSynthP': 1,
     'VocalSetS': 2,
     'VocalSetT': 9,
-}
+}
 
 MERT_BEST_LAYER_IDX = {
     'EMO': 5,
@@ -93,7 +94,7 @@ MERT_BEST_LAYER_IDX = {
     'NSynthP': 1,
     'VocalSetS': 2,
     'VocalSetT': 9,
-}
+}
 CLASSIFIERS = {
 
 }
@@ -135,7 +136,7 @@ def model_infernce(inputs):
         # print(f'setting rate from {sample_rate} to {resample_rate}')
         resampler = T.Resample(sample_rate, resample_rate)
         waveform = resampler(waveform)
-
+
     waveform = waveform.view(-1,) # make it (n_sample, )
     model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
     model_inputs.to(device)
@@ -159,12 +160,12 @@ def model_infernce(inputs):
         else:
             logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]])
         # print(f'task {task} logits:', logits.shape, 'num class:', num_class)
-
-        sorted_idx = torch.argsort(logits, dim = -1, descending=True)[0] # batch =1
+
+        sorted_idx = torch.argsort(logits, dim = -1, descending=True)[0] # batch =1
         sorted_prob,_ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True)
         # print(sorted_prob)
         # print(sorted_prob.shape)
-
+
         top_n_show = 5 if num_class >= 5 else num_class
         # task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n'
         # task_output_texts = task_output_texts + '----------------------\n'
@@ -185,17 +186,17 @@ def model_infernce(inputs):
         df_objects.append(row_elements)
     df = pd.DataFrame(df_objects, columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
     return df
-
+
 def convert_audio(inputs, microphone):
     if (microphone is not None):
         inputs = microphone
-    df = model_infernce(inputs)
+    df = model_infernce(inputs)
     return df
 
 def live_convert_audio(microphone):
     if (microphone is not None):
         inputs = microphone
-    df = model_infernce(inputs)
+    df = model_infernce(inputs)
     return df
 
 audio_chunked = gr.Interface(
@@ -228,11 +229,13 @@ audio_chunked = gr.Interface(
 # [
 # audio_chunked,
 # live_audio_chunked,
-# ],
+# ],
 # [
 # "Audio File or Recording",
 # "Live Streaming Music"
 # ]
 # )
 # demo.queue(concurrency_count=1, max_size=5)
-audio_chunked.launch(show_api=False)
+
+server_name = os.environ.get('GRADIO_SERVER_NAME', "127.0.0.1")
+audio_chunked.launch(server_name=server_name, show_api=False)
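
Gradio's launch() binds to 127.0.0.1 by default, and a server bound to localhost inside a Docker container is unreachable from the host. Reading the bind address from GRADIO_SERVER_NAME lets a container set it to 0.0.0.0 while local runs keep the localhost default. A minimal standalone sketch of the same pattern (the echo gr.Interface below is a placeholder, not this Space's actual app):

    import os

    import gradio as gr

    # Placeholder app standing in for this Space's audio_chunked interface.
    demo = gr.Interface(fn=lambda text: text, inputs="text", outputs="text")

    # Default to localhost for local development; in a container, set
    # GRADIO_SERVER_NAME=0.0.0.0 so the server listens on all interfaces
    # and is reachable through the published port.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    demo.launch(server_name=server_name, show_api=False)

With the container's port published (Gradio defaults to 7860), setting GRADIO_SERVER_NAME=0.0.0.0 in the container environment makes the UI reachable from the host. Recent Gradio releases also read GRADIO_SERVER_NAME themselves, but passing server_name explicitly keeps the behavior independent of the Gradio version.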