# Provenance (HuggingFace file-viewer metadata, not code):
#   author: Poobanchean
#   commit: "Upload folder using huggingface_hub" (f627408, verified)
#   size: 1.27 kB
import transformers
from transformers import Trainer
from llm_finetune.arguments import (
ModelArguments,
DataArguments,
TrainingArguments,
)
from llm_finetune.dataset import make_supervised_data_module
def train():
    """Fine-tune a causal language model on a supervised dataset.

    Parses ``ModelArguments``/``DataArguments``/``TrainingArguments`` from
    the command line, loads the pretrained model and tokenizer, builds the
    supervised data module, runs training (optionally resuming from a
    checkpoint), then saves the trainer state and final model.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Many causal-LM tokenizers ship without a pad token; reuse EOS so that
    # right-padded batches are well-defined.
    tokenizer.pad_token = tokenizer.eos_token

    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
    )
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
    # `checkpoint` is presumably a path or bool from TrainingArguments —
    # passing it by keyword makes the previously positional call explicit.
    trainer.train(resume_from_checkpoint=training_args.checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()