frutiemax committed
Commit 21640c8
1 Parent(s): a4c8091

Revert to not using accelerate

Files changed (1)
  1. train_model.py +15 -28
train_model.py CHANGED
@@ -103,18 +103,9 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
 
     # lets train for 100 epoch for each sprite in the dataset with a random noise level
     progress_bar = tqdm(total=epochs)
-    accelerator = Accelerator(
-        mixed_precision='fp16',
-        gradient_accumulation_steps=1,
-        log_with="tensorboard",
-        project_dir='logs',
-    )
 
     scheduler = DDPMScheduler(scheduler_num_timesteps)
-    unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(unet, scheduler, \
-        optimizer, lr_scheduler)
-
-    unet = unet.to(dtype=torch.float16)
+    unet = unet.to(device='cuda', dtype=torch.float16)
     scheduler.set_timesteps(scheduler_num_timesteps)
 
     for epoch in range(epochs):
@@ -123,27 +114,23 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
             progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
             batch_end = np.minimum(num_images, batch_index + batch_size)
             clean_images = targets[batch_index:batch_end]
-            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
+            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256)).to(device='cuda', dtype=torch.float16)
 
-            noise = torch.randn(clean_images.shape, dtype=torch.float16)
-            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
+            noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
+            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
             #timesteps = timesteps.to(dtype=torch.int, device='cuda')
-            noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
-
-            with accelerator.accumulate(unet):
-                noise_pred = unet(noisy_images, timesteps.to(device='cuda'), class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
-
-                #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
-                loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
-                accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
-
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
+            noise_pred = unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
+
+            #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
+            loss = F.mse_loss(noise_pred, noise)
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
 
         if (epoch + 1) % save_model_interval == 0:
-            model.unet = accelerator.unwrap_model(unet)
+            model.unet = unet
             model.scheduler = scheduler
             save_and_test(model, epoch)
             del model.unet
@@ -152,4 +139,4 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
 
 
 if __name__ == '__main__':
-    train_model(4)
+    train_model(8, save_model_interval=1)
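
Note: after this revert the UNet weights themselves are cast to float16 and the loss is backpropagated with no loss scaling. A common alternative when training without Accelerate is PyTorch's native automatic mixed precision, which keeps the weights in fp32 and scales the loss so fp16 gradients do not underflow. The sketch below is illustrative only and is not part of this commit; the Conv2d model, tensor shapes, and learning rate are placeholder stand-ins for the repo's UNet and sprite batches, and a CUDA device is assumed as in train_model.py.

import torch
import torch.nn.functional as F

# Hypothetical stand-ins: a small Conv2d and random tensors replace the repo's
# UNet and sprite batches so the pattern is runnable on its own (CUDA assumed).
model = torch.nn.Conv2d(12, 12, kernel_size=3, padding=1).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

clean_images = torch.randn(4, 12, 64, 64, device='cuda')
noise = torch.randn_like(clean_images)

# Weights stay in fp32; autocast runs the forward pass in fp16 where it is safe.
with torch.autocast(device_type='cuda', dtype=torch.float16):
    noise_pred = model(clean_images + noise)
    loss = F.mse_loss(noise_pred, noise)

scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)          # unscales gradients, then calls optimizer.step()
scaler.update()
optimizer.zero_grad()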