Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
commited on
Commit
•
f97f1bb
1
Parent(s):
f117179
Rename app.py to musiclib.py
Browse files- app.py +0 -0
- musiclib.py +67 -0
app.py
DELETED
File without changes
|
musiclib.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# apache 2.0 license, modified by mrfakename, from https://github.com/BlinkDL/ChatRWKV/tree/main/music
|
2 |
+
|
3 |
+
import os, sys
|
4 |
+
import numpy as np
|
5 |
+
from cached_path import cached_path
|
6 |
+
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
7 |
+
|
8 |
+
os.environ['RWKV_JIT_ON'] = '1' #### set these before import RWKV
|
9 |
+
os.environ["RWKV_CUDA_ON"] = '0'
|
10 |
+
os.environ["RWKV_RESCALE_LAYER"] = '999' # must set this for RWKV-music models and "pip install rwkv --upgrade" to v0.8.12+
|
11 |
+
|
12 |
+
from rwkv.model import RWKV
|
13 |
+
from rwkv.utils import PIPELINE
|
14 |
+
|
15 |
+
MODEL_FILE = str(cached_path('hf://BlinkDL/rwkv-4-music/RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth'))
|
16 |
+
|
17 |
+
ABC_MODE = ('-ABC-' in MODEL_FILE)
|
18 |
+
MIDI_MODE = ('-MIDI-' in MODEL_FILE)
|
19 |
+
|
20 |
+
model = RWKV(model=MODEL_FILE, strategy='mps fp32')
|
21 |
+
pipeline = PIPELINE(model, "tokenizer-midi.json")
|
22 |
+
|
23 |
+
tokenizer = pipeline
|
24 |
+
EOS_ID = 0
|
25 |
+
TOKEN_SEP = ' '
|
26 |
+
|
27 |
+
def musicgen(ccc='<pad>', piano_only=False):
|
28 |
+
# ccc = '<pad>'
|
29 |
+
ccc_output = '<start>'
|
30 |
+
# ccc = "v:5b:3 v:5b:2 t125 t125 t125 t106 pi:43:5 t24 pi:4a:7 t15 pi:4f:7 t17 pi:56:7 t18 pi:54:7 t125 t49 pi:51:7 t117 pi:4d:7 t125 t125 t111 pi:37:7 t14 pi:3e:6 t15 pi:43:6 t12 pi:4a:7 t17 pi:48:7 t125 t60 pi:45:7 t121 pi:41:7 t125 t117 s:46:5 s:52:5 f:46:5 f:52:5 t121 s:45:5 s:46:0 s:51:5 s:52:0 f:45:5 f:46:0 f:51:5 f:52:0 t121 s:41:5 s:45:0 s:4d:5 s:51:0 f:41:5 f:45:0 f:4d:5 f:51:0 t102 pi:37:0 pi:3e:0 pi:41:0 pi:43:0 pi:45:0 pi:48:0 pi:4a:0 pi:4d:0 pi:4f:0 pi:51:0 pi:54:0 pi:56:0 t19 s:3e:5 s:41:0 s:4a:5 s:4d:0 f:3e:5 f:41:0 f:4a:5 f:4d:0 t121 v:3a:5 t121 v:39:7 t15 v:3a:0 t106 v:35:8 t10 v:39:0 t111 v:30:8 v:35:0 t125 t117 v:32:8 t10 v:30:0 t125 t125 t103 v:5b:0 v:5b:0 t9 pi:4a:7"
|
31 |
+
# ccc = '<pad> ' + ccc
|
32 |
+
# ccc_output = '<start> pi:4a:7'
|
33 |
+
output = ''
|
34 |
+
|
35 |
+
output += (ccc_output)
|
36 |
+
|
37 |
+
occurrence = {}
|
38 |
+
state = None
|
39 |
+
for i in range(4096): # only trained with ctx4096 (will be longer soon)
|
40 |
+
if i == 0:
|
41 |
+
out, state = model.forward(tokenizer.encode(ccc), state)
|
42 |
+
else:
|
43 |
+
out, state = model.forward([token], state)
|
44 |
+
|
45 |
+
if MIDI_MODE: # seems only required for MIDI mode
|
46 |
+
for n in occurrence:
|
47 |
+
out[n] -= (0 + occurrence[n] * 0.5)
|
48 |
+
|
49 |
+
out[0] += (i - 2000) / 500 # not too short, not too long
|
50 |
+
out[127] -= 1 # avoid "t125"
|
51 |
+
if piano_only:
|
52 |
+
out[128:12416] -= 1e10
|
53 |
+
out[13952:20096] -= 1e10
|
54 |
+
token = pipeline.sample_logits(out, temperature=1.0, top_k=8, top_p=0.8)
|
55 |
+
if token == EOS_ID: break
|
56 |
+
|
57 |
+
if MIDI_MODE: # seems only required for MIDI mode
|
58 |
+
for n in occurrence: occurrence[n] *= 0.997 #### decay repetition penalty
|
59 |
+
if token >= 128 or token == 127:
|
60 |
+
occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
|
61 |
+
else:
|
62 |
+
occurrence[token] = 0.3 + (occurrence[token] if token in occurrence else 0)
|
63 |
+
|
64 |
+
output += (TOKEN_SEP + tokenizer.decode([token]))
|
65 |
+
|
66 |
+
output += (' <end>')
|
67 |
+
return output
|