Use ExponentialLR
train_model.py  +3 -7
@@ -120,11 +120,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
     text_encoder.requires_grad_(False)

     optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
-    lr_scheduler =
-        optimizer=optimizer,
-        num_warmup_steps=lr_warmup_steps,
-        num_training_steps=num_images * epochs
-    )
+    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999, verbose=True)
     model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder, vae_image_processor)
     unet = unet.to('cuda')

@@ -173,7 +169,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
         if (epoch + 1) % save_model_interval == 0:
             # inference in float16
             model = RCTDiffusionPipeline(unet.to(dtype=torch.float16), scheduler, \
-                vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16))
+                vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16), vae_image_processor)
             save_and_test(model, epoch)

             # training in float32
@@ -185,4 +181,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste


 if __name__ == '__main__':
-    train_model(batch_size=
+    train_model(batch_size=48, save_model_interval=25, epochs=1000, start_learning_rate=1e-3)
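The removed block called a warmup schedule whose name is cut off in this rendering; its keyword arguments (num_warmup_steps, num_training_steps) match the warmup-schedule helpers in diffusers. The replacement is plain exponential decay. Below is a minimal, self-contained sketch of what the new line does, assuming the scheduler is stepped once per epoch (the step() call site is outside this diff); the dummy parameter stands in for unet.parameters().

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)  # start_learning_rate=1e-3
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

for epoch in range(1000):  # epochs=1000, as in the __main__ change
    optimizer.step()       # the real loop would backprop a loss first
    lr_scheduler.step()    # lr <- lr * gamma

print(optimizer.param_groups[0]['lr'])  # ~3.7e-4, i.e. 1e-3 * 0.999**1000

With the new __main__ settings, per-epoch stepping would decay the learning rate from 1e-3 to roughly 3.7e-4 over the 1000 epochs; unlike the removed schedule, there is no warmup phase, so training starts at the full rate. The verbose=True flag in the diff only prints the learning rate when it changes (and is deprecated in recent PyTorch releases).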
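The second hunk also threads vae_image_processor into the float16 pipeline, matching the float32 construction in the first hunk. Around save_and_test the script casts the models down for inference and back up for training; below is a minimal sketch of that round trip with a stand-in module, since RCTDiffusionPipeline and save_and_test are repo-specific. Note that nn.Module.to(dtype=...) converts the module's parameters in place, which is why the cast back is required.

import torch

net = torch.nn.Linear(4, 4).to('cuda')  # stands in for unet/vae/text_encoder

# "inference in float16": cast down for the periodic save-and-test pass
net.to(dtype=torch.float16)
with torch.no_grad():
    sample = torch.randn(1, 4, device='cuda', dtype=torch.float16)
    out = net(sample)

# "training in float32": Module.to mutated net above, so cast back
# before the next optimizer step
net.to(dtype=torch.float32)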