from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

from datasets import load_dataset


model_name = "facebook/bart-base"

tokenizer_name = model_name

# Load the tokenizer and model up front so the tokenizer is available inside preprocess_function
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# cnn_dailymail requires an explicit config name; load both splits so per-epoch evaluation has data
dataset = load_dataset("cnn_dailymail", "3.0.0")

def preprocess_function(examples):
    # With batched=True, `examples` is a dict of column lists, not a list of rows
    inputs = examples["article"]
    targets = examples["highlights"]
    # Tokenize inputs, and tokenize the targets as labels via text_target;
    # padding is left to the data collator so label padding can be masked out of the loss
    tokenized_data = tokenizer(inputs, text_target=targets, truncation=True)
    return tokenized_data


tokenized_datasets = dataset.map(preprocess_function, batched=True)
train_data = tokenized_datasets["train"]
eval_data = tokenized_datasets["validation"]

training_args = Seq2SeqTrainingArguments(
    output_dir="./outputs",  # any desired output directory
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,  # adjust the number of epochs for training
    save_steps=10_000,
    evaluation_strategy="epoch",
    logging_steps=500,
    push_to_hub=False,  # set to True for direct upload to the Hub during training (requires a logged-in HF account)
)

# Pads each batch dynamically and replaces label padding with -100 so it is ignored by the loss
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
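
# A minimal inference sketch with the fine-tuned model; the article text and
# generation settings below are illustrative assumptions, not part of the training setup.
article = "Replace this with any news article you want to summarize."
inputs = tokenizer(article, return_tensors="pt", truncation=True).to(model.device)
summary_ids = model.generate(**inputs, max_new_tokens=128, num_beams=4)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))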