Spaces:
Running
on
Zero
Running
on
Zero
erl-j
commited on
Commit
·
b362624
0
Parent(s):
first commit
Browse files- .gitattributes +37 -0
- .gitignore +4 -0
- README.md +12 -0
- app.py +200 -0
- custom.css +203 -0
- custom.js +355 -0
- requirements.txt +9 -0
- sf-creator-fork/.gitignore +3 -0
- sf-creator-fork/README.md +90 -0
- sf-creator-fork/main.py +72 -0
- sf-creator-fork/requirements.txt +2 -0
- sf-creator-fork/resources/decentsampler/dspreset.xsd +134 -0
- sf-creator-fork/sfcreator/__init__.py +0 -0
- sf-creator-fork/sfcreator/decentsampler/__init__.py +47 -0
- sf-creator-fork/sfcreator/sfz/__init__.py +37 -0
- sf-creator-fork/sfcreator/soundfont/__init__.py +0 -0
- sf-creator-fork/sfcreator/soundfont/soundfont.py +119 -0
- sf-creator-fork/sfcreator/soundfont/test_soundfont.py +70 -0
- train_lfm.py +314 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
# everything in assets is large
|
37 |
+
assets/**/* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
*/.DS_Store
|
3 |
+
__pycache__/
|
4 |
+
output/
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: soundfont-generator
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.8.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import einops
|
3 |
+
import gradio as gr
|
4 |
+
import datetime
|
5 |
+
import numpy as np
|
6 |
+
import spaces
|
7 |
+
import soundfile
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import zipfile
|
11 |
+
from pathlib import Path
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
|
14 |
+
sys.path.append("sf-creator-fork")
|
15 |
+
from main import sfz, decentsampler
|
16 |
+
|
17 |
+
decoder_path = "erl-j/soundfont-generator-assets/decoder.pt"
|
18 |
+
model_path = "erl-j/soundfont-generator-assets/synth_lfm_modern_bfloat16.pt"
|
19 |
+
# Download models from Hugging Face Hub
|
20 |
+
decoder_path = hf_hub_download("erl-j/soundfont-generator-assets", "decoder.pt")
|
21 |
+
model_path = hf_hub_download("erl-j/soundfont-generator-assets", "synth_lfm_modern_bfloat16.pt")
|
22 |
+
|
23 |
+
# Load models once at startup
|
24 |
+
device = "cuda"
|
25 |
+
decoder = torch.load(decoder_path, map_location=device).half().eval()
|
26 |
+
model = (
|
27 |
+
torch.load(model_path, map_location=device)
|
28 |
+
.half()
|
29 |
+
.eval()
|
30 |
+
)
|
31 |
+
|
32 |
+
@spaces.GPU
|
33 |
+
def generate_and_export_soundfont(text, steps=20, instrument_name=None):
|
34 |
+
sample_start = datetime.datetime.now()
|
35 |
+
|
36 |
+
# Generate audio as before
|
37 |
+
z = model.sample(1, text=[text], steps=steps)
|
38 |
+
z_reshaped = einops.rearrange(z, "b t c d -> (b c) d t")
|
39 |
+
|
40 |
+
with torch.no_grad():
|
41 |
+
audio = decoder.decode(z_reshaped)
|
42 |
+
|
43 |
+
audio_output = einops.rearrange(audio, "b c t -> c (b t)").cpu().numpy()
|
44 |
+
audio_output = audio_output / np.max(np.abs(audio_output))
|
45 |
+
|
46 |
+
# Export individual wav files
|
47 |
+
export_audio = audio.cpu().numpy().astype(np.float32)
|
48 |
+
output_dir = "output"
|
49 |
+
os.makedirs(output_dir, exist_ok=True)
|
50 |
+
|
51 |
+
# Create instrument name if not provided
|
52 |
+
if not instrument_name:
|
53 |
+
instrument_name = text.replace(" ", "_")[:20]
|
54 |
+
|
55 |
+
# Save individual WAV files
|
56 |
+
pitches = [
|
57 |
+
"C1",
|
58 |
+
"F#1",
|
59 |
+
"C2",
|
60 |
+
"F#2",
|
61 |
+
"C3",
|
62 |
+
"F#3",
|
63 |
+
"C4",
|
64 |
+
"F#4",
|
65 |
+
"C5",
|
66 |
+
"F#5",
|
67 |
+
"C6",
|
68 |
+
"F#6",
|
69 |
+
"C7",
|
70 |
+
"F#7",
|
71 |
+
"C8",
|
72 |
+
]
|
73 |
+
wav_files = []
|
74 |
+
for i in range(audio.shape[0]):
|
75 |
+
wav_path = f"{output_dir}/{pitches[i]}.wav"
|
76 |
+
soundfile.write(wav_path, export_audio[i].T, 44100)
|
77 |
+
wav_files.append(wav_path)
|
78 |
+
|
79 |
+
# Generate SFZ file
|
80 |
+
sfz(
|
81 |
+
directory=output_dir,
|
82 |
+
lowkey="21",
|
83 |
+
highkey="108",
|
84 |
+
instrument=instrument_name,
|
85 |
+
loopmode="no_loop",
|
86 |
+
polyphony=None,
|
87 |
+
)
|
88 |
+
|
89 |
+
# Create zip file containing SFZ and WAV files for the complete soundfont
|
90 |
+
zip_path = f"{output_dir}/{instrument_name}_package.zip"
|
91 |
+
with zipfile.ZipFile(zip_path, "w") as zipf:
|
92 |
+
# Add SFZ file
|
93 |
+
sfz_file = f"{output_dir}/{instrument_name}.sfz"
|
94 |
+
zipf.write(sfz_file, os.path.basename(sfz_file))
|
95 |
+
# Add all WAV files
|
96 |
+
for wav_file in wav_files:
|
97 |
+
if os.path.exists(wav_file):
|
98 |
+
zipf.write(wav_file, os.path.basename(wav_file))
|
99 |
+
|
100 |
+
total_time = (datetime.datetime.now() - sample_start).total_seconds()
|
101 |
+
|
102 |
+
return (
|
103 |
+
(44100, audio_output.T),
|
104 |
+
f"Generation took {total_time:.2f}s\nFiles saved in {output_dir}",
|
105 |
+
zip_path,
|
106 |
+
wav_files,
|
107 |
+
)
|
108 |
+
|
109 |
+
custom_js = open("custom.js").read()
|
110 |
+
custom_css = open("custom.css").read()
|
111 |
+
|
112 |
+
demo = gr.Blocks(title="Erl-j's sound font generator", js=custom_js,
|
113 |
+
css = custom_css)
|
114 |
+
|
115 |
+
with demo:
|
116 |
+
gr.Markdown("""
|
117 |
+
# Erl-j's Soundfont Generator.
|
118 |
+
Generate soundfonts from text descriptions using latent flow matching. You can then download the complete SFZ soundfont package to use the instrument locally.
|
119 |
+
## Instructions
|
120 |
+
1. Enter a text prompt to describe the audio you want to generate.
|
121 |
+
2. Adjust the number of generation steps to tradeoff between quality and speed (kindof).
|
122 |
+
3. Click the "Generate Soundfont" button to generate the audio and soundfont.
|
123 |
+
4. Preview the generated instrument with the keyboard.
|
124 |
+
5. Export the soundfont by clicking the "Download SFZ Soundfont Package" button. You can then use the soundfont in a SFZ-compatible VST like [Sforzando](https://www.plogue.com/products/sforzando/).
|
125 |
+
""")
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
with gr.Row():
|
130 |
+
steps = gr.Slider(
|
131 |
+
minimum=1, maximum=50, value=20, step=1, label="Generation steps"
|
132 |
+
)
|
133 |
+
|
134 |
+
with gr.Row():
|
135 |
+
text_input = gr.Textbox(
|
136 |
+
label="Prompt",
|
137 |
+
placeholder="Enter text description (e.g. 'hard bass', 'sparkly bells')",
|
138 |
+
lines=2,
|
139 |
+
)
|
140 |
+
|
141 |
+
with gr.Row():
|
142 |
+
generate_btn = gr.Button("Generate Soundfont", variant="primary")
|
143 |
+
|
144 |
+
with gr.Row():
|
145 |
+
audio_output = gr.Audio(label="Generated Audio Preview", visible=False)
|
146 |
+
status_output = gr.Textbox(label="Status", lines=2, visible=False)
|
147 |
+
|
148 |
+
with gr.Row():
|
149 |
+
wav_files = gr.File(label="Individual WAV Files", file_count="multiple", visible=False, elem_id="individual-wav-files")
|
150 |
+
|
151 |
+
html = """
|
152 |
+
<div id="custom-player"
|
153 |
+
style="width: 100%; height: 600px; background-color: "red"; border: 1px solid #f8f9fa; border-radius: 5px; margin-top: 10px;"
|
154 |
+
></div>
|
155 |
+
"""
|
156 |
+
|
157 |
+
gr.HTML(html, min_height=800, max_height=800)
|
158 |
+
with gr.Row():
|
159 |
+
sf = gr.File(label="Download SFZ Soundfont Package", type="filepath", visible=True, elem_id="sfz")
|
160 |
+
|
161 |
+
gr.Markdown("""
|
162 |
+
# About
|
163 |
+
The model is a modified version of [stable audio open](https://huggingface.co/stabilityai/stable-audio-open-1.0).
|
164 |
+
|
165 |
+
Unlike the original model, this version uses latent flow matching rather than latent diffusion.
|
166 |
+
Secondly, the pitches are stacked in a channel dimension rather than concatenated in the time dimension.
|
167 |
+
This allows for faster generation.
|
168 |
+
|
169 |
+
Soundfont export code is based on the [sf-creator](https://github.com/paulwellnerbou/sf-creator) project.
|
170 |
+
|
171 |
+
Similar work by Nercessian and Imort: [InstrumentGen](https://instrumentgen.netlify.app/).
|
172 |
+
|
173 |
+
Thank you @carlthome for coming up with the name.
|
174 |
+
|
175 |
+
To cite this work, please use the following BibTeX entry:
|
176 |
+
```bibtex
|
177 |
+
@misc{erl-j-soundfont-generator,
|
178 |
+
title={Erl-j's Soundfont Generator},
|
179 |
+
author={Nicolas Jonason},
|
180 |
+
year={2024},
|
181 |
+
publisher={Huggingface},
|
182 |
+
}
|
183 |
+
```
|
184 |
+
""")
|
185 |
+
|
186 |
+
generate_btn.click(
|
187 |
+
fn=generate_and_export_soundfont,
|
188 |
+
inputs=[text_input, steps],
|
189 |
+
outputs=[audio_output, status_output, sf, wav_files],
|
190 |
+
).success(js="() => console.log('Success')")
|
191 |
+
|
192 |
+
text_input.submit(
|
193 |
+
fn=generate_and_export_soundfont,
|
194 |
+
inputs=[text_input, steps],
|
195 |
+
outputs=[audio_output, status_output, sf, wav_files],
|
196 |
+
)
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
print("Starting demo...")
|
200 |
+
demo.launch()
|
custom.css
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;500&display=swap');
|
2 |
+
|
3 |
+
.keyboard-container {
|
4 |
+
width: 100%;
|
5 |
+
padding: 1.5rem;
|
6 |
+
background: #fafafa;
|
7 |
+
border: 1px solid #e5e5e5;
|
8 |
+
border-radius: 4px;
|
9 |
+
font-family: 'Roboto', sans-serif;
|
10 |
+
user-select: none;
|
11 |
+
}
|
12 |
+
|
13 |
+
.keyboard-row {
|
14 |
+
display: flex;
|
15 |
+
gap: 0.25rem;
|
16 |
+
margin-bottom: 0.25rem;
|
17 |
+
width: 100%;
|
18 |
+
}
|
19 |
+
|
20 |
+
.key {
|
21 |
+
width: calc((100% - 2.75rem) / 12);
|
22 |
+
aspect-ratio: 1;
|
23 |
+
min-width: 40px;
|
24 |
+
flex: none;
|
25 |
+
border: 1px solid #e5e5e5;
|
26 |
+
border-radius: 4px;
|
27 |
+
display: flex;
|
28 |
+
flex-direction: column;
|
29 |
+
align-items: center;
|
30 |
+
justify-content: center;
|
31 |
+
cursor: pointer;
|
32 |
+
background: white;
|
33 |
+
transition: all 0.2s cubic-bezier(0.4, 0, 0.2, 1);
|
34 |
+
user-select: none;
|
35 |
+
padding: 0.5rem;
|
36 |
+
}
|
37 |
+
|
38 |
+
.key:hover {
|
39 |
+
background: #f5f5f5;
|
40 |
+
transform: translateY(-1px);
|
41 |
+
}
|
42 |
+
|
43 |
+
.key:active {
|
44 |
+
transform: translateY(0);
|
45 |
+
}
|
46 |
+
|
47 |
+
.key-label {
|
48 |
+
font-size: 0.875rem;
|
49 |
+
font-weight: 500;
|
50 |
+
color: #333;
|
51 |
+
user-select: none;
|
52 |
+
}
|
53 |
+
|
54 |
+
.note-label {
|
55 |
+
font-size: 0.75rem;
|
56 |
+
color: #666;
|
57 |
+
margin-top: 0.25rem;
|
58 |
+
user-select: none;
|
59 |
+
}
|
60 |
+
|
61 |
+
.controls {
|
62 |
+
display: flex;
|
63 |
+
flex-wrap: wrap;
|
64 |
+
gap: 1rem;
|
65 |
+
margin-bottom: 1.5rem;
|
66 |
+
}
|
67 |
+
|
68 |
+
.effects-controls {
|
69 |
+
display: grid;
|
70 |
+
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
|
71 |
+
gap: 1rem;
|
72 |
+
margin-bottom: 1.5rem;
|
73 |
+
padding: 1rem;
|
74 |
+
background: #fafafa;
|
75 |
+
border-radius: 4px;
|
76 |
+
}
|
77 |
+
|
78 |
+
.control-group {
|
79 |
+
display: flex;
|
80 |
+
flex-direction: column;
|
81 |
+
gap: 0.5rem;
|
82 |
+
}
|
83 |
+
|
84 |
+
.control-group label {
|
85 |
+
font-size: 0.75rem;
|
86 |
+
color: #666;
|
87 |
+
text-transform: uppercase;
|
88 |
+
letter-spacing: 0.05em;
|
89 |
+
user-select: none;
|
90 |
+
}
|
91 |
+
|
92 |
+
input[type="range"] {
|
93 |
+
width: 100%;
|
94 |
+
height: 24px; /* Increased height for better touch target */
|
95 |
+
background: transparent; /* Remove default background */
|
96 |
+
border-radius: 2px;
|
97 |
+
appearance: none;
|
98 |
+
cursor: pointer;
|
99 |
+
margin: 0;
|
100 |
+
padding: 10px 0; /* Add padding for better touch area */
|
101 |
+
}
|
102 |
+
|
103 |
+
input[type="range"]::-webkit-slider-thumb {
|
104 |
+
appearance: none;
|
105 |
+
width: 24px; /* Increased size */
|
106 |
+
height: 24px; /* Increased size */
|
107 |
+
background: #000000;
|
108 |
+
border-radius: 50%;
|
109 |
+
cursor: pointer;
|
110 |
+
transition: background 0.2s;
|
111 |
+
margin-top: -10px; /* Center thumb vertically */
|
112 |
+
user-select: none;
|
113 |
+
}
|
114 |
+
|
115 |
+
input[type="range"]::-webkit-slider-runnable-track {
|
116 |
+
background: #393a39;
|
117 |
+
height: 4px;
|
118 |
+
border-radius: 2px;
|
119 |
+
user-select: none;
|
120 |
+
}
|
121 |
+
|
122 |
+
select {
|
123 |
+
padding: 0.5rem;
|
124 |
+
border-radius: 4px;
|
125 |
+
border: 1px solid #e5e5e5;
|
126 |
+
background: white;
|
127 |
+
font-family: 'Roboto', sans-serif;
|
128 |
+
font-size: 0.875rem;
|
129 |
+
user-select: none;
|
130 |
+
cursor: pointer;
|
131 |
+
}
|
132 |
+
|
133 |
+
.button-group {
|
134 |
+
display: flex;
|
135 |
+
align-items: center;
|
136 |
+
gap: 0.5rem;
|
137 |
+
}
|
138 |
+
|
139 |
+
.button-group button {
|
140 |
+
width: 2rem;
|
141 |
+
height: 2rem;
|
142 |
+
padding: 0;
|
143 |
+
display: flex;
|
144 |
+
align-items: center;
|
145 |
+
justify-content: center;
|
146 |
+
border: 1px solid #e5e5e5;
|
147 |
+
border-radius: 4px;
|
148 |
+
background: white;
|
149 |
+
cursor: pointer;
|
150 |
+
transition: all 0.2s;
|
151 |
+
user-select: none;
|
152 |
+
}
|
153 |
+
|
154 |
+
.button-group button:hover {
|
155 |
+
background: #f5f5f5;
|
156 |
+
}
|
157 |
+
|
158 |
+
button {
|
159 |
+
padding: 0.5rem 1rem;
|
160 |
+
border: none;
|
161 |
+
border-radius: 4px;
|
162 |
+
background: #015131;
|
163 |
+
color: white;
|
164 |
+
font-family: 'Roboto', sans-serif;
|
165 |
+
font-size: 0.875rem;
|
166 |
+
cursor: pointer;
|
167 |
+
transition: all 0.2s;
|
168 |
+
user-select: none;
|
169 |
+
}
|
170 |
+
|
171 |
+
button:hover {
|
172 |
+
background: #002114;
|
173 |
+
}
|
174 |
+
|
175 |
+
body {
|
176 |
+
font-family: 'Roboto', sans-serif;
|
177 |
+
font-size: 1rem;
|
178 |
+
line-height: 1.5;
|
179 |
+
color: #333;
|
180 |
+
background: #f5f5f5;
|
181 |
+
margin: 0;
|
182 |
+
padding: 0;
|
183 |
+
display: flex;
|
184 |
+
justify-content: center;
|
185 |
+
align-items: center;
|
186 |
+
min-height: 100vh;
|
187 |
+
}
|
188 |
+
|
189 |
+
@media (max-width: 768px) {
|
190 |
+
.control-group { min-width: 100%; }
|
191 |
+
.key { min-width: 35px; }
|
192 |
+
.key-label { font-size: 0.75rem; }
|
193 |
+
|
194 |
+
input[type="range"] {
|
195 |
+
height: 32px; /* Even larger touch target on mobile */
|
196 |
+
padding: 14px 0;
|
197 |
+
}
|
198 |
+
|
199 |
+
input[type="range"]::-webkit-slider-thumb {
|
200 |
+
width: 28px; /* Larger thumb on mobile */
|
201 |
+
height: 28px;
|
202 |
+
}
|
203 |
+
}
|
custom.js
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function previewPlayer() {
|
2 |
+
class KeyboardPlayer {
|
3 |
+
constructor(containerId) {
|
4 |
+
this.container = document.getElementById(containerId);
|
5 |
+
this.initializeProperties();
|
6 |
+
this.loadToneJS().then(() => this.init());
|
7 |
+
this.setupWavFileObserver();
|
8 |
+
|
9 |
+
|
10 |
+
// Add click handlers for activation/deactivation
|
11 |
+
this.container.addEventListener('click', (e) => {
|
12 |
+
e.stopPropagation();
|
13 |
+
if (!this.keyboardEnabled) {
|
14 |
+
this.enableKeyboard();
|
15 |
+
}
|
16 |
+
});
|
17 |
+
|
18 |
+
document.addEventListener('click', (e) => {
|
19 |
+
if (!this.container.contains(e.target)) {
|
20 |
+
this.disableKeyboard();
|
21 |
+
}
|
22 |
+
});
|
23 |
+
|
24 |
+
// disable keyboard
|
25 |
+
this.disableKeyboard();
|
26 |
+
}
|
27 |
+
|
28 |
+
enableKeyboard() {
|
29 |
+
this.keyboardEnabled = true;
|
30 |
+
this.container.style.opacity = '1';
|
31 |
+
}
|
32 |
+
|
33 |
+
disableKeyboard() {
|
34 |
+
this.keyboardEnabled = false;
|
35 |
+
this.container.style.opacity = '0.5';
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
setupWavFileObserver() {
|
41 |
+
const observer = new MutationObserver((mutations) => {
|
42 |
+
const hasDownloadLinkChanges = mutations.some(mutation =>
|
43 |
+
mutation.type === 'childList' &&
|
44 |
+
mutation.target.classList.contains('download-link')
|
45 |
+
);
|
46 |
+
|
47 |
+
if (hasDownloadLinkChanges) {
|
48 |
+
this.initializeSampler();
|
49 |
+
this.enableKeyboard();
|
50 |
+
// scroll so middle of keyboard is in centre of viewport
|
51 |
+
const keyboardTop = this.container.querySelector('.keyboard').getBoundingClientRect().top;
|
52 |
+
window.scrollTo(0, keyboardTop - window.innerHeight / 2, { behavior: 'smooth' });
|
53 |
+
}
|
54 |
+
});
|
55 |
+
|
56 |
+
const wavFilesContainer = document.getElementById('individual-wav-files');
|
57 |
+
if (wavFilesContainer) {
|
58 |
+
observer.observe(wavFilesContainer, {
|
59 |
+
childList: true,
|
60 |
+
subtree: true
|
61 |
+
});
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
initializeProperties() {
|
66 |
+
this.sampler = null;
|
67 |
+
this.keyboardEnabled = true;
|
68 |
+
this.layout = null;
|
69 |
+
this.rootPitch = 60;
|
70 |
+
this.columnOffset = 2;
|
71 |
+
this.rowOffset = 4;
|
72 |
+
this.activeNotes = new Map();
|
73 |
+
this.reverb = null;
|
74 |
+
this.releaseTime = 0.1;
|
75 |
+
this.noteNames = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'];
|
76 |
+
this.majorScale = [0, 2, 4, 5, 7, 9, 11];
|
77 |
+
}
|
78 |
+
|
79 |
+
async loadToneJS() {
|
80 |
+
if (window.Tone) return;
|
81 |
+
const script = document.createElement('script');
|
82 |
+
script.src = 'https://cdnjs.cloudflare.com/ajax/libs/tone/14.8.49/Tone.js';
|
83 |
+
return new Promise((resolve, reject) => {
|
84 |
+
script.onload = resolve;
|
85 |
+
script.onerror = () => reject(new Error('Failed to load Tone.js'));
|
86 |
+
document.head.appendChild(script);
|
87 |
+
});
|
88 |
+
}
|
89 |
+
|
90 |
+
init() {
|
91 |
+
this.createUI();
|
92 |
+
this.detectKeyboardLayout();
|
93 |
+
this.setupEventListeners();
|
94 |
+
this.initializeEffects();
|
95 |
+
this.initializeSampler();
|
96 |
+
}
|
97 |
+
|
98 |
+
createUI() {
|
99 |
+
this.container.innerHTML = `
|
100 |
+
<div class="keyboard-container">
|
101 |
+
<div class="effects-controls">
|
102 |
+
<h3>Release & Reverb</h3>
|
103 |
+
<div class="effect-slider">
|
104 |
+
<label>Release: <span class="release-value">0.1s</span></label>
|
105 |
+
<input type="range" class="release-slider" min="0" max="3" step="0.1" value="0.1">
|
106 |
+
</div>
|
107 |
+
<div class="effect-slider">
|
108 |
+
<label>Reverb: <span class="reverb-value">50%</span></label>
|
109 |
+
<input type="range" class="reverb-slider" min="0" max="100" value="50">
|
110 |
+
</div>
|
111 |
+
</div>
|
112 |
+
<div class="keyboard"></div>
|
113 |
+
<br>
|
114 |
+
<div class="mapping-controls">
|
115 |
+
<h3>Keyboard Mapping</h3>
|
116 |
+
<div class="control-group">
|
117 |
+
<label>Root Pitch: <span class="root-value">C4</span></label>
|
118 |
+
<input type="range" class="root-slider" min="24" max="84" value="60">
|
119 |
+
</div>
|
120 |
+
<div class="control-group">
|
121 |
+
<label>Column Offset: <span class="column-value">2</span> keys from left</label>
|
122 |
+
<input type="range" class="column-slider" min="0" max="6" value="2">
|
123 |
+
</div>
|
124 |
+
<div class="control-group">
|
125 |
+
<label>Row Offset: <span class="row-value">4</span> scale degree(s)</label>
|
126 |
+
<input type="range" class="row-slider" min="1" max="20" value="4">
|
127 |
+
</div>
|
128 |
+
</div>
|
129 |
+
</div>
|
130 |
+
`;
|
131 |
+
this.cacheElements();
|
132 |
+
}
|
133 |
+
|
134 |
+
cacheElements() {
|
135 |
+
const selectors = {
|
136 |
+
keyboard: '.keyboard',
|
137 |
+
rootSlider: '.root-slider',
|
138 |
+
rootValue: '.root-value',
|
139 |
+
columnSlider: '.column-slider',
|
140 |
+
columnValue: '.column-value',
|
141 |
+
rowSlider: '.row-slider',
|
142 |
+
rowValue: '.row-value',
|
143 |
+
releaseSlider: '.release-slider',
|
144 |
+
releaseValue: '.release-value',
|
145 |
+
reverbSlider: '.reverb-slider',
|
146 |
+
reverbValue: '.reverb-value'
|
147 |
+
};
|
148 |
+
this.elements = Object.fromEntries(
|
149 |
+
Object.entries(selectors).map(([key, selector]) =>
|
150 |
+
[key, this.container.querySelector(selector)]
|
151 |
+
)
|
152 |
+
);
|
153 |
+
}
|
154 |
+
|
155 |
+
setupEventListeners() {
|
156 |
+
const handlers = {
|
157 |
+
releaseSlider: e => {
|
158 |
+
this.releaseTime = parseFloat(e.target.value);
|
159 |
+
this.elements.releaseValue.textContent = `${this.releaseTime}s`;
|
160 |
+
},
|
161 |
+
reverbSlider: e => {
|
162 |
+
const wetness = parseInt(e.target.value) / 100;
|
163 |
+
this.reverb.wet.value = wetness;
|
164 |
+
this.elements.reverbValue.textContent = `${e.target.value}%`;
|
165 |
+
},
|
166 |
+
rootSlider: e => {
|
167 |
+
this.rootPitch = parseInt(e.target.value);
|
168 |
+
this.elements.rootValue.textContent = this.midiToNoteName(this.rootPitch);
|
169 |
+
this.updateNotes();
|
170 |
+
},
|
171 |
+
columnSlider: e => {
|
172 |
+
this.columnOffset = parseInt(e.target.value);
|
173 |
+
this.elements.columnValue.textContent = this.columnOffset;
|
174 |
+
this.updateNotes();
|
175 |
+
},
|
176 |
+
rowSlider: e => {
|
177 |
+
this.rowOffset = parseInt(e.target.value);
|
178 |
+
this.elements.rowValue.textContent = this.rowOffset;
|
179 |
+
this.updateNotes();
|
180 |
+
}
|
181 |
+
};
|
182 |
+
|
183 |
+
Object.entries(handlers).forEach(([element, handler]) =>
|
184 |
+
this.elements[element].addEventListener('input', handler));
|
185 |
+
|
186 |
+
document.addEventListener('mouseup', () => this.handleMouseUp());
|
187 |
+
document.addEventListener('keydown', e => !e.repeat && this.handleKeyEvent(e, true));
|
188 |
+
document.addEventListener('keyup', e => this.handleKeyEvent(e, false));
|
189 |
+
}
|
190 |
+
|
191 |
+
initializeEffects() {
|
192 |
+
this.reverb = new Tone.Reverb({ decay: 1.5, wet: 0.5 }).toDestination();
|
193 |
+
}
|
194 |
+
|
195 |
+
async initializeSampler() {
|
196 |
+
const availableNotes = ['C1', 'F#1', 'C2', 'F#2', 'C3', 'F#3', 'C4', 'F#4', 'C5', 'F#5'];
|
197 |
+
const urls = Object.fromEntries(
|
198 |
+
availableNotes
|
199 |
+
.map(note => [note, document.querySelector(`a[href*="${note}.wav"]`)?.href])
|
200 |
+
.filter(([, url]) => url)
|
201 |
+
);
|
202 |
+
|
203 |
+
if (!Object.keys(urls).length) {
|
204 |
+
this.handleSamplerError();
|
205 |
+
return;
|
206 |
+
}
|
207 |
+
|
208 |
+
this.sampler = new Tone.Sampler({
|
209 |
+
urls,
|
210 |
+
onload: () => this.handleSamplerLoad(),
|
211 |
+
}).connect(this.reverb);
|
212 |
+
}
|
213 |
+
|
214 |
+
handleSamplerError() {
|
215 |
+
console.log('No WAV files found');
|
216 |
+
|
217 |
+
}
|
218 |
+
|
219 |
+
handleSamplerLoad() {
|
220 |
+
console.log('Sampler loaded');
|
221 |
+
this.container.querySelectorAll('.key').forEach(key => key.style.opacity = '1');
|
222 |
+
}
|
223 |
+
|
224 |
+
detectKeyboardLayout() {
|
225 |
+
this.layout = {
|
226 |
+
keys: [
|
227 |
+
{ keys: '1234567890'.split(''), offset: 0 },
|
228 |
+
{ keys: 'QWERTYUIOP'.split(''), offset: 1 },
|
229 |
+
{ keys: 'ASDFGHJKL'.split(''), offset: 1.5 },
|
230 |
+
{ keys: 'ZXCVBNM,.'.split(''), offset: 2 }
|
231 |
+
]
|
232 |
+
}.keys;
|
233 |
+
this.createKeyboard();
|
234 |
+
}
|
235 |
+
|
236 |
+
createKeyboard() {
|
237 |
+
this.elements.keyboard.innerHTML = '';
|
238 |
+
this.layout.forEach((row, rowIndex) => {
|
239 |
+
const rowElement = document.createElement('div');
|
240 |
+
rowElement.className = 'keyboard-row';
|
241 |
+
rowElement.style.paddingLeft = `${row.offset * 3}%`;
|
242 |
+
row.keys.forEach(key => rowElement.appendChild(this.createKey(key)));
|
243 |
+
this.elements.keyboard.appendChild(rowElement);
|
244 |
+
});
|
245 |
+
this.updateNotes();
|
246 |
+
}
|
247 |
+
|
248 |
+
createKey(keyLabel) {
|
249 |
+
const key = document.createElement('div');
|
250 |
+
key.className = 'key';
|
251 |
+
key.innerHTML = `
|
252 |
+
<div class="key-label">${keyLabel}</div>
|
253 |
+
<div class="note-label"></div>
|
254 |
+
`;
|
255 |
+
key.addEventListener('mousedown', () => this.startNote(key));
|
256 |
+
key.addEventListener('mouseenter', e => e.buttons === 1 && this.startNote(key));
|
257 |
+
key.addEventListener('mouseleave', () => this.stopNote(key));
|
258 |
+
return key;
|
259 |
+
}
|
260 |
+
|
261 |
+
updateNotes() {
|
262 |
+
Array.from(this.elements.keyboard.children).forEach((row, rowIndex) => {
|
263 |
+
Array.from(row.children).forEach((key, columnIndex) => {
|
264 |
+
const horizontalDistance = columnIndex - this.columnOffset;
|
265 |
+
const verticalDistance = rowIndex * this.rowOffset;
|
266 |
+
const totalScaleDegrees = horizontalDistance - verticalDistance;
|
267 |
+
const octaves = Math.floor(totalScaleDegrees / 7);
|
268 |
+
const remainingDegrees = ((totalScaleDegrees % 7) + 7) % 7;
|
269 |
+
const semitonesFromRoot = this.majorScale[remainingDegrees] + (octaves * 12);
|
270 |
+
const midiNote = this.rootPitch + semitonesFromRoot;
|
271 |
+
|
272 |
+
this.updateKeyDisplay(key, midiNote);
|
273 |
+
});
|
274 |
+
});
|
275 |
+
}
|
276 |
+
|
277 |
+
updateKeyDisplay(key, midiNote) {
|
278 |
+
const isBaseRoot = midiNote === this.rootPitch;
|
279 |
+
const isOctaveRoot = midiNote % 12 === this.rootPitch % 12;
|
280 |
+
key.style.backgroundColor = isBaseRoot ? '#90EE90' : isOctaveRoot ? '#E8F5E9' : '';
|
281 |
+
const noteName = this.midiToNoteName(midiNote);
|
282 |
+
key.querySelector('.note-label').textContent = noteName;
|
283 |
+
key.dataset.note = noteName;
|
284 |
+
key.dataset.midi = midiNote;
|
285 |
+
}
|
286 |
+
|
287 |
+
handleKeyEvent(e, isKeyDown) {
|
288 |
+
if (!this.keyboardEnabled || !this.sampler) return;
|
289 |
+
const keyElement = this.findKeyElement(e.key.toUpperCase());
|
290 |
+
if (keyElement) {
|
291 |
+
e.preventDefault();
|
292 |
+
isKeyDown ? this.startNote(keyElement) : this.stopNote(keyElement);
|
293 |
+
}
|
294 |
+
}
|
295 |
+
|
296 |
+
startNote(keyElement) {
|
297 |
+
if (!this.sampler || !keyElement || this.activeNotes.has(keyElement)) return;
|
298 |
+
const note = keyElement.dataset.note;
|
299 |
+
if (!note) return;
|
300 |
+
|
301 |
+
Tone.start().then(() => {
|
302 |
+
this.sampler.triggerAttack(note);
|
303 |
+
this.activeNotes.set(keyElement, { note });
|
304 |
+
this.animateKey(keyElement, true);
|
305 |
+
});
|
306 |
+
}
|
307 |
+
|
308 |
+
stopNote(keyElement) {
|
309 |
+
if (!this.sampler || !keyElement) return;
|
310 |
+
const noteInfo = this.activeNotes.get(keyElement);
|
311 |
+
if (noteInfo) {
|
312 |
+
this.sampler.triggerRelease(noteInfo.note, "+" + this.releaseTime);
|
313 |
+
this.activeNotes.delete(keyElement);
|
314 |
+
this.animateKey(keyElement, false);
|
315 |
+
}
|
316 |
+
}
|
317 |
+
|
318 |
+
handleMouseUp() {
|
319 |
+
this.activeNotes.forEach((_, keyElement) => this.stopNote(keyElement));
|
320 |
+
}
|
321 |
+
|
322 |
+
findKeyElement(keyLabel) {
|
323 |
+
for (const row of this.elements.keyboard.children) {
|
324 |
+
for (const key of row.children) {
|
325 |
+
if (key.querySelector('.key-label').textContent === keyLabel) return key;
|
326 |
+
}
|
327 |
+
}
|
328 |
+
return null;
|
329 |
+
}
|
330 |
+
|
331 |
+
animateKey(keyElement, isDown) {
|
332 |
+
const midiNote = parseInt(keyElement.dataset.midi);
|
333 |
+
const isBaseRoot = midiNote === this.rootPitch;
|
334 |
+
const isOctaveRoot = midiNote % 12 === this.rootPitch % 12;
|
335 |
+
|
336 |
+
keyElement.style.transform = isDown ? 'scale(0.95)' : '';
|
337 |
+
keyElement.style.backgroundColor = isBaseRoot ? '#90EE90' :
|
338 |
+
isOctaveRoot ? '#E8F5E9' :
|
339 |
+
isDown ? '#f0f0f0' : '';
|
340 |
+
}
|
341 |
+
|
342 |
+
midiToNoteName(midiNumber) {
|
343 |
+
const octave = Math.floor(midiNumber / 12) - 1;
|
344 |
+
return `${this.noteNames[midiNumber % 12]}${octave}`;
|
345 |
+
}
|
346 |
+
}
|
347 |
+
|
348 |
+
let container = document.getElementById('custom-player');
|
349 |
+
if (!container) {
|
350 |
+
container = document.createElement('div');
|
351 |
+
container.id = 'custom-player';
|
352 |
+
document.body.appendChild(container);
|
353 |
+
}
|
354 |
+
new KeyboardPlayer('custom-player');
|
355 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch
|
3 |
+
stable-audio-tools==0.0.16
|
4 |
+
gradio==5.8.0
|
5 |
+
einops
|
6 |
+
spaces
|
7 |
+
lxml
|
8 |
+
transformers==4.44.0
|
9 |
+
tokenizers==0.19.1
|
sf-creator-fork/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
__pycache__
|
3 |
+
*/__pycache__
|
sf-creator-fork/README.md
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Soundfont creation
|
2 |
+
|
3 |
+
This library aims to create a soundfont based on a directory containing sound files (`.wav`).
|
4 |
+
|
5 |
+
Current support:
|
6 |
+
|
7 |
+
* [SFZ Format](https://sfzformat.com/)
|
8 |
+
* [DecentSampler](https://www.decentsamples.com/product/decent-sampler-plugin/)
|
9 |
+
|
10 |
+
Planned support:
|
11 |
+
|
12 |
+
* [Soundfont 2 `sf2`](https://en.wikipedia.org/wiki/SoundFont) (This will be more tricky, as it is a binary format,
|
13 |
+
as [Polyphone](https://www.polyphone-soundfonts.com/) is able to convert `sfz` to `sf2`, I will postpone this)
|
14 |
+
|
15 |
+
## Project setup
|
16 |
+
|
17 |
+
### Creating a project specific virtual environment (recommended)
|
18 |
+
|
19 |
+
You can omit this step if you are ok with installing the dependencies
|
20 |
+
system wide and go directly to the next step: [Installing dependencies](#installing-dependencies).
|
21 |
+
|
22 |
+
```
|
23 |
+
virtualenv venv
|
24 |
+
source venv/bin/activate
|
25 |
+
```
|
26 |
+
or, under Windows:
|
27 |
+
```
|
28 |
+
virtualenv venv
|
29 |
+
venv\Scripts\activate
|
30 |
+
```
|
31 |
+
|
32 |
+
### Installing dependencies
|
33 |
+
|
34 |
+
```
|
35 |
+
pip install -r requirements.txt
|
36 |
+
```
|
37 |
+
|
38 |
+
## Run
|
39 |
+
|
40 |
+
This will create a file `soundfont.sfz` alongside the `wav` files in the given directory.
|
41 |
+
|
42 |
+
```
|
43 |
+
python main.py sfz <directory-to-wave-files>
|
44 |
+
```
|
45 |
+
|
46 |
+
Run `python main.py --help` or `python main.py <command> --help`, where `<command>` can be `sfz` or `decentsampler` for now,
|
47 |
+
to get the full list of arguments.
|
48 |
+
|
49 |
+
## Automatic note detection and mapping of samples
|
50 |
+
|
51 |
+
The given samples are scanned for note names (A0 to C8). If a note name is found in a filename of a sample, the midi
|
52 |
+
note for this sample will be set automatically.
|
53 |
+
|
54 |
+
In addition to that, in case of missing samples in between for certain notes an automatic distribution is calculated, so that all notes between A0 and C8 are covered.
|
55 |
+
|
56 |
+
If there are two samples for the same note available, a round robin/random change is assumed.
|
57 |
+
|
58 |
+
## TODO and resources
|
59 |
+
|
60 |
+
- [ ] Make automatic distribution over all midi notes from 21 to 108 optional, and add an option to configure the highest and the lowest, ideally relative to the hightest and lowest pitch of the samples
|
61 |
+
- [ ] Detect pitch automatically (for melodic instruments at least), using https://pypi.org/project/crepe/
|
62 |
+
|
63 |
+
### More SFZ Support
|
64 |
+
|
65 |
+
The best starting point for SFZ is https://sfzformat.com/.
|
66 |
+
|
67 |
+
- [ ] Look at [SFZ Python Automapper by Peter Eastman](https://vis.versilstudios.com/sfzconverter.html#u13452-4), this looks like there is a lot that can be reused for sfz files
|
68 |
+
- [ ] And https://github.com/freepats/freepats-tools, too?
|
69 |
+
- [ ] Add support for velocity levels
|
70 |
+
- [ ] Add support for more options supported by SFZ, reverb, effects, attack, release and so on
|
71 |
+
|
72 |
+
### DecentSampler support
|
73 |
+
|
74 |
+
An XML based format developed by David Hilowitz (see https://youtu.be/UxPRmD_RNCY).
|
75 |
+
|
76 |
+
- [x] Create an XML Schema for highlighting and autocompletion
|
77 |
+
- [x] Implement `DecentSamplerWriter`
|
78 |
+
- [x] Add options for UI (cover)
|
79 |
+
- [ ] ...and effects
|
80 |
+
|
81 |
+
### SF2 Support
|
82 |
+
|
83 |
+
This will get tricky, as this is a binary format with not too much examples. There are a few applications reading or even writing sf2 out there, at least at a very basic level.
|
84 |
+
But as [Polyphone](https://www.polyphone-soundfonts.com/) is able to convert `sfz` to `sf2`, I will postpone this.
|
85 |
+
|
86 |
+
* Basic SFZ to SF2 converter in python: https://github.com/freepats/freepats-tools
|
87 |
+
* C++ library [sf2cute](http://gocha.github.io/sf2cute/)
|
88 |
+
* Python library reading sf2: https://pypi.org/project/sf2utils/
|
89 |
+
* C# code writing basic sf2 file: https://github.com/Kermalis/SoundFont2
|
90 |
+
* Code of Polyphone: https://github.com/davy7125/polyphone
|
sf-creator-fork/main.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import glob
|
5 |
+
|
6 |
+
from sfcreator.decentsampler import DecentSamplerWriter
|
7 |
+
from sfcreator.sfz import SfzWriter
|
8 |
+
from sfcreator.soundfont.soundfont import SoundFont, Sample, NoteNameMidiNumberMapper, HighLowKeyDistributor
|
9 |
+
|
10 |
+
|
11 |
+
def list_supported_files(directory):
|
12 |
+
return glob.glob(directory + "/*.wav")
|
13 |
+
|
14 |
+
|
15 |
+
def map_note_name(note_name: str):
|
16 |
+
if note_name.isdigit():
|
17 |
+
return int(note_name)
|
18 |
+
else:
|
19 |
+
return NoteNameMidiNumberMapper().map(note_name)
|
20 |
+
|
21 |
+
# @cli.command()
|
22 |
+
# @click.argument('directory')
|
23 |
+
# @click.option('--highkey', required=False, type=str, default="108")
|
24 |
+
# @click.option('--lowkey', required=False, type=str, default="21")
|
25 |
+
# @click.option('--instrument', help="name of the instrument", required=False, type=str)
|
26 |
+
# @click.option('--loopmode',
|
27 |
+
# help="loop mode, no_loop (default), one_shot, loop_continuous or loop_sustain,"
|
28 |
+
# " see https://sfzformat.com/opcodes/loop_mode",
|
29 |
+
# required=False, type=str, default="no_loop")
|
30 |
+
# @click.option('--polyphony',
|
31 |
+
# help="Polyphony voice limit, see https://sfzformat.com/opcodes/polyphony",
|
32 |
+
# required=False, type=str, default=None)
|
33 |
+
def sfz(directory: str, lowkey: str, highkey: str, instrument: str, loopmode: str, polyphony: str):
|
34 |
+
soundfont = make_soundfont(directory, map_note_name(lowkey), map_note_name(highkey), instrument, loopmode,
|
35 |
+
polyphony)
|
36 |
+
SfzWriter().write(directory, soundfont)
|
37 |
+
|
38 |
+
|
39 |
+
# @cli.command()
|
40 |
+
# @click.argument('directory')
|
41 |
+
# @click.option('--highkey', required=False, type=str, default="108")
|
42 |
+
# @click.option('--lowkey', required=False, type=str, default="21")
|
43 |
+
# @click.option('--instrument', help="name of the instrument", required=False, type=str)
|
44 |
+
# @click.option('--image', required=False, type=str)
|
45 |
+
# @click.option('--loopmode',
|
46 |
+
# help="loop mode, only supported for sfz. For compatibility reasons, --loopmode=one_shot will cause"
|
47 |
+
# " --release=20 for decent sampler, to ensure samples are always played until the end.",
|
48 |
+
# required=False, type=str, default="no_loop")
|
49 |
+
# @click.option('--polyphony',
|
50 |
+
# help="Polyphony voice limit, not supported in decent sampler format.",
|
51 |
+
# required=False, type=str, default=None)
|
52 |
+
|
53 |
+
def decentsampler(directory: str, lowkey: str, highkey: str, instrument: str, loopmode: str, polyphony: str,
|
54 |
+
image: str):
|
55 |
+
soundfont = make_soundfont(directory, map_note_name(lowkey), map_note_name(highkey), instrument, loopmode,
|
56 |
+
polyphony)
|
57 |
+
DecentSamplerWriter().write(directory, soundfont, image)
|
58 |
+
|
59 |
+
|
60 |
+
def make_soundfont(directory: str, lowkey: int, highkey: int, instrument: str, loopmode: str, polyphony: str):
|
61 |
+
files = list_supported_files(directory)
|
62 |
+
samples: List[Sample] = []
|
63 |
+
mapper = NoteNameMidiNumberMapper()
|
64 |
+
for file in files:
|
65 |
+
samples.append(mapper.mapped_sample(file))
|
66 |
+
soundfont = SoundFont(samples, loop_mode=loopmode, polyphony=polyphony)
|
67 |
+
if instrument is None or len(instrument) == 0:
|
68 |
+
soundfont.instrument_name = os.path.basename(os.path.dirname(directory))
|
69 |
+
else:
|
70 |
+
soundfont.instrument_name = instrument
|
71 |
+
HighLowKeyDistributor().distribute(soundfont, low_key=lowkey, high_key=highkey)
|
72 |
+
return soundfont
|
sf-creator-fork/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
click~=7.1.2
|
2 |
+
lxml~=4.6.1
|
sf-creator-fork/resources/decentsampler/dspreset.xsd
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<xs:schema attributeFormDefault="unqualified" elementFormDefault="qualified" xmlns:xs="http://www.w3.org/2001/XMLSchema">
|
2 |
+
<xs:element name="DecentSampler">
|
3 |
+
<xs:complexType>
|
4 |
+
<xs:sequence>
|
5 |
+
<xs:element name="ui">
|
6 |
+
<xs:complexType>
|
7 |
+
<xs:sequence>
|
8 |
+
<xs:element name="tab">
|
9 |
+
<xs:complexType>
|
10 |
+
<xs:sequence>
|
11 |
+
<xs:element name="labeled-knob" maxOccurs="unbounded" minOccurs="0">
|
12 |
+
<xs:complexType>
|
13 |
+
<xs:sequence>
|
14 |
+
<xs:element name="binding" maxOccurs="unbounded" minOccurs="0">
|
15 |
+
<xs:complexType>
|
16 |
+
<xs:simpleContent>
|
17 |
+
<xs:extension base="xs:string">
|
18 |
+
<xs:attribute type="xs:string" name="type" use="optional"/>
|
19 |
+
<xs:attribute type="xs:string" name="level" use="optional"/>
|
20 |
+
<xs:attribute type="xs:byte" name="position" use="optional"/>
|
21 |
+
<xs:attribute type="xs:string" name="parameter" use="optional"/>
|
22 |
+
<xs:attribute type="xs:string" name="translation" use="optional"/>
|
23 |
+
<xs:attribute type="xs:byte" name="translationOutputMin" use="optional"/>
|
24 |
+
<xs:attribute type="xs:float" name="translationOutputMax" use="optional"/>
|
25 |
+
<xs:attribute type="xs:string" name="translationTable" use="optional"/>
|
26 |
+
</xs:extension>
|
27 |
+
</xs:simpleContent>
|
28 |
+
</xs:complexType>
|
29 |
+
</xs:element>
|
30 |
+
</xs:sequence>
|
31 |
+
<xs:attribute type="xs:short" name="x" use="optional"/>
|
32 |
+
<xs:attribute type="xs:byte" name="y" use="optional"/>
|
33 |
+
<xs:attribute type="xs:string" name="label" use="optional"/>
|
34 |
+
<xs:attribute type="xs:string" name="type" use="optional"/>
|
35 |
+
<xs:attribute type="xs:short" name="minValue" use="optional"/>
|
36 |
+
<xs:attribute type="xs:short" name="maxValue" use="optional"/>
|
37 |
+
<xs:attribute type="xs:string" name="textColor" use="optional"/>
|
38 |
+
<xs:attribute type="xs:float" name="value" use="optional"/>
|
39 |
+
</xs:complexType>
|
40 |
+
</xs:element>
|
41 |
+
</xs:sequence>
|
42 |
+
<xs:attribute type="xs:string" name="name"/>
|
43 |
+
</xs:complexType>
|
44 |
+
</xs:element>
|
45 |
+
</xs:sequence>
|
46 |
+
<xs:attribute type="xs:string" name="bgImage"/>
|
47 |
+
<xs:attribute type="xs:short" name="width"/>
|
48 |
+
<xs:attribute type="xs:short" name="height"/>
|
49 |
+
<xs:attribute type="xs:string" name="layoutMode"/>
|
50 |
+
<xs:attribute type="xs:string" name="bgMode"/>
|
51 |
+
</xs:complexType>
|
52 |
+
</xs:element>
|
53 |
+
<xs:element name="groups">
|
54 |
+
<xs:complexType>
|
55 |
+
<xs:sequence>
|
56 |
+
<xs:element name="group" maxOccurs="unbounded" minOccurs="0">
|
57 |
+
<xs:complexType>
|
58 |
+
<xs:sequence>
|
59 |
+
<xs:element name="sample" maxOccurs="unbounded" minOccurs="0">
|
60 |
+
<xs:complexType>
|
61 |
+
<xs:simpleContent>
|
62 |
+
<xs:extension base="xs:string">
|
63 |
+
<xs:attribute type="xs:string" name="path" use="optional"/>
|
64 |
+
<xs:attribute type="xs:byte" name="rootNote" use="optional"/>
|
65 |
+
<xs:attribute type="xs:byte" name="loNote" use="optional"/>
|
66 |
+
<xs:attribute type="xs:byte" name="hiNote" use="optional"/>
|
67 |
+
<xs:attribute type="xs:int" name="loopStart" use="optional"/>
|
68 |
+
<xs:attribute type="xs:int" name="loopEnd" use="optional"/>
|
69 |
+
</xs:extension>
|
70 |
+
</xs:simpleContent>
|
71 |
+
</xs:complexType>
|
72 |
+
</xs:element>
|
73 |
+
</xs:sequence>
|
74 |
+
<xs:attribute type="xs:string" name="name" use="optional"/>
|
75 |
+
<xs:attribute type="xs:string" name="volume" use="optional"/>
|
76 |
+
<xs:attribute type="xs:byte" name="ampVelTrack" use="optional"/>
|
77 |
+
<xs:attribute type="xs:byte" name="modVolume" use="optional"/>
|
78 |
+
</xs:complexType>
|
79 |
+
</xs:element>
|
80 |
+
</xs:sequence>
|
81 |
+
<xs:attribute type="xs:float" name="attack"/>
|
82 |
+
<xs:attribute type="xs:float" name="decay"/>
|
83 |
+
<xs:attribute type="xs:float" name="sustain"/>
|
84 |
+
<xs:attribute type="xs:byte" name="release"/>
|
85 |
+
</xs:complexType>
|
86 |
+
</xs:element>
|
87 |
+
<xs:element name="effects">
|
88 |
+
<xs:complexType>
|
89 |
+
<xs:sequence>
|
90 |
+
<xs:element name="effect" maxOccurs="unbounded" minOccurs="0">
|
91 |
+
<xs:complexType>
|
92 |
+
<xs:simpleContent>
|
93 |
+
<xs:extension base="xs:string">
|
94 |
+
<xs:attribute type="xs:string" name="type" use="optional"/>
|
95 |
+
</xs:extension>
|
96 |
+
</xs:simpleContent>
|
97 |
+
</xs:complexType>
|
98 |
+
</xs:element>
|
99 |
+
</xs:sequence>
|
100 |
+
</xs:complexType>
|
101 |
+
</xs:element>
|
102 |
+
<xs:element name="midi">
|
103 |
+
<xs:complexType>
|
104 |
+
<xs:sequence>
|
105 |
+
<xs:element name="cc">
|
106 |
+
<xs:complexType>
|
107 |
+
<xs:sequence>
|
108 |
+
<xs:element name="binding">
|
109 |
+
<xs:complexType>
|
110 |
+
<xs:simpleContent>
|
111 |
+
<xs:extension base="xs:string">
|
112 |
+
<xs:attribute type="xs:string" name="level"/>
|
113 |
+
<xs:attribute type="xs:string" name="type"/>
|
114 |
+
<xs:attribute type="xs:byte" name="position"/>
|
115 |
+
<xs:attribute type="xs:string" name="parameter"/>
|
116 |
+
<xs:attribute type="xs:string" name="translation"/>
|
117 |
+
<xs:attribute type="xs:byte" name="translationOutputMin"/>
|
118 |
+
<xs:attribute type="xs:float" name="translationOutputMax"/>
|
119 |
+
</xs:extension>
|
120 |
+
</xs:simpleContent>
|
121 |
+
</xs:complexType>
|
122 |
+
</xs:element>
|
123 |
+
</xs:sequence>
|
124 |
+
<xs:attribute type="xs:byte" name="number"/>
|
125 |
+
</xs:complexType>
|
126 |
+
</xs:element>
|
127 |
+
</xs:sequence>
|
128 |
+
</xs:complexType>
|
129 |
+
</xs:element>
|
130 |
+
</xs:sequence>
|
131 |
+
<xs:attribute type="xs:byte" name="pluginVersion"/>
|
132 |
+
</xs:complexType>
|
133 |
+
</xs:element>
|
134 |
+
</xs:schema>
|
sf-creator-fork/sfcreator/__init__.py
ADDED
File without changes
|
sf-creator-fork/sfcreator/decentsampler/__init__.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from lxml import etree
|
4 |
+
from sfcreator.soundfont.soundfont import SoundFont, Sample
|
5 |
+
|
6 |
+
class DecentSamplerWriter:
|
7 |
+
def write(self, directory: str, soundfont: SoundFont, image: str):
|
8 |
+
root = etree.Element("DecentSampler", pluginVersion="1")
|
9 |
+
|
10 |
+
if image is not None:
|
11 |
+
ui = etree.Element("ui", bgImage=image, width="812", height="375", layoutMode="relative", bgMode="top_left")
|
12 |
+
root.append(ui)
|
13 |
+
tab = etree.Element("tab", name="main")
|
14 |
+
ui.append(tab)
|
15 |
+
|
16 |
+
groups = etree.Element("groups")
|
17 |
+
root.append(groups)
|
18 |
+
|
19 |
+
if soundfont.loop_mode == "one_shot":
|
20 |
+
# assuming no sample will be longer than 20s
|
21 |
+
groups.set("release", "20")
|
22 |
+
elif soundfont.release is not None:
|
23 |
+
groups.set("release", str(soundfont.release))
|
24 |
+
|
25 |
+
for root_key in soundfont.root_keys():
|
26 |
+
self.write_root_key_sample_group(groups, soundfont.samples_for_root_key(root_key))
|
27 |
+
|
28 |
+
filename = directory + "/" + soundfont.instrument_name + ".dspreset"
|
29 |
+
print("Writing to " + filename)
|
30 |
+
et = etree.ElementTree(root)
|
31 |
+
et.write(filename, pretty_print=True, encoding='utf-8', xml_declaration=True)
|
32 |
+
|
33 |
+
def write_root_key_sample_group(self, groups, samples: List[Sample]):
|
34 |
+
group = etree.Element("group")
|
35 |
+
groups.append(group)
|
36 |
+
if len(samples) > 1:
|
37 |
+
group.set("seqMode", "random")
|
38 |
+
for index, sample in enumerate(samples, start=1):
|
39 |
+
xml_sample = etree.Element("sample")
|
40 |
+
group.append(xml_sample)
|
41 |
+
xml_sample.set("rootNote", str(sample.key_range.root_key))
|
42 |
+
xml_sample.set("loNote", str(sample.key_range.low_key))
|
43 |
+
xml_sample.set("hiNote", str(sample.key_range.high_key))
|
44 |
+
xml_sample.set("loVel", str(sample.velocity_range.low_velocity))
|
45 |
+
xml_sample.set("hiVel", str(sample.velocity_range.high_velocity))
|
46 |
+
xml_sample.set("seqPosition", str(index))
|
47 |
+
xml_sample.set("path", str(os.path.basename(sample.filename)))
|
sf-creator-fork/sfcreator/sfz/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from sfcreator.soundfont.soundfont import SoundFont, Sample
|
5 |
+
|
6 |
+
|
7 |
+
class SfzWriter:
|
8 |
+
def write(self, directory: str, soundfont: SoundFont):
|
9 |
+
f = open(directory + "/" + soundfont.instrument_name + ".sfz", "w")
|
10 |
+
f.write("<group>\r\n")
|
11 |
+
if soundfont.loop_mode is not None:
|
12 |
+
f.write(f"loop_mode={soundfont.loop_mode}\r\n")
|
13 |
+
if soundfont.polyphony is not None:
|
14 |
+
f.write(f"loop_mode={soundfont.polyphony}\r\n")
|
15 |
+
for root_key in soundfont.root_keys():
|
16 |
+
self.write_root_key_sample_group(f, soundfont.samples_for_root_key(root_key))
|
17 |
+
|
18 |
+
def write_root_key_sample_group(self, f, samples: List[Sample]):
|
19 |
+
if len(samples) == 1:
|
20 |
+
sample = samples[0]
|
21 |
+
f.write(f"<region> sample={os.path.basename(sample.filename)}"
|
22 |
+
f" pitch_keycenter={str(sample.key_range.root_key)} lokey={str(sample.key_range.low_key)}"
|
23 |
+
f" hikey={str(sample.key_range.high_key)}\r\n")
|
24 |
+
else:
|
25 |
+
lorand = 0.0
|
26 |
+
randstep = 1 / len(samples)
|
27 |
+
for sample in samples:
|
28 |
+
hirand = lorand + randstep
|
29 |
+
f.write(f"<region> sample={os.path.basename(sample.filename)}"
|
30 |
+
f" pitch_keycenter={str(sample.key_range.root_key)} lokey={str(sample.key_range.low_key)}"
|
31 |
+
f" hikey={str(sample.key_range.high_key)}")
|
32 |
+
if lorand > 0.0:
|
33 |
+
f.write(f" lorand={lorand}")
|
34 |
+
if hirand < 1.0:
|
35 |
+
f.write(f" hirand={hirand}")
|
36 |
+
f.write("\r\n")
|
37 |
+
lorand = hirand
|
sf-creator-fork/sfcreator/soundfont/__init__.py
ADDED
File without changes
|
sf-creator-fork/sfcreator/soundfont/soundfont.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
|
4 |
+
class KeyRange:
|
5 |
+
def __init__(self, root_key: int = 60, low_key: int = None, high_key: int = None):
|
6 |
+
self.root_key = root_key
|
7 |
+
self.low_key = low_key
|
8 |
+
self.high_key = high_key
|
9 |
+
if self.low_key is None:
|
10 |
+
self.low_key = self.root_key
|
11 |
+
if self.high_key is None:
|
12 |
+
self.high_key = self.root_key
|
13 |
+
|
14 |
+
def in_range(self, key):
|
15 |
+
return self.low_key <= key <= self.high_key
|
16 |
+
|
17 |
+
def __eq__(self, other):
|
18 |
+
return self.low_key == other.low_key and self.root_key == other.root_key and self.high_key == other.high_key
|
19 |
+
|
20 |
+
|
21 |
+
class VelocityRange:
|
22 |
+
def __init__(self, low_velocity: int = 0, high_velocity: int = 127):
|
23 |
+
self.low_velocity = low_velocity
|
24 |
+
self.high_velocity = high_velocity
|
25 |
+
|
26 |
+
def in_range(self, key):
|
27 |
+
return self.low_velocity <= key <= self.high_velocity
|
28 |
+
|
29 |
+
def __eq__(self, other):
|
30 |
+
return self.low_velocity == other.low_velocity and self.high_velocity == other.root_key and\
|
31 |
+
self.high_velocity == other.high_key
|
32 |
+
|
33 |
+
|
34 |
+
class Sample:
|
35 |
+
def __init__(self, filename: str, root_key: int = 60, low_key: int = None, high_key: int = None):
|
36 |
+
self.velocity_range = VelocityRange()
|
37 |
+
self.filename = filename
|
38 |
+
self.key_range = KeyRange(root_key, low_key, high_key)
|
39 |
+
self.index = 0
|
40 |
+
|
41 |
+
def __repr__(self) -> str:
|
42 |
+
return 'Sample(filename=' + self.filename + ', root_key=' + str(self.key_range.root_key) + ', low_key=' + str(
|
43 |
+
self.key_range.low_key) + ', high_key=' + str(self.key_range.high_key) + ')'
|
44 |
+
|
45 |
+
def __eq__(self, other):
|
46 |
+
return self.filename == other.filename and self.key_range == other.key_range and self.index == other.index
|
47 |
+
|
48 |
+
|
49 |
+
class SoundFont:
|
50 |
+
def __init__(self, samples: List[Sample], loop_mode: str = "no_loop", polyphony: str = None, release: int = None,
|
51 |
+
instrument_name=None):
|
52 |
+
self.instrument_name = instrument_name
|
53 |
+
self.samples = samples
|
54 |
+
self.loop_mode = loop_mode
|
55 |
+
self.polyphony = polyphony
|
56 |
+
self.release = release
|
57 |
+
|
58 |
+
def root_keys(self):
|
59 |
+
return sorted(set([sample.key_range.root_key for sample in self.samples]))
|
60 |
+
|
61 |
+
def range_for_key(self, key) -> KeyRange:
|
62 |
+
samples_in_range = [sample for sample in self.samples if sample.key_range.in_range(key)]
|
63 |
+
return samples_in_range[0].key_range if len(samples_in_range) > 0 else None
|
64 |
+
|
65 |
+
def samples_for_root_key(self, root_key):
|
66 |
+
return [sample for sample in self.samples if sample.key_range.root_key == root_key]
|
67 |
+
|
68 |
+
def set_range(self, root_key, low_key=None, high_key=None):
|
69 |
+
for sample in self.samples_for_root_key(root_key):
|
70 |
+
if low_key is not None:
|
71 |
+
sample.key_range.low_key = low_key
|
72 |
+
if high_key is not None:
|
73 |
+
sample.key_range.high_key = high_key
|
74 |
+
|
75 |
+
|
76 |
+
class HighLowKeyDistributor:
|
77 |
+
def distribute(self, soundfont: SoundFont, low_key: int = 21, high_key: int = 108):
|
78 |
+
soundfont.samples.sort(key=lambda sample: sample.key_range.root_key, reverse=False)
|
79 |
+
|
80 |
+
prev_root_key: int = None
|
81 |
+
for root_key in soundfont.root_keys():
|
82 |
+
range = soundfont.range_for_key(root_key)
|
83 |
+
if prev_root_key is None:
|
84 |
+
lo_key = min(low_key, range.low_key)
|
85 |
+
soundfont.set_range(root_key, low_key=min(low_key, lo_key))
|
86 |
+
else:
|
87 |
+
prev_range = soundfont.range_for_key(prev_root_key)
|
88 |
+
mid_sample_key = int((range.low_key - prev_range.high_key) / 2) + prev_range.high_key
|
89 |
+
soundfont.set_range(prev_root_key, high_key=mid_sample_key)
|
90 |
+
soundfont.set_range(root_key, low_key=mid_sample_key + 1)
|
91 |
+
prev_root_key = root_key
|
92 |
+
soundfont.set_range(soundfont.root_keys()[-1],
|
93 |
+
high_key=max(high_key, soundfont.range_for_key(soundfont.root_keys()[-1]).high_key))
|
94 |
+
|
95 |
+
|
96 |
+
class NoteNameMidiNumberMapper:
|
97 |
+
def __init__(self):
|
98 |
+
self.index_offset = 21
|
99 |
+
self.note_name_midi_number_map: List[str] = []
|
100 |
+
for octave_number in range(0, 8):
|
101 |
+
for c in range(ord('A'), ord('B') + 1):
|
102 |
+
self._add_note(c, octave_number)
|
103 |
+
for c in range(ord('C'), ord('G') + 1):
|
104 |
+
self._add_note(c, octave_number + 1)
|
105 |
+
self.note_name_midi_number_map.append("C8")
|
106 |
+
|
107 |
+
def _add_note(self, c, octave_number):
|
108 |
+
self.note_name_midi_number_map.append(f"{chr(c)}{str(octave_number)}")
|
109 |
+
if chr(c) != "A" and chr(c) != "E":
|
110 |
+
self.note_name_midi_number_map.append(f"{chr(c)}#{str(octave_number)}")
|
111 |
+
|
112 |
+
def mapped_sample(self, filename: str):
|
113 |
+
for note in self.note_name_midi_number_map:
|
114 |
+
if note in filename:
|
115 |
+
return Sample(filename, self.map(note))
|
116 |
+
return Sample(filename)
|
117 |
+
|
118 |
+
def map(self, note_name: str):
|
119 |
+
return self.note_name_midi_number_map.index(note_name) + self.index_offset
|
sf-creator-fork/sfcreator/soundfont/test_soundfont.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unittest import TestCase
|
2 |
+
|
3 |
+
from sfcreator.soundfont.soundfont import *
|
4 |
+
|
5 |
+
|
6 |
+
class TestHighLowKeyDistributor(TestCase):
|
7 |
+
def test_distribution_of_samples_one_sample_only(self):
|
8 |
+
distributor = HighLowKeyDistributor()
|
9 |
+
|
10 |
+
sf = SoundFont(samples=[
|
11 |
+
Sample("filename.wav", root_key=60)
|
12 |
+
])
|
13 |
+
distributor.distribute(sf)
|
14 |
+
|
15 |
+
def test_distribution_of_samples_unsorted_samples(self):
|
16 |
+
distributor = HighLowKeyDistributor()
|
17 |
+
sf = SoundFont(samples=[
|
18 |
+
Sample("40.wav", root_key=40),
|
19 |
+
Sample("60.wav", root_key=60),
|
20 |
+
Sample("20.wav", root_key=20)
|
21 |
+
])
|
22 |
+
distributor.distribute(sf)
|
23 |
+
|
24 |
+
self.assertEqual(Sample("20.wav", root_key=20, low_key=20, high_key=30), sf.samples[0])
|
25 |
+
self.assertEqual(Sample("40.wav", root_key=40, low_key=31, high_key=50), sf.samples[1])
|
26 |
+
self.assertEqual(Sample("60.wav", root_key=60, low_key=51, high_key=108), sf.samples[2])
|
27 |
+
|
28 |
+
def test_distribution_of_samples_over_108(self):
|
29 |
+
distributor = HighLowKeyDistributor()
|
30 |
+
sf = SoundFont(samples=[
|
31 |
+
Sample("40.wav", root_key=40),
|
32 |
+
Sample("110.wav", root_key=110)
|
33 |
+
])
|
34 |
+
distributor.distribute(sf)
|
35 |
+
|
36 |
+
self.assertEqual(Sample("40.wav", root_key=40, low_key=21, high_key=75), sf.samples[0])
|
37 |
+
self.assertEqual(Sample("110.wav", root_key=110, low_key=76, high_key=110), sf.samples[1])
|
38 |
+
|
39 |
+
def test_distribution_of_samples_with_duplicated_notes(self):
|
40 |
+
distributor = HighLowKeyDistributor()
|
41 |
+
sf = SoundFont(samples=[
|
42 |
+
Sample("40.wav", root_key=40),
|
43 |
+
Sample("40-1.wav", root_key=40),
|
44 |
+
Sample("60.wav", root_key=60),
|
45 |
+
Sample("20.wav", root_key=20)
|
46 |
+
])
|
47 |
+
distributor.distribute(sf)
|
48 |
+
|
49 |
+
self.assertEqual(Sample("20.wav", root_key=20, low_key=20, high_key=30), sf.samples[0])
|
50 |
+
self.assertEqual(Sample("40.wav", root_key=40, low_key=31, high_key=50), sf.samples[1])
|
51 |
+
self.assertEqual(Sample("40-1.wav", root_key=40, low_key=31, high_key=50), sf.samples[2])
|
52 |
+
self.assertEqual(Sample("60.wav", root_key=60, low_key=51, high_key=108), sf.samples[3])
|
53 |
+
|
54 |
+
|
55 |
+
class TestSoundFont(TestCase):
|
56 |
+
def test_creation_of_soundfont(self):
|
57 |
+
sf = SoundFont(samples=[
|
58 |
+
Sample("filename.wav", root_key=60)
|
59 |
+
])
|
60 |
+
self.assertEqual(sf.samples[0].key_range.high_key, 60)
|
61 |
+
self.assertEqual(sf.samples[0].key_range.low_key, 60)
|
62 |
+
|
63 |
+
|
64 |
+
class TestNoteNameMidiNumberMapper(TestCase):
|
65 |
+
def test_map_note_name(self):
|
66 |
+
mapper = NoteNameMidiNumberMapper()
|
67 |
+
self.assertEqual(21, mapper.map("A0"))
|
68 |
+
self.assertEqual(108, mapper.map("C8"))
|
69 |
+
self.assertEqual(61, mapper.map("C#4"))
|
70 |
+
self.assertEqual(60, mapper.map("C4"))
|
train_lfm.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import pytorch_lightning as pl
|
3 |
+
from torch import nn
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import einops
|
7 |
+
import wandb
|
8 |
+
import torch
|
9 |
+
# import wandb logging
|
10 |
+
from pytorch_lightning.loggers import WandbLogger
|
11 |
+
from stable_audio_tools import get_pretrained_model
|
12 |
+
from transformers import T5Tokenizer, T5EncoderModel
|
13 |
+
|
14 |
+
|
15 |
+
class SinActivation(nn.Module):
|
16 |
+
def forward(self, x):
|
17 |
+
return torch.sin(x)
|
18 |
+
|
19 |
+
class FourierFeatures(nn.Module):
|
20 |
+
|
21 |
+
def __init__(self, in_features, out_features, n_layers):
|
22 |
+
super().__init__()
|
23 |
+
self.in_features = in_features
|
24 |
+
self.out_features = out_features
|
25 |
+
self.n_layers = n_layers
|
26 |
+
layers = []
|
27 |
+
layers += [nn.Linear(in_features, out_features)]
|
28 |
+
# add sin activation
|
29 |
+
layers += [SinActivation()]
|
30 |
+
for i in range(n_layers-1):
|
31 |
+
layers += [nn.Linear(out_features, out_features)]
|
32 |
+
layers += [SinActivation()]
|
33 |
+
self.layers = nn.Sequential(*layers)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
return self.layers(x)
|
37 |
+
|
38 |
+
class FlowMatchingModule(pl.LightningModule):
|
39 |
+
|
40 |
+
def __init__(self, main_model=None, text_conditioner=None, max_tokens=128, n_channels=None, t_input=None):
|
41 |
+
super().__init__()
|
42 |
+
self.save_hyperparameters(ignore=['main_model', "text_conditioner"])
|
43 |
+
|
44 |
+
self.model = main_model.transformer
|
45 |
+
self.input_layer = main_model.transformer.project_in
|
46 |
+
self.output_layer = main_model.transformer.project_out
|
47 |
+
|
48 |
+
self.text_conditioner = text_conditioner
|
49 |
+
|
50 |
+
self.d_model = self.input_layer.weight.shape[0]
|
51 |
+
self.d_input = self.input_layer.weight.shape[1]
|
52 |
+
|
53 |
+
# use fourier features for schedule
|
54 |
+
self.schedule_embedding = FourierFeatures(1, self.d_model, 2)
|
55 |
+
# use learned positional encoding
|
56 |
+
self.pitch_embedding = nn.Parameter(torch.randn(n_channels, self.d_model))
|
57 |
+
# make embedding layer for tags
|
58 |
+
self.channels = n_channels
|
59 |
+
|
60 |
+
mean_proj = []
|
61 |
+
for layer in self.model.layers:
|
62 |
+
mean_proj += [nn.Linear(self.d_model, self.d_model)]
|
63 |
+
self.mean_proj = nn.ModuleList(mean_proj)
|
64 |
+
|
65 |
+
def get_example_inputs(self):
|
66 |
+
text = "A piano playing a C major chord"
|
67 |
+
conditioning, conditioning_mask = self.text_conditioner(text, device = self.device)
|
68 |
+
|
69 |
+
# repeat conditioning
|
70 |
+
conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
|
71 |
+
conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
|
72 |
+
|
73 |
+
t = torch.rand(1, device=self.device)
|
74 |
+
z = torch.randn(1, self.hparams.t_input ,self.hparams.n_channels, self.d_input , device=self.device)
|
75 |
+
return z, conditioning, conditioning_mask, t
|
76 |
+
|
77 |
+
|
78 |
+
def forward(self, x, conditioning, conditioning_mask, t):
|
79 |
+
batch, t_input, n_channels, d_input = x.shape
|
80 |
+
|
81 |
+
# add conditioning to x
|
82 |
+
x = self.input_layer(x)
|
83 |
+
tz = self.schedule_embedding(t[:,None,None,None])
|
84 |
+
pitch_z = self.pitch_embedding[None, None, :n_channels, :]
|
85 |
+
# print shapes
|
86 |
+
x = x + tz + pitch_z
|
87 |
+
rot = self.model.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
88 |
+
|
89 |
+
conditioning = einops.rearrange(conditioning, 'b t c d -> (b c) t d', c=self.channels)
|
90 |
+
conditioning_mask = einops.rearrange(conditioning_mask, 'b t c -> (b c) t', c=self.channels)
|
91 |
+
|
92 |
+
for layer_idx, layer in enumerate(self.model.layers):
|
93 |
+
x = einops.rearrange(x, 'b t c d -> (b c) t d')
|
94 |
+
|
95 |
+
x = layer(x, rotary_pos_emb=rot, context = conditioning, context_mask = conditioning_mask)
|
96 |
+
x = einops.rearrange(x, '(b c) t d -> b t c d', c=self.channels)
|
97 |
+
x_ch_mean = x.mean(dim=2)
|
98 |
+
x_ch_mean = self.mean_proj[layer_idx](x_ch_mean)
|
99 |
+
# non linearity
|
100 |
+
# x_ch_mean = torch.relu(x_ch_mean)
|
101 |
+
# # layer norm
|
102 |
+
# x_ch_mean = torch.layer_norm(x_ch_mean, x_ch_mean.shape[1:])
|
103 |
+
x += x_ch_mean[:, :, None, :]
|
104 |
+
x = self.output_layer(x)
|
105 |
+
return x
|
106 |
+
|
107 |
+
def step(self, batch, batch_idx):
|
108 |
+
x = batch["z"]
|
109 |
+
text = batch["text"]
|
110 |
+
conditioning, conditioning_mask = self.text_conditioner(text, device = self.device)
|
111 |
+
|
112 |
+
# repeat conditioning
|
113 |
+
conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
|
114 |
+
conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
|
115 |
+
|
116 |
+
x = einops.rearrange(x, 'b c d t -> b t c d')
|
117 |
+
z0 = torch.randn(x.shape, device=x.device)
|
118 |
+
z1 = x
|
119 |
+
t = torch.rand(x.shape[0], device=x.device)
|
120 |
+
zt = t[:,None,None,None] * z1 + (1 - t[:,None,None,None]) * z0
|
121 |
+
vt = self(zt,conditioning,conditioning_mask,t)
|
122 |
+
loss = (vt - (z1 - z0)).pow(2).mean()
|
123 |
+
return loss
|
124 |
+
|
125 |
+
@torch.inference_mode()
|
126 |
+
def sample(self, batch_size, text, steps=10, same_latent=False):
|
127 |
+
# Ensure model is on the correct device
|
128 |
+
device = next(self.parameters()).device
|
129 |
+
dtype = self.input_layer.weight.dtype
|
130 |
+
|
131 |
+
# Move conditioning to the correct device and dtype
|
132 |
+
conditioning, conditioning_mask = self.text_conditioner(text, device=device)
|
133 |
+
conditioning = einops.repeat(conditioning, "b t d-> b t c d", c=self.channels)
|
134 |
+
conditioning_mask = einops.repeat(
|
135 |
+
conditioning_mask, "b t -> b t c", c=self.channels
|
136 |
+
)
|
137 |
+
conditioning = conditioning.to(device=device, dtype=dtype)
|
138 |
+
conditioning_mask = conditioning_mask.to(device=device)
|
139 |
+
|
140 |
+
self.eval()
|
141 |
+
with torch.no_grad():
|
142 |
+
# Create initial noise on the correct device and dtype
|
143 |
+
z0 = torch.randn(
|
144 |
+
batch_size,
|
145 |
+
self.hparams.t_input,
|
146 |
+
self.hparams.n_channels,
|
147 |
+
self.d_input,
|
148 |
+
device=device,
|
149 |
+
dtype=dtype,
|
150 |
+
)
|
151 |
+
|
152 |
+
if same_latent:
|
153 |
+
z0 = z0[0].repeat(batch_size, 1, 1, 1)
|
154 |
+
|
155 |
+
zt = z0
|
156 |
+
for step in tqdm(range(steps)):
|
157 |
+
t = torch.tensor([step / steps], device=device, dtype=dtype)
|
158 |
+
zt = zt + (1 / steps) * self.forward(
|
159 |
+
zt, conditioning, conditioning_mask, t
|
160 |
+
)
|
161 |
+
|
162 |
+
return zt
|
163 |
+
|
164 |
+
|
165 |
+
def training_step(self, batch, batch_idx):
|
166 |
+
loss = self.step(batch, batch_idx)
|
167 |
+
self.log('trn_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
168 |
+
return loss
|
169 |
+
|
170 |
+
def validation_step(self, batch, batch_idx):
|
171 |
+
loss = self.step(batch, batch_idx)
|
172 |
+
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
173 |
+
return loss
|
174 |
+
|
175 |
+
def configure_optimizers(self):
|
176 |
+
return torch.optim.Adam(self.parameters(), lr=1e-5)
|
177 |
+
|
178 |
+
class EncodedAudioDataset(torch.utils.data.Dataset):
|
179 |
+
def __init__(self, paths, pitch_range):
|
180 |
+
records = []
|
181 |
+
print("Loading data")
|
182 |
+
for path in tqdm(paths):
|
183 |
+
records+=torch.load(path)
|
184 |
+
self.records = records
|
185 |
+
self.pitch_range = pitch_range
|
186 |
+
|
187 |
+
# keep only records with z
|
188 |
+
self.records = [r for r in self.records if "z" in r]
|
189 |
+
|
190 |
+
print(f"Loaded {len(self.records)} records")
|
191 |
+
|
192 |
+
|
193 |
+
def compose_prompt(self,record):
|
194 |
+
title = record["name"] if "name" in record else record["title"]
|
195 |
+
|
196 |
+
tags = record["tags"]
|
197 |
+
|
198 |
+
# take tags
|
199 |
+
# shuffle
|
200 |
+
tags = np.random.choice(tags, len(tags), replace=False)
|
201 |
+
# take random number of tags
|
202 |
+
tags = list(tags[:np.random.randint(0, len(tags)+1)])
|
203 |
+
#
|
204 |
+
# take either the title or group or type or nothing
|
205 |
+
if "type_group" in record and "type" in record:
|
206 |
+
type_group = record["type_group"]
|
207 |
+
type = record["type"]
|
208 |
+
head = np.random.choice([title, type_group, type])
|
209 |
+
else:
|
210 |
+
head = np.random.choice([title])
|
211 |
+
|
212 |
+
# append tags
|
213 |
+
# with 75% chance add head
|
214 |
+
elements = tags
|
215 |
+
if np.random.rand() < 0.75:
|
216 |
+
elements = [head] + elements
|
217 |
+
|
218 |
+
# shuffle elements
|
219 |
+
elements = np.random.choice(elements, len(elements), replace=False)
|
220 |
+
|
221 |
+
prompt = " ".join(elements)
|
222 |
+
|
223 |
+
# make everything lowercase
|
224 |
+
prompt = prompt.lower()
|
225 |
+
return prompt
|
226 |
+
|
227 |
+
def __len__(self):
|
228 |
+
return len(self.records)
|
229 |
+
|
230 |
+
def __getitem__(self, idx):
|
231 |
+
return {
|
232 |
+
"z": self.records[idx]["z"][self.pitch_range[0]:self.pitch_range[1]],
|
233 |
+
"text": self.compose_prompt(self.records[idx])
|
234 |
+
}
|
235 |
+
|
236 |
+
def check_for_nans(self):
|
237 |
+
for r in self.records:
|
238 |
+
# check if z has nan values
|
239 |
+
if np.isnan(r["z"]).any():
|
240 |
+
raise ValueError("Nan values in z")
|
241 |
+
|
242 |
+
def get_z_shape(self):
|
243 |
+
shapes = [r["z"].shape for r in self.records]
|
244 |
+
# return unique shapes
|
245 |
+
return list(set(shapes))
|
246 |
+
|
247 |
+
|
248 |
+
if __name__ == "__main__":
|
249 |
+
|
250 |
+
# set seed
|
251 |
+
SEED = 0
|
252 |
+
torch.manual_seed(SEED)
|
253 |
+
|
254 |
+
BATCH_SIZE = 1
|
255 |
+
LATENT_T = 86
|
256 |
+
|
257 |
+
# initialize wandb logger
|
258 |
+
wandb.init()
|
259 |
+
logger = WandbLogger(project="synth_flow")
|
260 |
+
|
261 |
+
# don't log models
|
262 |
+
wandb.config.log_model = False
|
263 |
+
|
264 |
+
DATASET = "dataset_a"
|
265 |
+
if DATASET == "dataset_a":
|
266 |
+
PITCH_RANGE = [2,12]
|
267 |
+
|
268 |
+
trn_ds = EncodedAudioDataset([f"artefacts/synth_data_{i}.pt" for i in range(9)], PITCH_RANGE)
|
269 |
+
trn_ds.check_for_nans()
|
270 |
+
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
|
271 |
+
|
272 |
+
val_ds = EncodedAudioDataset([f"artefacts/synth_data_9.pt"], PITCH_RANGE)
|
273 |
+
val_ds.check_for_nans()
|
274 |
+
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
|
275 |
+
|
276 |
+
|
277 |
+
elif DATASET == "dataset_b":
|
278 |
+
|
279 |
+
PITCH_RANGE = [0,10]
|
280 |
+
trn_ds = EncodedAudioDataset([f"artefacts/synth_data_2_joined_{i}.pt" for i in range(3)], PITCH_RANGE)
|
281 |
+
trn_ds.check_for_nans()
|
282 |
+
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
|
283 |
+
|
284 |
+
val_ds = EncodedAudioDataset([f"artefacts/synth_data_2_joined_3.pt"], PITCH_RANGE)
|
285 |
+
val_ds.check_for_nans()
|
286 |
+
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
|
287 |
+
|
288 |
+
src_model = get_pretrained_model("stabilityai/stable-audio-open-1.0")[0].to("cpu")
|
289 |
+
src_model = src_model.to("cpu")
|
290 |
+
transformer_model = src_model.model.model
|
291 |
+
transformer_model = transformer_model.train()
|
292 |
+
text_conditioner = src_model.conditioner.conditioners.prompt
|
293 |
+
|
294 |
+
t5_version = "google-t5/t5-base"
|
295 |
+
|
296 |
+
|
297 |
+
lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
|
298 |
+
|
299 |
+
model = FlowMatchingModule(
|
300 |
+
main_model=transformer_model,
|
301 |
+
text_conditioner=text_conditioner,
|
302 |
+
n_channels=PITCH_RANGE[1] - PITCH_RANGE[0],
|
303 |
+
t_input=LATENT_T,
|
304 |
+
)
|
305 |
+
|
306 |
+
trainer = pl.Trainer(devices = [3], logger=logger, gradient_clip_val=1.0, callbacks=[lr_callback], max_epochs=1000, precision="16-mixed")
|
307 |
+
|
308 |
+
trainer.fit(model, trn_dl, val_dl, ckpt_path="synth_flow/9gzpz0i6/epoch=85-step=774000.ckpt")
|
309 |
+
# save checkpoint
|
310 |
+
trainer.save_checkpoint("artefacts/model_finetuned_2.ckpt")
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
|