|
import transformers |
|
from datasets import ClassLabel |
|
import random |
|
import pandas as pd |
|
|
|
|
|
def tokenize_function(examples):
    """Tokenize a batch's raw text with the module-level ``tokenizer``.

    Expects ``examples`` to be a batched mapping with a ``'text'`` column
    (as produced by ``datasets.map(..., batched=True)``) and returns the
    tokenizer's output dict (``input_ids`` etc.), special tokens included.
    """
    texts = examples['text']
    return tokenizer(texts, add_special_tokens=True)
|
|
|
|
|
def group_texts(examples):
    """Concatenate every column of a tokenized batch, then re-split into
    contiguous chunks of ``block_size`` tokens.

    Any tail shorter than ``block_size`` is dropped.  A ``labels`` column is
    added as a copy of ``input_ids`` so the model can compute the LM loss.
    """
    keys = list(examples.keys())

    # Flatten each column's list-of-lists into one long token sequence.
    joined = {key: sum(examples[key], []) for key in keys}

    # Total token count, truncated down to a multiple of block_size
    # (the short remainder at the end is discarded).
    n_tokens = len(joined[keys[0]])
    n_tokens -= n_tokens % block_size

    chunked = {}
    for key, seq in joined.items():
        chunked[key] = [seq[start:start + block_size]
                        for start in range(0, n_tokens, block_size)]

    # Labels mirror the inputs; the data collator masks tokens on the fly.
    chunked["labels"] = chunked["input_ids"].copy()
    return chunked
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Script body: fine-tune a Cantonese BART checkpoint with masked-language
# modelling on the Cantonese Wikipedia corpus.  Runs top to bottom; the
# fp16 flag below requires a CUDA-capable GPU.
# ---------------------------------------------------------------------------

# Length (in tokens) of each training chunk produced by group_texts().
block_size = 128

from datasets import load_dataset

# Raw corpus: articles exposed through a "text" column.
datasets = load_dataset('jed351/cantonese-wikipedia')

from transformers import AutoTokenizer

model_checkpoint = "Ayaka/bart-base-cantonese"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# Tokenize in parallel; drop the raw "text" column so that only token
# columns (input_ids, attention_mask, ...) reach group_texts().
tokenized_datasets = datasets.map(tokenize_function,
                                  batched=True, num_proc=4, remove_columns=["text"])

# Re-pack variable-length examples into fixed block_size chunks and add
# the "labels" column.
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

# Dynamic masking: 15% of tokens are masked on the fly for each batch
# (the collator's mlm mode defaults to True).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

training_args = TrainingArguments(
    "bart-finetuned-wikitext2",  # output directory (fix: was an f-string with no placeholders)
    # NOTE(review): newer transformers renames this to eval_strategy — confirm
    # the installed version before upgrading.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
    per_device_train_batch_size=72,
    fp16=True,  # mixed precision; requires CUDA
    save_steps=5000,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    # assumes the dataset ships a "test" split — TODO confirm; otherwise
    # create one with train_test_split() before this point.
    eval_dataset=lm_datasets["test"],
    data_collator=data_collator,
)

trainer.train()
|
|