|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import csv |
|
import json |
|
import math |
|
import time |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.optim |
|
import torch.optim as optim |
|
import torch.utils.data as data |
|
from nltk.tokenize.treebank import TreebankWordDetokenizer |
|
from torchtext import data as torchtext_data |
|
from torchtext import datasets |
|
from tqdm import tqdm, trange |
|
|
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel |
|
|
|
# Seed the PyTorch and NumPy random number generators with a fixed value.
torch.manual_seed(0)
np.random.seed(0)

# Small constant added to the mask-sum denominator in
# Discriminator.avg_representation so the division never hits zero.
EPSILON = 1e-10

# Device used for tensors and models; train_discriminator() rebinds this
# module-level name to "cuda" when CUDA is available and not disabled.
device = "cpu"

# Sentence classified by predict() at the end of every training epoch.
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."

# Examples whose tokenized length is not strictly below this value are
# skipped when loading the clickbait, toxic and generic datasets.
max_length_seq = 100
|
|
|
|
|
class ClassificationHead(torch.nn.Module):
    """Linear classification layer applied on top of an encoder embedding.

    Args:
        class_size: number of output classes (size of the logit vector).
        embed_size: size of the last dimension of the incoming hidden state.
    """

    def __init__(self, class_size, embed_size):
        super().__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # A single affine projection from embedding space to class logits.
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        """Return the unnormalized class logits for ``hidden_state``."""
        return self.mlp(hidden_state)
