frutiemax committed
Commit b8efa96 · Parent(s): aa6b13c

Use exponentialLR

Files changed (1):
  1. train_model.py +3 -7

train_model.py CHANGED
@@ -120,11 +120,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
     text_encoder.requires_grad_(False)
 
     optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
-    lr_scheduler = get_cosine_schedule_with_warmup(
-        optimizer=optimizer,
-        num_warmup_steps=lr_warmup_steps,
-        num_training_steps=num_images * epochs
-    )
+    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999, verbose=True)
     model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder, vae_image_processor)
     unet = unet.to('cuda')
 
@@ -173,7 +169,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
         if (epoch + 1) % save_model_interval == 0:
             # inference in float16
             model = RCTDiffusionPipeline(unet.to(dtype=torch.float16), scheduler, \
-                vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16))
+                vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16), vae_image_processor)
             save_and_test(model, epoch)
 
             # training in float32
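This hunk also fixes the float16 test pipeline to pass vae_image_processor, matching the six-argument RCTDiffusionPipeline call used at construction time above. A subtlety behind the surrounding comments: nn.Module.to(dtype=...) converts the module's parameters in place and returns the same object, so building the float16 pipeline leaves unet, vae and text_encoder themselves in float16, which is why the code must cast back for the "# training in float32" section. A small sketch with a stand-in module:

import torch

unet = torch.nn.Linear(4, 4)  # stand-in for the actual UNet

fp16 = unet.to(dtype=torch.float16)  # converts in place and returns self
print(fp16 is unet)                  # True: same module object
print(unet.weight.dtype)             # torch.float16

unet.to(dtype=torch.float32)         # cast back before training resumes
print(unet.weight.dtype)             # torch.float32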
@@ -185,4 +181,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
 
 
 if __name__ == '__main__':
-    train_model(batch_size=1, total_images=4, save_model_interval=25, epochs=500, start_learning_rate=1e-5)
+    train_model(batch_size=48, save_model_interval=25, epochs=1000, start_learning_rate=1e-3)
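The entry-point change scales the run up: batch_size 1 → 48, the total_images=4 debug cap is dropped (back to the total_images=-1 default, presumably the full dataset), epochs 500 → 1000, and start_learning_rate 1e-5 → 1e-3. A back-of-envelope check on where gamma=0.999 leaves that rate, assuming one lr_scheduler.step() per epoch:

start_lr, gamma, epochs = 1e-3, 0.999, 1000
print(start_lr * gamma ** epochs)  # ~3.68e-04, since 0.999**1000 ≈ e**-1 ≈ 0.368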
 