# gp-uq-tester / train.py
# Author: tombm
# Commit 5212a08: "Add functionality to app"
# This is a heavily adapted version of this notebook:
# https://github.com/huggingface/notebooks/blob/main/examples/text_classification.ipynb
# Here we show, on a simple text classification problem, how components for
# uncertainty quantification can be integrated into large pretrained models.
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
AutoTokenizer,
TrainingArguments,
Trainer,
TrainerCallback,
)
from uq import BertForUQSequenceClassification
# Training configuration: per-device batch sizes for training and evaluation,
# and the device the model is explicitly moved to before training.
BATCH_SIZE = 16
EVAL_BATCH_SIZE = 128
DEVICE = "cpu"
# GLUE CoLA task: determining whether sentences are grammatically correct
task = "cola"
model_checkpoint = "bert-base-uncased"
# Load the GLUE CoLA dataset splits and the matching GLUE metric for the task
dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)
# Load the (fast) tokenizer that matches the pretrained checkpoint; the data is
# tokenized in batches below via `dataset.map`.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
def tokenize_data(data):
    """Tokenize the ``"sentence"`` field of a batch of dataset examples.

    This is passed to ``dataset.map(..., batched=True)``, so ``data["sentence"]``
    is presumably a list of strings (one per example in the batch). Truncation
    is enabled so over-long sentences are cut to the tokenizer's limit.

    Returns:
        The tokenizer's output (e.g. input IDs and attention mask), which
        ``dataset.map`` adds to the dataset as new columns.
    """
    sentences = data["sentence"]
    encoded = tokenizer(sentences, truncation=True)
    return encoded
# Tokenize every split; with `batched=True`, `tokenize_data` receives batches of
# examples and its outputs are added to each split as new columns.
encoded_dataset = dataset.map(tokenize_data, batched=True)
# Now we can load our pretrained model and introduce our uncertainty quantification component,
# which in this case is a GP final layer without any spectral normalization of the transformer weights
num_labels = 2
# Human-readable names for the two class ids, plus the inverse name -> id map
id2label = {0: "Invalid", 1: "Valid"}
label2id = {val: key for key, val in id2label.items()}
model = BertForUQSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=num_labels, id2label=id2label, label2id=label2id
)
# Specify training arguments
# Key of the metric returned by `compute_metrics` that selects the best checkpoint
metric_name = "matthews_correlation"
# Last path component of the checkpoint id, used to name the output directory
model_name = model_checkpoint.split("/")[-1]
# NOTE(review): several of these argument names are deprecated in newer
# transformers releases (e.g. `evaluation_strategy` -> `eval_strategy`,
# `no_cuda` -> `use_cpu`, `use_mps_device`); confirm against the pinned version.
args = TrainingArguments(
    # Output directory, e.g. "bert-base-uncased-finetuned-cola"
    f"{model_name}-finetuned-{task}",
    # Evaluate and checkpoint once per epoch so the best epoch can be reloaded
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    num_train_epochs=3,
    weight_decay=0.01,
    # After training, reload the checkpoint with the best `metric_name` score
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    push_to_hub=True,
    # Force CPU training (consistent with DEVICE = "cpu")
    use_mps_device=False,
    no_cuda=True,
)
# Set up metric tracking
def compute_metrics(eval_predictions):
    """Compute the GLUE metric for the current task from Trainer predictions.

    Args:
        eval_predictions: ``(predictions, labels)`` pair supplied by the
            Trainer (e.g. an ``EvalPrediction``). ``predictions`` holds
            per-class scores with classes on the last axis. If the model
            returns several outputs, the Trainer passes ``predictions`` as a
            tuple; its first element is assumed to be the logits.

    Returns:
        dict: the values returned by ``metric.compute`` (for this task it is
        expected to include ``"matthews_correlation"``).
    """
    predictions, labels = eval_predictions
    # The UQ model may return extra outputs besides the logits (e.g. uncertainty
    # estimates); the Trainer then hands them over as a tuple. Following the
    # transformers example scripts, the logits are taken to be the first entry.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Predicted class = index of the highest score along the class axis.
    predicted_classes = np.argmax(predictions, axis=-1)
    return metric.compute(predictions=predicted_classes, references=labels)
# Finally, set up trainer for finetuning the model
# NOTE(review): with `no_cuda=True` the Trainer presumably keeps the model on
# CPU anyway, so this explicit move is likely redundant; kept for clarity.
model.to(DEVICE)
# NOTE(review): newer transformers releases deprecate the `tokenizer=` kwarg of
# `Trainer` in favour of `processing_class=`; confirm against the pinned version.
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    # The validation split is used for per-epoch evaluation and thus for
    # best-checkpoint selection via `metric_for_best_model`
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# Callback that resets the GP classifier's covariance at the end of every epoch
# on which an evaluation is scheduled (with `evaluation_strategy="epoch"` this is
# presumably every epoch). Only the covariance from the final epoch is actually
# needed, and resetting it avoids double counting any of the data. A more
# elegant solution is possible, but the covariance computation is cheap, so
# doing it once per epoch rather than once overall is not a big deal.
# NOTE(review): in the HF Trainer loop `on_epoch_end` appears to fire *before*
# that epoch's evaluation, so the reset precedes evaluation (including after the
# final epoch); confirm this matches how `classifier.reset_cov()` and covariance
# accumulation interact in the `uq` module.
class ResetCovarianceCallback(TrainerCallback):
    """Reset ``trainer.model.classifier``'s covariance when an epoch ends and an evaluation is due."""

    def __init__(self, trainer) -> None:
        super().__init__()
        # Hold the trainer (not the model) so the current model is looked up
        # at callback time.
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        """Clear the classifier covariance if the Trainer is about to evaluate."""
        if not control.should_evaluate:
            return
        classifier = self._trainer.model.classifier
        classifier.reset_cov()
# Register the covariance-reset callback, then fine-tune the model.
trainer.add_callback(ResetCovarianceCallback(trainer))
trainer.train()
# Upload the final model to the Hugging Face Hub (presumably requires a
# logged-in Hub account). With `load_best_model_at_end=True`, the weights pushed
# here are presumably the best checkpoint by `metric_name`, not the last epoch's.
trainer.push_to_hub()