teticio commited on
Commit
b929114
·
1 Parent(s): 1d6411d

no flip, start training from epoch

Browse files
Files changed (1) hide show
  1. src/train_unconditional.py +43 -29
src/train_unconditional.py CHANGED
@@ -20,7 +20,6 @@ from torchvision.transforms import (
20
  Compose,
21
  InterpolationMode,
22
  Normalize,
23
- RandomHorizontalFlip,
24
  Resize,
25
  ToTensor,
26
  )
@@ -40,29 +39,32 @@ def main(args):
40
  logging_dir=logging_dir,
41
  )
42
 
43
- model = UNet2DModel(
44
- sample_size=args.resolution,
45
- in_channels=1,
46
- out_channels=1,
47
- layers_per_block=2,
48
- block_out_channels=(128, 128, 256, 256, 512, 512),
49
- down_block_types=(
50
- "DownBlock2D",
51
- "DownBlock2D",
52
- "DownBlock2D",
53
- "DownBlock2D",
54
- "AttnDownBlock2D",
55
- "DownBlock2D",
56
- ),
57
- up_block_types=(
58
- "UpBlock2D",
59
- "AttnUpBlock2D",
60
- "UpBlock2D",
61
- "UpBlock2D",
62
- "UpBlock2D",
63
- "UpBlock2D",
64
- ),
65
- )
 
 
 
66
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
67
  optimizer = torch.optim.AdamW(
68
  model.parameters(),
@@ -76,7 +78,6 @@ def main(args):
76
  [
77
  Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
78
  CenterCrop(args.resolution),
79
- RandomHorizontalFlip(),
80
  ToTensor(),
81
  Normalize([0.5], [0.5]),
82
  ]
@@ -142,11 +143,22 @@ def main(args):
142
 
143
  global_step = 0
144
  for epoch in range(args.num_epochs):
145
- model.train()
146
  progress_bar = tqdm(
147
  total=len(train_dataloader), disable=not accelerator.is_local_main_process
148
  )
149
  progress_bar.set_description(f"Epoch {epoch}")
 
 
 
 
 
 
 
 
 
 
 
 
150
  for step, batch in enumerate(train_dataloader):
151
  clean_images = batch["input"]
152
  # Sample noise that we'll add to the images
@@ -271,12 +283,12 @@ if __name__ == "__main__":
271
  parser.add_argument("--adam_beta2", type=float, default=0.999)
272
  parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
273
  parser.add_argument("--adam_epsilon", type=float, default=1e-08)
274
- parser.add_argument("--use_ema", type=bool, default=True)
275
  parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
276
  parser.add_argument("--ema_power", type=float, default=3 / 4)
277
  parser.add_argument("--ema_max_decay", type=float, default=0.9999)
278
- parser.add_argument("--push_to_hub", type=bool, default=False)
279
- parser.add_argument("--use_auth_token", type=bool, default=False)
280
  parser.add_argument("--hub_token", type=str, default=None)
281
  parser.add_argument("--hub_model_id", type=str, default=None)
282
  parser.add_argument("--hub_private_repo", type=bool, default=False)
@@ -293,6 +305,8 @@ if __name__ == "__main__":
293
  ),
294
  )
295
  parser.add_argument("--hop_length", type=int, default=512)
 
 
296
 
297
  args = parser.parse_args()
298
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 
20
  Compose,
21
  InterpolationMode,
22
  Normalize,
 
23
  Resize,
24
  ToTensor,
25
  )
 
39
  logging_dir=logging_dir,
40
  )
41
 
42
+ if args.from_pretrained is not None:
43
+ model = UNet2DModel.from_pretrained(args.from_pretrained)
44
+ else:
45
+ model = UNet2DModel(
46
+ sample_size=args.resolution,
47
+ in_channels=1,
48
+ out_channels=1,
49
+ layers_per_block=2,
50
+ block_out_channels=(128, 128, 256, 256, 512, 512),
51
+ down_block_types=(
52
+ "DownBlock2D",
53
+ "DownBlock2D",
54
+ "DownBlock2D",
55
+ "DownBlock2D",
56
+ "AttnDownBlock2D",
57
+ "DownBlock2D",
58
+ ),
59
+ up_block_types=(
60
+ "UpBlock2D",
61
+ "AttnUpBlock2D",
62
+ "UpBlock2D",
63
+ "UpBlock2D",
64
+ "UpBlock2D",
65
+ "UpBlock2D",
66
+ ),
67
+ )
68
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
69
  optimizer = torch.optim.AdamW(
70
  model.parameters(),
 
78
  [
79
  Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
80
  CenterCrop(args.resolution),
 
81
  ToTensor(),
82
  Normalize([0.5], [0.5]),
83
  ]
 
143
 
144
  global_step = 0
145
  for epoch in range(args.num_epochs):
 
146
  progress_bar = tqdm(
147
  total=len(train_dataloader), disable=not accelerator.is_local_main_process
148
  )
149
  progress_bar.set_description(f"Epoch {epoch}")
150
+
151
+ if epoch < args.start_epoch:
152
+ for step in range(len(train_dataloader)):
153
+ optimizer.step()
154
+ lr_scheduler.step()
155
+ progress_bar.update(1)
156
+ global_step += 1
157
+ if epoch == args.start_epoch - 1 and args.use_ema:
158
+ ema_model.optimization_step = global_step
159
+ continue
160
+
161
+ model.train()
162
  for step, batch in enumerate(train_dataloader):
163
  clean_images = batch["input"]
164
  # Sample noise that we'll add to the images
 
283
  parser.add_argument("--adam_beta2", type=float, default=0.999)
284
  parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
285
  parser.add_argument("--adam_epsilon", type=float, default=1e-08)
286
+ parser.add_argument("--use_ema", type=bool, default=True)
287
  parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
288
  parser.add_argument("--ema_power", type=float, default=3 / 4)
289
  parser.add_argument("--ema_max_decay", type=float, default=0.9999)
290
+ parser.add_argument("--push_to_hub", type=bool, default=False)
291
+ parser.add_argument("--use_auth_token", type=bool, default=False)
292
  parser.add_argument("--hub_token", type=str, default=None)
293
  parser.add_argument("--hub_model_id", type=str, default=None)
294
  parser.add_argument("--hub_private_repo", type=bool, default=False)
 
305
  ),
306
  )
307
  parser.add_argument("--hop_length", type=int, default=512)
308
+ parser.add_argument("--from_pretrained", type=str, default=None)
309
+ parser.add_argument("--start_epoch", type=int, default=0)
310
 
311
  args = parser.parse_args()
312
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))