import os
import re
from typing import Dict, Iterable, List, Optional, Tuple
import json
import random
import argparse

from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.models.model import Model
from allennlp.nn.util import get_text_field_mask
from allennlp.predictors.predictor import Predictor
import pandas as pd
import spacy
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from allennlp.common.util import pad_sequence_to_length
from allennlp.data import TextFieldTensors
from allennlp.data.vocabulary import Vocabulary
from allennlp.data import DatasetReader, TokenIndexer, Instance, Token
from allennlp.data.fields import TextField, LabelField
from allennlp.data.token_indexers.pretrained_transformer_indexer import (
    PretrainedTransformerIndexer,
)
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import (
    PretrainedTransformerTokenizer,
)
from allennlp.models import BasicClassifier
from allennlp.modules.text_field_embedders.basic_text_field_embedder import (
    BasicTextFieldEmbedder,
)
from allennlp.modules.token_embedders.pretrained_transformer_embedder import (
    PretrainedTransformerEmbedder,
)
from allennlp.modules.seq2vec_encoders.bert_pooler import BertPooler
from allennlp.modules.seq2vec_encoders.cls_pooler import ClsPooler
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader
from allennlp.training.optimizers import AdamOptimizer
from allennlp.predictors.text_classifier import TextClassifierPredictor
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

random.seed(1986)

SEQ_LABELS = ["humansMentioned", "vehiclesMentioned", "eventVerb", "activeEventVerb"]
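
# Note (illustrative, not part of the original code): the MultiLabelBinarizer used below is
# fit on SEQ_LABELS, so each token's label set becomes a 4-dimensional 0/1 vector whose
# columns follow the alphabetical order of the label names, e.g.
#   MultiLabelBinarizer().fit([SEQ_LABELS]).transform([["eventVerb", "activeEventVerb"]])
#   -> [[1, 1, 0, 0]]  # columns: activeEventVerb, eventVerb, humansMentioned, vehiclesMentioned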

# adapted from bert-for-framenet project
class SequenceMultiLabelField(Field):
    """A sequence field in which every token carries a (possibly empty) set of labels."""

    def __init__(self,
                 labels: List[List[str]],
                 sequence_field: SequenceField,
                 binarizer: MultiLabelBinarizer,
                 label_namespace: str
                 ):
        self.labels = labels
        self._indexed_labels = None
        self._label_namespace = label_namespace
        self.sequence_field = sequence_field
        self.binarizer = binarizer

    @staticmethod
    def retokenize_tags(tags: List[List[str]],
                        offsets: List[Tuple[int, int]],
                        wp_primary_token: str = "last",
                        wp_secondary_tokens: str = "empty",
                        empty_value=lambda: []
                        ) -> List[List[str]]:
        """Expand word-level tag lists to wordpiece-level tag lists."""
        tags_per_wordpiece = [
            empty_value()  # [CLS]
        ]
        for i, (off_start, off_end) in enumerate(offsets):
            tag = tags[i]
            # put the tag on one wordpiece of the word token and fill the rest
            # e.g. "hello" --> "he" + "##ll" + "##o" --> 2 extra tokens
            # TAGS: [..., TAG, None, None, ...]
            num_extra_tokens = off_end - off_start
            if wp_primary_token == "first":
                tags_per_wordpiece.append(tag)
            if wp_secondary_tokens == "repeat":
                tags_per_wordpiece.extend(num_extra_tokens * [tag])
            else:
                tags_per_wordpiece.extend(num_extra_tokens * [empty_value()])
            if wp_primary_token == "last":
                tags_per_wordpiece.append(tag)
        tags_per_wordpiece.append(empty_value())  # [SEP]
        return tags_per_wordpiece
    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        for label_list in self.labels:
            for label in label_list:
                counter[self._label_namespace][label] += 1

    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tokens": self.sequence_field.sequence_length()}

    def index(self, vocab: Vocabulary):
        indexed_labels: List[List[int]] = []
        for sentence_labels in self.labels:
            sentence_indexed_labels = []
            for label in sentence_labels:
                try:
                    sentence_indexed_labels.append(
                        vocab.get_token_index(label, self._label_namespace))
                except KeyError:
                    print(f"[WARNING] Ignore unknown label {label}")
            indexed_labels.append(sentence_indexed_labels)
        self._indexed_labels = indexed_labels
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        # binarize the original string labels (the binarizer was fit on label strings,
        # not on vocabulary indices)
        binarized_seq = self.binarizer.transform(self.labels).tolist()
        # padding
        desired_num_tokens = padding_lengths["num_tokens"]
        padded_tags = pad_sequence_to_length(
            binarized_seq, desired_num_tokens,
            default_value=lambda: list(self.binarizer.transform([[]])[0]))
        tensor = torch.tensor(padded_tags, dtype=torch.float)
        return tensor

    def empty_field(self) -> 'Field':
        field = SequenceMultiLabelField(
            [], self.sequence_field.empty_field(), self.binarizer, self._label_namespace)
        field._indexed_labels = []
        return field
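
# Illustrative example (not from the original source) of how retokenize_tags expands
# word-level tags to wordpiece positions with the default settings ("last"/"empty"):
# a word split into three wordpieces keeps its tags on the last piece, and empty label
# lists are inserted for the remaining pieces and for [CLS]/[SEP]:
#   tags    = [["humansMentioned"], []]
#   offsets = [(1, 3), (4, 4)]   # word 0 -> wordpieces 1..3, word 1 -> wordpiece 4
#   result  = [[], [], [], ["humansMentioned"], [], []]
#             # [CLS], wp1, wp2, wp3, wp4, [SEP]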

# adapted from bert-for-framenet project
class MultiSequenceLabelModel(Model):
    def __init__(self, embedder: BasicTextFieldEmbedder, decoder_output_size: int,
                 hidden_size: int, vocab: Vocabulary, embedding_size: int = 768):
        super().__init__(vocab)
        self.embedder = embedder
        self.out_features = decoder_output_size
        self.hidden_size = hidden_size
        self.layers = nn.Sequential(
            nn.Linear(in_features=embedding_size,
                      out_features=self.hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_size,
                      out_features=self.out_features)
        )

    def forward(self, tokens: TextFieldTensors, label: Optional[torch.FloatTensor] = None):
        embeddings = self.embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        tag_logits = self.layers(embeddings)
        mask = mask.reshape(mask.shape[0], mask.shape[1], 1).repeat(1, 1, self.out_features)
        output = {"tag_logits": tag_logits}
        if label is not None:
            # masked multi-label loss: padding positions do not contribute
            loss = binary_cross_entropy_with_logits(tag_logits, label, mask)
            output["loss"] = loss
        return output

    def get_metrics(self, _) -> Dict[str, float]:
        return {}

    def make_human_readable(self,
                            prediction: torch.Tensor,
                            label_namespace: str,
                            threshold: float = 0.2,
                            sigmoid: bool = True
                            ) -> List[str]:
        if sigmoid:
            prediction = torch.sigmoid(prediction)
        predicted_labels: List[List[str]] = [[] for _ in range(len(prediction))]
        # collect all labels whose probability exceeds the threshold
        for coord in torch.nonzero(prediction > threshold):
            label = self.vocab.get_token_from_index(int(coord[1]), label_namespace)
            predicted_labels[coord[0]].append(f"{label}:{prediction[coord[0], coord[1]]:.3f}")
        str_predictions: List[str] = []
        for label_list in predicted_labels:
            str_predictions.append("|".join(label_list) or "_")
        return str_predictions
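
# Note (illustrative, not from the original source): make_human_readable renders one string
# per token. After the optional sigmoid, every label whose probability exceeds `threshold`
# is kept, so a token comes out e.g. as "eventVerb:0.812|activeEventVerb:0.305", or as "_"
# when no label clears the threshold.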

class TrafficBechdelReader(DatasetReader):
    def __init__(self, token_indexers, tokenizer, binarizer):
        super().__init__()
        self.token_indexers = token_indexers
        self.tokenizer: PretrainedTransformerTokenizer = tokenizer
        self.binarizer = binarizer
        self.orig_data = []

    def _read(self, file_path) -> Iterable[Instance]:
        self.orig_data.clear()
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                # skip any empty lines
                if not line.strip():
                    continue
                sentence_parts = line.strip().lstrip("[").rstrip("]").split(",")
                token_txts = []
                token_mlabels = []
                for sp in sentence_parts:
                    sp_txt, sp_lbl_str = sp.split(":")
                    if sp_lbl_str == "[]":
                        sp_lbls = []
                    else:
                        sp_lbls = sp_lbl_str.lstrip("[").rstrip("]").split("|")
                    # if the text is a WordNet sense identifier, keep only the lemma
                    wn_match = re.match(r"^(.+)-n-\d+$", sp_txt)
                    if wn_match:
                        sp_txt = wn_match.group(1)
                    # multi-token text: every token inherits the part's labels
                    sp_toks = sp_txt.split()
                    for tok in sp_toks:
                        token_txts.append(tok)
                        token_mlabels.append(sp_lbls)
                self.orig_data.append({
                    "sentence": token_txts,
                    "labels": token_mlabels,
                })
                yield self.text_to_instance(token_txts, token_mlabels)

    def text_to_instance(self, sentence: List[str], labels: Optional[List[List[str]]] = None) -> Instance:
        tokens, offsets = self.tokenizer.intra_word_tokenize(sentence)
        text_field = TextField(tokens, self.token_indexers)
        fields = {"tokens": text_field}
        if labels is not None:
            labels_ = SequenceMultiLabelField.retokenize_tags(labels, offsets)
            label_field = SequenceMultiLabelField(labels_, text_field, self.binarizer, "labels")
            fields["label"] = label_field
        return Instance(fields)
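
# Sketch of the input format _read expects, inferred from the parsing logic above rather
# than taken from an actual data file: one bracketed sentence per line, comma-separated
# "text:[label|label|...]" parts, "[]" for unlabeled parts, and WordNet-style sense
# identifiers reduced to their lemma, e.g. (hypothetical line):
#   [fietser-n-1:[humansMentioned],aangereden:[eventVerb|activeEventVerb],door:[],auto-n-1:[vehiclesMentioned]]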

def count_parties(sentence, lexical_dicts, nlp):
    num_humans = 0
    num_vehicles = 0

    def is_in_words(lemma, category):
        for subcategory, words in lexical_dicts[category].items():
            if subcategory.startswith("WN:"):
                words = [re.match(r"^(.+)-n-\d+$", w).group(1) for w in words]
            if lemma in words:
                return True
        return False

    doc = nlp(sentence.lower())
    for token in doc:
        lemma = token.lemma_
        if is_in_words(lemma, "persons"):
            num_humans += 1
        if is_in_words(lemma, "vehicles"):
            num_vehicles += 1
    return num_humans, num_vehicles
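
# Sketch of the lexical_dicts structure count_parties assumes (subcategory names and words
# here are hypothetical; only the "persons"/"vehicles" keys and the "WN:" convention are
# taken from the code): each category maps subcategory names to word lists, and "WN:"
# subcategories hold WordNet-style entries that are reduced to their lemma before matching:
#   {"persons": {"WN:person": ["voetganger-n-1", "fietser-n-1"], "other": ["kind"]},
#    "vehicles": {"WN:vehicle": ["auto-n-1"], "other": ["tram"]}}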

def predict_rule_based(annotations="data/crashes/bechdel_annotations_dev_first_25.csv"):
    data_crashes = pd.read_csv(annotations)
    with open("output/crashes/predict_bechdel/lexical_dicts.json", encoding="utf-8") as f:
        lexical_dicts = json.load(f)
    nlp = spacy.load("nl_core_news_md")
    for _, row in data_crashes.iterrows():
        sentence = row["sentence"]
        num_humans, num_vehicles = count_parties(sentence, lexical_dicts, nlp)
        print(sentence)
        print(f"\thumans={num_humans}, vehicles={num_vehicles}")

def evaluate_crashes(predictor, attrib,
                     annotations="data/crashes/bechdel_annotations_dev_first_25.csv",
                     out_file="output/crashes/predict_bechdel/predictions_crashes25.csv"):
    data_crashes = pd.read_csv(annotations)
    labels_crashes = [
        {
            "party_mentioned": str(row["mentioned"]),
            "party_human": str(row["as_human"]),
            "active": str(str(row["active"]).lower() == "true")
        }
        for _, row in data_crashes.iterrows()
    ]
    predictions_crashes = [predictor.predict(row["sentence"])
                           for _, row in data_crashes.iterrows()]
    crashes_out = []
    correct = 0
    partial_2_attrs = 0
    partial_1_attr = 0
    correct_mentions = 0
    correct_humans = 0
    correct_active = 0
    for sentence, label, prediction in zip(data_crashes["sentence"], labels_crashes, predictions_crashes):
        predicted = prediction["label"]
        if attrib == "all":
            gold = "|".join([f"{k}={v}" for k, v in label.items()])
        else:
            gold = label[attrib]
        if gold == predicted:
            correct += 1
            if attrib == "all":
                partial_2_attrs += 1
                partial_1_attr += 1
        if attrib == "all":
            gold_attrs = set(gold.split("|"))
            pred_attrs = set(predicted.split("|"))
            if len(gold_attrs & pred_attrs) == 2:
                partial_2_attrs += 1
                partial_1_attr += 1
            elif len(gold_attrs & pred_attrs) == 1:
                partial_1_attr += 1
            if gold.split("|")[0] == predicted.split("|")[0]:
                correct_mentions += 1
            if gold.split("|")[1] == predicted.split("|")[1]:
                correct_humans += 1
            if gold.split("|")[2] == predicted.split("|")[2]:
                correct_active += 1
        crashes_out.append(
            {"sentence": sentence, "gold": gold, "prediction": predicted})
    print("ACC_crashes (strict) = ", correct / len(data_crashes))
    print("ACC_crashes (partial:2) = ", partial_2_attrs / len(data_crashes))
    print("ACC_crashes (partial:1) = ", partial_1_attr / len(data_crashes))
    print("ACC_crashes (mentions) = ", correct_mentions / len(data_crashes))
    print("ACC_crashes (humans) = ", correct_humans / len(data_crashes))
    print("ACC_crashes (active) = ", correct_active / len(data_crashes))
    pd.DataFrame(crashes_out).to_csv(out_file)

def filter_events_for_bechdel():
    with open("data/crashes/thecrashes_data_all_text.json", encoding="utf-8") as f:
        events = json.load(f)
    total_articles = 0
    data_out = []
    for ev in events:
        total_articles += len(ev["articles"])
        num_persons = len(ev["persons"])
        num_transport_modes = len({p["transportationmode"]
                                    for p in ev["persons"]})
        if num_transport_modes <= 2:
            for art in ev["articles"]:
                data_out.append({"event_id": ev["id"], "article_id": art["id"], "headline": art["title"],
                                 "num_persons": num_persons, "num_transport_modes": num_transport_modes})
    print("Total articles = ", total_articles)
    print("Filtered articles: ", len(data_out))
    out_df = pd.DataFrame(data_out)
    out_df.to_csv("output/crashes/predict_bechdel/filtered_headlines.csv")

def train_and_eval(train=True):
    # use_gpu = False
    use_gpu = True
    cuda_device = 0 if use_gpu and torch.cuda.is_available() else -1

    transformer = "GroNLP/bert-base-dutch-cased"
    # transformer = "xlm-roberta-large"
    token_indexers = {"tokens": PretrainedTransformerIndexer(transformer)}
    tokenizer = PretrainedTransformerTokenizer(transformer)
    binarizer = MultiLabelBinarizer()
    binarizer.fit([SEQ_LABELS])

    reader = TrafficBechdelReader(token_indexers, tokenizer, binarizer)
    instances = list(reader.read("output/prolog/bechdel_headlines.txt"))
    orig_data = reader.orig_data

    # shuffle the instances and the original token/label data in parallel
    zipped = list(zip(instances, orig_data))
    random.shuffle(zipped)
    instances_ = [i[0] for i in zipped]
    orig_data_ = [i[1] for i in zipped]

    num_dev = round(0.05 * len(instances_))
    num_test = round(0.25 * len(instances_))
    num_train = len(instances_) - num_dev - num_test
    print("LEN(train/dev/test)=", num_train, num_dev, num_test)
    instances_train = instances_[:num_train]
    instances_dev = instances_[num_train:num_train + num_dev]
    # instances_test = instances_[num_train + num_dev:]
    # orig_train = orig_data_[:num_train]
    orig_dev = orig_data_[num_train:num_train + num_dev]

    vocab = Vocabulary.from_instances(instances_train + instances_dev)
    embedder = BasicTextFieldEmbedder(
        {"tokens": PretrainedTransformerEmbedder(transformer)})
    model = MultiSequenceLabelModel(embedder, len(SEQ_LABELS), 1000, vocab)
    if use_gpu and cuda_device >= 0:
        model = model.cuda(cuda_device)

    # checkpoint_dir = f"output/crashes/predict_bechdel/model_{attrib}/"
    checkpoint_dir = "/scratch/p289731/predict_bechdel/model_seqlabel/"
    serialization_dir = "/scratch/p289731/predict_bechdel/serialization_seqlabel/"

    if train:
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(serialization_dir, exist_ok=True)
        tensorboard = TensorBoardCallback(
            serialization_dir, should_log_learning_rate=True)
        checkpointer = Checkpointer(serialization_dir=checkpoint_dir)
        optimizer = AdamOptimizer(
            [(n, p) for n, p in model.named_parameters() if p.requires_grad],
            lr=1e-5
        )
        train_loader = SimpleDataLoader(
            instances_train, batch_size=8, shuffle=True)
        dev_loader = SimpleDataLoader(
            instances_dev, batch_size=8, shuffle=False)
        train_loader.index_with(vocab)
        dev_loader.index_with(vocab)

        print("\t\tTraining BERT model")
        trainer = GradientDescentTrainer(
            model,
            optimizer,
            train_loader,
            validation_data_loader=dev_loader,
            # patience=32,
            patience=2,
            # num_epochs=1,
            checkpointer=checkpointer,
            cuda_device=cuda_device,
            serialization_dir=serialization_dir,
            callbacks=[tensorboard]
        )
        trainer.train()
    else:
        map_location = f"cuda:{cuda_device}" if cuda_device >= 0 else "cpu"
        state_dict = torch.load(
            "/scratch/p289731/predict_bechdel/serialization_all/best.th",
            map_location=map_location)
        model.load_state_dict(state_dict)

    print("\t\tProducing predictions...")
    predictor = Predictor(model, reader)
    predictions_dev = [predictor.predict_instance(i) for i in instances_dev]
    data_out = []
    for sentence, prediction in zip(orig_dev, predictions_dev):
        readable = model.make_human_readable(
            torch.tensor(prediction["tag_logits"]), "labels")
        text = sentence["sentence"]
        gold = sentence["labels"]
        predicted = readable
        data_out.append(
            {"sentence": text, "gold": gold, "predicted": predicted})
    df_out = pd.DataFrame(data_out)
    df_out.to_csv("output/crashes/predict_bechdel/predictions_dev.csv")

    # print()
    # print("First 25 crashes:")
    # evaluate_crashes(predictor, attrib, annotations="data/crashes/bechdel_annotations_dev_first_25.csv",
    #                  out_file="output/crashes/predict_bechdel/predictions_first_25.csv")
    # print()
    # print("Next 75 crashes:")
    # evaluate_crashes(predictor, attrib, annotations="data/crashes/bechdel_annotations_dev_next_75.csv",
    #                  out_file="output/crashes/predict_bechdel/predictions_next_75.csv")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("action", choices=["train", "predict", "rules", "filter"])
    args = ap.parse_args()

    if args.action == "train":
        train_and_eval(train=True)
    elif args.action == "predict":
        train_and_eval(train=False)
    elif args.action == "rules":
        predict_rule_based()
    else:
        filter_events_for_bechdel()