import transformers
from transformers import Trainer

from llm_finetune.arguments import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
)
from llm_finetune.dataset import make_supervised_data_module


def train():
    # Parse command-line flags into the three argument dataclasses.
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the pretrained causal-LM weights.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # Load the matching tokenizer; inputs longer than model_max_length are
    # truncated, and padding goes on the right (the usual choice for training).
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Some base models ship without a pad token; fall back to EOS.
    tokenizer.pad_token = tokenizer.eos_token

    # Build the training dataset and collator for supervised fine-tuning.
    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
    )

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
    # Resume from a prior checkpoint when one is configured.
    trainer.train(resume_from_checkpoint=training_args.checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
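
The three dataclasses come from `llm_finetune.arguments`, and the script reads custom fields (`cache_dir`, `model_max_length`, `checkpoint`) off `TrainingArguments`. A minimal sketch of what that module might look like; only the fields used in the script are grounded, while the defaults, help strings, and the `data_path` field are assumptions:

from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ModelArguments:
    # Consumed by from_pretrained() for both the model and the tokenizer.
    model_name_or_path: Optional[str] = field(default=None)


@dataclass
class DataArguments:
    # Hypothetical field: whatever make_supervised_data_module reads.
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Extends the stock HF arguments with the extra fields the script uses.
    cache_dir: Optional[str] = field(default=None)
    model_max_length: int = field(
        default=512, metadata={"help": "Truncate sequences to this length."}
    )
    checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Checkpoint to resume training from."}
    )

Since HfArgumentParser turns each dataclass field into a CLI flag, a run might look like `python train.py --model_name_or_path gpt2 --data_path data.json --output_dir out --model_max_length 512` (where `--data_path` assumes the hypothetical field sketched above).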
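
Because `data_module` is unpacked directly into `Trainer(...)`, `make_supervised_data_module` must return a dict whose keys are `Trainer` keyword arguments (`train_dataset`, `eval_dataset`, `data_collator`). A sketch under that assumption; the `SupervisedDataset` class and the prompt/response JSON record format are hypothetical, not the package's real implementation:

import json
from typing import Dict

import torch
import transformers
from torch.utils.data import Dataset

IGNORE_INDEX = -100  # label id that the cross-entropy loss skips


class SupervisedDataset(Dataset):
    # Hypothetical map-style dataset over [{"prompt": ..., "response": ...}].
    def __init__(self, tokenizer, data_path):
        with open(data_path) as f:
            records = json.load(f)
        self.examples = []
        for rec in records:
            # truncation=True caps sequences at tokenizer.model_max_length.
            ids = tokenizer(
                rec["prompt"] + rec["response"],
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]
            # Mask prompt tokens so the loss covers only the response.
            prompt_len = len(tokenizer(rec["prompt"], truncation=True).input_ids)
            labels = ids.clone()
            labels[:prompt_len] = IGNORE_INDEX
            self.examples.append(dict(input_ids=ids, labels=labels))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return self.examples[i]


def make_supervised_data_module(tokenizer, data_args) -> Dict:
    # Keys must line up with Trainer's keyword arguments, since the caller
    # does Trainer(..., **data_module).
    return dict(
        train_dataset=SupervisedDataset(tokenizer, data_args.data_path),
        eval_dataset=None,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, padding="longest", label_pad_token_id=IGNORE_INDEX
        ),
    )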