File size: 4,155 Bytes
f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca f2a478c 369c9ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
# Use the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One pretrained checkpoint name is shared by the tokenizer and the model.
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Sequence classifier with six output logits, placed on the chosen device.
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=6,
)
model = model.to(device)

# Token length passed to the tokenizer as ``max_length`` in TweetDataset.
max_len = 200

# Hyper-parameters handed to the Hugging Face Trainer.
training_args = TrainingArguments(
    # where results and logs are written
    output_dir="results",
    logging_dir="./logs",
    logging_steps=10,
    # schedule and batch sizes
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # optimizer settings
    learning_rate=5e-5,
    warmup_steps=500,
    weight_decay=0.01,
)
# dataset class that inherits from torch.utils.data.Dataset
class TokenizerDataset(Dataset):
    """Thin ``Dataset`` wrapper that serves stored strings by position.

    The ``strings`` argument is kept as-is on the instance (no copy) and is
    indexed directly, so any object supporting ``len()`` and ``[]`` works.
    """

    def __init__(self, strings):
        self.strings = strings

    def __len__(self):
        """Return how many strings are stored."""
        size = len(self.strings)
        return size

    def __getitem__(self, idx):
        """Return the entry of ``strings`` at position ``idx``."""
        text = self.strings[idx]
        return text
# Names of the six binary label columns, in the order used for the targets.
LABEL_COLUMNS = [
    "toxic", "severe_toxic",
    "obscene", "threat",
    "insult", "identity_hate",
]

# Training split: comment text and its six labels come from the same file.
train_data = pd.read_csv("data/train.csv")
print(train_data)
train_text = train_data["comment_text"]
train_labels = train_data[LABEL_COLUMNS]

# Test split: text and labels live in separate files and are paired by row
# position.
test_text = pd.read_csv("data/test.csv")["comment_text"]
test_labels = pd.read_csv("data/test_labels.csv")[LABEL_COLUMNS]

# NOTE(review): the column names match the Kaggle Jigsaw toxic-comment
# release, where test_labels.csv uses -1 for rows that were never scored.
# A -1 is not a valid target for a binary cross-entropy loss, so keep only
# rows whose labels are all non-negative. When the file contains no negative
# labels this filter keeps every row.
scored_rows = (test_labels >= 0).all(axis=1).to_numpy()
test_text = test_text[scored_rows]
test_labels = test_labels[scored_rows]

# data preprocessing: convert the pandas objects to plain Python lists
train_text = train_text.values.tolist()
train_labels = train_labels.values.tolist()
test_text = test_text.values.tolist()
test_labels = test_labels.values.tolist()
# prepare tokenizer and dataset
class TweetDataset(Dataset):
    """Pair raw texts with label vectors, tokenizing one item per access.

    Args:
        encodings: Indexable source of raw strings; ``encodings[idx]`` must
            return the text for row ``idx``. Both a ``TokenizerDataset`` and
            a plain list satisfy this.
        labels: Indexable collection of per-row label vectors. Its length is
            reported as the dataset length.
        tok: Optional tokenizer callable. When omitted, the module-level
            ``tokenizer`` is used.
        max_length: Optional ``max_length`` forwarded to the tokenizer. When
            omitted, the module-level ``max_len`` is read at access time.
    """

    def __init__(self, encodings, labels, tok=None, max_length=None):
        self.encodings = encodings
        self.labels = labels
        # Only touch the module-level tokenizer when none was supplied.
        self.tok = tokenizer if tok is None else tok
        self.max_length = max_length

    def __getitem__(self, idx):
        """Return a dict of tensors for row ``idx``.

        The dict holds one tensor per key produced by the tokenizer plus a
        ``"labels"`` tensor built from ``labels[idx]``.
        """
        # Resolve the length lazily so the module-level default is looked up
        # on each access unless an explicit value was given.
        length = max_len if self.max_length is None else self.max_length
        # Index the text source directly instead of reaching into a
        # ``.strings`` attribute, so any indexable of strings is accepted.
        encoding = self.tok(
            self.encodings[idx],
            truncation=True,
            padding="max_length",
            max_length=length,
        )
        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of label rows."""
        return len(self.labels)
# Wrap the raw comment lists so they can be indexed as datasets.
train_strings, test_strings = (
    TokenizerDataset(train_text),
    TokenizerDataset(test_text),
)

# Loaders over the raw strings: shuffled, 16 texts per batch.
train_dataloader = DataLoader(dataset=train_strings, shuffle=True, batch_size=16)
test_dataloader = DataLoader(dataset=test_strings, shuffle=True, batch_size=16)

# Earlier drafts (kept here for reference only) tokenized the whole corpus up
# front, e.g.
#   tokenizer.batch_encode_plus(texts, max_length=200, pad_to_max_length=True,
#                               truncation=True, return_token_type_ids=False)
#   tokenizer(texts, truncation=True, padding=True)
# The active code instead hands the raw strings to TweetDataset, which calls
# the tokenizer per item on access.
train_dataset = TweetDataset(train_strings, train_labels)
test_dataset = TweetDataset(test_strings, test_labels)

# Presumably a sanity check that labels and texts have the same length.
print(len(train_dataset.labels))
print(len(train_strings))
class MultilabelTrainer(Trainer):
    """Trainer whose loss is ``BCEWithLogitsLoss`` over the model's logits."""

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        """Compute the binary cross-entropy-with-logits loss for one batch.

        Args:
            model: Model called with the remaining ``inputs`` as keywords.
            inputs: Batch dict; its ``"labels"`` entry is popped (the dict is
                mutated) and used as the target.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted so that Trainer versions which pass
                this argument to ``compute_loss`` do not fail with a
                ``TypeError``; it is not used here.

        Returns:
            The loss tensor, or a ``(loss, outputs)`` tuple.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        num_labels = self.model.config.num_labels
        loss_fct = torch.nn.BCEWithLogitsLoss()
        # Targets are cast to float because BCEWithLogitsLoss needs
        # floating-point targets; both sides are reshaped to
        # (-1, num_labels) before comparison.
        loss = loss_fct(logits.view(-1, num_labels),
                        labels.float().view(-1, num_labels))
        return (loss, outputs) if return_outputs else loss
# training
# Build the custom trainer from the model, hyper-parameters and datasets
# defined above, then run the training loop.
trainer = MultilabelTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
# The original line carried a stray trailing "|" that made it a syntax error.
trainer.train()