suric committed on
Commit
4336e0a
·
1 Parent(s): 923611d

add transcribe button

Browse files
Files changed (2) hide show
  1. app.py +8 -4
  2. gradio_components/prediction.py +25 -1
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
 
3
  import gradio as gr
4
 
5
- from gradio_components.prediction import predict
6
 
7
  theme = gr.themes.Glass(
8
  primary_hue="fuchsia",
@@ -136,13 +136,17 @@ def UI():
136
  )
137
  with gr.Row():
138
  submit = gr.Button("Generate Music")
139
- output = gr.Audio("listen to the generated music", type="filepath")
 
 
 
 
140
 
141
  submit.click(
142
  fn=predict,
143
  inputs=[model_path, prompt, melody, duration, topk, topp, temperature,
144
  sample_rate],
145
- outputs=output
146
  )
147
 
148
  gr.Examples(
@@ -204,7 +208,7 @@ def UI():
204
  ],
205
  inputs=[melody, difficulty, sample_rate, duration],
206
  label="Audio Examples",
207
- outputs=[output],
208
  # cache_examples=True,
209
  )
210
  demo.queue().launch()
 
2
 
3
  import gradio as gr
4
 
5
+ from gradio_components.prediction import predict, transcribe
6
 
7
  theme = gr.themes.Glass(
8
  primary_hue="fuchsia",
 
136
  )
137
  with gr.Row():
138
  submit = gr.Button("Generate Music")
139
+ output_audio = gr.Audio("listen to the generated music", type="filepath")
140
+ with gr.Row():
141
+ transcribe_button = gr.Button("Transcribe")
142
+ d = gr.DownloadButton("Download the file", visible=False)
143
+ transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
144
 
145
  submit.click(
146
  fn=predict,
147
  inputs=[model_path, prompt, melody, duration, topk, topp, temperature,
148
  sample_rate],
149
+ outputs=output_audio
150
  )
151
 
152
  gr.Examples(
 
208
  ],
209
  inputs=[melody, difficulty, sample_rate, duration],
210
  label="Audio Examples",
211
+ outputs=[output_audio],
212
  # cache_examples=True,
213
  )
214
  demo.queue().launch()
gradio_components/prediction.py CHANGED
@@ -9,6 +9,9 @@ from audiocraft.models import MusicGen
9
  from tempfile import NamedTemporaryFile
10
  from pathlib import Path
11
  from transformers import AutoModelForSeq2SeqLM
 
 
 
12
 
13
 
14
  def load_model(version='facebook/musicgen-melody'):
@@ -103,4 +106,25 @@ def predict(model_path, text, melody, duration, topk, topp, temperature, target_
103
  top_p=topp,
104
  temperature=temperature,
105
  gradio_progress=progress)
106
- return wavs[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from tempfile import NamedTemporaryFile
10
  from pathlib import Path
11
  from transformers import AutoModelForSeq2SeqLM
12
+ import basic_pitch
13
+ import basic_pitch.inference
14
+ from basic_pitch import ICASSP_2022_MODEL_PATH
15
 
16
 
17
  def load_model(version='facebook/musicgen-melody'):
 
106
  top_p=topp,
107
  temperature=temperature,
108
  gradio_progress=progress)
109
+ return wavs[0]
110
+
111
+
112
def transcribe(audio_path):
    """Transcribe an audio file to MIDI and return a download button for it.

    Runs ``basic_pitch.inference.predict`` on ``audio_path`` using the model
    at ``ICASSP_2022_MODEL_PATH``, writes the resulting MIDI data to a
    temporary ``.mid`` file, and returns a visible ``gr.DownloadButton``
    whose value is that file's path.

    Args:
        audio_path: Filesystem path of the audio to transcribe. The UI wires
            this to a ``gr.Audio(type="filepath")`` component, so it is
            empty/``None`` until some audio is available.

    Returns:
        A ``gr.DownloadButton`` (visible) pointing at the written MIDI file.

    Raises:
        gr.Error: If ``audio_path`` is empty, so the user gets a readable
            message instead of an opaque failure inside the inference call.
        Exception: Any error raised while writing the MIDI file is printed
            and re-raised unchanged.
    """
    # NOTE(review): this relies on `gr` (gradio) being imported at module
    # level in prediction.py; that import is not visible in this hunk.
    if not audio_path:
        raise gr.Error("No audio to transcribe: generate music first.")

    # Only the MIDI object is needed; the raw model output and note events
    # returned alongside it are intentionally ignored.
    _model_output, midi_data, _note_events = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )

    # delete=False: the returned button references the file by path, so the
    # file has to survive past the end of this `with` block.
    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
        try:
            midi_data.write(file)
            print(f"midi file saved to {file.name}")
        except Exception as e:
            print(f"Error while writing midi file: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    return gr.DownloadButton(
        value=file.name,
        label=f"Download MIDI file {file.name}",
        visible=True)