eyad-silx committed · verified
Commit d606189 · Parent: 6b70460

Update train_dtat.py

Files changed (1): train_dtat.py (+6 −5)
train_dtat.py CHANGED
@@ -82,11 +82,12 @@ def get_lr(it, config):
     if it < config.warmup_iters:
         return config.learning_rate * it / config.warmup_iters
 
-    # Cosine decay
+    # Cosine decay with minimum learning rate
     if config.decay_lr:
-        decay_ratio = (it - config.warmup_iters) / (config.max_iters - config.warmup_iters)
-        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # Cosine decay
-        return config.learning_rate * coeff
+        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
+        decay_ratio = min(decay_ratio, 1.0)  # Cap at 1.0
+        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
 
     return config.learning_rate
 
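For context on this hunk: the old schedule divided by config.max_iters and decayed the rate all the way to zero at the end of training. The new version decays over config.lr_decay_iters instead, caps decay_ratio at 1.0 so that iterations past that point hold steady (without the cap, cos(pi * decay_ratio) would curl the rate back upward once the ratio exceeded 1), and floors the result at config.min_lr. A minimal self-contained sketch of the updated schedule; the specific values for learning_rate, min_lr, warmup_iters, and lr_decay_iters below are illustrative assumptions, not values from this repo:

import math
from dataclasses import dataclass

@dataclass
class Config:
    # Illustrative values only -- assumptions for this sketch, not the repo's defaults.
    learning_rate: float = 6e-4
    min_lr: float = 6e-5
    warmup_iters: int = 100
    lr_decay_iters: int = 5000
    decay_lr: bool = True

def get_lr(it, config):
    # Linear warmup from 0 up to the peak learning rate.
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    # Cosine decay from learning_rate down to min_lr over
    # [warmup_iters, lr_decay_iters]; flat at min_lr afterwards.
    if config.decay_lr:
        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
        decay_ratio = min(decay_ratio, 1.0)  # cap: beyond lr_decay_iters, stay at min_lr
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # runs 1.0 -> 0.0
        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
    return config.learning_rate

cfg = Config()
print(get_lr(0, cfg))     # 0.0      (warmup start)
print(get_lr(100, cfg))   # 6e-4     (peak at end of warmup)
print(get_lr(2550, cfg))  # ~3.3e-4  (decay midpoint: min_lr + 0.5 * (learning_rate - min_lr))
print(get_lr(9000, cfg))  # 6e-5     (held at min_lr past lr_decay_iters)

Decoupling lr_decay_iters from max_iters resembles the common nanoGPT-style schedule: the rate can reach its floor before training ends rather than hitting zero on the final step.
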
@@ -280,7 +281,7 @@ def main():
             })
 
         # Save regular checkpoint every 5000 iterations
-        if iter_num % 5000 == 0:
+        if iter_num % 1000 == 0:
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
 
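The second hunk tightens the save interval from every 5000 iterations to every 1000 (the comment above the check still reads 5000). A sketch of how a checkpoint dict with these keys round-trips through torch.save / torch.load; the save_checkpoint and load_checkpoint helpers, the 'iter_num' key, and the file path are illustrative assumptions, not taken from train_dtat.py:

import torch

def save_checkpoint(model, optimizer, iter_num, path="ckpt.pt"):
    # Same two state-dict keys as the diff; 'iter_num' is an assumed extra key.
    checkpoint = {
        'model_state_dict': model.state_dict(),          # model weights
        'optimizer_state_dict': optimizer.state_dict(),  # optimizer moments, step counts
        'iter_num': iter_num,                            # resume point (assumption)
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, optimizer, path="ckpt.pt"):
    # Restore weights and optimizer state, then return the iteration to resume from.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint.get('iter_num', 0)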