skytnt commited on
Commit
183a87e
·
1 Parent(s): 94ac77e
Files changed (2) hide show
  1. app.py +32 -12
  2. midi_model.py +5 -1
app.py CHANGED
@@ -142,7 +142,12 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
142
  def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
143
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
144
  seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
145
- model = models[model_name]
 
 
 
 
 
146
  model.to(device=opt.device)
147
  tokenizer = model.tokenizer
148
  bpm = int(bpm)
@@ -253,7 +258,7 @@ def finish_run(model_name, mid_seq):
253
  if mid_seq is None:
254
  outputs = [None] * OUTPUT_BATCH_SIZE
255
  return *outputs, []
256
- tokenizer = models[model_name].tokenizer
257
  outputs = []
258
  end_msgs = [create_msg("progress", [0, 0])]
259
  if not os.path.exists("outputs"):
@@ -277,7 +282,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
277
  if (not should_render_audio) or mid_seq is None:
278
  outputs = [None] * OUTPUT_BATCH_SIZE
279
  return tuple(outputs)
280
- tokenizer = models[model_name].tokenizer
281
  outputs = []
282
  if not os.path.exists("outputs"):
283
  os.mkdir("outputs")
@@ -294,7 +299,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
294
  def undo_continuation(model_name, mid_seq, continuation_state):
295
  if mid_seq is None or len(continuation_state) < 2:
296
  return mid_seq, continuation_state, send_msgs([])
297
- tokenizer = models[model_name].tokenizer
298
  if isinstance(continuation_state[-1], list):
299
  mid_seq = continuation_state[-1]
300
  else:
@@ -364,12 +369,21 @@ if __name__ == "__main__":
364
  thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
365
  synthesizer = MidiSynthesizer(soundfont_path)
366
  models_info = {
367
- "generic pretrain model (tv2o-medium) by skytnt": ["skytnt/midi-model-tv2o-medium", "", "tv2o-medium"],
368
- "generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
369
- "generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
370
- "generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
371
- "j-pop finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "jpop-tv2o-medium/", "tv2o-medium"],
372
- "touhou finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "touhou-tv2o-medium/", "tv2o-medium"],
 
 
 
 
 
 
 
 
 
373
  }
374
  models = {}
375
  if opt.device == "cuda":
@@ -379,14 +393,20 @@ if __name__ == "__main__":
379
  torch.backends.cudnn.allow_tf32 = True
380
  torch.backends.cuda.enable_mem_efficient_sdp(True)
381
  torch.backends.cuda.enable_flash_sdp(True)
382
- for name, (repo_id, path, config) in models_info.items():
383
  model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
384
  model = MIDIModel(config=MIDIModelConfig.from_name(config))
385
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
386
  state_dict = ckpt.get("state_dict", ckpt)
387
  model.load_state_dict(state_dict, strict=False)
 
 
 
 
388
  model.to(device="cpu", dtype=torch.float32).eval()
389
- models[name] = model
 
 
390
 
391
  load_javascript()
392
  app = gr.Blocks()
 
142
  def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
143
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
144
  seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
145
+ model, lora_name = models[model_name]
146
+ if lora_name is None and model.peft_loaded():
147
+ model.disable_adapters()
148
+ elif lora_name is not None:
149
+ model.enable_adapters()
150
+ model.set_adapter(lora_name)
151
  model.to(device=opt.device)
152
  tokenizer = model.tokenizer
153
  bpm = int(bpm)
 
258
  if mid_seq is None:
259
  outputs = [None] * OUTPUT_BATCH_SIZE
260
  return *outputs, []
261
+ tokenizer = models[model_name][0].tokenizer
262
  outputs = []
263
  end_msgs = [create_msg("progress", [0, 0])]
264
  if not os.path.exists("outputs"):
 
282
  if (not should_render_audio) or mid_seq is None:
283
  outputs = [None] * OUTPUT_BATCH_SIZE
284
  return tuple(outputs)
285
+ tokenizer = models[model_name][0].tokenizer
286
  outputs = []
287
  if not os.path.exists("outputs"):
