Increase unet size
Browse files- train_model.py +2 -2
train_model.py
CHANGED
@@ -108,7 +108,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
|
|
108 |
unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
|
109 |
down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
|
110 |
up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
|
111 |
-
block_out_channels=(
|
112 |
unet = unet.to(dtype=torch.float32)
|
113 |
|
114 |
#https://discuss.pytorch.org/t/training-with-half-precision/11815
|
@@ -192,4 +192,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
|
|
192 |
|
193 |
|
194 |
if __name__ == '__main__':
|
195 |
-
train_model(
|
|
|
108 |
unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
|
109 |
down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
|
110 |
up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
|
111 |
+
block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
|
112 |
unet = unet.to(dtype=torch.float32)
|
113 |
|
114 |
#https://discuss.pytorch.org/t/training-with-half-precision/11815
|
|
|
192 |
|
193 |
|
194 |
if __name__ == '__main__':
|
195 |
+
train_model(batch_size=16, save_model_interval=25, epochs=500, start_learning_rate=1e-5)
|