Spaces:
Paused
Paused
add lora
Browse files
- app.py +32 -12
- midi_model.py +5 -1
app.py
CHANGED
@@ -142,7 +142,12 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
|
|
142 |
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
143 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
144 |
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
145 |
-
model = models[model_name]
|
|
|
|
|
|
|
|
|
|
|
146 |
model.to(device=opt.device)
|
147 |
tokenizer = model.tokenizer
|
148 |
bpm = int(bpm)
|
@@ -253,7 +258,7 @@ def finish_run(model_name, mid_seq):
|
|
253 |
if mid_seq is None:
|
254 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
255 |
return *outputs, []
|
256 |
-
tokenizer = models[model_name].tokenizer
|
257 |
outputs = []
|
258 |
end_msgs = [create_msg("progress", [0, 0])]
|
259 |
if not os.path.exists("outputs"):
|
@@ -277,7 +282,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
|
|
277 |
if (not should_render_audio) or mid_seq is None:
|
278 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
279 |
return tuple(outputs)
|
280 |
-
tokenizer = models[model_name].tokenizer
|
281 |
outputs = []
|
282 |
if not os.path.exists("outputs"):
|
283 |
os.mkdir("outputs")
|
@@ -294,7 +299,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
|
|
294 |
def undo_continuation(model_name, mid_seq, continuation_state):
|
295 |
if mid_seq is None or len(continuation_state) < 2:
|
296 |
return mid_seq, continuation_state, send_msgs([])
|
297 |
-
tokenizer = models[model_name].tokenizer
|
298 |
if isinstance(continuation_state[-1], list):
|
299 |
mid_seq = continuation_state[-1]
|
300 |
else:
|
@@ -364,12 +369,21 @@ if __name__ == "__main__":
|
|
364 |
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
365 |
synthesizer = MidiSynthesizer(soundfont_path)
|
366 |
models_info = {
|
367 |
-
"generic pretrain model (tv2o-medium) by skytnt": [
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
}
|
374 |
models = {}
|
375 |
if opt.device == "cuda":
|
@@ -379,14 +393,20 @@ if __name__ == "__main__":
|
|
379 |
torch.backends.cudnn.allow_tf32 = True
|
380 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
381 |
torch.backends.cuda.enable_flash_sdp(True)
|
382 |
-
for name, (repo_id, path, config) in models_info.items():
|
383 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
384 |
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
385 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
386 |
state_dict = ckpt.get("state_dict", ckpt)
|
387 |
model.load_state_dict(state_dict, strict=False)
|
|
|
|
|
|
|
|
|
388 |
model.to(device="cpu", dtype=torch.float32).eval()
|
389 |
-
models[name] = model
|
|
|
|
|
390 |
|
391 |
load_javascript()
|
392 |
app = gr.Blocks()
|
|
|
142 |
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
143 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
144 |
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
145 |
+
model, lora_name = models[model_name]
|
146 |
+
if lora_name is None and model.peft_loaded():
|
147 |
+
model.disable_adapters()
|
148 |
+
elif lora_name is not None:
|
149 |
+
model.enable_adapters()
|
150 |
+
model.set_adapter(lora_name)
|
151 |
model.to(device=opt.device)
|
152 |
tokenizer = model.tokenizer
|
153 |
bpm = int(bpm)
|
|
|
258 |
if mid_seq is None:
|
259 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
260 |
return *outputs, []
|
261 |
+
tokenizer = models[model_name][0].tokenizer
|
262 |
outputs = []
|
263 |
end_msgs = [create_msg("progress", [0, 0])]
|
264 |
if not os.path.exists("outputs"):
|
|
|
282 |
if (not should_render_audio) or mid_seq is None:
|
283 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
284 |
return tuple(outputs)
|
285 |
+
tokenizer = models[model_name][0].tokenizer
|
286 |
outputs = []
|
287 |
if not os.path.exists("outputs"):
|
288 |
os.mkdir("outputs")
|
|
|
299 |
def undo_continuation(model_name, mid_seq, continuation_state):
|
300 |
if mid_seq is None or len(continuation_state) < 2:
|
301 |
return mid_seq, continuation_state, send_msgs([])
|
302 |
+
tokenizer = models[model_name][0].tokenizer
|
303 |
if isinstance(continuation_state[-1], list):
|
304 |
mid_seq = continuation_state[-1]
|
305 |
else:
|
|
|
369 |
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
370 |
synthesizer = MidiSynthesizer(soundfont_path)
|
371 |
models_info = {
|
372 |
+
"generic pretrain model (tv2o-medium) by skytnt": [
|
373 |
+
"skytnt/midi-model-tv2o-medium", "", "tv2o-medium", {
|
374 |
+
"jpop": "skytnt/midi-model-tv2om-jpop-lora",
|
375 |
+
"touhou": "skytnt/midi-model-tv2om-touhou-lora"
|
376 |
+
}
|
377 |
+
],
|
378 |
+
"generic pretrain model (tv2o-large) by asigalov61": [
|
379 |
+
"asigalov61/Music-Llama", "", "tv2o-large", {}
|
380 |
+
],
|
381 |
+
"generic pretrain model (tv2o-medium) by asigalov61": [
|
382 |
+
"asigalov61/Music-Llama-Medium", "", "tv2o-medium", {}
|
383 |
+
],
|
384 |
+
"generic pretrain model (tv1-medium) by skytnt": [
|
385 |
+
"skytnt/midi-model", "", "tv1-medium", {}
|
386 |
+
]
|
387 |
}
|
388 |
models = {}
|
389 |
if opt.device == "cuda":
|
|
|
393 |
torch.backends.cudnn.allow_tf32 = True
|
394 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
395 |
torch.backends.cuda.enable_flash_sdp(True)
|
396 |
+
for name, (repo_id, path, config, loras) in models_info.items():
|
397 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
398 |
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
399 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
400 |
state_dict = ckpt.get("state_dict", ckpt)
|
401 |
model.load_state_dict(state_dict, strict=False)
|
402 |
+
for lora_name, lora_repo in loras.items():
|
403 |
+
model.load_adapter(lora_repo, lora_name)
|
404 |
+
if loras:
|
405 |
+
model.disable_adapters()
|
406 |
model.to(device="cpu", dtype=torch.float32).eval()
|
407 |
+
models[name] = model, None
|
408 |
+
for lora_name, lora_repo in loras.items():
|
409 |
+
models[f"{name} with {lora_name} lora"] = model, lora_name
|
410 |
|
411 |
load_javascript()
|
412 |
app = gr.Blocks()
|
midi_model.py
CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
|
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
8 |
from transformers import LlamaModel, LlamaConfig
|
|
|
9 |
|
10 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
11 |
|
@@ -55,7 +56,7 @@ class MIDIModelConfig:
|
|
55 |
raise ValueError(f"Unknown model size {size}")
|
56 |
|
57 |
|
58 |
-
class MIDIModel(nn.Module):
|
59 |
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
60 |
super(MIDIModel, self).__init__()
|
61 |
self.tokenizer = config.tokenizer
|
@@ -69,6 +70,9 @@ class MIDIModel(nn.Module):
|
|
69 |
self.device = kwargs["device"]
|
70 |
return super(MIDIModel, self).to(*args, **kwargs)
|
71 |
|
|
|
|
|
|
|
72 |
def forward_token(self, hidden_state, x=None):
|
73 |
"""
|
74 |
|
|
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
8 |
from transformers import LlamaModel, LlamaConfig
|
9 |
+
from transformers.integrations import PeftAdapterMixin
|
10 |
|
11 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
12 |
|
|
|
56 |
raise ValueError(f"Unknown model size {size}")
|
57 |
|
58 |
|
59 |
+
class MIDIModel(nn.Module, PeftAdapterMixin):
|
60 |
def __init__(self, config: MIDIModelConfig, *args, **kwargs):
|
61 |
super(MIDIModel, self).__init__()
|
62 |
self.tokenizer = config.tokenizer
|
|
|
70 |
self.device = kwargs["device"]
|
71 |
return super(MIDIModel, self).to(*args, **kwargs)
|
72 |
|
73 |
+
def peft_loaded(self):
    """Return this model's ``_hf_peft_config_loaded`` flag.

    The flag is not assigned anywhere in the visible part of this class;
    presumably it is maintained by ``PeftAdapterMixin`` and becomes truthy
    once an adapter has been loaded (TODO confirm against transformers).
    Callers use it to decide whether adapters need to be disabled.
    """
    adapter_config_loaded = self._hf_peft_config_loaded
    return adapter_config_loaded
|
75 |
+
|
76 |
def forward_token(self, hidden_state, x=None):
|
77 |
"""
|
78 |
|