Spaces:
Runtime error
Runtime error
update for sagemaker
Browse files- README.md +15 -12
- src/train_unconditional.py +5 -5
README.md
CHANGED
@@ -1,19 +1,22 @@
|
|
1 |
# audio-diffusion
|
2 |
```bash
|
|
|
|
|
|
|
3 |
python src/audio_to_images.py \
|
4 |
-
--resolution
|
5 |
-
--input_dir
|
6 |
-
--output_dir
|
7 |
```
|
8 |
```bash
|
9 |
accelerate launch src/train_unconditional.py \
|
10 |
-
--dataset_name
|
11 |
-
--resolution
|
12 |
-
--output_dir
|
13 |
-
--train_batch_size
|
14 |
-
--num_epochs
|
15 |
-
--gradient_accumulation_steps
|
16 |
-
--learning_rate
|
17 |
-
--lr_warmup_steps
|
18 |
-
--mixed_precision
|
19 |
```
|
|
|
1 |
# audio-diffusion
|
2 |
```bash
|
3 |
+
accelerate config
|
4 |
+
```
|
5 |
+
```bash
|
6 |
python src/audio_to_images.py \
|
7 |
+
--resolution 256 \
|
8 |
+
--input_dir path-to-audio-files \
|
9 |
+
--output_dir data-256
|
10 |
```
|
11 |
```bash
|
12 |
accelerate launch src/train_unconditional.py \
|
13 |
+
--dataset_name data-256 \
|
14 |
+
--resolution 256 \
|
15 |
+
--output_dir ddpm-ema-audio-256 \
|
16 |
+
--train_batch_size 16 \
|
17 |
+
--num_epochs 100 \
|
18 |
+
--gradient_accumulation_steps 1 \
|
19 |
+
--learning_rate 1e-4 \
|
20 |
+
--lr_warmup_steps 500 \
|
21 |
+
--mixed_precision no
|
22 |
```
|
src/train_unconditional.py
CHANGED
@@ -253,7 +253,7 @@ if __name__ == "__main__":
|
|
253 |
help="A folder containing the training data.",
|
254 |
)
|
255 |
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
|
256 |
-
parser.add_argument("--overwrite_output_dir",
|
257 |
parser.add_argument("--cache_dir", type=str, default=None)
|
258 |
parser.add_argument("--resolution", type=int, default=64)
|
259 |
parser.add_argument("--train_batch_size", type=int, default=16)
|
@@ -269,15 +269,15 @@ if __name__ == "__main__":
|
|
269 |
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
270 |
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
|
271 |
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
272 |
-
parser.add_argument("--use_ema",
|
273 |
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
274 |
parser.add_argument("--ema_power", type=float, default=3 / 4)
|
275 |
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
|
276 |
-
parser.add_argument("--push_to_hub",
|
277 |
-
parser.add_argument("--use_auth_token",
|
278 |
parser.add_argument("--hub_token", type=str, default=None)
|
279 |
parser.add_argument("--hub_model_id", type=str, default=None)
|
280 |
-
parser.add_argument("--hub_private_repo",
|
281 |
parser.add_argument("--logging_dir", type=str, default="logs")
|
282 |
parser.add_argument(
|
283 |
"--mixed_precision",
|
|
|
253 |
help="A folder containing the training data.",
|
254 |
)
|
255 |
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
|
256 |
+
parser.add_argument("--overwrite_output_dir", type=bool, default=False)
|
257 |
parser.add_argument("--cache_dir", type=str, default=None)
|
258 |
parser.add_argument("--resolution", type=int, default=64)
|
259 |
parser.add_argument("--train_batch_size", type=int, default=16)
|
|
|
269 |
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
270 |
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
|
271 |
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
272 |
+
parser.add_argument("--use_ema", type=bool, default=True)
|
273 |
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
274 |
parser.add_argument("--ema_power", type=float, default=3 / 4)
|
275 |
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
|
276 |
+
parser.add_argument("--push_to_hub", type=bool, default=False)
|
277 |
+
parser.add_argument("--use_auth_token", type=bool, default=False)
|
278 |
parser.add_argument("--hub_token", type=str, default=None)
|
279 |
parser.add_argument("--hub_model_id", type=str, default=None)
|
280 |
+
parser.add_argument("--hub_private_repo", type=bool, default=False)
|
281 |
parser.add_argument("--logging_dir", type=str, default="logs")
|
282 |
parser.add_argument(
|
283 |
"--mixed_precision",
|