Katpeeler committed
Commit fcd504e · Parent: 2d56d1d

Update app.py

Files changed (1):
  1. app.py +37 -8
app.py CHANGED
@@ -2,15 +2,24 @@ import gradio as gr
 import note_seq
 import numpy as np
 from transformers import AutoTokenizer, AutoModelForCausalLM
+# The instrument list is imported but not currently used.
 from constants import GM_INSTRUMENTS
 
+# Load the current midi_model.
 tokenizer = AutoTokenizer.from_pretrained("Katpeeler/midi_model_3")
 model = AutoModelForCausalLM.from_pretrained("Katpeeler/midi_model_3")
 
+# Define the note and bar lengths, relative to 120 bpm.
+# These are overridden if the user adjusts the bpm.
 NOTE_LENGTH_16TH_120BPM = 0.25 * 60 / 120
 BAR_LENGTH_120BPM = 4.0 * 60 / 120
+# The sample rate should never change, and should be imported from constants.
+# I will do that once I confirm I can't use a higher sample rate for audio playback here.
 SAMPLE_RATE = 44100
 
+# Main method for converting tokens back to MIDI notes.
+# An instrument_mapper can be specified when more sounds are ready to be added.
+# THIS METHOD IS FROM DR. TRISTAN BEHRENS (https://huggingface.co/TristanBehrens)
 def token_sequence_to_note_sequence(token_sequence, use_program=True, use_drums=True, instrument_mapper=None, only_piano=False):
     if isinstance(token_sequence, str):
         token_sequence = token_sequence.split()
@@ -109,54 +118,74 @@ def empty_note_sequence(qpm=120.0, total_time=0.0):
     note_sequence.total_time = total_time
     return note_sequence
 
+# The process that runs when the user clicks the "generate audio" button.
+# It currently takes three number arguments, corresponding to two parts of
+# the input prompt, and the bpm.
 def process(num1, num2, num3):
+    # The prompt used to generate, hard-coded for now to make generation smoother.
+    # It includes the start of the MIDI file, plus style and genre (since they are unused),
+    # starts a track, and lets the user set the instrument number and first note from the UI.
     created_text = f"""PIECE_START STYLE=JSFAKES GENRE=JSFAKES TRACK_START INST={num1} BAR_START NOTE_ON={num2}"""
+
+    # Adjustments for bpm.
     global NOTE_LENGTH_16TH_120BPM
     NOTE_LENGTH_16TH_120BPM = 0.25 * 60 / num3
     global BAR_LENGTH_120BPM
     BAR_LENGTH_120BPM = 4.0 * 60 / num3
+
+    # Send the input prompt to the tokenizer, and generate.
     input_ids = tokenizer.encode(created_text, return_tensors="pt")
     generated_ids = model.generate(input_ids, max_length=500)
     global generated_sequence
     generated_sequence = tokenizer.decode(generated_ids[0])
 
-    # Convert text of notes to audio
+    # Convert the text of notes to audio.
     note_sequence = token_sequence_to_note_sequence(generated_sequence)
+    # The synth engine for playing sound.
     synth = note_seq.midi_synth.synthesize
     array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
     note_plot = note_seq.plot_sequence(note_sequence, False)
     array_of_floats /= 1.414
     array_of_floats *= 32767
     int16_data = array_of_floats.astype(np.int16)
+    # Return the sample rate and array, which the Gradio audio widget needs.
     return SAMPLE_RATE, int16_data
 
-
+# Simple call to show the generated tokens.
 def generation():
     return generated_sequence
 
-
+# Unused call that stored instant feedback from the Gradio sliders.
+# A simpler method replaced it, but it is kept in case it becomes useful later.
 def identity(x, state):
     state += 1
     return x, state, state
 
+# Gradio app structure.
 with gr.Blocks() as demo:
+    # Title of the page.
     gr.Markdown("Midi Generation")
-    #with gr.Tab("Token generation"):
-    #    text_output = gr.Textbox()
-    #    text_button = gr.Button("show generated tokens")
+    # The audio generation tab.
     with gr.Tab("Audio generation"):
+        # An audio widget.
         audio_output = gr.Audio()
+        # The slider widgets for the user to adjust the generation values.
         number1 = gr.Slider(1, 100, value=25, label="Inst number", step=1, info="Choose between 1 and 100")
         number2 = gr.Slider(1, 100, value=40, label="Note number", step=1, info="Choose between 1 and 100")
         number3 = gr.Slider(60, 140, value=120, label="BPM", step=5, info="Choose between 60 and 140")
+        # The button that sends the prompt.
         audio_button = gr.Button("generate audio")
+    # The token generation tab.
     with gr.Tab("Token generation"):
+        # A text widget to display the generated tokens.
         text_output = gr.Textbox()
+        # The button that displays the generated tokens.
         text_button = gr.Button("show generated tokens")
-
+
+    # The handlers for the button clicks.
     text_button.click(generation, inputs=None, outputs=text_output)
     audio_button.click(process, inputs=[number1, number2, number3], outputs=audio_output)
 
-
+# Run the application.
 if __name__ == "__main__":
     demo.launch()
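A note on the timing constants this commit documents: 60 / bpm is the length of one beat in seconds, a 16th note is a quarter of a beat, and a bar in 4/4 time is four beats. A minimal sketch of that arithmetic, assuming 4/4 time as the formulas imply (the function names are illustrative, not part of app.py):

def sixteenth_note_seconds(bpm: float) -> float:
    # A 16th note is a quarter of one beat; one beat lasts 60 / bpm seconds.
    return 0.25 * 60 / bpm

def bar_seconds(bpm: float) -> float:
    # A 4/4 bar is four beats.
    return 4.0 * 60 / bpm

assert sixteenth_note_seconds(120) == 0.125  # NOTE_LENGTH_16TH_120BPM
assert bar_seconds(120) == 2.0               # BAR_LENGTH_120BPM

Computing the lengths through functions like these would also avoid the global-variable rebinding that process() does whenever the bpm slider changes.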
 
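The prompt tokens (PIECE_START, TRACK_START, INST=, BAR_START, NOTE_ON=) follow the event-style vocabulary used by this family of MIDI language models, where, typically, TIME_DELTA= advances time in 16th-note steps and NOTE_OFF= closes a note. The real conversion is token_sequence_to_note_sequence() above; the sketch below is only a hypothetical mini-decoder in that spirit, and the token names beyond those in the prompt are assumptions, not confirmed by this diff:

NOTE_LENGTH_16TH = 0.125  # seconds per 16th-note step at 120 bpm

def decode_notes(tokens: str):
    # Walk the token stream, tracking the current time and open notes.
    time, open_notes, notes = 0.0, {}, []
    for token in tokens.split():
        if token.startswith("NOTE_ON="):
            open_notes[int(token.split("=")[1])] = time
        elif token.startswith("TIME_DELTA="):  # assumed token name
            time += int(token.split("=")[1]) * NOTE_LENGTH_16TH
        elif token.startswith("NOTE_OFF="):    # assumed token name
            pitch = int(token.split("=")[1])
            if pitch in open_notes:
                notes.append((pitch, open_notes.pop(pitch), time))
    return notes  # (pitch, start_seconds, end_seconds)

print(decode_notes("NOTE_ON=60 TIME_DELTA=4 NOTE_OFF=60"))
# [(60, 0.0, 0.5)] -- four 16th steps, i.e. a quarter note at middle C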
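The scaling at the end of process() is the usual route into Gradio's audio widget: note_seq's synthesizer returns floats roughly in [-1, 1], dividing by 1.414 (about the square root of 2) leaves headroom, and multiplying by 32767 fills the int16 range. A self-contained sketch of the same conversion, using a sine wave in place of the synthesized sequence:

import numpy as np

SAMPLE_RATE = 44100  # the same rate app.py uses

# Stand-in for the note_seq synth output: one second of a 440 Hz sine wave.
t = np.linspace(0, 1, SAMPLE_RATE, endpoint=False)
array_of_floats = np.sin(2 * np.pi * 440 * t)

# The same scaling as process(): ~1/sqrt(2) headroom, then the int16 range.
int16_data = (array_of_floats / 1.414 * 32767).astype(np.int16)

assert int16_data.dtype == np.int16
assert abs(int16_data).max() <= 32767

A (SAMPLE_RATE, int16_data) tuple is the value shape that gr.Audio accepts, which is why process() returns exactly that pair.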