frutiemax commited on
Commit
f53657a
·
1 Parent(s): 3ab0859

Increase unet size

Browse files
Files changed (1) hide show
  1. train_model.py +2 -2
train_model.py CHANGED
@@ -108,7 +108,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
108
  unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
109
  down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
110
  up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
111
- block_out_channels=(128, 256, 512, 512), norm_num_groups=32)
112
  unet = unet.to(dtype=torch.float32)
113
 
114
  #https://discuss.pytorch.org/t/training-with-half-precision/11815
@@ -192,4 +192,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
192
 
193
 
194
  if __name__ == '__main__':
195
- train_model(1, save_model_interval=10, epochs=100)
 
108
  unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
109
  down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
110
  up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
111
+ block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
112
  unet = unet.to(dtype=torch.float32)
113
 
114
  #https://discuss.pytorch.org/t/training-with-half-precision/11815
 
192
 
193
 
194
  if __name__ == '__main__':
195
+ train_model(batch_size=16, save_model_interval=25, epochs=500, start_learning_rate=1e-5)