Fix batch size
Browse files
- train_model.py +4 -3
train_model.py
CHANGED
@@ -100,16 +100,17 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
|
|
100 |
|
101 |
# lets train for 100 epoch for each sprite in the dataset with a random noise level
|
102 |
progress_bar = tqdm(total=epochs)
|
|
|
103 |
for epoch in range(epochs):
|
104 |
# create a noisy version of each sprite
|
105 |
for batch_index in range(0, num_images, batch_size):
|
106 |
progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
|
107 |
batch_end = np.minimum(num_images, batch_index + batch_size)
|
108 |
clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
|
109 |
-
clean_images = torch.reshape(clean_images, (
|
110 |
|
111 |
noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
|
112 |
-
timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (
|
113 |
timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
114 |
noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
|
115 |
noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
|
@@ -128,4 +129,4 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
|
|
128 |
|
129 |
|
130 |
if __name__ == '__main__':
|
131 |
-
train_model()
|
|
|
100 |
|
101 |
# lets train for 100 epoch for each sprite in the dataset with a random noise level
|
102 |
progress_bar = tqdm(total=epochs)
|
103 |
+
|
104 |
for epoch in range(epochs):
|
105 |
# create a noisy version of each sprite
|
106 |
for batch_index in range(0, num_images, batch_size):
|
107 |
progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
|
108 |
batch_end = np.minimum(num_images, batch_index + batch_size)
|
109 |
clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
|
110 |
+
clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
|
111 |
|
112 |
noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
|
113 |
+
timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
|
114 |
timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
115 |
noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
|
116 |
noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
|
|
|
129 |
|
130 |
|
131 |
if __name__ == '__main__':
|
132 |
+
train_model(8)
|