JacobLinCool commited on
Commit
fa9dd69
1 Parent(s): 0e6fd1f

feat: early return when trained 10 epoch

Browse files
Files changed (1) hide show
  1. infer/modules/train/train.py +5 -1
infer/modules/train/train.py CHANGED
@@ -248,8 +248,8 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
248
  scaler = GradScaler(enabled=hps.train.fp16_run)
249
 
250
  cache = []
 
251
  for epoch in range(epoch_str, hps.train.epochs + 1):
252
- print("epoch", epoch)
253
  if rank == 0:
254
  train_and_evaluate(
255
  rank,
@@ -283,6 +283,10 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
283
  scheduler_g.step()
284
  scheduler_d.step()
285
 
 
 
 
 
286
 
287
  def train_and_evaluate(
288
  rank,
 
248
  scaler = GradScaler(enabled=hps.train.fp16_run)
249
 
250
  cache = []
251
+ trained = 0
252
  for epoch in range(epoch_str, hps.train.epochs + 1):
 
253
  if rank == 0:
254
  train_and_evaluate(
255
  rank,
 
283
  scheduler_g.step()
284
  scheduler_d.step()
285
 
286
+ trained += 1
287
+ if trained >= 10:
288
+ break
289
+
290
 
291
  def train_and_evaluate(
292
  rank,