|
|
|
|
|
class Discriminator(torch.nn.Module):
    """Transformer encoder followed by a Classification Head.

    Args:
        class_size: number of classes predicted by the classification head.
        pretrained_model: name or path passed to ``from_pretrained`` for both
            the GPT-2 tokenizer and the GPT-2 LM-head model.
        cached_mode: when True, ``forward`` expects already-averaged
            representations instead of token ids and skips the encoder.
    """

    def __init__(
            self,
            class_size,
            pretrained_model="gpt2-medium",
            cached_mode=False
    ):
        super(Discriminator, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        self.embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(
            class_size=class_size,
            embed_size=self.embed_size
        )
        self.cached_mode = cached_mode

    def get_classifier(self):
        """Return the classification head module."""
        return self.classifier_head

    def train_custom(self):
        """Freeze every encoder parameter and put only the head in train mode."""
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
        """Return the mean encoder hidden state over the non-zero positions of x.

        Args:
            x: 2-D tensor of token ids. Positions whose id is 0 are treated
                as padding and excluded from the average (collate_fn pads
                with zeros). NOTE(review): a genuine token with id 0 is
                excluded as well.

        Returns:
            Tensor with one averaged hidden vector per row of ``x``.
        """
        # Index with [0] instead of unpacking a 2-tuple: this works both when
        # the encoder returns a plain tuple and when it returns a ModelOutput
        # (iterating a ModelOutput yields its keys, not its tensors).
        hidden = self.encoder.transformer(x)[0]
        # (batch, seq, 1) mask; broadcasting over the hidden dimension gives
        # the same result as materializing an embed_size-wide copy.
        mask = x.ne(0).unsqueeze(2).float().to(hidden.device).detach()
        masked_hidden = hidden * mask
        # EPSILON keeps the division finite for rows that are all padding.
        avg_hidden = torch.sum(masked_hidden, dim=1) / (
            torch.sum(mask, dim=1).detach() + EPSILON
        )
        return avg_hidden

    def forward(self, x):
        """Return per-class log-probabilities for ``x``.

        Args:
            x: token ids, or, when ``cached_mode`` is set, representations
                already produced by ``avg_representation``.
        """
        if self.cached_mode:
            avg_hidden = x.to(device)
        else:
            avg_hidden = self.avg_representation(x.to(device))

        logits = self.classifier_head(avg_hidden)
        # log_softmax output pairs with F.nll_loss in the training loop.
        log_probs = F.log_softmax(logits, dim=-1)

        return log_probs
|
|
|
|
|
class Dataset(data.Dataset):
    """In-memory dataset that pairs ``X[i]`` with ``y[i]`` by index."""

    def __init__(self, X, y):
        """Store the input collection ``X`` and the label collection ``y``."""
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of stored inputs."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``index``-th example as a dict with keys "X" and "y"."""
        return {"X": self.X[index], "y": self.y[index]}
|
|
|
|
|
def collate_fn(data):
    """Collate a list of {"X": sequence, "y": label} items into batch tensors.

    Sequences are right-padded with zeros up to the longest sequence in the
    batch.

    Returns:
        Tuple ``(x_batch, y_batch)``: a long tensor of padded sequences and a
        long tensor of labels.
    """
    # Transpose the list of dicts into a dict of lists, keyed like item 0.
    columns = {key: [item[key] for item in data] for key in data[0]}

    sequences = columns["X"]
    longest = max(len(seq) for seq in sequences)

    x_batch = torch.zeros(len(sequences), longest).long()
    for row, seq in enumerate(sequences):
        x_batch[row, :len(seq)] = seq

    y_batch = torch.tensor(columns["y"], dtype=torch.long)
    return x_batch, y_batch
|
|
|
|
|
def cached_collate_fn(data):
    """Collate items whose "X" is an already-computed representation tensor.

    The "X" tensors are concatenated along dimension 0 and the "y" labels
    are packed into a long tensor.

    Returns:
        Tuple ``(x_batch, y_batch)``.
    """
    # Transpose the list of dicts into a dict of lists, keyed like item 0.
    columns = {key: [item[key] for item in data] for key in data[0]}

    x_batch = torch.cat(columns["X"], 0)
    y_batch = torch.tensor(columns["y"], dtype=torch.long)
    return x_batch, y_batch
|
|
|
|
|
def train_epoch(data_loader, discriminator, optimizer,
                epoch=0, log_interval=10):
    """Run one optimization pass over ``data_loader``.

    Args:
        data_loader: iterable of ``(input, target)`` batches; its ``dataset``
            attribute is used for progress reporting.
        discriminator: model exposing ``train_custom()`` and returning
            log-probabilities when called.
        optimizer: optimizer stepped once per batch.
        epoch: zero-based epoch index, used only in the log line.
        log_interval: print the loss every this many batches; a value of 0
            or less disables logging.
    """
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        # A fresh graph is built by every forward pass, so there is nothing
        # to retain; retain_graph=True would only keep buffers alive longer.
        loss.backward()
        optimizer.step()

        samples_so_far += len(input_t)

        # Guard against log_interval <= 0 (modulo by zero).
        if log_interval > 0 and batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far, len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset), loss.item()
                )
            )
|
|
|
|
|
def evaluate_performance(data_loader, discriminator):
    """Print the average NLL loss and accuracy of ``discriminator``.

    The model is switched to eval mode and run without gradient tracking
    over every batch of ``data_loader``; totals are normalized by
    ``len(data_loader.dataset)``.
    """
    discriminator.eval()

    summed_loss = 0
    num_correct = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            log_probs = discriminator(batch_inputs)

            # Accumulate the per-sample losses; averaged after the loop.
            summed_loss += F.nll_loss(
                log_probs, batch_targets, reduction="sum"
            ).item()

            # Predicted class = index of the largest log-probability.
            predictions = log_probs.argmax(dim=1)
            num_correct += (predictions == batch_targets).sum().item()

    num_samples = len(data_loader.dataset)
    mean_loss = summed_loss / num_samples

    print(
        "Performance on test set: "
        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            mean_loss, num_correct, num_samples,
            100. * num_correct / num_samples
        )
    )