288
  os.mkdir("outputs")
 
299
  def undo_continuation(model_name, mid_seq, continuation_state):
300
  if mid_seq is None or len(continuation_state) < 2:
301
  return mid_seq, continuation_state, send_msgs([])
302
+ tokenizer = models[model_name][0].tokenizer
303
  if isinstance(continuation_state[-1], list):
304
  mid_seq = continuation_state[-1]
305
  else:
 
369
  thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
370
  synthesizer = MidiSynthesizer(soundfont_path)
371
  models_info = {
372
+ "generic pretrain model (tv2o-medium) by skytnt": [
373
+ "skytnt/midi-model-tv2o-medium", "", "tv2o-medium", {
374
+ "jpop": "skytnt/midi-model-tv2om-jpop-lora",
375
+ "touhou": "skytnt/midi-model-tv2om-touhou-lora"
376
+ }
377
+ ],
378
+ "generic pretrain model (tv2o-large) by asigalov61": [
379
+ "asigalov61/Music-Llama", "", "tv2o-large", {}
380
+ ],
381
+ "generic pretrain model (tv2o-medium) by asigalov61": [
382
+ "asigalov61/Music-Llama-Medium", "", "tv2o-medium", {}
383
+ ],
384
+ "generic pretrain model (tv1-medium) by skytnt": [
385
+ "skytnt/midi-model", "", "tv1-medium", {}
386
+ ]
387
  }
388
  models = {}
389
  if opt.device == "cuda":
 
393
  torch.backends.cudnn.allow_tf32 = True
394
  torch.backends.cuda.enable_mem_efficient_sdp(True)
395
  torch.backends.cuda.enable_flash_sdp(True)
396
+ for name, (repo_id, path, config, loras) in models_info.items():
397
  model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
398
  model = MIDIModel(config=MIDIModelConfig.from_name(config))
399
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
400
  state_dict = ckpt.get("state_dict", ckpt)
401
  model.load_state_dict(state_dict, strict=False)
402
+ for lora_name, lora_repo in loras.items():
403
+ model.load_adapter(lora_repo, lora_name)
404
+ if loras:
405
+ model.disable_adapters()
406
  model.to(device="cpu", dtype=torch.float32).eval()
407
+ models[name] = model, None
408
+ for lora_name, lora_repo in loras.items():
409
+ models[f"{name} with {lora_name} lora"] = model, lora_name
410
 
411
  load_javascript()
412
  app = gr.Blocks()
midi_model.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  import tqdm
8
  from transformers import LlamaModel, LlamaConfig
 
9
 
10
  from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
11
 
@@ -55,7 +56,7 @@ class MIDIModelConfig:
55
  raise ValueError(f"Unknown model size {size}")
56
 
57
 
58
- class MIDIModel(nn.Module):
59
  def __init__(self, config: MIDIModelConfig, *args, **kwargs):
60
  super(MIDIModel, self).__init__()
61
  self.tokenizer = config.tokenizer
@@ -69,6 +70,9 @@ class MIDIModel(nn.Module):
69
  self.device = kwargs["device"]
70
  return super(MIDIModel, self).to(*args, **kwargs)
71
 
 
 
 
72
  def forward_token(self, hidden_state, x=None):
73
  """
74
 
 
6
  import torch.nn.functional as F
7
  import tqdm
8
  from transformers import LlamaModel, LlamaConfig
9
+ from transformers.integrations import PeftAdapterMixin
10
 
11
  from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
12
 
 
56
  raise ValueError(f"Unknown model size {size}")
57
 
58
 
59
+ class MIDIModel(nn.Module, PeftAdapterMixin):
60
  def __init__(self, config: MIDIModelConfig, *args, **kwargs):
61
  super(MIDIModel, self).__init__()
62
  self.tokenizer = config.tokenizer
 
70
  self.device = kwargs["device"]
71
  return super(MIDIModel, self).to(*args, **kwargs)
72
 
73
+ def peft_loaded(self):
74
+ return self._hf_peft_config_loaded
75
+
76
  def forward_token(self, hidden_state, x=None):
77
  """
78