teticio commited on
Commit
e97d748
·
1 Parent(s): 1dea888

add push_to_hub

Browse files
.gitignore CHANGED
@@ -1,5 +1,5 @@
1
  .vscode
2
  __pycache__
3
  .ipynb_checkpoints
4
- data
5
  ddpm-ema-audio-*
 
1
  .vscode
2
  __pycache__
3
  .ipynb_checkpoints
4
+ data*
5
  ddpm-ema-audio-*
src/audio_to_images.py CHANGED
@@ -58,6 +58,8 @@ def main(args):
58
  )
59
  dsd = DatasetDict({"train": ds})
60
  dsd.save_to_disk(os.path.join(args.output_dir))
 
 
61
 
62
 
63
  if __name__ == "__main__":
@@ -68,5 +70,6 @@ if __name__ == "__main__":
68
  parser.add_argument("--output_dir", type=str, default="data")
69
  parser.add_argument("--resolution", type=int, default=256)
70
  parser.add_argument("--hop_length", type=int, default=512)
 
71
  args = parser.parse_args()
72
  main(args)
 
58
  )
59
  dsd = DatasetDict({"train": ds})
60
  dsd.save_to_disk(os.path.join(args.output_dir))
61
+ if args.push_to_hub:
62
+ dsd.push_to_hub(args.push_to_hub)
63
 
64
 
65
  if __name__ == "__main__":
 
70
  parser.add_argument("--output_dir", type=str, default="data")
71
  parser.add_argument("--resolution", type=int, default=256)
72
  parser.add_argument("--hop_length", type=int, default=512)
73
+ parser.add_argument("--push_to_hub", type=str, default=None)
74
  args = parser.parse_args()
75
  main(args)
src/train_unconditional.py CHANGED
@@ -80,7 +80,18 @@ def main(args):
80
  )
81
 
82
  if args.dataset_name is not None:
83
- dataset = load_from_disk(args.dataset_name, args.dataset_config_name)["train"]
 
 
 
 
 
 
 
 
 
 
 
84
  else:
85
  dataset = load_dataset(
86
  "imagefolder",
@@ -203,11 +214,14 @@ def main(args):
203
  accelerator.trackers[0].writer.add_images(
204
  "test_samples", images_processed, epoch
205
  )
206
- for image in images_processed:
207
  image = Image.fromarray(np.mean(image, axis=0).astype("uint8"))
208
  audio = mel.image_to_audio(image)
209
  accelerator.trackers[0].writer.add_audio(
210
- "test_samples", audio, epoch, sample_rate=mel.get_sample_rate()
 
 
 
211
  )
212
 
213
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
 
80
  )
81
 
82
  if args.dataset_name is not None:
83
+ if os.path.exists(args.dataset_name):
84
+ dataset = load_from_disk(args.dataset_name, args.dataset_config_name)[
85
+ "train"
86
+ ]
87
+ else:
88
+ dataset = load_dataset(
89
+ args.dataset_name,
90
+ args.dataset_config_name,
91
+ cache_dir=args.cache_dir,
92
+ use_auth_token=True if args.use_auth_token else None,
93
+ split="train",
94
+ )
95
  else:
96
  dataset = load_dataset(
97
  "imagefolder",
 
214
  accelerator.trackers[0].writer.add_images(
215
  "test_samples", images_processed, epoch
216
  )
217
+ for _, image in enumerate(images_processed):
218
  image = Image.fromarray(np.mean(image, axis=0).astype("uint8"))
219
  audio = mel.image_to_audio(image)
220
  accelerator.trackers[0].writer.add_audio(
221
+ f"test_audio_{_}",
222
+ audio,
223
+ epoch,
224
+ sample_rate=mel.get_sample_rate(),
225
  )
226
 
227
  if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: