In [5]:
import torch
from torch.utils.data import Dataset

import pandas as pd
from sklearn.model_selection import train_test_split

from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers import Trainer, TrainingArguments

In [6]:
# setup

# Prefer the first CUDA device, but fall back to CPU when none is visible so
# that `.to(device)` below does not fail on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The same pretrained checkpoint is used for the tokenizer and the classifier.
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
# num_labels=6 gives one output logit per label column selected in the
# data-loading cell (toxic, severe_toxic, obscene, threat, insult, identity_hate).
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6).to(device)
# Token length that TweetDataset truncates / pads every example to.
max_len = 200

training_args = TrainingArguments(
    output_dir="results",
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10
    )

# dataset classes that inherit from torch.utils.data.Dataset  
class TokenizerDataset(Dataset):
    """Thin map-style dataset over an indexable collection of raw strings."""

    def __init__(self, strings):
        # Kept as a public attribute: TweetDataset reads ``.strings`` directly.
        self.strings = strings

    def __len__(self):
        """Return the number of stored strings."""
        return len(self.strings)

    def __getitem__(self, idx):
        """Return the string stored at position ``idx``."""
        text = self.strings[idx]
        return text

class TweetDataset(Dataset):
    """Map-style dataset that tokenizes one text per access and attaches its labels.

    Args:
        encodings: Source of the raw texts. Either an object exposing a
            ``strings`` sequence (such as ``TokenizerDataset``) or a plain
            indexable sequence of strings. The parameter name is kept for
            backward compatibility; the texts are tokenized lazily in
            ``__getitem__``, not pre-encoded.
        labels: Sequence of per-example label vectors, aligned with the texts.
        tok: Optional tokenizer callable. When omitted, the module-level
            ``tokenizer`` is used (the original behaviour).
        max_length: Optional length to truncate / pad to. When omitted, the
            module-level ``max_len`` is read at access time (the original
            behaviour).
    """

    def __init__(self, encodings, labels, tok=None, max_length=None):
        self.encodings = encodings
        self.labels = labels
        # Only touch the global tokenizer when the caller did not supply one.
        self.tok = tok if tok is not None else tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``: tokenizer outputs plus ``labels``."""
        # Accept both a ``.strings``-bearing wrapper and a bare sequence of texts.
        texts = getattr(self.encodings, "strings", self.encodings)
        limit = self.max_length if self.max_length is not None else max_len
        encoding = self.tok(texts[idx], truncation=True,
                            padding="max_length", max_length=limit)
        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from the label sequence."""
        return len(self.labels)

# from https://discuss.huggingface.co/t/fine-tune-for-multiclass-or-multilabel-multiclass/4035/8
class MultilabelTrainer(Trainer):
    """Trainer whose loss is binary cross-entropy with logits over every label.

    Each of the model's ``num_labels`` outputs is scored independently against
    the matching entry of the label vector, instead of using the loss computed
    inside the model's own forward pass.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the multi-label loss for one batch.

        Args:
            model: The model to run the forward pass with.
            inputs: Batch dict; its ``"labels"`` entry is removed before the
                forward pass so the model does not compute its own loss.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` releases, whose Trainer passes this argument;
                unused here because the loss is a mean over the batch.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        num_labels = self.model.config.num_labels
        loss_fct = torch.nn.BCEWithLogitsLoss()
        # BCEWithLogitsLoss needs float targets with the same shape as the logits.
        loss = loss_fct(logits.view(-1, num_labels),
                        labels.float().view(-1, num_labels))
        return (loss, outputs) if return_outputs else loss

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [7]:
# load data

# Target columns; their order here is the order of every label vector.
label_columns = ["toxic", "severe_toxic",
                 "obscene", "threat",
                 "insult", "identity_hate"]

# Keep the full frame under its own name so the split below does not
# overwrite the data it was derived from.
raw_data = pd.read_csv("train.csv")
print(raw_data)
# Fixed random_state so the 80/20 split is identical on every run.
train_data, test_data = train_test_split(raw_data, test_size=0.2, random_state=42)

# Plain Python lists: texts as list[str], labels as one list of six values per row.
train_text = train_data["comment_text"].values.tolist()
train_labels = train_data[label_columns].values.tolist()
test_text = test_data["comment_text"].values.tolist()
test_labels = test_data[label_columns].values.tolist()

                      id                                       comment_text  \
0       0000997932d777bf  Explanation\nWhy the edits made under my usern...   
1       000103f0d9cfb60f  D'aww! He matches this background colour I'm s...   
2       000113f07ec002fd  Hey man, I'm really not trying to edit war. It...   
3       0001b41b1c6bb37e  "\nMore\nI can't make any real suggestions on ...   
4       0001d958c54c6e35  You, sir, are my hero. Any chance you remember...   
...                  ...                                                ...   
159566  ffe987279560d7ff  ":::::And for the second time of asking, when ...   
159567  ffea4adeee384e90  You should be ashamed of yourself \n\nThat is ...   
159568  ffee36eab5c267c9  Spitzer \n\nUmm, theres no actual article for ...   
159569  fff125370e4aaaf3  And it looks like it was actually you who put ...   
159570  fff46fc426af1f9a  "\nAnd ... I really don't think you understand...   

        toxic  severe_toxic  obscene  threat  insul

In [8]:
# prepare datasets and trainer

# Wrap the raw text lists first, then pair each wrapper with its labels.
train_strings, test_strings = (
    TokenizerDataset(texts) for texts in (train_text, test_text)
)
train_dataset = TweetDataset(train_strings, train_labels)
test_dataset = TweetDataset(test_strings, test_labels)

# Sanity check: the number of label rows and the number of texts, one per line.
print(len(train_dataset.labels), len(train_strings), sep="\n")

trainer = MultilabelTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

127656
127656


In [9]:
# Show the GPU model, driver version and current memory usage before training.
!nvidia-smi

Tue Apr 25 18:46:42 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    26W /  70W |   1025MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [10]:
# Fine-tune the model on train_dataset with the settings in training_args.
# Long-running: the recorded run below reports a train_runtime of about 8650 s.
trainer.train()



Step,Training Loss
10,0.7009
20,0.6893
30,0.6486
40,0.5893
50,0.5256
60,0.4485
70,0.3823
80,0.3177
90,0.2767
100,0.1935


Step,Training Loss
10,0.7009
20,0.6893
30,0.6486
40,0.5893
50,0.5256
60,0.4485
70,0.3823
80,0.3177
90,0.2767
100,0.1935


TrainOutput(global_step=15958, training_loss=0.04224376326704895, metrics={'train_runtime': 8650.202, 'train_samples_per_second': 29.515, 'train_steps_per_second': 1.845, 'total_flos': 2.62413368475264e+16, 'train_loss': 0.04224376326704895, 'epoch': 2.0})

In [11]:
# Write the trained model to ./results/final.
trainer.save_model("./results/final")