skytnt commited on
Commit
2b86b7e
·
1 Parent(s): 469ef25
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -100,9 +100,9 @@ def get_duration(model_name, tab, mid_seq, continuation_state, instruments, drum
100
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
101
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
102
  if "large" in model_name:
103
- return gen_events // 10 + 10
104
  else:
105
- return gen_events // 20 + 10
106
 
107
 
108
  @spaces.GPU(duration=get_duration)
@@ -110,7 +110,7 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
110
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
111
  gen_events, temp, top_p, top_k, allow_cc):
112
  model = models[model_name]
113
- model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
114
  tokenizer = model.tokenizer
115
  bpm = int(bpm)
116
  if time_sig == "auto":
@@ -302,8 +302,8 @@ if __name__ == "__main__":
302
  "generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
303
  "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
304
  "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
305
- "j-pop finetune model (tv1-medium) by skytnt": ["skytnt/midi-model-ft", "jpop/", "tv1-medium"],
306
- "touhou finetune model (tv1-medium) by skytnt": ["skytnt/midi-model-ft", "touhou/", "tv1-medium"],
307
  }
308
  models = {}
309
  if opt.device == "cuda":
@@ -315,6 +315,7 @@ if __name__ == "__main__":
315
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
316
  state_dict = ckpt.get("state_dict", ckpt)
317
  model.load_state_dict(state_dict, strict=False)
 
318
  models[name] = model
319
 
320
  load_javascript()
 
100
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
101
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
102
  if "large" in model_name:
103
+ return gen_events // 10 + 15
104
  else:
105
+ return gen_events // 20 + 15
106
 
107
 
108
  @spaces.GPU(duration=get_duration)
 
110
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
111
  gen_events, temp, top_p, top_k, allow_cc):
112
  model = models[model_name]
113
+ model.to(device=opt.device)
114
  tokenizer = model.tokenizer
115
  bpm = int(bpm)
116
  if time_sig == "auto":
 
302
  "generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
303
  "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
304
  "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
305
+ "j-pop finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "jpop-tv2o-medium/", "tv2o-medium"],
306
+ "touhou finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "touhou-tv2o-medium/", "tv2o-medium"],
307
  }
308
  models = {}
309
  if opt.device == "cuda":
 
315
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
316
  state_dict = ckpt.get("state_dict", ckpt)
317
  model.load_state_dict(state_dict, strict=False)
318
+ model.to(device="cpu", dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32)
319
  models[name] = model
320
 
321
  load_javascript()