monet-joe commited on
Commit
e1ce1aa
1 Parent(s): afd60f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -33
app.py CHANGED
@@ -4,18 +4,17 @@ import time
4
  import torch
5
  import shutil
6
  import argparse
 
7
  import gradio as gr
8
  from config import *
9
  from utils import Patchilizer, TunesFormer, DEVICE
10
- from convert import abc_to_midi, abc_to_musicxml, musicxml_to_mxl, mxl2jpg, midi2wav
11
  from modelscope import snapshot_download
12
  from transformers import GPT2Config
13
- import warnings
14
 
15
- warnings.filterwarnings("ignore")
16
 
17
  # 模型下载
18
- MODEL_DIR = snapshot_download("MuGeminorum/hoyoGPT")
19
 
20
 
21
  def get_args(parser: argparse.ArgumentParser):
@@ -61,7 +60,7 @@ def get_args(parser: argparse.ArgumentParser):
61
  return args
62
 
63
 
64
- def generate_abc(args, epochs: str, region: str):
65
  patchilizer = Patchilizer()
66
 
67
  patch_config = GPT2Config(
@@ -79,14 +78,11 @@ def generate_abc(args, epochs: str, region: str):
79
  )
80
 
81
  model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
82
- filename = f"{MODEL_DIR}/{epochs}/weights.pth"
83
- checkpoint = torch.load(filename, map_location=torch.device("cpu"))
84
  model.load_state_dict(checkpoint["model"])
85
  model = model.to(DEVICE)
86
  model.eval()
87
-
88
  prompt = f"A:{region}\n"
89
-
90
  tunes = ""
91
  num_tunes = args.num_tunes
92
  max_patch = args.max_patch
@@ -96,14 +92,12 @@ def generate_abc(args, epochs: str, region: str):
96
  seed = args.seed
97
  show_control_code = args.show_control_code
98
 
99
- print(" HYPERPARAMETERS ".center(60, "#"), "\n")
100
  arg_dict: dict = vars(args)
101
-
102
  for key in arg_dict.keys():
103
  print(f"{key}: {str(arg_dict[key])}")
104
 
105
- print("\n", " OUTPUT TUNES ".center(60, "#"))
106
-
107
  start_time = time.time()
108
  for i in range(num_tunes):
109
  title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
@@ -179,40 +173,37 @@ def generate_abc(args, epochs: str, region: str):
179
  os.makedirs(TEMP_DIR, exist_ok=True)
180
  timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
181
  try:
