add sample_rate and n_fft params
- README.md +3 -1
- scripts/audio_to_images.py +3 -1
- scripts/train_unconditional.py +5 -1
- scripts/train_vae.py +15 -3
README.md
CHANGED
@@ -71,7 +71,9 @@ python scripts/audio_to_images.py \
   --output_dir data/audio-diffusion-256 \
   --push_to_hub teticio/audio-diffusion-256
 ```
-
+
+Note that the default `sample_rate` is 22050 and audios will be resampled if they are at a different rate. If you change this value, you may find that the results in the `test_mel.ipynb` notebook are not good (for example, if `sample_rate` is 48000) and that it is necessary to adjust `n_fft` (for example, to 2000 instead of the default value of 2048; alternatively, you can resample to a `sample_rate` of 44100). Make sure you use the same parameters for training and inference. You should also bear in mind that not all resolutions work with the neural network architecture as currently configured - you should be safe if you stick to powers of 2.
+
 ## Train model
 #### Run training on local machine.
 ```bash
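To make the "same parameters for training and inference" point above concrete, here is a minimal sketch. It uses only the `Mel` constructor arguments that appear in the script diffs below; the import path and the example values are assumptions for illustration, not part of this commit.

```python
# Minimal sketch (not part of this commit): the Mel parameters used when
# building the dataset must be repeated wherever spectrogram images are
# turned back into audio, e.g. in test_mel.ipynb or at inference time.
# The import path is an assumption based on the repo layout.
from audiodiffusion.mel import Mel

# Values chosen when running scripts/audio_to_images.py ...
MEL_PARAMS = dict(
    x_res=256,          # spectrogram width  (a power of 2, per the note above)
    y_res=256,          # spectrogram height (a power of 2, per the note above)
    hop_length=512,
    sample_rate=44100,  # non-default rate, so it must be passed explicitly
    n_fft=2048,
)

# ... must be the same values used later for training and inference.
mel = Mel(**MEL_PARAMS)
```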
scripts/audio_to_images.py
CHANGED
@@ -19,7 +19,8 @@ def main(args):
     mel = Mel(x_res=args.resolution[0],
               y_res=args.resolution[1],
               hop_length=args.hop_length,
-              sample_rate=args.sample_rate)
+              sample_rate=args.sample_rate,
+              n_fft=args.n_fft)
     os.makedirs(args.output_dir, exist_ok=True)
     audio_files = [
         os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
@@ -86,6 +87,7 @@ if __name__ == "__main__":
     parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--push_to_hub", type=str, default=None)
     parser.add_argument("--sample_rate", type=int, default=22050)
+    parser.add_argument("--n_fft", type=int, default=2048)
     args = parser.parse_args()

     if args.input_dir is None:
scripts/train_unconditional.py
CHANGED
@@ -173,7 +173,9 @@ def main(args):

     mel = Mel(x_res=resolution[1],
               y_res=resolution[0],
-              hop_length=args.hop_length)
+              hop_length=args.hop_length,
+              sample_rate=args.sample_rate,
+              n_fft=args.n_fft)

     global_step = 0
     for epoch in range(args.num_epochs):
@@ -362,6 +364,8 @@ if __name__ == "__main__":
             "and an Nvidia Ampere GPU."),
    )
    parser.add_argument("--hop_length", type=int, default=512)
+   parser.add_argument("--sample_rate", type=int, default=22050)
+   parser.add_argument("--n_fft", type=int, default=2048)
    parser.add_argument("--from_pretrained", type=str, default=None)
    parser.add_argument("--start_epoch", type=int, default=0)
    parser.add_argument("--num_train_steps", type=int, default=1000)
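For context on why `hop_length`, `sample_rate` and `n_fft` are threaded through the scripts together, here is some rough, standard STFT arithmetic; this is a sketch under assumed values, not code taken from this repo.

```python
# Back-of-the-envelope STFT arithmetic (illustration only): these
# parameters jointly determine how much audio one spectrogram image
# covers and how long each analysis window is.
sample_rate = 22050   # default added by this commit
hop_length = 512      # existing default
n_fft = 2048          # default added by this commit
x_res = 256           # spectrogram width, an assumed example value

# Roughly x_res * hop_length samples of audio per image ...
samples_per_image = x_res * hop_length
print(samples_per_image / sample_rate)  # ~5.9 seconds of audio per image

# ... and each FFT frame spans n_fft / sample_rate seconds, so changing
# the sample rate changes the effective analysis window.
print(n_fft / sample_rate)              # ~0.093 s per frame at 22050 Hz
```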
scripts/train_vae.py
CHANGED
@@ -60,10 +60,16 @@ class AudioDiffusionDataModule(pl.LightningDataModule):

 class ImageLogger(Callback):

-    def __init__(self, every=1000, hop_length=512):
+    def __init__(self,
+                 every=1000,
+                 hop_length=512,
+                 sample_rate=22050,
+                 n_fft=2048):
         super().__init__()
         self.every = every
         self.hop_length = hop_length
+        self.sample_rate = sample_rate
+        self.n_fft = n_fft

     @rank_zero_only
     def log_images_and_audios(self, pl_module, batch):
@@ -76,7 +82,9 @@ class ImageLogger(Callback):
         channels = image_shape[1]
         mel = Mel(x_res=image_shape[2],
                   y_res=image_shape[3],
-                  hop_length=self.hop_length)
+                  hop_length=self.hop_length,
+                  sample_rate=self.sample_rate,
+                  n_fft=self.n_fft)

         for k in images:
             images[k] = images[k].detach().cpu()
@@ -145,6 +153,8 @@ if __name__ == "__main__":
                         type=int,
                         default=1)
     parser.add_argument("--hop_length", type=int, default=512)
+    parser.add_argument("--sample_rate", type=int, default=22050)
+    parser.add_argument("--n_fft", type=int, default=2048)
     parser.add_argument("--save_images_batches", type=int, default=1000)
     parser.add_argument("--max_epochs", type=int, default=100)
     args = parser.parse_args()
@@ -166,7 +176,9 @@ if __name__ == "__main__":
         resume_from_checkpoint=args.resume_from_checkpoint,
         callbacks=[
             ImageLogger(every=args.save_images_batches,
-                        hop_length=args.hop_length),
+                        hop_length=args.hop_length,
+                        sample_rate=args.sample_rate,
+                        n_fft=args.n_fft),
             HFModelCheckpoint(ldm_config=config,
                               hf_checkpoint=args.hf_checkpoint_dir,
                               dirpath=args.ldm_checkpoint_dir,