|
|
|
|
|
def predict(input_sentence, model, classes, cached=False):
    """Classify ``input_sentence`` and print one probability per class.

    Args:
        input_sentence: text to classify.
        model: discriminator exposing ``tokenizer`` and ``avg_representation``.
        classes: class names, in the same order as the model outputs.
        cached: when True, the averaged representation is computed here and
            passed to the model instead of the raw token ids.
    """
    token_ids = model.tokenizer.encode(input_sentence)
    model_input = torch.tensor([token_ids], dtype=torch.long, device=device)
    if cached:
        model_input = model.avg_representation(model_input)

    log_probs = model(model_input).detach().cpu().numpy().flatten().tolist()

    # The model returns log-probabilities; exponentiate them for display.
    formatted = [
        "{}: {:.4f}".format(class_name, math.exp(log_prob))
        for class_name, log_prob in zip(classes, log_probs)
    ]
    print("Input sentence:", input_sentence)
    print("Predictions:", ", ".join(formatted))
|
|
|
|
|
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
    """Precompute averaged representations and wrap them in a DataLoader.

    Every example of ``dataset`` is run once through
    ``discriminator.avg_representation`` (without gradients); the resulting
    per-example tensors and their labels are stored in a new ``Dataset``.

    Args:
        dataset: dataset of {"X": token ids, "y": label} items.
        batch_size: batch size for both the encoding pass and the returned
            loader.
        discriminator: model providing ``avg_representation``.
        shuffle: whether the returned loader shuffles its examples.

    Returns:
        A DataLoader over the cached representations using
        ``cached_collate_fn``.
    """
    raw_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
    )

    cached_reps = []
    cached_labels = []
    for token_batch, label_batch in tqdm(raw_loader, ascii=True):
        with torch.no_grad():
            token_batch = token_batch.to(device)
            batch_reps = discriminator.avg_representation(
                token_batch
            ).cpu().detach()
            # Split the batch into one (1, embed) tensor per example so that
            # cached_collate_fn can concatenate them again along dim 0.
            cached_reps.extend(torch.unbind(batch_reps.unsqueeze(1)))
            cached_labels.extend(label_batch.cpu().numpy().tolist())

    return torch.utils.data.DataLoader(
        dataset=Dataset(cached_reps, cached_labels),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=cached_collate_fn,
    )
|
|
|
|
|
def _build_discriminator(class_size, pretrained_model, cached):
    """Create a Discriminator with ``class_size`` outputs on the active device."""
    return Discriminator(
        class_size=class_size,
        pretrained_model=pretrained_model,
        cached_mode=cached
    ).to(device)


def _to_input_tensor(token_ids):
    """Prepend token id 50256 to ``token_ids`` and return a long tensor."""
    # NOTE(review): 50256 is presumably GPT-2's <|endoftext|> id - confirm
    # before using a tokenizer with a different vocabulary.
    return torch.tensor([50256] + token_ids, device=device, dtype=torch.long)


