import torch

# Fail fast if no GPU is available
if torch.cuda.is_available():
    print('gpu is available')
else:
    raise Exception('gpu is NOT available')
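
# Optional sanity check (illustrative): show which CUDA device was detected.
print(torch.cuda.get_device_name(0))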

from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import Trainer
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
import pandas as pd
import torch
import random


from pprint import pprint
from datasets import load_dataset
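
# random/numpy/torch are imported above but no seed is fixed; a minimal
# reproducibility sketch, assuming transformers' set_seed (seeds Python,
# NumPy and PyTorch in one call):
from transformers import set_seed
set_seed(42)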

# Load the train/validation splits of the WRIME sentiment dataset
train_dataset = load_dataset("llm-book/wrime-sentiment", split="train", remove_neutral=False)
valid_dataset = load_dataset("llm-book/wrime-sentiment", split="validation", remove_neutral=False)

pprint(train_dataset)
pprint(valid_dataset)
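
# Optional: peek at one raw example to confirm the 'sentence' and 'label'
# fields that preprocess_text() relies on below.
pprint(train_dataset[0])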

# Tokenizer for the Japanese BERT model
model_name = "cl-tohoku/bert-base-japanese-whole-word-masking"
tokenizer = AutoTokenizer.from_pretrained(model_name)
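
# Optional (illustrative, made-up sentence): see how the tokenizer splits text.
print(tokenizer.tokenize("今日はいい天気ですね。"))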

# Tokenize each example and carry its label over as `labels`.
# truncation=True is added so sentences longer than 512 tokens are cut off
# instead of overflowing BERT's maximum input length.
def preprocess_text(batch):
    encoded_batch = tokenizer(batch['sentence'], max_length=512, truncation=True)
    encoded_batch['labels'] = batch['label']
    return encoded_batch

encoded_train_dataset = train_dataset.map(
    preprocess_text,
    remove_columns=train_dataset.column_names,
)
encoded_valid_dataset = valid_dataset.map(
    preprocess_text,
    remove_columns=valid_dataset.column_names,
)
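
# Optional: the mapped datasets should now contain only the tokenizer outputs
# (input_ids, token_type_ids, attention_mask) plus 'labels'.
pprint(encoded_train_dataset[0])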

from transformers import DataCollatorWithPadding

# Pad dynamically to the longest sequence in each batch instead of padding
# everything to max_length up front.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
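
# Optional sketch: the collator turns a list of variable-length examples into
# padded tensors (the exact shapes depend on the sampled examples).
sample_batch = data_collator([encoded_train_dataset[i] for i in range(4)])
print({k: v.shape for k, v in sample_batch.items()})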

from transformers import AutoModelForSequenceClassification

# Build label <-> id mappings from the dataset's ClassLabel feature and
# attach them to the classification model's config.
class_label = train_dataset.features["label"]
label2id = {label: id for id, label in enumerate(class_label.names)}
id2label = {id: label for id, label in enumerate(class_label.names)}
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=class_label.num_classes,
    label2id=label2id,
    id2label=id2label,
)
print(type(model).__name__)
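
# Optional: confirm the label mappings were registered in the model config.
print(model.config.id2label)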

from transformers import TrainingArguments

save_dir = 'bert-finetuned-wrime-base'

training_args = TrainingArguments(
    output_dir=save_dir,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    lr_scheduler_type="constant",
    warmup_ratio=0.1,  # no effect with "constant"; use "constant_with_warmup" if warmup is wanted
    num_train_epochs=100,  # upper bound; early stopping (set up below) usually ends training sooner
    save_strategy="epoch",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True,
)

# Metrics reported at each evaluation: accuracy and weighted F1.
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
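
# Optional sanity check with dummy predictions (illustrative values only):
# argmax over the fake logits gives [0, 1, 0, 0] against labels [0, 1, 1, 0],
# so accuracy is 0.75 and the weighted F1 is about 0.73.
from types import SimpleNamespace
_dummy = SimpleNamespace(
    label_ids=np.array([0, 1, 1, 0]),
    predictions=np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]]),
)
print(compute_metrics(_dummy))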

from transformers import Trainer
from transformers import EarlyStoppingCallback

# Train with early stopping: stop once validation accuracy has not improved
# for 3 consecutive evaluations, then reload the best checkpoint.
trainer = Trainer(
    model=model,
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_valid_dataset,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
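
# Optional: re-run evaluation on the validation set; with
# load_best_model_at_end=True these are the best checkpoint's metrics.
pprint(trainer.evaluate())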

# Save the fine-tuned model and tokenizer, plus the training log history.
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

import os
os.makedirs('base_line', exist_ok=True)  # the CSV and plot below go under base_line/
history_df = pd.DataFrame(trainer.state.log_history)
history_df.to_csv('base_line/wrime_baseline_history.csv')
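
# Optional (illustrative text, not from the dataset): quick smoke test of the
# saved model via the text-classification pipeline.
from transformers import pipeline
classifier = pipeline("text-classification", model=save_dir, tokenizer=save_dir)
print(classifier("今日はとても楽しい一日だった。"))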

import matplotlib.pyplot as plt


# Plot train loss, validation loss, and validation accuracy/F1 from the
# Trainer's log history, then save the figure to `output`.
def show_graph(df, suptitle, output='output.png'):
    suptitle_size = 23
    graph_title_size = 20
    legend_size = 18
    ticks_size = 13

    plt.figure(figsize=(20, 5))
    plt.suptitle(suptitle, fontsize=suptitle_size)

    plt.subplot(131)
    plt.title('Train Loss', fontsize=graph_title_size)
    plt.plot(df['loss'].dropna(), label='train')
    plt.legend(fontsize=legend_size)
    plt.yticks(fontsize=ticks_size)

    plt.subplot(132)
    plt.title('Val Loss', fontsize=graph_title_size)
    y = df['eval_loss'].dropna().values
    plt.plot(y, color='tab:orange', label='val')
    plt.legend(fontsize=legend_size)
    plt.yticks(fontsize=ticks_size)

    plt.subplot(133)
    plt.title('Val Accuracy/F1', fontsize=graph_title_size)
    plt.plot(df['eval_accuracy'].dropna(), label='accuracy')
    plt.plot(df['eval_f1'].dropna(), label='F1')
    plt.legend(fontsize=legend_size)
    plt.yticks(fontsize=ticks_size)
    plt.tight_layout()

    plt.savefig(output)


suptitle = 'batch:32, lr:2e-5, type:constant'
show_graph(history_df, suptitle, 'base_line/wrime_baseline_output.png')