frutiemax committed on
Commit
1618027
1 Parent(s): 04d70cd

Fix batch size

Browse files
Files changed (1) hide show
  1. train_model.py +4 -3
train_model.py CHANGED
@@ -100,16 +100,17 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
100
 
101
  # lets train for 100 epoch for each sprite in the dataset with a random noise level
102
  progress_bar = tqdm(total=epochs)
 
103
  for epoch in range(epochs):
104
  # create a noisy version of each sprite
105
  for batch_index in range(0, num_images, batch_size):
106
  progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
107
  batch_end = np.minimum(num_images, batch_index + batch_size)
108
  clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
109
- clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))
110
 
111
  noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
112
- timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, ))
113
  timesteps = timesteps.to(dtype=torch.int, device='cuda')
114
  noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
115
  noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
@@ -128,4 +129,4 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
128
 
129
 
130
  if __name__ == '__main__':
131
- train_model()
 
100
 
101
  # lets train for 100 epoch for each sprite in the dataset with a random noise level
102
  progress_bar = tqdm(total=epochs)
103
+
104
  for epoch in range(epochs):
105
  # create a noisy version of each sprite
106
  for batch_index in range(0, num_images, batch_size):
107
  progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
108
  batch_end = np.minimum(num_images, batch_index + batch_size)
109
  clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
110
+ clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
111
 
112
  noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
113
+ timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
114
  timesteps = timesteps.to(dtype=torch.int, device='cuda')
115
  noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
116
  noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
 
129
 
130
 
131
  if __name__ == '__main__':
132
+ train_model(8)