erl-j committed on
Commit b362624 · 0 Parent(s):

first commit
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # everything in assets is large
+ assets/**/* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .DS_Store
+ */.DS_Store
+ __pycache__/
+ output/
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: soundfont-generator
+ emoji: 🚀
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.8.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+ import einops
+ import gradio as gr
+ import datetime
+ import numpy as np
+ import spaces
+ import soundfile
+ import os
+ import sys
+ import zipfile
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download
+
+ sys.path.append("sf-creator-fork")
+ from main import sfz, decentsampler
+
+ # Download models from the Hugging Face Hub
+ decoder_path = hf_hub_download("erl-j/soundfont-generator-assets", "decoder.pt")
+ model_path = hf_hub_download("erl-j/soundfont-generator-assets", "synth_lfm_modern_bfloat16.pt")
+
+ # Load models once at startup
+ device = "cuda"
+ decoder = torch.load(decoder_path, map_location=device).half().eval()
+ model = torch.load(model_path, map_location=device).half().eval()
+
+ @spaces.GPU
+ def generate_and_export_soundfont(text, steps=20, instrument_name=None):
+     sample_start = datetime.datetime.now()
+
+     # Generate latents from the text prompt, then decode them to audio
+     z = model.sample(1, text=[text], steps=steps)
+     z_reshaped = einops.rearrange(z, "b t c d -> (b c) d t")
+
+     with torch.no_grad():
+         audio = decoder.decode(z_reshaped)
+
+     audio_output = einops.rearrange(audio, "b c t -> c (b t)").cpu().numpy()
+     audio_output = audio_output / np.max(np.abs(audio_output))
+
+     # Export individual WAV files
+     export_audio = audio.cpu().numpy().astype(np.float32)
+     output_dir = "output"
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Derive an instrument name from the prompt if none is provided
+     if not instrument_name:
+         instrument_name = text.replace(" ", "_")[:20]
+
+     # Save one WAV file per generated pitch
+     pitches = [
+         "C1", "F#1", "C2", "F#2", "C3", "F#3", "C4", "F#4",
+         "C5", "F#5", "C6", "F#6", "C7", "F#7", "C8",
+     ]
+     wav_files = []
+     for i in range(audio.shape[0]):
+         wav_path = f"{output_dir}/{pitches[i]}.wav"
+         soundfile.write(wav_path, export_audio[i].T, 44100)
+         wav_files.append(wav_path)
+
+     # Generate the SFZ file
+     sfz(
+         directory=output_dir,
+         lowkey="21",
+         highkey="108",
+         instrument=instrument_name,
+         loopmode="no_loop",
+         polyphony=None,
+     )
+
+     # Zip the SFZ and WAV files into a complete soundfont package
+     zip_path = f"{output_dir}/{instrument_name}_package.zip"
+     with zipfile.ZipFile(zip_path, "w") as zipf:
+         # Add the SFZ file
+         sfz_file = f"{output_dir}/{instrument_name}.sfz"
+         zipf.write(sfz_file, os.path.basename(sfz_file))
+         # Add all WAV files
+         for wav_file in wav_files:
+             if os.path.exists(wav_file):
+                 zipf.write(wav_file, os.path.basename(wav_file))
+
+     total_time = (datetime.datetime.now() - sample_start).total_seconds()
+
+     return (
+         (44100, audio_output.T),
+         f"Generation took {total_time:.2f}s\nFiles saved in {output_dir}",
+         zip_path,
+         wav_files,
+     )
+
+ custom_js = open("custom.js").read()
+ custom_css = open("custom.css").read()
+
+ demo = gr.Blocks(title="Erl-j's sound font generator", js=custom_js, css=custom_css)
+
+ with demo:
+     gr.Markdown("""
+     # Erl-j's Soundfont Generator
+     Generate soundfonts from text descriptions using latent flow matching, then download the complete SFZ soundfont package to use the instrument locally.
+     ## Instructions
+     1. Enter a text prompt describing the sound you want to generate.
+     2. Adjust the number of generation steps to trade off quality against speed (roughly).
+     3. Click the "Generate Soundfont" button to generate the audio and soundfont.
+     4. Preview the generated instrument with the keyboard.
+     5. Export the soundfont by clicking the "Download SFZ Soundfont Package" button. You can then use the soundfont in an SFZ-compatible VST such as [Sforzando](https://www.plogue.com/products/sforzando/).
+     """)
+
+     with gr.Row():
+         steps = gr.Slider(
+             minimum=1, maximum=50, value=20, step=1, label="Generation steps"
+         )
+
+     with gr.Row():
+         text_input = gr.Textbox(
+             label="Prompt",
+             placeholder="Enter text description (e.g. 'hard bass', 'sparkly bells')",
+             lines=2,
+         )
+
+     with gr.Row():
+         generate_btn = gr.Button("Generate Soundfont", variant="primary")
+
+     with gr.Row():
+         audio_output = gr.Audio(label="Generated Audio Preview", visible=False)
+         status_output = gr.Textbox(label="Status", lines=2, visible=False)
+
+     with gr.Row():
+         wav_files = gr.File(label="Individual WAV Files", file_count="multiple", visible=False, elem_id="individual-wav-files")
+
+     html = """
+     <div id="custom-player"
+         style="width: 100%; height: 600px; border: 1px solid #f8f9fa; border-radius: 5px; margin-top: 10px;"
+     ></div>
+     """
+
+     gr.HTML(html, min_height=800, max_height=800)
+     with gr.Row():
+         sf = gr.File(label="Download SFZ Soundfont Package", type="filepath", visible=True, elem_id="sfz")
+
+     gr.Markdown("""
+     # About
+     The model is a modified version of [stable audio open](https://huggingface.co/stabilityai/stable-audio-open-1.0).
+
+     Unlike the original model, this version uses latent flow matching rather than latent diffusion.
+     Additionally, the pitches are stacked along a channel dimension rather than concatenated along the time dimension, which allows for faster generation.
+
+     Soundfont export code is based on the [sf-creator](https://github.com/paulwellnerbou/sf-creator) project.
+
+     Similar work by Nercessian and Imort: [InstrumentGen](https://instrumentgen.netlify.app/).
+
+     Thank you @carlthome for coming up with the name.
+
+     To cite this work, please use the following BibTeX entry:
+     ```bibtex
+     @misc{erl-j-soundfont-generator,
+         title={Erl-j's Soundfont Generator},
+         author={Nicolas Jonason},
+         year={2024},
+         publisher={Huggingface},
+     }
+     ```
+     """)
+
+     generate_btn.click(
+         fn=generate_and_export_soundfont,
+         inputs=[text_input, steps],
+         outputs=[audio_output, status_output, sf, wav_files],
+     ).success(js="() => console.log('Success')")
+
+     text_input.submit(
+         fn=generate_and_export_soundfont,
+         inputs=[text_input, steps],
+         outputs=[audio_output, status_output, sf, wav_files],
+     )
+
+ if __name__ == "__main__":
+     print("Starting demo...")
+     demo.launch()
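
The "Generation steps" slider above maps directly onto the number of Euler integration steps used by the flow-matching sampler described in the About section. As a rough sketch of why fewer steps are faster (a simplified stand-in, not the repository's exact model code; `velocity_field` is a hypothetical placeholder for the model's forward pass):

```python
import torch

@torch.no_grad()
def euler_sample(velocity_field, shape, steps=20, device="cuda"):
    # Integrate dz/dt = v(z, t) from t=0 (noise) to t=1 (data).
    # Each step costs one forward pass, hence the quality/speed tradeoff.
    zt = torch.randn(shape, device=device)  # z0 ~ N(0, I)
    for step in range(steps):
        t = torch.full((shape[0],), step / steps, device=device)
        zt = zt + (1.0 / steps) * velocity_field(zt, t)
    return zt
```

This mirrors the sampling loop in `FlowMatchingModule.sample` in `train_lfm.py` below.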
custom.css ADDED
@@ -0,0 +1,203 @@
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;500&display=swap');
+
+ .keyboard-container {
+     width: 100%;
+     padding: 1.5rem;
+     background: #fafafa;
+     border: 1px solid #e5e5e5;
+     border-radius: 4px;
+     font-family: 'Roboto', sans-serif;
+     user-select: none;
+ }
+
+ .keyboard-row {
+     display: flex;
+     gap: 0.25rem;
+     margin-bottom: 0.25rem;
+     width: 100%;
+ }
+
+ .key {
+     width: calc((100% - 2.75rem) / 12);
+     aspect-ratio: 1;
+     min-width: 40px;
+     flex: none;
+     border: 1px solid #e5e5e5;
+     border-radius: 4px;
+     display: flex;
+     flex-direction: column;
+     align-items: center;
+     justify-content: center;
+     cursor: pointer;
+     background: white;
+     transition: all 0.2s cubic-bezier(0.4, 0, 0.2, 1);
+     user-select: none;
+     padding: 0.5rem;
+ }
+
+ .key:hover {
+     background: #f5f5f5;
+     transform: translateY(-1px);
+ }
+
+ .key:active {
+     transform: translateY(0);
+ }
+
+ .key-label {
+     font-size: 0.875rem;
+     font-weight: 500;
+     color: #333;
+     user-select: none;
+ }
+
+ .note-label {
+     font-size: 0.75rem;
+     color: #666;
+     margin-top: 0.25rem;
+     user-select: none;
+ }
+
+ .controls {
+     display: flex;
+     flex-wrap: wrap;
+     gap: 1rem;
+     margin-bottom: 1.5rem;
+ }
+
+ .effects-controls {
+     display: grid;
+     grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
+     gap: 1rem;
+     margin-bottom: 1.5rem;
+     padding: 1rem;
+     background: #fafafa;
+     border-radius: 4px;
+ }
+
+ .control-group {
+     display: flex;
+     flex-direction: column;
+     gap: 0.5rem;
+ }
+
+ .control-group label {
+     font-size: 0.75rem;
+     color: #666;
+     text-transform: uppercase;
+     letter-spacing: 0.05em;
+     user-select: none;
+ }
+
+ input[type="range"] {
+     width: 100%;
+     height: 24px; /* Increased height for better touch target */
+     background: transparent; /* Remove default background */
+     border-radius: 2px;
+     appearance: none;
+     cursor: pointer;
+     margin: 0;
+     padding: 10px 0; /* Add padding for better touch area */
+ }
+
+ input[type="range"]::-webkit-slider-thumb {
+     appearance: none;
+     width: 24px; /* Increased size */
+     height: 24px; /* Increased size */
+     background: #000000;
+     border-radius: 50%;
+     cursor: pointer;
+     transition: background 0.2s;
+     margin-top: -10px; /* Center thumb vertically */
+     user-select: none;
+ }
+
+ input[type="range"]::-webkit-slider-runnable-track {
+     background: #393a39;
+     height: 4px;
+     border-radius: 2px;
+     user-select: none;
+ }
+
+ select {
+     padding: 0.5rem;
+     border-radius: 4px;
+     border: 1px solid #e5e5e5;
+     background: white;
+     font-family: 'Roboto', sans-serif;
+     font-size: 0.875rem;
+     user-select: none;
+     cursor: pointer;
+ }
+
+ .button-group {
+     display: flex;
+     align-items: center;
+     gap: 0.5rem;
+ }
+
+ .button-group button {
+     width: 2rem;
+     height: 2rem;
+     padding: 0;
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     border: 1px solid #e5e5e5;
+     border-radius: 4px;
+     background: white;
+     cursor: pointer;
+     transition: all 0.2s;
+     user-select: none;
+ }
+
+ .button-group button:hover {
+     background: #f5f5f5;
+ }
+
+ button {
+     padding: 0.5rem 1rem;
+     border: none;
+     border-radius: 4px;
+     background: #015131;
+     color: white;
+     font-family: 'Roboto', sans-serif;
+     font-size: 0.875rem;
+     cursor: pointer;
+     transition: all 0.2s;
+     user-select: none;
+ }
+
+ button:hover {
+     background: #002114;
+ }
+
+ body {
+     font-family: 'Roboto', sans-serif;
+     font-size: 1rem;
+     line-height: 1.5;
+     color: #333;
+     background: #f5f5f5;
+     margin: 0;
+     padding: 0;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     min-height: 100vh;
+ }
+
+ @media (max-width: 768px) {
+     .control-group { min-width: 100%; }
+     .key { min-width: 35px; }
+     .key-label { font-size: 0.75rem; }
+
+     input[type="range"] {
+         height: 32px; /* Even larger touch target on mobile */
+         padding: 14px 0;
+     }
+
+     input[type="range"]::-webkit-slider-thumb {
+         width: 28px; /* Larger thumb on mobile */
+         height: 28px;
+     }
+ }
custom.js ADDED
@@ -0,0 +1,355 @@
+ function previewPlayer() {
+     class KeyboardPlayer {
+         constructor(containerId) {
+             this.container = document.getElementById(containerId);
+             this.initializeProperties();
+             this.loadToneJS().then(() => this.init());
+             this.setupWavFileObserver();
+
+             // Add click handlers for activation/deactivation
+             this.container.addEventListener('click', (e) => {
+                 e.stopPropagation();
+                 if (!this.keyboardEnabled) {
+                     this.enableKeyboard();
+                 }
+             });
+
+             document.addEventListener('click', (e) => {
+                 if (!this.container.contains(e.target)) {
+                     this.disableKeyboard();
+                 }
+             });
+
+             // Start with the keyboard disabled
+             this.disableKeyboard();
+         }
+
+         enableKeyboard() {
+             this.keyboardEnabled = true;
+             this.container.style.opacity = '1';
+         }
+
+         disableKeyboard() {
+             this.keyboardEnabled = false;
+             this.container.style.opacity = '0.5';
+         }
+
+         setupWavFileObserver() {
+             const observer = new MutationObserver((mutations) => {
+                 const hasDownloadLinkChanges = mutations.some(mutation =>
+                     mutation.type === 'childList' &&
+                     mutation.target.classList.contains('download-link')
+                 );
+
+                 if (hasDownloadLinkChanges) {
+                     this.initializeSampler();
+                     this.enableKeyboard();
+                     // Scroll so the middle of the keyboard is in the centre of the viewport
+                     const keyboardTop = this.container.querySelector('.keyboard').getBoundingClientRect().top;
+                     window.scrollTo({ top: keyboardTop - window.innerHeight / 2, behavior: 'smooth' });
+                 }
+             });
+
+             const wavFilesContainer = document.getElementById('individual-wav-files');
+             if (wavFilesContainer) {
+                 observer.observe(wavFilesContainer, {
+                     childList: true,
+                     subtree: true
+                 });
+             }
+         }
+
+         initializeProperties() {
+             this.sampler = null;
+             this.keyboardEnabled = true;
+             this.layout = null;
+             this.rootPitch = 60;
+             this.columnOffset = 2;
+             this.rowOffset = 4;
+             this.activeNotes = new Map();
+             this.reverb = null;
+             this.releaseTime = 0.1;
+             this.noteNames = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'];
+             this.majorScale = [0, 2, 4, 5, 7, 9, 11];
+         }
+
+         async loadToneJS() {
+             if (window.Tone) return;
+             const script = document.createElement('script');
+             script.src = 'https://cdnjs.cloudflare.com/ajax/libs/tone/14.8.49/Tone.js';
+             return new Promise((resolve, reject) => {
+                 script.onload = resolve;
+                 script.onerror = () => reject(new Error('Failed to load Tone.js'));
+                 document.head.appendChild(script);
+             });
+         }
+
+         init() {
+             this.createUI();
+             this.detectKeyboardLayout();
+             this.setupEventListeners();
+             this.initializeEffects();
+             this.initializeSampler();
+         }
+
+         createUI() {
+             this.container.innerHTML = `
+                 <div class="keyboard-container">
+                     <div class="effects-controls">
+                         <h3>Release & Reverb</h3>
+                         <div class="effect-slider">
+                             <label>Release: <span class="release-value">0.1s</span></label>
+                             <input type="range" class="release-slider" min="0" max="3" step="0.1" value="0.1">
+                         </div>
+                         <div class="effect-slider">
+                             <label>Reverb: <span class="reverb-value">50%</span></label>
+                             <input type="range" class="reverb-slider" min="0" max="100" value="50">
+                         </div>
+                     </div>
+                     <div class="keyboard"></div>
+                     <br>
+                     <div class="mapping-controls">
+                         <h3>Keyboard Mapping</h3>
+                         <div class="control-group">
+                             <label>Root Pitch: <span class="root-value">C4</span></label>
+                             <input type="range" class="root-slider" min="24" max="84" value="60">
+                         </div>
+                         <div class="control-group">
+                             <label>Column Offset: <span class="column-value">2</span> keys from left</label>
+                             <input type="range" class="column-slider" min="0" max="6" value="2">
+                         </div>
+                         <div class="control-group">
+                             <label>Row Offset: <span class="row-value">4</span> scale degree(s)</label>
+                             <input type="range" class="row-slider" min="1" max="20" value="4">
+                         </div>
+                     </div>
+                 </div>
+             `;
+             this.cacheElements();
+         }
+
+         cacheElements() {
+             const selectors = {
+                 keyboard: '.keyboard',
+                 rootSlider: '.root-slider',
+                 rootValue: '.root-value',
+                 columnSlider: '.column-slider',
+                 columnValue: '.column-value',
+                 rowSlider: '.row-slider',
+                 rowValue: '.row-value',
+                 releaseSlider: '.release-slider',
+                 releaseValue: '.release-value',
+                 reverbSlider: '.reverb-slider',
+                 reverbValue: '.reverb-value'
+             };
+             this.elements = Object.fromEntries(
+                 Object.entries(selectors).map(([key, selector]) =>
+                     [key, this.container.querySelector(selector)]
+                 )
+             );
+         }
+
+         setupEventListeners() {
+             const handlers = {
+                 releaseSlider: e => {
+                     this.releaseTime = parseFloat(e.target.value);
+                     this.elements.releaseValue.textContent = `${this.releaseTime}s`;
+                 },
+                 reverbSlider: e => {
+                     const wetness = parseInt(e.target.value) / 100;
+                     this.reverb.wet.value = wetness;
+                     this.elements.reverbValue.textContent = `${e.target.value}%`;
+                 },
+                 rootSlider: e => {
+                     this.rootPitch = parseInt(e.target.value);
+                     this.elements.rootValue.textContent = this.midiToNoteName(this.rootPitch);
+                     this.updateNotes();
+                 },
+                 columnSlider: e => {
+                     this.columnOffset = parseInt(e.target.value);
+                     this.elements.columnValue.textContent = this.columnOffset;
+                     this.updateNotes();
+                 },
+                 rowSlider: e => {
+                     this.rowOffset = parseInt(e.target.value);
+                     this.elements.rowValue.textContent = this.rowOffset;
+                     this.updateNotes();
+                 }
+             };
+
+             Object.entries(handlers).forEach(([element, handler]) =>
+                 this.elements[element].addEventListener('input', handler));
+
+             document.addEventListener('mouseup', () => this.handleMouseUp());
+             document.addEventListener('keydown', e => !e.repeat && this.handleKeyEvent(e, true));
+             document.addEventListener('keyup', e => this.handleKeyEvent(e, false));
+         }
+
+         initializeEffects() {
+             this.reverb = new Tone.Reverb({ decay: 1.5, wet: 0.5 }).toDestination();
+         }
+
+         async initializeSampler() {
+             // Only a subset of the exported pitches is loaded;
+             // Tone.Sampler repitches between the available samples.
+             const availableNotes = ['C1', 'F#1', 'C2', 'F#2', 'C3', 'F#3', 'C4', 'F#4', 'C5', 'F#5'];
+             const urls = Object.fromEntries(
+                 availableNotes
+                     .map(note => [note, document.querySelector(`a[href*="${note}.wav"]`)?.href])
+                     .filter(([, url]) => url)
+             );
+
+             if (!Object.keys(urls).length) {
+                 this.handleSamplerError();
+                 return;
+             }
+
+             this.sampler = new Tone.Sampler({
+                 urls,
+                 onload: () => this.handleSamplerLoad(),
+             }).connect(this.reverb);
+         }
+
+         handleSamplerError() {
+             console.log('No WAV files found');
+         }
+
+         handleSamplerLoad() {
+             console.log('Sampler loaded');
+             this.container.querySelectorAll('.key').forEach(key => key.style.opacity = '1');
+         }
+
+         detectKeyboardLayout() {
+             this.layout = {
+                 keys: [
+                     { keys: '1234567890'.split(''), offset: 0 },
+                     { keys: 'QWERTYUIOP'.split(''), offset: 1 },
+                     { keys: 'ASDFGHJKL'.split(''), offset: 1.5 },
+                     { keys: 'ZXCVBNM,.'.split(''), offset: 2 }
+                 ]
+             }.keys;
+             this.createKeyboard();
+         }
+
+         createKeyboard() {
+             this.elements.keyboard.innerHTML = '';
+             this.layout.forEach((row, rowIndex) => {
+                 const rowElement = document.createElement('div');
+                 rowElement.className = 'keyboard-row';
+                 rowElement.style.paddingLeft = `${row.offset * 3}%`;
+                 row.keys.forEach(key => rowElement.appendChild(this.createKey(key)));
+                 this.elements.keyboard.appendChild(rowElement);
+             });
+             this.updateNotes();
+         }
+
+         createKey(keyLabel) {
+             const key = document.createElement('div');
+             key.className = 'key';
+             key.innerHTML = `
+                 <div class="key-label">${keyLabel}</div>
+                 <div class="note-label"></div>
+             `;
+             key.addEventListener('mousedown', () => this.startNote(key));
+             key.addEventListener('mouseenter', e => e.buttons === 1 && this.startNote(key));
+             key.addEventListener('mouseleave', () => this.stopNote(key));
+             return key;
+         }
+
+         updateNotes() {
+             // Map each (row, column) position to a major-scale degree relative
+             // to the root pitch, then convert scale degrees to semitones.
+             Array.from(this.elements.keyboard.children).forEach((row, rowIndex) => {
+                 Array.from(row.children).forEach((key, columnIndex) => {
+                     const horizontalDistance = columnIndex - this.columnOffset;
+                     const verticalDistance = rowIndex * this.rowOffset;
+                     const totalScaleDegrees = horizontalDistance - verticalDistance;
+                     const octaves = Math.floor(totalScaleDegrees / 7);
+                     const remainingDegrees = ((totalScaleDegrees % 7) + 7) % 7;
+                     const semitonesFromRoot = this.majorScale[remainingDegrees] + (octaves * 12);
+                     const midiNote = this.rootPitch + semitonesFromRoot;
+
+                     this.updateKeyDisplay(key, midiNote);
+                 });
+             });
+         }
+
+         updateKeyDisplay(key, midiNote) {
+             const isBaseRoot = midiNote === this.rootPitch;
+             const isOctaveRoot = midiNote % 12 === this.rootPitch % 12;
+             key.style.backgroundColor = isBaseRoot ? '#90EE90' : isOctaveRoot ? '#E8F5E9' : '';
+             const noteName = this.midiToNoteName(midiNote);
+             key.querySelector('.note-label').textContent = noteName;
+             key.dataset.note = noteName;
+             key.dataset.midi = midiNote;
+         }
+
+         handleKeyEvent(e, isKeyDown) {
+             if (!this.keyboardEnabled || !this.sampler) return;
+             const keyElement = this.findKeyElement(e.key.toUpperCase());
+             if (keyElement) {
+                 e.preventDefault();
+                 isKeyDown ? this.startNote(keyElement) : this.stopNote(keyElement);
+             }
+         }
+
+         startNote(keyElement) {
+             if (!this.sampler || !keyElement || this.activeNotes.has(keyElement)) return;
+             const note = keyElement.dataset.note;
+             if (!note) return;
+
+             Tone.start().then(() => {
+                 this.sampler.triggerAttack(note);
+                 this.activeNotes.set(keyElement, { note });
+                 this.animateKey(keyElement, true);
+             });
+         }
+
+         stopNote(keyElement) {
+             if (!this.sampler || !keyElement) return;
+             const noteInfo = this.activeNotes.get(keyElement);
+             if (noteInfo) {
+                 this.sampler.triggerRelease(noteInfo.note, "+" + this.releaseTime);
+                 this.activeNotes.delete(keyElement);
+                 this.animateKey(keyElement, false);
+             }
+         }
+
+         handleMouseUp() {
+             this.activeNotes.forEach((_, keyElement) => this.stopNote(keyElement));
+         }
+
+         findKeyElement(keyLabel) {
+             for (const row of this.elements.keyboard.children) {
+                 for (const key of row.children) {
+                     if (key.querySelector('.key-label').textContent === keyLabel) return key;
+                 }
+             }
+             return null;
+         }
+
+         animateKey(keyElement, isDown) {
+             const midiNote = parseInt(keyElement.dataset.midi);
+             const isBaseRoot = midiNote === this.rootPitch;
+             const isOctaveRoot = midiNote % 12 === this.rootPitch % 12;
+
+             keyElement.style.transform = isDown ? 'scale(0.95)' : '';
+             keyElement.style.backgroundColor = isBaseRoot ? '#90EE90' :
+                 isOctaveRoot ? '#E8F5E9' :
+                 isDown ? '#f0f0f0' : '';
+         }
+
+         midiToNoteName(midiNumber) {
+             const octave = Math.floor(midiNumber / 12) - 1;
+             return `${this.noteNames[midiNumber % 12]}${octave}`;
+         }
+     }
+
+     let container = document.getElementById('custom-player');
+     if (!container) {
+         container = document.createElement('div');
+         container.id = 'custom-player';
+         document.body.appendChild(container);
+     }
+     new KeyboardPlayer('custom-player');
+ }
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ --extra-index-url https://download.pytorch.org/whl/cu113
+ torch
+ stable-audio-tools==0.0.16
+ gradio==5.8.0
+ einops
+ spaces
+ lxml
+ transformers==4.44.0
+ tokenizers==0.19.1
sf-creator-fork/.gitignore ADDED
@@ -0,0 +1,3 @@
+ venv/
+ __pycache__
+ */__pycache__
sf-creator-fork/README.md ADDED
@@ -0,0 +1,90 @@
+ # Soundfont creation
+
+ This library aims to create a soundfont from a directory containing sound files (`.wav`).
+
+ Current support:
+
+ * [SFZ Format](https://sfzformat.com/)
+ * [DecentSampler](https://www.decentsamples.com/product/decent-sampler-plugin/)
+
+ Planned support:
+
+ * [Soundfont 2 `sf2`](https://en.wikipedia.org/wiki/SoundFont) (This will be trickier, as it is a binary format;
+   since [Polyphone](https://www.polyphone-soundfonts.com/) is able to convert `sfz` to `sf2`, I will postpone this)
+
+ ## Project setup
+
+ ### Creating a project-specific virtual environment (recommended)
+
+ You can omit this step if you are OK with installing the dependencies
+ system-wide and go directly to the next step: [Installing dependencies](#installing-dependencies).
+
+ ```
+ virtualenv venv
+ source venv/bin/activate
+ ```
+ or, under Windows:
+ ```
+ virtualenv venv
+ venv\Scripts\activate
+ ```
+
+ ### Installing dependencies
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Run
+
+ This will create a file `soundfont.sfz` alongside the `wav` files in the given directory:
+
+ ```
+ python main.py sfz <directory-to-wave-files>
+ ```
+
+ Run `python main.py --help` or `python main.py <command> --help`, where `<command>` can currently be `sfz` or `decentsampler`,
+ to get the full list of arguments.
+
+ ## Automatic note detection and mapping of samples
+
+ The given samples are scanned for note names (A0 to C8). If a note name is found in a sample's filename, the MIDI
+ note for that sample is set automatically.
+
+ In addition, when samples are missing for notes in between, an automatic distribution is calculated so that all notes between A0 and C8 are covered (see the sketch after this README).
+
+ If two samples are available for the same note, round-robin/random alternation is assumed.
+
+ ## TODO and resources
+
+ - [ ] Make the automatic distribution over all MIDI notes from 21 to 108 optional, and add an option to configure the highest and the lowest key, ideally relative to the highest and lowest pitch of the samples
+ - [ ] Detect pitch automatically (for melodic instruments at least), using https://pypi.org/project/crepe/
+
+ ### More SFZ Support
+
+ The best starting point for SFZ is https://sfzformat.com/.
+
+ - [ ] Look at the [SFZ Python Automapper by Peter Eastman](https://vis.versilstudios.com/sfzconverter.html#u13452-4); it looks like a lot of it can be reused for sfz files
+ - [ ] And https://github.com/freepats/freepats-tools, too?
+ - [ ] Add support for velocity levels
+ - [ ] Add support for more options supported by SFZ: reverb, effects, attack, release and so on
+
+ ### DecentSampler support
+
+ An XML-based format developed by David Hilowitz (see https://youtu.be/UxPRmD_RNCY).
+
+ - [x] Create an XML Schema for highlighting and autocompletion
+ - [x] Implement `DecentSamplerWriter`
+ - [x] Add options for UI (cover)
+ - [ ] ...and effects
+
+ ### SF2 Support
+
+ This will get tricky, as this is a binary format with not many examples. There are a few applications out there reading or even writing sf2, at least at a very basic level.
+ But as [Polyphone](https://www.polyphone-soundfonts.com/) is able to convert `sfz` to `sf2`, I will postpone this.
+
+ * Basic SFZ to SF2 converter in python: https://github.com/freepats/freepats-tools
+ * C++ library [sf2cute](http://gocha.github.io/sf2cute/)
+ * Python library reading sf2: https://pypi.org/project/sf2utils/
+ * C# code writing basic sf2 file: https://github.com/Kermalis/SoundFont2
+ * Code of Polyphone: https://github.com/davy7125/polyphone
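
A minimal sketch of the note detection and key-range distribution described in the README above (simplified from `NoteNameMidiNumberMapper` and `HighLowKeyDistributor` in `sfcreator/soundfont/soundfont.py`; not the library's exact code):

```python
NOTE_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

def note_to_midi(name: str) -> int:
    # "A0" -> 21, "C4" -> 60, "C8" -> 108 (single-digit octaves only)
    pitch, octave = name[:-1], int(name[-1])
    return NOTE_NAMES.index(pitch) + (octave + 1) * 12

def distribute(roots: list) -> list:
    # Give each detected root key a [low, high] range by splitting the gap
    # between neighbouring roots at the midpoint, clamped to 21..108.
    roots = sorted(roots)
    ranges = []
    for i, root in enumerate(roots):
        low = min(21, root) if i == 0 else (roots[i - 1] + root) // 2 + 1
        high = max(108, root) if i == len(roots) - 1 else (root + roots[i + 1]) // 2
        ranges.append((low, high))
    return ranges

print(note_to_midi("C4"))        # 60
print(distribute([20, 40, 60]))  # [(20, 30), (31, 50), (51, 108)]
```

These values match the expectations in `sfcreator/soundfont/test_soundfont.py`.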
sf-creator-fork/main.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ from typing import List
+
+ import glob
+
+ from sfcreator.decentsampler import DecentSamplerWriter
+ from sfcreator.sfz import SfzWriter
+ from sfcreator.soundfont.soundfont import SoundFont, Sample, NoteNameMidiNumberMapper, HighLowKeyDistributor
+
+
+ def list_supported_files(directory):
+     return glob.glob(directory + "/*.wav")
+
+
+ def map_note_name(note_name: str):
+     if note_name.isdigit():
+         return int(note_name)
+     else:
+         return NoteNameMidiNumberMapper().map(note_name)
+
+ # @cli.command()
+ # @click.argument('directory')
+ # @click.option('--highkey', required=False, type=str, default="108")
+ # @click.option('--lowkey', required=False, type=str, default="21")
+ # @click.option('--instrument', help="name of the instrument", required=False, type=str)
+ # @click.option('--loopmode',
+ #               help="loop mode, no_loop (default), one_shot, loop_continuous or loop_sustain,"
+ #                    " see https://sfzformat.com/opcodes/loop_mode",
+ #               required=False, type=str, default="no_loop")
+ # @click.option('--polyphony',
+ #               help="Polyphony voice limit, see https://sfzformat.com/opcodes/polyphony",
+ #               required=False, type=str, default=None)
+ def sfz(directory: str, lowkey: str, highkey: str, instrument: str, loopmode: str, polyphony: str):
+     soundfont = make_soundfont(directory, map_note_name(lowkey), map_note_name(highkey), instrument, loopmode,
+                                polyphony)
+     SfzWriter().write(directory, soundfont)
+
+
+ # @cli.command()
+ # @click.argument('directory')
+ # @click.option('--highkey', required=False, type=str, default="108")
+ # @click.option('--lowkey', required=False, type=str, default="21")
+ # @click.option('--instrument', help="name of the instrument", required=False, type=str)
+ # @click.option('--image', required=False, type=str)
+ # @click.option('--loopmode',
+ #               help="loop mode, only supported for sfz. For compatibility reasons, --loopmode=one_shot will cause"
+ #                    " --release=20 for decent sampler, to ensure samples are always played until the end.",
+ #               required=False, type=str, default="no_loop")
+ # @click.option('--polyphony',
+ #               help="Polyphony voice limit, not supported in decent sampler format.",
+ #               required=False, type=str, default=None)
+ def decentsampler(directory: str, lowkey: str, highkey: str, instrument: str, loopmode: str, polyphony: str,
+                   image: str):
+     soundfont = make_soundfont(directory, map_note_name(lowkey), map_note_name(highkey), instrument, loopmode,
+                                polyphony)
+     DecentSamplerWriter().write(directory, soundfont, image)
+
+
+ def make_soundfont(directory: str, lowkey: int, highkey: int, instrument: str, loopmode: str, polyphony: str):
+     files = list_supported_files(directory)
+     samples: List[Sample] = []
+     mapper = NoteNameMidiNumberMapper()
+     for file in files:
+         samples.append(mapper.mapped_sample(file))
+     soundfont = SoundFont(samples, loop_mode=loopmode, polyphony=polyphony)
+     if instrument is None or len(instrument) == 0:
+         soundfont.instrument_name = os.path.basename(os.path.dirname(directory))
+     else:
+         soundfont.instrument_name = instrument
+     HighLowKeyDistributor().distribute(soundfont, low_key=lowkey, high_key=highkey)
+     return soundfont
sf-creator-fork/requirements.txt ADDED
@@ -0,0 +1,2 @@
+ click~=7.1.2
+ lxml~=4.6.1
sf-creator-fork/resources/decentsampler/dspreset.xsd ADDED
@@ -0,0 +1,134 @@
+ <xs:schema attributeFormDefault="unqualified" elementFormDefault="qualified" xmlns:xs="http://www.w3.org/2001/XMLSchema">
+   <xs:element name="DecentSampler">
+     <xs:complexType>
+       <xs:sequence>
+         <xs:element name="ui">
+           <xs:complexType>
+             <xs:sequence>
+               <xs:element name="tab">
+                 <xs:complexType>
+                   <xs:sequence>
+                     <xs:element name="labeled-knob" maxOccurs="unbounded" minOccurs="0">
+                       <xs:complexType>
+                         <xs:sequence>
+                           <xs:element name="binding" maxOccurs="unbounded" minOccurs="0">
+                             <xs:complexType>
+                               <xs:simpleContent>
+                                 <xs:extension base="xs:string">
+                                   <xs:attribute type="xs:string" name="type" use="optional"/>
+                                   <xs:attribute type="xs:string" name="level" use="optional"/>
+                                   <xs:attribute type="xs:byte" name="position" use="optional"/>
+                                   <xs:attribute type="xs:string" name="parameter" use="optional"/>
+                                   <xs:attribute type="xs:string" name="translation" use="optional"/>
+                                   <xs:attribute type="xs:byte" name="translationOutputMin" use="optional"/>
+                                   <xs:attribute type="xs:float" name="translationOutputMax" use="optional"/>
+                                   <xs:attribute type="xs:string" name="translationTable" use="optional"/>
+                                 </xs:extension>
+                               </xs:simpleContent>
+                             </xs:complexType>
+                           </xs:element>
+                         </xs:sequence>
+                         <xs:attribute type="xs:short" name="x" use="optional"/>
+                         <xs:attribute type="xs:byte" name="y" use="optional"/>
+                         <xs:attribute type="xs:string" name="label" use="optional"/>
+                         <xs:attribute type="xs:string" name="type" use="optional"/>
+                         <xs:attribute type="xs:short" name="minValue" use="optional"/>
+                         <xs:attribute type="xs:short" name="maxValue" use="optional"/>
+                         <xs:attribute type="xs:string" name="textColor" use="optional"/>
+                         <xs:attribute type="xs:float" name="value" use="optional"/>
+                       </xs:complexType>
+                     </xs:element>
+                   </xs:sequence>
+                   <xs:attribute type="xs:string" name="name"/>
+                 </xs:complexType>
+               </xs:element>
+             </xs:sequence>
+             <xs:attribute type="xs:string" name="bgImage"/>
+             <xs:attribute type="xs:short" name="width"/>
+             <xs:attribute type="xs:short" name="height"/>
+             <xs:attribute type="xs:string" name="layoutMode"/>
+             <xs:attribute type="xs:string" name="bgMode"/>
+           </xs:complexType>
+         </xs:element>
+         <xs:element name="groups">
+           <xs:complexType>
+             <xs:sequence>
+               <xs:element name="group" maxOccurs="unbounded" minOccurs="0">
+                 <xs:complexType>
+                   <xs:sequence>
+                     <xs:element name="sample" maxOccurs="unbounded" minOccurs="0">
+                       <xs:complexType>
+                         <xs:simpleContent>
+                           <xs:extension base="xs:string">
+                             <xs:attribute type="xs:string" name="path" use="optional"/>
+                             <xs:attribute type="xs:byte" name="rootNote" use="optional"/>
+                             <xs:attribute type="xs:byte" name="loNote" use="optional"/>
+                             <xs:attribute type="xs:byte" name="hiNote" use="optional"/>
+                             <xs:attribute type="xs:int" name="loopStart" use="optional"/>
+                             <xs:attribute type="xs:int" name="loopEnd" use="optional"/>
+                           </xs:extension>
+                         </xs:simpleContent>
+                       </xs:complexType>
+                     </xs:element>
+                   </xs:sequence>
+                   <xs:attribute type="xs:string" name="name" use="optional"/>
+                   <xs:attribute type="xs:string" name="volume" use="optional"/>
+                   <xs:attribute type="xs:byte" name="ampVelTrack" use="optional"/>
+                   <xs:attribute type="xs:byte" name="modVolume" use="optional"/>
+                 </xs:complexType>
+               </xs:element>
+             </xs:sequence>
+             <xs:attribute type="xs:float" name="attack"/>
+             <xs:attribute type="xs:float" name="decay"/>
+             <xs:attribute type="xs:float" name="sustain"/>
+             <xs:attribute type="xs:byte" name="release"/>
+           </xs:complexType>
+         </xs:element>
+         <xs:element name="effects">
+           <xs:complexType>
+             <xs:sequence>
+               <xs:element name="effect" maxOccurs="unbounded" minOccurs="0">
+                 <xs:complexType>
+                   <xs:simpleContent>
+                     <xs:extension base="xs:string">
+                       <xs:attribute type="xs:string" name="type" use="optional"/>
+                     </xs:extension>
+                   </xs:simpleContent>
+                 </xs:complexType>
+               </xs:element>
+             </xs:sequence>
+           </xs:complexType>
+         </xs:element>
+         <xs:element name="midi">
+           <xs:complexType>
+             <xs:sequence>
+               <xs:element name="cc">
+                 <xs:complexType>
+                   <xs:sequence>
+                     <xs:element name="binding">
+                       <xs:complexType>
+                         <xs:simpleContent>
+                           <xs:extension base="xs:string">
+                             <xs:attribute type="xs:string" name="level"/>
+                             <xs:attribute type="xs:string" name="type"/>
+                             <xs:attribute type="xs:byte" name="position"/>
+                             <xs:attribute type="xs:string" name="parameter"/>
+                             <xs:attribute type="xs:string" name="translation"/>
+                             <xs:attribute type="xs:byte" name="translationOutputMin"/>
+                             <xs:attribute type="xs:float" name="translationOutputMax"/>
+                           </xs:extension>
+                         </xs:simpleContent>
+                       </xs:complexType>
+                     </xs:element>
+                   </xs:sequence>
+                   <xs:attribute type="xs:byte" name="number"/>
+                 </xs:complexType>
+               </xs:element>
+             </xs:sequence>
+           </xs:complexType>
+         </xs:element>
+       </xs:sequence>
+       <xs:attribute type="xs:byte" name="pluginVersion"/>
+     </xs:complexType>
+   </xs:element>
+ </xs:schema>
sf-creator-fork/sfcreator/__init__.py ADDED
File without changes
sf-creator-fork/sfcreator/decentsampler/__init__.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ from typing import List
+ from lxml import etree
+ from sfcreator.soundfont.soundfont import SoundFont, Sample
+
+ class DecentSamplerWriter:
+     def write(self, directory: str, soundfont: SoundFont, image: str):
+         root = etree.Element("DecentSampler", pluginVersion="1")
+
+         if image is not None:
+             ui = etree.Element("ui", bgImage=image, width="812", height="375", layoutMode="relative", bgMode="top_left")
+             root.append(ui)
+             tab = etree.Element("tab", name="main")
+             ui.append(tab)
+
+         groups = etree.Element("groups")
+         root.append(groups)
+
+         if soundfont.loop_mode == "one_shot":
+             # assuming no sample will be longer than 20s
+             groups.set("release", "20")
+         elif soundfont.release is not None:
+             groups.set("release", str(soundfont.release))
+
+         for root_key in soundfont.root_keys():
+             self.write_root_key_sample_group(groups, soundfont.samples_for_root_key(root_key))
+
+         filename = directory + "/" + soundfont.instrument_name + ".dspreset"
+         print("Writing to " + filename)
+         et = etree.ElementTree(root)
+         et.write(filename, pretty_print=True, encoding='utf-8', xml_declaration=True)
+
+     def write_root_key_sample_group(self, groups, samples: List[Sample]):
+         group = etree.Element("group")
+         groups.append(group)
+         if len(samples) > 1:
+             group.set("seqMode", "random")
+         for index, sample in enumerate(samples, start=1):
+             xml_sample = etree.Element("sample")
+             group.append(xml_sample)
+             xml_sample.set("rootNote", str(sample.key_range.root_key))
+             xml_sample.set("loNote", str(sample.key_range.low_key))
+             xml_sample.set("hiNote", str(sample.key_range.high_key))
+             xml_sample.set("loVel", str(sample.velocity_range.low_velocity))
+             xml_sample.set("hiVel", str(sample.velocity_range.high_velocity))
+             xml_sample.set("seqPosition", str(index))
+             xml_sample.set("path", str(os.path.basename(sample.filename)))
sf-creator-fork/sfcreator/sfz/__init__.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ from typing import List
+
+ from sfcreator.soundfont.soundfont import SoundFont, Sample
+
+
+ class SfzWriter:
+     def write(self, directory: str, soundfont: SoundFont):
+         f = open(directory + "/" + soundfont.instrument_name + ".sfz", "w")
+         f.write("<group>\r\n")
+         if soundfont.loop_mode is not None:
+             f.write(f"loop_mode={soundfont.loop_mode}\r\n")
+         if soundfont.polyphony is not None:
+             f.write(f"polyphony={soundfont.polyphony}\r\n")
+         for root_key in soundfont.root_keys():
+             self.write_root_key_sample_group(f, soundfont.samples_for_root_key(root_key))
+         f.close()
+
+     def write_root_key_sample_group(self, f, samples: List[Sample]):
+         if len(samples) == 1:
+             sample = samples[0]
+             f.write(f"<region> sample={os.path.basename(sample.filename)}"
+                     f" pitch_keycenter={str(sample.key_range.root_key)} lokey={str(sample.key_range.low_key)}"
+                     f" hikey={str(sample.key_range.high_key)}\r\n")
+         else:
+             # Spread multiple samples for the same root key over the [0, 1) random
+             # range so the player alternates between them (lorand/hirand opcodes)
+             lorand = 0.0
+             randstep = 1 / len(samples)
+             for sample in samples:
+                 hirand = lorand + randstep
+                 f.write(f"<region> sample={os.path.basename(sample.filename)}"
+                         f" pitch_keycenter={str(sample.key_range.root_key)} lokey={str(sample.key_range.low_key)}"
+                         f" hikey={str(sample.key_range.high_key)}")
+                 if lorand > 0.0:
+                     f.write(f" lorand={lorand}")
+                 if hirand < 1.0:
+                     f.write(f" hirand={hirand}")
+                 f.write("\r\n")
+                 lorand = hirand
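
As an illustration, for two samples sharing root key 60 (filenames and key ranges hypothetical), the writer above emits regions split over the lorand/hirand random range so the player alternates between them:

```
<group>
loop_mode=no_loop
<region> sample=C4.wav pitch_keycenter=60 lokey=51 hikey=69 hirand=0.5
<region> sample=C4-1.wav pitch_keycenter=60 lokey=51 hikey=69 lorand=0.5
```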
sf-creator-fork/sfcreator/soundfont/__init__.py ADDED
File without changes
sf-creator-fork/sfcreator/soundfont/soundfont.py ADDED
@@ -0,0 +1,119 @@
+ from typing import List
+
+
+ class KeyRange:
+     def __init__(self, root_key: int = 60, low_key: int = None, high_key: int = None):
+         self.root_key = root_key
+         self.low_key = low_key
+         self.high_key = high_key
+         if self.low_key is None:
+             self.low_key = self.root_key
+         if self.high_key is None:
+             self.high_key = self.root_key
+
+     def in_range(self, key):
+         return self.low_key <= key <= self.high_key
+
+     def __eq__(self, other):
+         return self.low_key == other.low_key and self.root_key == other.root_key and self.high_key == other.high_key
+
+
+ class VelocityRange:
+     def __init__(self, low_velocity: int = 0, high_velocity: int = 127):
+         self.low_velocity = low_velocity
+         self.high_velocity = high_velocity
+
+     def in_range(self, velocity):
+         return self.low_velocity <= velocity <= self.high_velocity
+
+     def __eq__(self, other):
+         return self.low_velocity == other.low_velocity and \
+                self.high_velocity == other.high_velocity
+
+
+ class Sample:
+     def __init__(self, filename: str, root_key: int = 60, low_key: int = None, high_key: int = None):
+         self.velocity_range = VelocityRange()
+         self.filename = filename
+         self.key_range = KeyRange(root_key, low_key, high_key)
+         self.index = 0
+
+     def __repr__(self) -> str:
+         return 'Sample(filename=' + self.filename + ', root_key=' + str(self.key_range.root_key) + ', low_key=' + str(
+             self.key_range.low_key) + ', high_key=' + str(self.key_range.high_key) + ')'
+
+     def __eq__(self, other):
+         return self.filename == other.filename and self.key_range == other.key_range and self.index == other.index
+
+
+ class SoundFont:
+     def __init__(self, samples: List[Sample], loop_mode: str = "no_loop", polyphony: str = None, release: int = None,
+                  instrument_name=None):
+         self.instrument_name = instrument_name
+         self.samples = samples
+         self.loop_mode = loop_mode
+         self.polyphony = polyphony
+         self.release = release
+
+     def root_keys(self):
+         return sorted(set([sample.key_range.root_key for sample in self.samples]))
+
+     def range_for_key(self, key) -> KeyRange:
+         samples_in_range = [sample for sample in self.samples if sample.key_range.in_range(key)]
+         return samples_in_range[0].key_range if len(samples_in_range) > 0 else None
+
+     def samples_for_root_key(self, root_key):
+         return [sample for sample in self.samples if sample.key_range.root_key == root_key]
+
+     def set_range(self, root_key, low_key=None, high_key=None):
+         for sample in self.samples_for_root_key(root_key):
+             if low_key is not None:
+                 sample.key_range.low_key = low_key
+             if high_key is not None:
+                 sample.key_range.high_key = high_key
+
+
+ class HighLowKeyDistributor:
+     def distribute(self, soundfont: SoundFont, low_key: int = 21, high_key: int = 108):
+         soundfont.samples.sort(key=lambda sample: sample.key_range.root_key, reverse=False)
+
+         prev_root_key: int = None
+         for root_key in soundfont.root_keys():
+             key_range = soundfont.range_for_key(root_key)
+             if prev_root_key is None:
+                 # The lowest sample covers everything down to low_key
+                 soundfont.set_range(root_key, low_key=min(low_key, key_range.low_key))
+             else:
+                 # Split the gap between neighbouring root keys at the midpoint
+                 prev_range = soundfont.range_for_key(prev_root_key)
+                 mid_sample_key = int((key_range.low_key - prev_range.high_key) / 2) + prev_range.high_key
+                 soundfont.set_range(prev_root_key, high_key=mid_sample_key)
+                 soundfont.set_range(root_key, low_key=mid_sample_key + 1)
+             prev_root_key = root_key
+         # The highest sample covers everything up to high_key
+         soundfont.set_range(soundfont.root_keys()[-1],
+                             high_key=max(high_key, soundfont.range_for_key(soundfont.root_keys()[-1]).high_key))
+
+
+ class NoteNameMidiNumberMapper:
+     def __init__(self):
+         self.index_offset = 21
+         self.note_name_midi_number_map: List[str] = []
+         for octave_number in range(0, 8):
+             for c in range(ord('A'), ord('B') + 1):
+                 self._add_note(c, octave_number)
+             for c in range(ord('C'), ord('G') + 1):
+                 self._add_note(c, octave_number + 1)
+         self.note_name_midi_number_map.append("C8")
+
+     def _add_note(self, c, octave_number):
+         self.note_name_midi_number_map.append(f"{chr(c)}{str(octave_number)}")
+         if chr(c) != "B" and chr(c) != "E":  # B and E have no sharp
+             self.note_name_midi_number_map.append(f"{chr(c)}#{str(octave_number)}")
+
+     def mapped_sample(self, filename: str):
+         for note in self.note_name_midi_number_map:
+             if note in filename:
+                 return Sample(filename, self.map(note))
+         return Sample(filename)
+
+     def map(self, note_name: str):
+         return self.note_name_midi_number_map.index(note_name) + self.index_offset
sf-creator-fork/sfcreator/soundfont/test_soundfont.py ADDED
@@ -0,0 +1,70 @@
+ from unittest import TestCase
+
+ from sfcreator.soundfont.soundfont import *
+
+
+ class TestHighLowKeyDistributor(TestCase):
+     def test_distribution_of_samples_one_sample_only(self):
+         distributor = HighLowKeyDistributor()
+
+         sf = SoundFont(samples=[
+             Sample("filename.wav", root_key=60)
+         ])
+         distributor.distribute(sf)
+
+     def test_distribution_of_samples_unsorted_samples(self):
+         distributor = HighLowKeyDistributor()
+         sf = SoundFont(samples=[
+             Sample("40.wav", root_key=40),
+             Sample("60.wav", root_key=60),
+             Sample("20.wav", root_key=20)
+         ])
+         distributor.distribute(sf)
+
+         self.assertEqual(Sample("20.wav", root_key=20, low_key=20, high_key=30), sf.samples[0])
+         self.assertEqual(Sample("40.wav", root_key=40, low_key=31, high_key=50), sf.samples[1])
+         self.assertEqual(Sample("60.wav", root_key=60, low_key=51, high_key=108), sf.samples[2])
+
+     def test_distribution_of_samples_over_108(self):
+         distributor = HighLowKeyDistributor()
+         sf = SoundFont(samples=[
+             Sample("40.wav", root_key=40),
+             Sample("110.wav", root_key=110)
+         ])
+         distributor.distribute(sf)
+
+         self.assertEqual(Sample("40.wav", root_key=40, low_key=21, high_key=75), sf.samples[0])
+         self.assertEqual(Sample("110.wav", root_key=110, low_key=76, high_key=110), sf.samples[1])
+
+     def test_distribution_of_samples_with_duplicated_notes(self):
+         distributor = HighLowKeyDistributor()
+         sf = SoundFont(samples=[
+             Sample("40.wav", root_key=40),
+             Sample("40-1.wav", root_key=40),
+             Sample("60.wav", root_key=60),
+             Sample("20.wav", root_key=20)
+         ])
+         distributor.distribute(sf)
+
+         self.assertEqual(Sample("20.wav", root_key=20, low_key=20, high_key=30), sf.samples[0])
+         self.assertEqual(Sample("40.wav", root_key=40, low_key=31, high_key=50), sf.samples[1])
+         self.assertEqual(Sample("40-1.wav", root_key=40, low_key=31, high_key=50), sf.samples[2])
+         self.assertEqual(Sample("60.wav", root_key=60, low_key=51, high_key=108), sf.samples[3])
+
+
+ class TestSoundFont(TestCase):
+     def test_creation_of_soundfont(self):
+         sf = SoundFont(samples=[
+             Sample("filename.wav", root_key=60)
+         ])
+         self.assertEqual(sf.samples[0].key_range.high_key, 60)
+         self.assertEqual(sf.samples[0].key_range.low_key, 60)
+
+
+ class TestNoteNameMidiNumberMapper(TestCase):
+     def test_map_note_name(self):
+         mapper = NoteNameMidiNumberMapper()
+         self.assertEqual(21, mapper.map("A0"))
+         self.assertEqual(108, mapper.map("C8"))
+         self.assertEqual(61, mapper.map("C#4"))
+         self.assertEqual(60, mapper.map("C4"))
train_lfm.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from torch import nn
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import einops
7
+ import wandb
8
+ import torch
9
+ # import wandb logging
10
+ from pytorch_lightning.loggers import WandbLogger
11
+ from stable_audio_tools import get_pretrained_model
12
+ from transformers import T5Tokenizer, T5EncoderModel
13
+
14
+
15
+ class SinActivation(nn.Module):
16
+ def forward(self, x):
17
+ return torch.sin(x)
18
+
19
+ class FourierFeatures(nn.Module):
20
+
21
+ def __init__(self, in_features, out_features, n_layers):
22
+ super().__init__()
23
+ self.in_features = in_features
24
+ self.out_features = out_features
25
+ self.n_layers = n_layers
26
+ layers = []
27
+ layers += [nn.Linear(in_features, out_features)]
28
+ # add sin activation
29
+ layers += [SinActivation()]
30
+ for i in range(n_layers-1):
31
+ layers += [nn.Linear(out_features, out_features)]
32
+ layers += [SinActivation()]
33
+ self.layers = nn.Sequential(*layers)
34
+
35
+ def forward(self, x):
36
+ return self.layers(x)
37
+
38
+ class FlowMatchingModule(pl.LightningModule):
39
+
40
+ def __init__(self, main_model=None, text_conditioner=None, max_tokens=128, n_channels=None, t_input=None):
41
+ super().__init__()
42
+ self.save_hyperparameters(ignore=['main_model', "text_conditioner"])
43
+
44
+ self.model = main_model.transformer
45
+ self.input_layer = main_model.transformer.project_in
46
+ self.output_layer = main_model.transformer.project_out
47
+
48
+ self.text_conditioner = text_conditioner
49
+
50
+ self.d_model = self.input_layer.weight.shape[0]
51
+ self.d_input = self.input_layer.weight.shape[1]
52
+
53
+ # use fourier features for schedule
54
+ self.schedule_embedding = FourierFeatures(1, self.d_model, 2)
55
+ # use learned positional encoding
56
+ self.pitch_embedding = nn.Parameter(torch.randn(n_channels, self.d_model))
57
+ # make embedding layer for tags
58
+ self.channels = n_channels
59
+
60
+ mean_proj = []
61
+ for layer in self.model.layers:
62
+ mean_proj += [nn.Linear(self.d_model, self.d_model)]
63
+ self.mean_proj = nn.ModuleList(mean_proj)
64
+
65
+ def get_example_inputs(self):
66
+ text = "A piano playing a C major chord"
67
+ conditioning, conditioning_mask = self.text_conditioner(text, device = self.device)
68
+
69
+ # repeat conditioning
70
+ conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
71
+ conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
72
+
73
+ t = torch.rand(1, device=self.device)
74
+ z = torch.randn(1, self.hparams.t_input ,self.hparams.n_channels, self.d_input , device=self.device)
75
+ return z, conditioning, conditioning_mask, t
76
+
77
+     def forward(self, x, conditioning, conditioning_mask, t):
+         batch, t_input, n_channels, d_input = x.shape
+
+         # project into the model dimension, then add time and pitch embeddings
+         x = self.input_layer(x)
+         tz = self.schedule_embedding(t[:, None, None, None])
+         pitch_z = self.pitch_embedding[None, None, :n_channels, :]
+         x = x + tz + pitch_z
+         rot = self.model.rotary_pos_emb.forward_from_seq_len(x.shape[1])
+
+         # fold pitch channels into the batch so each channel runs as its own sequence
+         conditioning = einops.rearrange(conditioning, 'b t c d -> (b c) t d', c=self.channels)
+         conditioning_mask = einops.rearrange(conditioning_mask, 'b t c -> (b c) t', c=self.channels)
+
+         for layer_idx, layer in enumerate(self.model.layers):
+             x = einops.rearrange(x, 'b t c d -> (b c) t d')
+             x = layer(x, rotary_pos_emb=rot, context=conditioning, context_mask=conditioning_mask)
+             x = einops.rearrange(x, '(b c) t d -> b t c d', c=self.channels)
+             # mix information across pitch channels via the per-layer mean projection
+             x_ch_mean = x.mean(dim=2)
+             x_ch_mean = self.mean_proj[layer_idx](x_ch_mean)
+             # earlier variants, kept for reference:
+             # x_ch_mean = torch.relu(x_ch_mean)
+             # x_ch_mean = torch.layer_norm(x_ch_mean, x_ch_mean.shape[1:])
+             x = x + x_ch_mean[:, :, None, :]
+         x = self.output_layer(x)
+         return x
+
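+     # For concreteness, a shape walkthrough of forward() under the training
+     # configuration used below (BATCH_SIZE=1, LATENT_T=86, 10 pitch channels;
+     # d_input and d_model come from the pretrained projections):
+     #
+     #     x: (1, 86, 10, d_input) --input_layer--> (1, 86, 10, d_model)
+     #     per layer: (1, 86, 10, d_model) -> (10, 86, d_model) -> back
+     #     output_layer: (1, 86, 10, d_model) -> (1, 86, 10, d_input)
+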
+     def step(self, batch, batch_idx):
+         x = batch["z"]
+         text = batch["text"]
+         conditioning, conditioning_mask = self.text_conditioner(text, device=self.device)
+
+         # repeat the text conditioning across pitch channels
+         conditioning = einops.repeat(conditioning, 'b t d -> b t c d', c=self.channels)
+         conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
+
+         x = einops.rearrange(x, 'b c d t -> b t c d')
+         # conditional flow matching: draw noise z0 and data z1, pick a random
+         # time t, build the linear interpolation zt, and regress the velocity
+         z0 = torch.randn(x.shape, device=x.device)
+         z1 = x
+         t = torch.rand(x.shape[0], device=x.device)
+         zt = t[:, None, None, None] * z1 + (1 - t[:, None, None, None]) * z0
+         vt = self(zt, conditioning, conditioning_mask, t)
+         loss = (vt - (z1 - z0)).pow(2).mean()
+         return loss
+
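+     # The objective above is standard flow matching with a linear (rectified)
+     # path:
+     #
+     #     z_t = t * z1 + (1 - t) * z0,    dz_t/dt = z1 - z0
+     #     loss = E_{t, z0, z1} || v_theta(z_t, t) - (z1 - z0) ||^2
+     #
+     # so `sample` below integrates dz/dt = v_theta(z, t) from t = 0 (noise)
+     # to t = 1 (data) with plain Euler steps.
+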
+     @torch.inference_mode()
+     def sample(self, batch_size, text, steps=10, same_latent=False):
+         # ensure inputs match the model's device and dtype
+         device = next(self.parameters()).device
+         dtype = self.input_layer.weight.dtype
+
+         conditioning, conditioning_mask = self.text_conditioner(text, device=device)
+         conditioning = einops.repeat(conditioning, "b t d -> b t c d", c=self.channels)
+         conditioning_mask = einops.repeat(
+             conditioning_mask, "b t -> b t c", c=self.channels
+         )
+         conditioning = conditioning.to(device=device, dtype=dtype)
+         conditioning_mask = conditioning_mask.to(device=device)
+
+         self.eval()
+         with torch.no_grad():
+             # start from Gaussian noise at t = 0
+             z0 = torch.randn(
+                 batch_size,
+                 self.hparams.t_input,
+                 self.hparams.n_channels,
+                 self.d_input,
+                 device=device,
+                 dtype=dtype,
+             )
+
+             # optionally share one latent across the batch so only the prompt varies
+             if same_latent:
+                 z0 = z0[0].repeat(batch_size, 1, 1, 1)
+
+             # Euler integration of dz/dt = v(z, t) from t = 0 to t = 1
+             zt = z0
+             for step in tqdm(range(steps)):
+                 t = torch.tensor([step / steps], device=device, dtype=dtype)
+                 zt = zt + (1 / steps) * self.forward(
+                     zt, conditioning, conditioning_mask, t
+                 )
+
+         return zt
+
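+     # A minimal usage sketch (hypothetical names; the decoder that turns these
+     # latents back into audio lives outside this file):
+     #
+     #     module = FlowMatchingModule.load_from_checkpoint(
+     #         ckpt_path, main_model=transformer_model, text_conditioner=text_conditioner)
+     #     latents = module.sample(2, ["warm analog pad", "plucked koto"], steps=50)
+     #     # latents: (2, t_input, n_channels, d_input)
+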
+     def training_step(self, batch, batch_idx):
+         loss = self.step(batch, batch_idx)
+         self.log('trn_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self.step(batch, batch_idx)
+         self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+         return loss
+
+     def configure_optimizers(self):
+         return torch.optim.Adam(self.parameters(), lr=1e-5)
+
+
+ class EncodedAudioDataset(torch.utils.data.Dataset):
+     """Pre-encoded audio latents plus text metadata, loaded from .pt files."""
+
+     def __init__(self, paths, pitch_range):
+         records = []
+         print("Loading data")
+         for path in tqdm(paths):
+             records += torch.load(path)
+         self.records = records
+         self.pitch_range = pitch_range
+
+         # keep only records that actually contain an encoded latent
+         self.records = [r for r in self.records if "z" in r]
+
+         print(f"Loaded {len(self.records)} records")
+
+     def compose_prompt(self, record):
+         title = record["name"] if "name" in record else record["title"]
+         tags = record["tags"]
+
+         # shuffle the tags, then keep a random-length prefix (possibly empty)
+         tags = np.random.choice(tags, len(tags), replace=False)
+         tags = list(tags[:np.random.randint(0, len(tags) + 1)])
+
+         # pick a head element: the title, or (when available) the type group or type
+         if "type_group" in record and "type" in record:
+             head = np.random.choice([title, record["type_group"], record["type"]])
+         else:
+             head = title
+
+         # prepend the head with 75% probability
+         elements = tags
+         if np.random.rand() < 0.75:
+             elements = [head] + elements
+
+         # shuffle the final elements (guarding against an empty list,
+         # which np.random.choice rejects)
+         if len(elements) > 0:
+             elements = np.random.choice(elements, len(elements), replace=False)
+
+         prompt = " ".join(elements)
+         return prompt.lower()
+
+     def __len__(self):
+         return len(self.records)
+
+     def __getitem__(self, idx):
+         return {
+             "z": self.records[idx]["z"][self.pitch_range[0]:self.pitch_range[1]],
+             "text": self.compose_prompt(self.records[idx]),
+         }
+
+     def check_for_nans(self):
+         for r in self.records:
+             if np.isnan(r["z"]).any():
+                 raise ValueError("NaN values in z")
+
+     def get_z_shape(self):
+         # return the unique latent shapes present in the dataset
+         shapes = [r["z"].shape for r in self.records]
+         return list(set(shapes))
+
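+ # For illustration, a hypothetical record such as
+ #
+ #     {"name": "Warm Pad", "tags": ["analog", "pad"], "z": ...}
+ #
+ # might yield prompts like "warm pad analog", "pad", or "analog", since
+ # compose_prompt shuffles and subsamples the metadata on every call.
+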
+
+ if __name__ == "__main__":
+
+     # set seed for reproducibility
+     SEED = 0
+     torch.manual_seed(SEED)
+
+     BATCH_SIZE = 1
+     LATENT_T = 86
+
+     # initialize wandb logging
+     wandb.init()
+     logger = WandbLogger(project="synth_flow")
+
+     # don't log model checkpoints to wandb
+     wandb.config.log_model = False
+
+     DATASET = "dataset_a"
+     if DATASET == "dataset_a":
+         PITCH_RANGE = [2, 12]
+
+         trn_ds = EncodedAudioDataset([f"artefacts/synth_data_{i}.pt" for i in range(9)], PITCH_RANGE)
+         trn_ds.check_for_nans()
+         trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
+
+         val_ds = EncodedAudioDataset(["artefacts/synth_data_9.pt"], PITCH_RANGE)
+         val_ds.check_for_nans()
+         val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
+
+     elif DATASET == "dataset_b":
+         PITCH_RANGE = [0, 10]
+
+         trn_ds = EncodedAudioDataset([f"artefacts/synth_data_2_joined_{i}.pt" for i in range(3)], PITCH_RANGE)
+         trn_ds.check_for_nans()
+         trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
+
+         val_ds = EncodedAudioDataset(["artefacts/synth_data_2_joined_3.pt"], PITCH_RANGE)
+         val_ds.check_for_nans()
+         val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
+
+     # load the pretrained Stable Audio Open model and pull out its transformer
+     # and text conditioner
+     src_model = get_pretrained_model("stabilityai/stable-audio-open-1.0")[0].to("cpu")
+     transformer_model = src_model.model.model
+     transformer_model = transformer_model.train()
+     text_conditioner = src_model.conditioner.conditioners.prompt
+
+     t5_version = "google-t5/t5-base"
+
+     lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
+
+     model = FlowMatchingModule(
+         main_model=transformer_model,
+         text_conditioner=text_conditioner,
+         n_channels=PITCH_RANGE[1] - PITCH_RANGE[0],
+         t_input=LATENT_T,
+     )
+
+     trainer = pl.Trainer(devices=[3], logger=logger, gradient_clip_val=1.0, callbacks=[lr_callback], max_epochs=1000, precision="16-mixed")
+
+     # resume fine-tuning from an earlier run, then save the result
+     trainer.fit(model, trn_dl, val_dl, ckpt_path="synth_flow/9gzpz0i6/epoch=85-step=774000.ckpt")
+     trainer.save_checkpoint("artefacts/model_finetuned_2.ckpt")