def _split_train_test(x, y):
    """Wrap ``x``/``y`` in a Dataset and randomly split it 90% / 10%."""
    full_dataset = Dataset(x, y)
    train_size = int(0.9 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    return torch.utils.data.random_split(
        full_dataset, [train_size, test_size]
    )


def _encode_sst_split(split, tokenizer, class2idx, detokenizer):
    """Detokenize, re-tokenize and label every example of one SST split."""
    x = []
    y = []
    for i in trange(len(split), ascii=True):
        example = vars(split[i])
        seq = tokenizer.encode(detokenizer.detokenize(example["text"]))
        x.append(_to_input_tensor(seq))
        y.append(class2idx[example["label"]])
    return Dataset(x, y)


def _load_literal_file(path, tokenizer, label_fn):
    """Load a file holding one Python dict literal per line.

    ``label_fn`` maps the parsed dict to its integer label. Lines that fail
    to parse, tokenize or label are reported and skipped, as are lines whose
    token count is not below ``max_length_seq``.
    """
    x = []
    y = []
    with open(path) as f:
        for i, line in enumerate(tqdm(f, ascii=True)):
            try:
                # SECURITY NOTE(review): eval() executes arbitrary code found
                # in the dataset file. Only use trusted files; consider
                # ast.literal_eval if the lines are plain literals.
                d = eval(line)
                seq = tokenizer.encode(d["text"])

                if len(seq) >= max_length_seq:
                    print("Line {} is longer than maximum length {}".format(
                        i, max_length_seq
                    ))
                    continue

                tensor = _to_input_tensor(seq)
                label = label_fn(d)
            except Exception:
                print("Error evaluating / tokenizing"
                      " line {}, skipping it".format(i))
                continue

            # Append both only after everything succeeded so that x and y
            # can never get out of step.
            x.append(tensor)
            y.append(label)
    return x, y


def _read_generic_classes(dataset_fp):
    """Return the sorted distinct values of the first TSV column."""
    classes = set()
    with open(dataset_fp) as f:
        csv_reader = csv.reader(f, delimiter="\t")
        for row in tqdm(csv_reader, ascii=True):
            if row:
                classes.add(row[0])
    return sorted(classes)


def _load_generic_file(dataset_fp, tokenizer, class2idx):
    """Load ``class \\t text`` rows from a TSV file into tensors and labels."""
    x = []
    y = []
    with open(dataset_fp) as f:
        csv_reader = csv.reader(f, delimiter="\t")
        for i, row in enumerate(tqdm(csv_reader, ascii=True)):
            if not row:
                continue
            if len(row) < 2:
                # A row without a text column used to raise IndexError.
                print("Line {} has fewer than two columns, "
                      "skipping it".format(i))
                continue

            label = row[0]
            text = row[1]

            try:
                seq = tokenizer.encode(text)
                if len(seq) >= max_length_seq:
                    print(
                        "Line {} is longer than maximum length {}".format(
                            i, max_length_seq
                        ))
                    continue

                tensor = _to_input_tensor(seq)
                label_idx = class2idx[label]
            except Exception:
                print("Error tokenizing line {}, skipping it".format(i))
                continue

            x.append(tensor)
            y.append(label_idx)
    return x, y


def train_discriminator(
        dataset, dataset_fp=None, pretrained_model="gpt2-medium",
        epochs=10, batch_size=64, log_interval=10,
        save_model=False, cached=False, no_cuda=False):
    """Train a classification head on top of a pretrained GPT-2 encoder.

    Args:
        dataset: "SST", "clickbait" or "toxic"; any other value selects the
            generic TSV loader (``class \\t text`` rows).
        dataset_fp: path of the TSV file; required for the generic loader.
        pretrained_model: name or path of the pretrained encoder.
        epochs: number of training epochs.
        batch_size: batch size for training, evaluation and caching.
        log_interval: number of batches between training log lines.
        save_model: when True, write ``<dataset>_classifier_head_meta.json``
            once and ``<dataset>_classifier_head_epoch_<n>.pt`` every epoch.
        cached: when True, encoder representations are computed once up
            front and only the head runs during training.
        no_cuda: when True, stay on the CPU even if CUDA is available.

    Raises:
        ValueError: if the generic loader is selected without ``dataset_fp``.
    """
    global device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    print("Preprocessing {} dataset...".format(dataset))
    start = time.time()

    if dataset == "SST":
        idx2class = ["positive", "negative", "very positive", "very negative",
                     "neutral"]
        class2idx = {c: i for i, c in enumerate(idx2class)}
        default_class = 2

        discriminator = _build_discriminator(
            len(idx2class), pretrained_model, cached
        )

        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        # The validation split is not used by this script.
        train_data, _, test_data = datasets.SST.splits(
            text,
            label,
            fine_grained=True,
            train_subtrees=True,
        )

        detokenizer = TreebankWordDetokenizer()
        train_dataset = _encode_sst_split(
            train_data, discriminator.tokenizer, class2idx, detokenizer
        )
        test_dataset = _encode_sst_split(
            test_data, discriminator.tokenizer, class2idx, detokenizer
        )

    elif dataset == "clickbait":
        idx2class = ["non_clickbait", "clickbait"]
        class2idx = {c: i for i, c in enumerate(idx2class)}
        default_class = 1

        discriminator = _build_discriminator(
            len(idx2class), pretrained_model, cached
        )

        x, y = _load_literal_file(
            "datasets/clickbait/clickbait_train_prefix.txt",
            discriminator.tokenizer,
            lambda d: d["label"],
        )
        train_dataset, test_dataset = _split_train_test(x, y)

    elif dataset == "toxic":
        idx2class = ["non_toxic", "toxic"]
        class2idx = {c: i for i, c in enumerate(idx2class)}
        default_class = 0

        discriminator = _build_discriminator(
            len(idx2class), pretrained_model, cached
        )

        # An example counts as toxic when the sum of its label entries is
        # positive.
        x, y = _load_literal_file(
            "datasets/toxic/toxic_train.txt",
            discriminator.tokenizer,
            lambda d: int(np.sum(d["label"]) > 0),
        )
        train_dataset, test_dataset = _split_train_test(x, y)

    else:
        # Treat both None and "" (the command-line default) as missing.
        if not dataset_fp:
            raise ValueError("When generic dataset is selected, "
                             "dataset_fp needs to be specified aswell.")

        idx2class = _read_generic_classes(dataset_fp)
        class2idx = {c: i for i, c in enumerate(idx2class)}
        default_class = 0

        discriminator = _build_discriminator(
            len(idx2class), pretrained_model, cached
        )

        x, y = _load_generic_file(
            dataset_fp, discriminator.tokenizer, class2idx
        )
        train_dataset, test_dataset = _split_train_test(x, y)

    discriminator_meta = {
        "class_size": len(idx2class),
        "embed_size": discriminator.embed_size,
        "pretrained_model": pretrained_model,
        "class_vocab": class2idx,
        "default_class": default_class,
    }

    end = time.time()
    print("Preprocessed {} data points".format(
        len(train_dataset) + len(test_dataset))
    )
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
        print("Building representation cache...")

        start = time.time()

        train_loader = get_cached_data_loader(
            train_dataset, batch_size, discriminator, shuffle=True
        )

        test_loader = get_cached_data_loader(
            test_dataset, batch_size, discriminator
        )

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))

    else:
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=collate_fn)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset),
                  "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    for epoch in range(epochs):
        start = time.time()
        print("\nEpoch", epoch + 1)

        train_epoch(
            discriminator=discriminator,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval
        )
        evaluate_performance(
            data_loader=test_loader,
            discriminator=discriminator
        )

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class, cached)

        if save_model:
            # Only the classification head is saved, not the encoder.
            torch.save(discriminator.get_classifier().state_dict(),
                       "{}_classifier_head_epoch_{}.pt".format(dataset,
                                                               epoch + 1))
