Update train_dtat.py
train_dtat.py (+6 -5)
@@ -82,11 +82,12 @@ def get_lr(it, config):
     if it < config.warmup_iters:
         return config.learning_rate * it / config.warmup_iters
 
-    # Cosine decay
+    # Cosine decay with minimum learning rate
     if config.decay_lr:
-        decay_ratio = (it - config.warmup_iters) / (config.
-
-
+        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
+        decay_ratio = min(decay_ratio, 1.0)  # Cap at 1.0
+        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
 
     return config.learning_rate
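For context, a minimal runnable sketch of how the updated schedule behaves: linear warmup, then cosine decay from learning_rate down to min_lr. The get_lr body mirrors the diff above; the LRConfig dataclass and its example values are illustrative assumptions, not code taken from the repository.

import math
from dataclasses import dataclass

@dataclass
class LRConfig:
    # Stand-in for the training config; field names follow the diff, values are examples only.
    learning_rate: float = 6e-4
    min_lr: float = 6e-5
    warmup_iters: int = 2000
    lr_decay_iters: int = 600000
    decay_lr: bool = True

def get_lr(it, config):
    # Linear warmup from 0 to learning_rate over the first warmup_iters steps.
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    # Cosine decay with minimum learning rate (the behaviour added in this commit).
    if config.decay_lr:
        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
        decay_ratio = min(decay_ratio, 1.0)  # Cap at 1.0 so the lr settles at min_lr
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1.0 to 0.0 over the decay window
        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
    # Constant learning rate when decay is disabled.
    return config.learning_rate

if __name__ == "__main__":
    cfg = LRConfig()
    for step in (0, 1000, 2000, 300000, 600000, 700000):
        print(step, get_lr(step, cfg))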
@@ -280,7 +281,7 @@ def main():
     })
 
     # Save regular checkpoint every 5000 iterations
-    if iter_num %
+    if iter_num % 1000 == 0:
         checkpoint = {
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),