import argparse
import json
import os
import random
import re
from typing import Dict, Iterable, List, Optional, Tuple

import pandas as pd
import spacy
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

from allennlp.common.util import pad_sequence_to_length
from allennlp.data import DatasetReader, Instance, TextFieldTensors
from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader
from allennlp.data.fields import TextField
from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.token_indexers.pretrained_transformer_indexer import (
    PretrainedTransformerIndexer,
)
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import (
    PretrainedTransformerTokenizer,
)
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.modules.text_field_embedders.basic_text_field_embedder import (
    BasicTextFieldEmbedder,
)
from allennlp.modules.token_embedders.pretrained_transformer_embedder import (
    PretrainedTransformerEmbedder,
)
from allennlp.nn.util import get_text_field_mask
from allennlp.predictors.predictor import Predictor
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
from allennlp.training.optimizers import AdamOptimizer

random.seed(1986)

SEQ_LABELS = ["humansMentioned", "vehiclesMentioned", "eventVerb", "activeEventVerb"]


# adapted from the bert-for-framenet project
class SequenceMultiLabelField(Field):
    """A field holding one set of labels per token of a sequence."""

    def __init__(self,
                 labels: List[List[str]],
                 sequence_field: SequenceField,
                 binarizer: MultiLabelBinarizer,
                 label_namespace: str):
        self.labels = labels
        self._indexed_labels = None
        self._label_namespace = label_namespace
        self.sequence_field = sequence_field
        self.binarizer = binarizer

    @staticmethod
    def retokenize_tags(tags: List[List[str]],
                        offsets: List[Tuple[int, int]],
                        wp_primary_token: str = "last",
                        wp_secondary_tokens: str = "empty",
                        empty_value=lambda: []) -> List[List[str]]:
        """Spread word-level tag sets over the wordpieces of each word."""
        tags_per_wordpiece = [
            empty_value()  # [CLS]
        ]
        for i, (off_start, off_end) in enumerate(offsets):
            tag = tags[i]
            # Put the tag on one wordpiece of the word token and fill the rest,
            # e.g. "hello" --> "he" + "##ll" + "##o" --> 2 extra tokens:
            # TAGS: [..., TAG, None, None, ...]
            num_extra_tokens = off_end - off_start
            if wp_primary_token == "first":
                tags_per_wordpiece.append(tag)
            if wp_secondary_tokens == "repeat":
                tags_per_wordpiece.extend(num_extra_tokens * [tag])
            else:
                tags_per_wordpiece.extend(num_extra_tokens * [empty_value()])
            if wp_primary_token == "last":
                tags_per_wordpiece.append(tag)
        tags_per_wordpiece.append(empty_value())  # [SEP]
        return tags_per_wordpiece

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        for label_list in self.labels:
            for label in label_list:
                counter[self._label_namespace][label] += 1

    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tokens": self.sequence_field.sequence_length()}

    def index(self, vocab: Vocabulary):
        # Note: the vocab-indexed labels are kept for completeness but are not
        # used by as_tensor, which works on the raw label strings.
        indexed_labels: List[List[int]] = []
        for sentence_labels in self.labels:
            sentence_indexed_labels = []
            for label in sentence_labels:
                try:
                    sentence_indexed_labels.append(
                        vocab.get_token_index(label, self._label_namespace))
                except KeyError:
                    print(f"[WARNING] Ignoring unknown label {label}")
            indexed_labels.append(sentence_indexed_labels)
        self._indexed_labels = indexed_labels

    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        # Binarize the raw label strings: the binarizer is fit on label strings
        # (SEQ_LABELS), so transforming the vocab-indexed integer labels
        # instead would silently produce all-zero rows.
        binarized_seq = self.binarizer.transform(self.labels).tolist()
        # pad to the wordpiece sequence length with all-zero label vectors
        desired_num_tokens = padding_lengths["num_tokens"]
        padded_tags = pad_sequence_to_length(
            binarized_seq, desired_num_tokens,
            default_value=lambda: list(self.binarizer.transform([[]])[0]))
        return torch.tensor(padded_tags, dtype=torch.float)

    def empty_field(self) -> "Field":
        field = SequenceMultiLabelField(
            [], self.sequence_field.empty_field(), self.binarizer,
            self._label_namespace)
        field._indexed_labels = []
        return field
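# A minimal sketch (not part of the original pipeline) of how retokenize_tags
# aligns word-level tag sets with wordpieces. The offsets follow AllenNLP's
# intra_word_tokenize convention: an inclusive (start, end) wordpiece span per
# input word, with [CLS] at position 0.
def _demo_retokenize_tags() -> None:
    tags = [["humansMentioned"], [], ["eventVerb", "activeEventVerb"]]
    offsets = [(1, 2), (3, 3), (4, 4)]  # the first word splits into two wordpieces
    wp_tags = SequenceMultiLabelField.retokenize_tags(tags, offsets)
    # [CLS] wp0  wp1                   wp2  wp3                                [SEP]
    # [[],  [],  ['humansMentioned'],  [],  ['eventVerb', 'activeEventVerb'],  []]
    print(wp_tags)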
# adapted from the bert-for-framenet project
class MultiSequenceLabelModel(Model):
    def __init__(self,
                 embedder: BasicTextFieldEmbedder,
                 decoder_output_size: int,
                 hidden_size: int,
                 vocab: Vocabulary,
                 embedding_size: int = 768):
        super().__init__(vocab)
        self.embedder = embedder
        self.out_features = decoder_output_size
        self.hidden_size = hidden_size
        # two-layer MLP decoder producing one logit per token per label
        self.layers = nn.Sequential(
            nn.Linear(in_features=embedding_size, out_features=self.hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_size, out_features=self.out_features),
        )

    def forward(self,
                tokens: TextFieldTensors,
                label: Optional[torch.FloatTensor] = None):
        # the embedder is a BasicTextFieldEmbedder, which takes the whole
        # TextFieldTensors dict and returns (batch, num_wordpieces, embedding_size)
        embeddings = self.embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        tag_logits = self.layers(embeddings)
        # expand the token mask over the label dimension
        mask = mask.reshape(mask.shape[0], mask.shape[1], 1).repeat(
            1, 1, self.out_features)
        output = {"tag_logits": tag_logits}
        if label is not None:
            # the mask acts as an element-wise weight so that padding tokens
            # contribute zero loss
            loss = binary_cross_entropy_with_logits(tag_logits, label, mask)
            output["loss"] = loss
        return output

    def get_metrics(self, _) -> Dict[str, float]:
        return {}

    def make_human_readable(self,
                            prediction: torch.Tensor,
                            label_namespace: str,
                            threshold: float = 0.2,
                            sigmoid: bool = True) -> List[str]:
        """Turn per-token logits into "label:probability" strings.

        Caveat: the logit columns follow the binarizer's (alphabetical) class
        order, while labels are looked up by vocabulary index here; these two
        orderings should be checked against each other.
        """
        if sigmoid:
            prediction = torch.sigmoid(prediction)
        predicted_labels: List[List[str]] = [[] for _ in range(len(prediction))]
        # collect all predictions above the probability threshold
        for coord in torch.nonzero(prediction > threshold):
            label = self.vocab.get_token_from_index(int(coord[1]), label_namespace)
            predicted_labels[coord[0]].append(
                f"{label}:{prediction[coord[0], coord[1]]:.3f}")
        str_predictions: List[str] = []
        for label_list in predicted_labels:
            str_predictions.append("|".join(label_list) or "_")
        return str_predictions
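# A minimal sketch (not part of the original pipeline) of the masked
# multi-label loss used in MultiSequenceLabelModel.forward: every token gets
# an independent sigmoid per label, and the expanded token mask is passed as
# the element-wise weight of binary_cross_entropy_with_logits, so padding
# positions contribute zero loss (they still count toward the mean, though).
def _demo_masked_bce() -> None:
    logits = torch.tensor([[[2.0, -2.0], [0.5, 0.5]]])  # (batch=1, tokens=2, labels=2)
    gold = torch.tensor([[[1.0, 0.0], [0.0, 0.0]]])
    mask = torch.tensor([[[1.0, 1.0], [0.0, 0.0]]])     # second token is padding
    loss = binary_cross_entropy_with_logits(logits, gold, mask)
    print(float(loss))  # mean over all four cells, with the padded cells zeroed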
class TrafficBechdelReader(DatasetReader):
    def __init__(self, token_indexers, tokenizer, binarizer):
        self.token_indexers = token_indexers
        self.tokenizer: PretrainedTransformerTokenizer = tokenizer
        self.binarizer = binarizer
        self.orig_data = []
        super().__init__()

    def _read(self, file_path) -> Iterable[Instance]:
        # Each line holds one sentence as comma-separated "token:[labels]"
        # pairs wrapped in brackets, with multiple labels "|"-separated, e.g.
        #   [fietser:[humansMentioned],aangereden:[eventVerb|activeEventVerb]]
        self.orig_data.clear()
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                # skip any empty lines
                line = line.strip()
                if not line:
                    continue
                sentence_parts = line.lstrip("[").rstrip("]").split(",")
                token_txts = []
                token_mlabels = []
                for sp in sentence_parts:
                    sp_txt, sp_lbl_str = sp.split(":")
                    if sp_lbl_str == "[]":
                        sp_lbls = []
                    else:
                        sp_lbls = sp_lbl_str.lstrip("[").rstrip("]").split("|")
                    # strip WordNet-style sense suffixes, e.g. "fietser-n-1"
                    wn_match = re.match(r"^(.+)-n-\d+$", sp_txt)
                    if wn_match:
                        sp_txt = wn_match.group(1)
                    # the text may span multiple whitespace-separated tokens
                    sp_toks = sp_txt.split()
                    for tok in sp_toks:
                        token_txts.append(tok)
                        token_mlabels.append(sp_lbls)
                self.orig_data.append({
                    "sentence": token_txts,
                    "labels": token_mlabels,
                })
                yield self.text_to_instance(token_txts, token_mlabels)

    def text_to_instance(self,
                         sentence: List[str],
                         labels: Optional[List[List[str]]] = None) -> Instance:
        tokens, offsets = self.tokenizer.intra_word_tokenize(sentence)
        text_field = TextField(tokens, self.token_indexers)
        fields = {"tokens": text_field}
        if labels is not None:
            labels_ = SequenceMultiLabelField.retokenize_tags(labels, offsets)
            label_field = SequenceMultiLabelField(
                labels_, text_field, self.binarizer, "labels")
            fields["label"] = label_field
        return Instance(fields)


def count_parties(sentence, lexical_dicts, nlp):
    """Count person and vehicle mentions in a sentence using lexical lists."""
    num_humans = 0
    num_vehicles = 0

    def is_in_words(lemma, category):
        for subcategory, words in lexical_dicts[category].items():
            if subcategory.startswith("WN:"):
                # WordNet entries look like "fietser-n-1"; keep the lemma only
                words = [re.match(r"^(.+)-n-\d+$", w).group(1) for w in words]
            if lemma in words:
                return True
        return False

    doc = nlp(sentence.lower())
    for token in doc:
        lemma = token.lemma_
        if is_in_words(lemma, "persons"):
            num_humans += 1
        if is_in_words(lemma, "vehicles"):
            num_vehicles += 1
    return num_humans, num_vehicles


def predict_rule_based(annotations="data/crashes/bechdel_annotations_dev_first_25.csv"):
    data_crashes = pd.read_csv(annotations)
    with open("output/crashes/predict_bechdel/lexical_dicts.json", encoding="utf-8") as f:
        lexical_dicts = json.load(f)
    nlp = spacy.load("nl_core_news_md")
    for _, row in data_crashes.iterrows():
        sentence = row["sentence"]
        num_humans, num_vehicles = count_parties(sentence, lexical_dicts, nlp)
        print(sentence)
        print(f"\thumans={num_humans}, vehicles={num_vehicles}")
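# A minimal usage sketch for count_parties. The dictionary below is a toy
# stand-in (an assumption) for output/crashes/predict_bechdel/lexical_dicts.json,
# and the nl_core_news_md spaCy model must be installed for this to run.
def _demo_count_parties() -> None:
    toy_dicts = {
        "persons": {"WN:riders": ["fietser-n-1"], "pedestrians": ["voetganger"]},
        "vehicles": {"motorized": ["auto", "vrachtwagen"]},
    }
    nlp = spacy.load("nl_core_news_md")
    # "Cyclist hit by car": expect 1 human and 1 vehicle with the toy lists
    print(count_parties("Fietser aangereden door auto", toy_dicts, nlp))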
def evaluate_crashes(predictor, attrib,
                     annotations="data/crashes/bechdel_annotations_dev_first_25.csv",
                     out_file="output/crashes/predict_bechdel/predictions_crashes25.csv"):
    data_crashes = pd.read_csv(annotations)
    labels_crashes = [
        {
            "party_mentioned": str(row["mentioned"]),
            "party_human": str(row["as_human"]),
            "active": str(str(row["active"]).lower() == "true"),
        }
        for _, row in data_crashes.iterrows()
    ]
    predictions_crashes = [
        predictor.predict(row["sentence"]) for _, row in data_crashes.iterrows()
    ]

    crashes_out = []
    correct = 0
    partial_2_attrs = 0
    partial_1_attr = 0
    correct_mentions = 0
    correct_humans = 0
    correct_active = 0
    for sentence, label, prediction in zip(
            data_crashes["sentence"], labels_crashes, predictions_crashes):
        predicted = prediction["label"]
        if attrib == "all":
            gold = "|".join([f"{k}={v}" for k, v in label.items()])
        else:
            gold = label[attrib]
        if gold == predicted:
            correct += 1
            if attrib == "all":
                partial_2_attrs += 1
                partial_1_attr += 1
        if attrib == "all":
            gold_attrs = set(gold.split("|"))
            pred_attrs = set(predicted.split("|"))
            if len(gold_attrs & pred_attrs) == 2:
                partial_2_attrs += 1
                partial_1_attr += 1
            elif len(gold_attrs & pred_attrs) == 1:
                partial_1_attr += 1
            if gold.split("|")[0] == predicted.split("|")[0]:
                correct_mentions += 1
            if gold.split("|")[1] == predicted.split("|")[1]:
                correct_humans += 1
            if gold.split("|")[2] == predicted.split("|")[2]:
                correct_active += 1
        crashes_out.append(
            {"sentence": sentence, "gold": gold, "prediction": predicted})

    print("ACC_crashes (strict) = ", correct / len(data_crashes))
    print("ACC_crashes (partial:2) = ", partial_2_attrs / len(data_crashes))
    print("ACC_crashes (partial:1) = ", partial_1_attr / len(data_crashes))
    print("ACC_crashes (mentions) = ", correct_mentions / len(data_crashes))
    print("ACC_crashes (humans) = ", correct_humans / len(data_crashes))
    print("ACC_crashes (active) = ", correct_active / len(data_crashes))
    pd.DataFrame(crashes_out).to_csv(out_file)


def filter_events_for_bechdel():
    with open("data/crashes/thecrashes_data_all_text.json", encoding="utf-8") as f:
        events = json.load(f)

    total_articles = 0
    data_out = []
    for ev in events:
        total_articles += len(ev["articles"])
        num_persons = len(ev["persons"])
        num_transport_modes = len({p["transportationmode"] for p in ev["persons"]})
        # keep only events involving at most two modes of transport
        if num_transport_modes <= 2:
            for art in ev["articles"]:
                data_out.append({
                    "event_id": ev["id"],
                    "article_id": art["id"],
                    "headline": art["title"],
                    "num_persons": num_persons,
                    "num_transport_modes": num_transport_modes,
                })
    print("Total articles = ", total_articles)
    print("Filtered articles: ", len(data_out))
    out_df = pd.DataFrame(data_out)
    out_df.to_csv("output/crashes/predict_bechdel/filtered_headlines.csv")
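# A minimal sketch (not part of the original pipeline) of the label encoding
# set up in train_and_eval below: the MultiLabelBinarizer is fit once on
# SEQ_LABELS and turns each token's label set into a multi-hot row, with
# columns in the binarizer's alphabetically sorted class order.
def _demo_label_binarization() -> None:
    binarizer = MultiLabelBinarizer()
    binarizer.fit([SEQ_LABELS])
    row = binarizer.transform([["eventVerb", "activeEventVerb"]])[0]
    # classes_: ['activeEventVerb', 'eventVerb', 'humansMentioned', 'vehiclesMentioned']
    print(dict(zip(binarizer.classes_, row)))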
def train_and_eval(train=True):
    # use_gpu = False
    use_gpu = True
    cuda_device = 0 if use_gpu and torch.cuda.is_available() else -1

    transformer = "GroNLP/bert-base-dutch-cased"
    # transformer = "xlm-roberta-large"
    token_indexers = {"tokens": PretrainedTransformerIndexer(transformer)}
    tokenizer = PretrainedTransformerTokenizer(transformer)
    binarizer = MultiLabelBinarizer()
    binarizer.fit([SEQ_LABELS])

    reader = TrafficBechdelReader(token_indexers, tokenizer, binarizer)
    instances = list(reader.read("output/prolog/bechdel_headlines.txt"))
    orig_data = reader.orig_data

    # shuffle instances and raw sentences in parallel before splitting
    zipped = list(zip(instances, orig_data))
    random.shuffle(zipped)
    instances_ = [i[0] for i in zipped]
    orig_data_ = [i[1] for i in zipped]

    num_dev = round(0.05 * len(instances_))
    num_test = round(0.25 * len(instances_))
    num_train = len(instances_) - num_dev - num_test
    print("LEN(train/dev/test)=", num_train, num_dev, num_test)

    instances_train = instances_[:num_train]
    instances_dev = instances_[num_train:num_train + num_dev]
    # instances_test = instances_[num_train + num_dev:]

    # orig_train = orig_data_[:num_train]
    orig_dev = orig_data_[num_train:num_train + num_dev]

    vocab = Vocabulary.from_instances(instances_train + instances_dev)
    embedder = BasicTextFieldEmbedder(
        {"tokens": PretrainedTransformerEmbedder(transformer)})
    model = MultiSequenceLabelModel(embedder, len(SEQ_LABELS), 1000, vocab)
    if cuda_device >= 0:
        model = model.cuda(cuda_device)

    # checkpoint_dir = f"output/crashes/predict_bechdel/model_{attrib}/"
    checkpoint_dir = "/scratch/p289731/predict_bechdel/model_seqlabel/"
    serialization_dir = "/scratch/p289731/predict_bechdel/serialization_seqlabel/"
    if train:
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(serialization_dir, exist_ok=True)

        tensorboard = TensorBoardCallback(
            serialization_dir, should_log_learning_rate=True)
        checkpointer = Checkpointer(serialization_dir=checkpoint_dir)
        optimizer = AdamOptimizer(
            [(n, p) for n, p in model.named_parameters() if p.requires_grad],
            lr=1e-5,
        )
        train_loader = SimpleDataLoader(instances_train, batch_size=8, shuffle=True)
        dev_loader = SimpleDataLoader(instances_dev, batch_size=8, shuffle=False)
        train_loader.index_with(vocab)
        dev_loader.index_with(vocab)

        print("\t\tTraining BERT model")
        trainer = GradientDescentTrainer(
            model,
            optimizer,
            train_loader,
            validation_data_loader=dev_loader,
            # patience=32,
            patience=2,
            # num_epochs=1,
            checkpointer=checkpointer,
            cuda_device=cuda_device,
            serialization_dir=serialization_dir,
            callbacks=[tensorboard],
        )
        trainer.train()
    else:
        state_dict = torch.load(
            "/scratch/p289731/predict_bechdel/serialization_all/best.th",
            map_location=f"cuda:{cuda_device}" if cuda_device >= 0 else "cpu")
        model.load_state_dict(state_dict)

    print("\t\tProducing predictions...")
    predictor = Predictor(model, reader)
    predictions_dev = [predictor.predict_instance(i) for i in instances_dev]

    data_out = []
    for sentence, prediction in zip(orig_dev, predictions_dev):
        # predict_instance returns sanitized (nested-list) outputs, so wrap
        # the logits back into a tensor before decoding them
        readable = model.make_human_readable(
            torch.tensor(prediction["tag_logits"]), "labels")
        data_out.append({
            "sentence": sentence["sentence"],
            "gold": sentence["labels"],
            "predicted": readable,
        })
    df_out = pd.DataFrame(data_out)
    df_out.to_csv("output/crashes/predict_bechdel/predictions_dev.csv")

    # print()
    # print("First 25 crashes:")
    # evaluate_crashes(predictor, attrib,
    #                  annotations="data/crashes/bechdel_annotations_dev_first_25.csv",
    #                  out_file="output/crashes/predict_bechdel/predictions_first_25.csv")
    # print()
    # print("Next 75 crashes:")
    # evaluate_crashes(predictor, attrib,
    #                  annotations="data/crashes/bechdel_annotations_dev_next_75.csv",
    #                  out_file="output/crashes/predict_bechdel/predictions_next_75.csv")


if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("action", choices=["train", "predict", "rules", "filter"])
    args = ap.parse_args()
    if args.action == "train":
        train_and_eval(train=True)
    elif args.action == "predict":
        train_and_eval(train=False)
    elif args.action == "rules":
        predict_rule_based()
    else:
        filter_events_for_bechdel()
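# Example invocations, assuming this file is saved as predict_bechdel.py
# (the /scratch paths above are cluster-specific):
#
#   python predict_bechdel.py train    # fine-tune BERT and predict on the dev split
#   python predict_bechdel.py predict  # reload a saved best.th and re-run predictions
#   python predict_bechdel.py rules    # lexicon-based human/vehicle counting
#   python predict_bechdel.py filter   # filter articles from the crash data dump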