|
|
|
|
|
# Command-line entry point: every parsed option is forwarded as a keyword
# argument to train_discriminator, so option names must match its parameters.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a discriminator on top of GPT-2 representations")
    # The help fragments end with a space so the concatenated text reads
    # correctly.
    parser.add_argument("--dataset", type=str, default="SST",
                        choices=("SST", "clickbait", "toxic", "generic"),
                        help="dataset to train the discriminator on. "
                             "In case of generic, the dataset is expected "
                             "to be a TSV file with structure: class \\t text")
    # Default is None so a missing path is detected by train_discriminator
    # instead of failing later while opening an empty file name.
    parser.add_argument("--dataset_fp", type=str, default=None,
                        help="File path of the dataset to use. "
                             "Needed only in case of generic dataset")
    parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
                        help="Pretrained model to use as encoder")
    parser.add_argument("--epochs", type=int, default=10, metavar="N",
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64, metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--log_interval", type=int, default=10, metavar="N",
                        help="how many batches to wait before logging training status")
    parser.add_argument("--save_model", action="store_true",
                        help="whether to save the model")
    parser.add_argument("--cached", action="store_true",
                        help="whether to cache the input representations")
    parser.add_argument("--no_cuda", action="store_true",
                        help="use to turn off cuda")
    args = parser.parse_args()

    train_discriminator(**(vars(args)))
|
|