182
- out_midi = abc_to_midi(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.mid")
183
- out_xml = abc_to_musicxml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
184
- out_mxl = musicxml_to_mxl(f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
185
- pdf_file, jpg_file = mxl2jpg(out_mxl)
186
- wav_file = midi2wav(out_midi)
187
 
188
- return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file
189
 
190
  except Exception as e:
191
  print(f"Invalid abc generated: {e}, retrying...")
192
- return generate_abc(args, epochs, region)
193
 
194
 
195
- def inference(epochs: str, region: str):
196
  if os.path.exists(TEMP_DIR):
197
  shutil.rmtree(TEMP_DIR)
198
 
199
  parser = argparse.ArgumentParser()
200
  args = get_args(parser)
201
- return generate_abc(args, epochs, region)
202
 
203
 
204
  if __name__ == "__main__":
 
 
205
  with gr.Blocks() as demo:
206
  gr.Markdown(
207
- """<center>Welcome to this Space, made by bilibili <a href="https://space.bilibili.com/30620472">@MuGeminorum</a> based on the Tunesformer open source project, completely free.</center>"""
208
  )
209
  with gr.Row():
210
  with gr.Column():
211
- weight_opt = gr.Dropdown(
212
- choices=["5", "15", "20"],
213
- value="15",
214
- label="Model Selection(epochs)",
215
- )
216
  region_opt = gr.Dropdown(
217
  choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
218
  value="Mondstadt",
@@ -222,9 +213,7 @@ if __name__ == "__main__":
222
  gr.Markdown(
223
  """
224
  <center>
225
- Currently, the model is still under debugging, and since the training data is converted from MIDI, there are a lot of spectral irregularities, which leads to many generated results with spectral irregularities and other problems.<br>
226
-
227
- Planned in the Genshin main line killed, all countries and regions after all the characters are open, the second creation of the concert will be complete and balanced samples, then re-fine-tune the model and add the reality of the style of screening to assist the game of the various countries output gatekeepers, in order to enhance the output of the differentiation and quality.
228
 
229
  Data source: <a href="https://musescore.org">MuseScore</a><br>
230
  Tag embedded data source: <a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
@@ -236,7 +225,7 @@ if __name__ == "__main__":
236
  with gr.Column():
237
  wav_output = gr.Audio(label="Audio", type="filepath")
238
  dld_midi = gr.components.File(label="Download MIDI")
239
- pdf_score = gr.components.File(label="Download PDF Score")
240
  dld_xml = gr.components.File(label="Download MusicXML")
241
  dld_mxl = gr.components.File(label="Download MXL")
242
  abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
@@ -244,7 +233,7 @@ if __name__ == "__main__":
244
 
245
  gen_btn.click(
246
  inference,
247
- inputs=[weight_opt, region_opt],
248
  outputs=[
249
  abc_output,
250
  dld_midi,
 
4
  import torch
5
  import shutil
6
  import argparse
7
+ import warnings
8
  import gradio as gr
9
  from config import *
10
  from utils import Patchilizer, TunesFormer, DEVICE
11
+ from convert import abc2xml, xml2, xml2img
12
  from modelscope import snapshot_download
13
  from transformers import GPT2Config
 
14
 
 
15
 
16
  # 模型下载
17
+ WEIGHTS_PATH = snapshot_download("MuGeminorum/hoyoGPT") + "/weights.pth"
18
 
19
 
20
  def get_args(parser: argparse.ArgumentParser):
 
60
  return args
61
 
62
 
63
+ def generate_abc(args, region: str):
64
  patchilizer = Patchilizer()
65
 
66
  patch_config = GPT2Config(
 
78
  )
79
 
80
  model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
81
+ checkpoint = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
 
82
  model.load_state_dict(checkpoint["model"])
83
  model = model.to(DEVICE)
84
  model.eval()
 
85
  prompt = f"A:{region}\n"
 
86
  tunes = ""
87
  num_tunes = args.num_tunes
88
  max_patch = args.max_patch
 
92
  seed = args.seed
93
  show_control_code = args.show_control_code
94
 
95
+ print(" Hyper parms ".center(60, "#"), "\n")
96
  arg_dict: dict = vars(args)
 
97
  for key in arg_dict.keys():
98
  print(f"{key}: {str(arg_dict[key])}")
99
 
100
+ print("\n", " Output tunes ".center(60, "#"))
 
101
  start_time = time.time()
102
  for i in range(num_tunes):
103
  title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
 
173
  os.makedirs(TEMP_DIR, exist_ok=True)
174
  timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
175
  try:
176
+ xml = abc2xml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
177
+ midi = xml2(xml, "mid")
178
+ audio = xml2(xml, "wav")
179
+ pdf, jpg = xml2img(xml)
180
+ mxl = xml2(xml, "mxl")
181
 
182
+ return tunes, midi, pdf, xml, mxl, jpg, audio
183
 
184
  except Exception as e:
185
  print(f"Invalid abc generated: {e}, retrying...")
186
+ return generate_abc(args, region)
187
 
188
 
189
+ def inference(region: str):
190
  if os.path.exists(TEMP_DIR):
191
  shutil.rmtree(TEMP_DIR)
192
 
193
  parser = argparse.ArgumentParser()
194
  args = get_args(parser)
195
+ return generate_abc(args, region)
196
 
197
 
198
  if __name__ == "__main__":
199
+ warnings.filterwarnings("ignore")
200
+
201
  with gr.Blocks() as demo:
202
  gr.Markdown(
203
+ """<center>Welcome to this space, made by bilibili <a href="https://space.bilibili.com/30620472">@MuGeminorum</a> based on Tunesformer open source project, totally free.</center>"""
204
  )
205
  with gr.Row():
206
  with gr.Column():
 
 
 
 
 
207
  region_opt = gr.Dropdown(
208
  choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
209
  value="Mondstadt",
 
213
  gr.Markdown(
214
  """
215
  <center>
216
+ Currently, the model is still under debugging. Planned in the Genshin main line killed, all countries and regions after all the characters are open, the second creation of the concert will be complete and balanced samples, then re-fine-tune the model and add the reality of the style of screening to assist the game of the various countries output gatekeepers, in order to enhance the output of the differentiation and quality.
 
 
217
 
218
  Data source: <a href="https://musescore.org">MuseScore</a><br>
219
  Tag embedded data source: <a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
 
225
  with gr.Column():
226
  wav_output = gr.Audio(label="Audio", type="filepath")
227
  dld_midi = gr.components.File(label="Download MIDI")
228
+ pdf_score = gr.components.File(label="Download PDF score")
229
  dld_xml = gr.components.File(label="Download MusicXML")
230
  dld_mxl = gr.components.File(label="Download MXL")
231
  abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
 
233
 
234
  gen_btn.click(
235
  inference,
236
+ inputs=region_opt,
237
  outputs=[
238
  abc_output,
239
  dld_midi,