Spaces:
Running
Running
File size: 1,696 Bytes
6de3e11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import os
import os
import shutil
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
#
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that runs on every checkpoint save.

    On the process with ``args.local_rank`` equal to 0 or -1 it:

    1. saves ``kwargs["model"]`` into ``<checkpoint>/adapter_model`` and then
       removes that directory again (see the NOTE(review) in
       ``_save_and_discard_adapter``), and
    2. mirrors ``state.best_model_checkpoint``, when it is set and still on
       disk, into ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-best``.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Handle the Trainer's save event.

        Args:
            args: Training arguments; ``output_dir`` and ``local_rank`` are read.
            state: Trainer state; ``global_step``, ``best_model_checkpoint``
                and ``best_metric`` are read.
            control: Returned unchanged.
            **kwargs: Must contain ``"model"``, an object exposing
                ``save_pretrained(path)``.

        Returns:
            The ``control`` object that was passed in, unmodified.
        """
        # Filesystem work is restricted to local rank 0 (or -1, which
        # presumably means a non-distributed run).
        if args.local_rank in (0, -1):
            checkpoint_folder = os.path.join(
                args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
            )
            self._save_and_discard_adapter(kwargs["model"], checkpoint_folder)
            self._mirror_best_checkpoint(args.output_dir, state)
        return control

    @staticmethod
    def _save_and_discard_adapter(model, checkpoint_folder: str) -> None:
        """Save ``model`` to ``<checkpoint_folder>/adapter_model``, then delete that directory."""
        peft_model_dir = os.path.join(checkpoint_folder, "adapter_model")
        model.save_pretrained(peft_model_dir)
        # NOTE(review): the directory written just above is removed right away,
        # so no adapter copy remains in the checkpoint. Confirm this is the
        # intended behaviour. rmtree deletes every file inside (config, weights
        # in whatever format save_pretrained produced), so no per-file removal
        # is needed.
        if os.path.exists(peft_model_dir):
            shutil.rmtree(peft_model_dir)

    @staticmethod
    def _mirror_best_checkpoint(output_dir: str, state: TrainerState) -> None:
        """Replace ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-best`` with a copy of the best checkpoint."""
        best_checkpoint_folder = os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-best")
        best_checkpoint = state.best_model_checkpoint
        # best_model_checkpoint is None until a best checkpoint has been
        # recorded, and os.path.exists(None) raises TypeError. The existence
        # check also skips a best checkpoint that is no longer on disk
        # (presumably removed by checkpoint rotation; verify).
        if best_checkpoint is not None and os.path.exists(best_checkpoint):
            if os.path.exists(best_checkpoint_folder):
                shutil.rmtree(best_checkpoint_folder)
            shutil.copytree(best_checkpoint, best_checkpoint_folder)
        print(f"best checkpoint: {best_checkpoint}, best metric: {state.best_metric}")
|