Kevin Fink
committed on
Commit
·
8ba99ef
1
Parent(s):
1ec3de2
dev
Browse files
app.py
CHANGED
@@ -231,29 +231,32 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
231 |
##compute_metrics=compute_metrics,
|
232 |
##data_collator=data_collator,
|
233 |
##processing_class=tokenizer,
|
234 |
-
#)
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
if os.path.exists(
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
previous_checkpoints = sorted(glob.glob(os.path.join(os.path.dirname(checkpoint_dir), 'checkpoint-*')), key=os.path.getmtime)
|
245 |
-
|
246 |
-
if previous_checkpoints:
|
247 |
-
# Load the most recent previous checkpoint
|
248 |
-
last_checkpoint = previous_checkpoints[-1]
|
249 |
-
print(f"Loading previous checkpoint: {last_checkpoint}")
|
250 |
-
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
|
251 |
else:
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
trainer.push_to_hub(commit_message="Training complete!")
|
259 |
except Exception as e:
|
|
|
231 |
##compute_metrics=compute_metrics,
|
232 |
##data_collator=data_collator,
|
233 |
##processing_class=tokenizer,
|
234 |
+
#)
|
235 |
+
try:
|
236 |
+
train_result = trainer.train(resume_from_checkpoint=True)
|
237 |
+
except:
|
238 |
+
checkpoint_dir = training_args.output_dir
|
239 |
+
if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
|
240 |
+
# Check if the trainer_state.json file exists in the specified checkpoint
|
241 |
+
trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
|
242 |
+
if os.path.exists(trainer_state_path):
|
243 |
+
train_result = trainer.train(resume_from_checkpoint=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
else:
|
245 |
+
# If the trainer_state.json is missing, look for the previous checkpoint
|
246 |
+
print(f"Checkpoint {checkpoint_dir} is missing 'trainer_state.json'. Looking for previous checkpoints...")
|
247 |
+
previous_checkpoints = sorted(glob.glob(os.path.join(os.path.dirname(checkpoint_dir), 'checkpoint-*')), key=os.path.getmtime)
|
248 |
+
|
249 |
+
if previous_checkpoints:
|
250 |
+
# Load the most recent previous checkpoint
|
251 |
+
last_checkpoint = previous_checkpoints[-1]
|
252 |
+
print(f"Loading previous checkpoint: {last_checkpoint}")
|
253 |
+
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
|
254 |
+
else:
|
255 |
+
print("No previous checkpoints found. Starting training from scratch.")
|
256 |
+
train_result = trainer.train()
|
257 |
+
else:
|
258 |
+
print("No checkpoints found. Starting training from scratch.")
|
259 |
+
train_result = trainer.train()
|
260 |
|
261 |
trainer.push_to_hub(commit_message="Training complete!")
|
262 |
except Exception as e:
|