Kevin Fink committed on
Commit
8325fbf
·
1 Parent(s): 63431bc
Files changed (1) hide show
  1. app.py +17 -23
app.py CHANGED
@@ -68,8 +68,6 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
68
 
69
  # Set training arguments
70
  training_args = TrainingArguments(
71
- remove_unused_columns=False,
72
- torch_empty_cache_steps=100,
73
  output_dir='/data/results',
74
  eval_strategy="steps", # Change this to steps
75
  save_strategy='steps',
@@ -84,7 +82,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
84
  metric_for_best_model="loss",
85
  greater_is_better=True,
86
  logging_dir='/data/logs',
87
- logging_steps=200,
88
  #push_to_hub=True,
89
  hub_model_id=hub_id.strip(),
90
  fp16=True,
@@ -231,32 +229,28 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
231
  ##data_collator=data_collator,
232
  ##processing_class=tokenizer,
233
  #)
 
 
 
 
 
 
234
  try:
235
  train_result = trainer.train(resume_from_checkpoint=True)
236
  except:
237
  checkpoint_dir = training_args.output_dir
238
- if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
239
- # Check if the trainer_state.json file exists in the specified checkpoint
240
- trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
241
- if os.path.exists(trainer_state_path):
242
- train_result = trainer.train(resume_from_checkpoint=True)
243
- else:
244
- # If the trainer_state.json is missing, look for the previous checkpoint
245
- print(f"Checkpoint {checkpoint_dir} is missing 'trainer_state.json'. Looking for previous checkpoints...")
246
- previous_checkpoints = sorted(glob.glob(os.path.join(os.path.dirname(checkpoint_dir), 'checkpoint-*')), key=os.path.getmtime)
247
-
248
- if previous_checkpoints:
249
- # Load the most recent previous checkpoint
250
- last_checkpoint = previous_checkpoints[-1]
251
- print(f"Loading previous checkpoint: {last_checkpoint}")
252
- train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
253
- else:
254
- print("No previous checkpoints found. Starting training from scratch.")
255
- train_result = trainer.train()
256
  else:
257
- print("No checkpoints found. Starting training from scratch.")
258
  train_result = trainer.train()
259
-
260
  trainer.push_to_hub(commit_message="Training complete!")
261
  except Exception as e:
262
  return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
 
68
 
69
  # Set training arguments
70
  training_args = TrainingArguments(
 
 
71
  output_dir='/data/results',
72
  eval_strategy="steps", # Change this to steps
73
  save_strategy='steps',
 
82
  metric_for_best_model="loss",
83
  greater_is_better=True,
84
  logging_dir='/data/logs',
85
+ logging_steps=100,
86
  #push_to_hub=True,
87
  hub_model_id=hub_id.strip(),
88
  fp16=True,
 
229
  ##data_collator=data_collator,
230
  ##processing_class=tokenizer,
231
  #)
232
+ print(f'ROOTDIR: {os.listdir('/data/results')}')
233
+ for entry in os.listdir('data/results'):
234
+ try:
235
+ print(f'{entry}: {os.listdir(entry)}')
236
+ except:
237
+ pass
238
  try:
239
  train_result = trainer.train(resume_from_checkpoint=True)
240
  except:
241
  checkpoint_dir = training_args.output_dir
242
+ # If the trainer_state.json is missing, look for the previous checkpoint
243
+ print(f"Checkpoint {checkpoint_dir} is missing 'trainer_state.json'. Looking for previous checkpoints...")
244
+ previous_checkpoints = sorted(glob.glob(os.path.join(os.path.dirname(checkpoint_dir), 'checkpoint-*')), key=os.path.getmtime)
245
+
246
+ if previous_checkpoints:
247
+ # Load the most recent previous checkpoint
248
+ last_checkpoint = previous_checkpoints[-1]
249
+ print(f"Loading previous checkpoint: {last_checkpoint}")
250
+ train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
 
 
 
 
 
 
 
 
 
251
  else:
252
+ print("No previous checkpoints found. Starting training from scratch.")
253
  train_result = trainer.train()
 
254
  trainer.push_to_hub(commit_message="Training complete!")
255
  except Exception as e:
256
  return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"