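# Evaluation/explanation pipeline for an ERASER-style rationale benchmark:
# loads {train,val,test}.jsonl annotations plus the referenced documents,
# fine-tunes a BERT evidence classifier, and then writes token-level
# explanation heatmaps (rendered to LaTeX) and hard-rationale predictions for
# several attribution methods (transformer attribution, partial LRP, attention
# GradCAM, full LRP, last-layer attention, attention rollout, ground truth).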
# TODO: consider whether this can be collapsed back into pipeline_train.py
import argparse
import json
import logging
import os
import random
from collections import OrderedDict
from itertools import chain
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
from BERT_explainability.modules.BERT.BERT_cls_lrp import \
BertForSequenceClassification as BertForClsOrigLrp
from BERT_explainability.modules.BERT.BertForSequenceClassification import \
BertForSequenceClassification as BertForSequenceClassificationTest
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_rationale_benchmark.utils import (Annotation, Evidence,
load_datasets, load_documents,
write_jsonl)
from sklearn.metrics import accuracy_score
from transformers import BertForSequenceClassification, BertTokenizer
logging.basicConfig(
level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
)
logger = logging.getLogger(__name__)
# let's make this more or less deterministic (not resistant to restarts)
random.seed(12345)
np.random.seed(67890)
torch.manual_seed(10111213)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
latex_special_token = ["!@#$%^&*()"]
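# generate(): render a list of tokens as a standalone LaTeX document, colouring each
# token with \colorbox at an intensity proportional to its relevance score. Scores are
# min-max scaled to 0-100 and values below 1 are zeroed; wordpiece continuations
# (tokens containing "##") are glued to the preceding token without a space.
# Note: rescale() is referenced below but not defined in this file; rescale_value is
# never set to True by this script.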
def generate(text_list, attention_list, latex_file, color="red", rescale_value=False):
attention_list = attention_list[: len(text_list)]
if attention_list.max() == attention_list.min():
attention_list = torch.zeros_like(attention_list)
else:
attention_list = (
100
* (attention_list - attention_list.min())
/ (attention_list.max() - attention_list.min())
)
attention_list[attention_list < 1] = 0
attention_list = attention_list.tolist()
text_list = [text_list[i].replace("$", "") for i in range(len(text_list))]
if rescale_value:
attention_list = rescale(attention_list)
word_num = len(text_list)
text_list = clean_word(text_list)
with open(latex_file, "w") as f:
f.write(
r"""\documentclass[varwidth=150mm]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}"""
+ "\n"
)
string = (
r"""{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{"""
+ "\n"
)
for idx in range(word_num):
# string += "\\colorbox{%s!%s}{"%(color, attention_list[idx])+"\\strut " + text_list[idx]+"} "
# print(text_list[idx])
if "\#\#" in text_list[idx]:
token = text_list[idx].replace("\#\#", "")
string += (
"\\colorbox{%s!%s}{" % (color, attention_list[idx])
+ "\\strut "
+ token
+ "}"
)
else:
string += (
" "
+ "\\colorbox{%s!%s}{" % (color, attention_list[idx])
+ "\\strut "
+ text_list[idx]
+ "}"
)
string += "\n}}}"
f.write(string + "\n")
f.write(
r"""\end{CJK*}
\end{document}"""
)
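# clean_word(): escape characters that are special to LaTeX (backslash, %, &, ^, #, _, {, })
# so that tokens can be embedded verbatim in the generated document.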
def clean_word(word_list):
new_word_list = []
for word in word_list:
for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
if latex_sensitive in word:
word = word.replace(latex_sensitive, "\\" + latex_sensitive)
new_word_list.append(word)
return new_word_list
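# Map per-wordpiece scores back onto the whitespace-separated source words: special
# tokens are dropped, each remaining wordpiece's score is repeated once per character,
# and every source word receives the max score over the characters it spans. The
# character-level reconstruction doubles as a sanity check against the input words.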
def scores_per_word_from_scores_per_token(input, tokenizer, input_ids, scores_per_id):
words = tokenizer.convert_ids_to_tokens(input_ids)
words = [word.replace("##", "") for word in words]
score_per_char = []
# TODO: DELETE
input_ids_chars = []
for word in words:
if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
continue
input_ids_chars += list(word)
# TODO: DELETE
for i in range(len(scores_per_id)):
if words[i] in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
continue
score_per_char += [scores_per_id[i]] * len(words[i])
score_per_word = []
start_idx = 0
end_idx = 0
# TODO: DELETE
words_from_chars = []
for inp in input:
if start_idx >= len(score_per_char):
break
end_idx = end_idx + len(inp)
score_per_word.append(np.max(score_per_char[start_idx:end_idx]))
# TODO: DELETE
words_from_chars.append("".join(input_ids_chars[start_idx:end_idx]))
start_idx = end_idx
if words_from_chars[:-1] != input[: len(words_from_chars) - 1]:
print(words_from_chars)
print(input[: len(words_from_chars)])
print(words)
print(tokenizer.convert_ids_to_tokens(input_ids))
assert False
return torch.tensor(score_per_word)
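# Reconstruct the (possibly truncated) list of source words actually covered by the
# tokenised input, using the same character-alignment trick as above.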
def get_input_words(input, tokenizer, input_ids):
words = tokenizer.convert_ids_to_tokens(input_ids)
words = [word.replace("##", "") for word in words]
input_ids_chars = []
for word in words:
if word in ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]:
continue
input_ids_chars += list(word)
start_idx = 0
end_idx = 0
words_from_chars = []
for inp in input:
if start_idx >= len(input_ids_chars):
break
end_idx = end_idx + len(inp)
words_from_chars.append("".join(input_ids_chars[start_idx:end_idx]))
start_idx = end_idx
if words_from_chars[:-1] != input[: len(words_from_chars) - 1]:
print(words_from_chars)
print(input[: len(words_from_chars)])
print(words)
print(tokenizer.convert_ids_to_tokens(input_ids))
assert False
return words_from_chars
def bert_tokenize_doc(
doc: List[List[str]], tokenizer, special_token_map
) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
"""Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words"""
sents = []
sent_token_spans = []
for sent in doc:
tokens = []
spans = []
start = 0
for w in sent:
if w in special_token_map:
tokens.append(w)
else:
tokens.extend(tokenizer.tokenize(w))
end = len(tokens)
spans.append((start, end))
start = end
sents.append(tokens)
sent_token_spans.append(spans)
return sents, sent_token_spans
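# Build the tokenizer and the BERT evidence classifier configured in the model-params
# JSON; returns the classifier, the wordpiece interner/de-interner (vocab maps), the
# class-label mapping, and the tokenizer itself.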
def initialize_models(params: dict, batch_first: bool, use_half_precision=False):
assert batch_first
max_length = params["max_length"]
tokenizer = BertTokenizer.from_pretrained(params["bert_vocab"])
pad_token_id = tokenizer.pad_token_id
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
bert_dir = params["bert_dir"]
evidence_classes = dict(
(y, x) for (x, y) in enumerate(params["evidence_classifier"]["classes"])
)
evidence_classifier = BertForSequenceClassification.from_pretrained(
bert_dir, num_labels=len(evidence_classes)
)
word_interner = tokenizer.vocab
de_interner = tokenizer.ids_to_tokens
return evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer
BATCH_FIRST = True
def extract_docid_from_dataset_element(element):
return next(iter(element.evidences))[0].docid
def extract_evidence_from_dataset_element(element):
return next(iter(element.evidences))
def main():
parser = argparse.ArgumentParser(
description="""Trains a pipeline model.
Step 1 is evidence identification: identify whether a given sentence is evidence or not.
Step 2 is evidence classification: given an evidence sentence, classify the final outcome for the task
(e.g. sentiment or significance).
These models should be separated into two separate steps; at the moment the pipeline is:
* prep data (load and intern documents, load json)
* convert data for evidence identification - for training data we take all the positives and sample some
negatives
* side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a
broader sampling of negative values
* train evidence identification
* convert data for evidence classification - take all rationales + decisions and use these as input
* train evidence classification
* decode the evidence first, then run classification for each split
""",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--data_dir",
dest="data_dir",
required=True,
help="Which directory contains a {train,val,test}.jsonl file?",
)
parser.add_argument(
"--output_dir",
dest="output_dir",
required=True,
help="Where shall we write intermediate models + final data to?",
)
parser.add_argument(
"--model_params",
dest="model_params",
required=True,
help="JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.",
)
args = parser.parse_args()
assert BATCH_FIRST
os.makedirs(args.output_dir, exist_ok=True)
with open(args.model_params, "r") as fp:
logger.info(f"Loading model parameters from {args.model_params}")
model_params = json.load(fp)
logger.info(f"Params: {json.dumps(model_params, indent=2, sort_keys=True)}")
train, val, test = load_datasets(args.data_dir)
docids = set(
e.docid
for e in chain.from_iterable(
chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))
)
)
documents = load_documents(args.data_dir, docids)
logger.info(f"Loaded {len(documents)} documents")
(
evidence_classifier,
word_interner,
de_interner,
evidence_classes,
tokenizer,
) = initialize_models(model_params, batch_first=BATCH_FIRST)
logger.info(f"We have {len(word_interner)} wordpieces")
cache = os.path.join(args.output_dir, "preprocessed.pkl")
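    # Tokenised documents are cached to preprocessed.pkl so repeated runs skip re-encoding.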
if os.path.exists(cache):
logger.info(f"Loading interned documents from {cache}")
(interned_documents) = torch.load(cache)
else:
logger.info(f"Interning documents")
interned_documents = {}
for d, doc in documents.items():
encoding = tokenizer.encode_plus(
doc,
add_special_tokens=True,
max_length=model_params["max_length"],
return_token_type_ids=False,
pad_to_max_length=False,
return_attention_mask=True,
return_tensors="pt",
truncation=True,
)
interned_documents[d] = encoding
torch.save((interned_documents), cache)
evidence_classifier = evidence_classifier.cuda()
optimizer = None
scheduler = None
save_dir = args.output_dir
logging.info(f"Beginning training classifier")
evidence_classifier_output_dir = os.path.join(save_dir, "classifier")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(evidence_classifier_output_dir, exist_ok=True)
model_save_file = os.path.join(evidence_classifier_output_dir, "classifier.pt")
epoch_save_file = os.path.join(
evidence_classifier_output_dir, "classifier_epoch_data.pt"
)
device = next(evidence_classifier.parameters()).device
if optimizer is None:
optimizer = torch.optim.Adam(
evidence_classifier.parameters(),
lr=model_params["evidence_classifier"]["lr"],
)
criterion = nn.CrossEntropyLoss(reduction="none")
batch_size = model_params["evidence_classifier"]["batch_size"]
epochs = model_params["evidence_classifier"]["epochs"]
patience = model_params["evidence_classifier"]["patience"]
max_grad_norm = model_params["evidence_classifier"].get("max_grad_norm", None)
class_labels = [k for k, v in sorted(evidence_classes.items())]
results = {
"train_loss": [],
"train_f1": [],
"train_acc": [],
"val_loss": [],
"val_f1": [],
"val_acc": [],
}
best_epoch = -1
best_val_acc = 0
best_val_loss = float("inf")
best_model_state_dict = None
start_epoch = 0
epoch_data = {}
if os.path.exists(epoch_save_file):
logging.info(f"Restoring model from {model_save_file}")
evidence_classifier.load_state_dict(torch.load(model_save_file))
epoch_data = torch.load(epoch_save_file)
start_epoch = epoch_data["epoch"] + 1
# handle finishing because patience was exceeded or we didn't get the best final epoch
if bool(epoch_data.get("done", 0)):
start_epoch = epochs
results = epoch_data["results"]
best_epoch = start_epoch
best_model_state_dict = OrderedDict(
{k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
)
logging.info(f"Restoring training from epoch {start_epoch}")
logging.info(
f"Training evidence classifier from epoch {start_epoch} until epoch {epochs}"
)
optimizer.zero_grad()
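    # Fine-tuning loop: shuffle the training annotations each epoch, batch them, run the
    # classifier on the cached encodings of each annotation's document, and optimise a
    # summed per-example cross-entropy with optional gradient clipping.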
for epoch in range(start_epoch, epochs):
epoch_train_data = random.sample(train, k=len(train))
epoch_train_loss = 0
epoch_training_acc = 0
evidence_classifier.train()
logging.info(
f"Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples"
)
for batch_start in range(0, len(epoch_train_data), batch_size):
batch_elements = epoch_train_data[
batch_start : min(batch_start + batch_size, len(epoch_train_data))
]
targets = [evidence_classes[s.classification] for s in batch_elements]
targets = torch.tensor(targets, dtype=torch.long, device=device)
samples_encoding = [
interned_documents[extract_docid_from_dataset_element(s)]
for s in batch_elements
]
input_ids = (
torch.stack(
[
samples_encoding[i]["input_ids"]
for i in range(len(samples_encoding))
]
)
.squeeze(1)
.to(device)
)
attention_masks = (
torch.stack(
[
samples_encoding[i]["attention_mask"]
for i in range(len(samples_encoding))
]
)
.squeeze(1)
.to(device)
)
preds = evidence_classifier(
input_ids=input_ids, attention_mask=attention_masks
)[0]
epoch_training_acc += accuracy_score(
preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
)
loss = criterion(preds, targets.to(device=preds.device)).sum()
epoch_train_loss += loss.item()
loss.backward()
assert loss == loss # for nans
if max_grad_norm:
torch.nn.utils.clip_grad_norm_(
evidence_classifier.parameters(), max_grad_norm
)
optimizer.step()
if scheduler:
scheduler.step()
optimizer.zero_grad()
epoch_train_loss /= len(epoch_train_data)
epoch_training_acc /= len(epoch_train_data)
assert epoch_train_loss == epoch_train_loss # for nans
results["train_loss"].append(epoch_train_loss)
logging.info(f"Epoch {epoch} training loss {epoch_train_loss}")
logging.info(f"Epoch {epoch} training accuracy {epoch_training_acc}")
with torch.no_grad():
epoch_val_loss = 0
epoch_val_acc = 0
epoch_val_data = random.sample(val, k=len(val))
evidence_classifier.eval()
val_batch_size = 32
logging.info(
f"Validating with {len(epoch_val_data) // val_batch_size} batches with {len(epoch_val_data)} examples"
)
for batch_start in range(0, len(epoch_val_data), val_batch_size):
batch_elements = epoch_val_data[
batch_start : min(batch_start + val_batch_size, len(epoch_val_data))
]
targets = [evidence_classes[s.classification] for s in batch_elements]
targets = torch.tensor(targets, dtype=torch.long, device=device)
samples_encoding = [
interned_documents[extract_docid_from_dataset_element(s)]
for s in batch_elements
]
input_ids = (
torch.stack(
[
samples_encoding[i]["input_ids"]
for i in range(len(samples_encoding))
]
)
.squeeze(1)
.to(device)
)
attention_masks = (
torch.stack(
[
samples_encoding[i]["attention_mask"]
for i in range(len(samples_encoding))
]
)
.squeeze(1)
.to(device)
)
preds = evidence_classifier(
input_ids=input_ids, attention_mask=attention_masks
)[0]
epoch_val_acc += accuracy_score(
preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
)
loss = criterion(preds, targets.to(device=preds.device)).sum()
epoch_val_loss += loss.item()
epoch_val_loss /= len(val)
epoch_val_acc /= len(val)
results["val_acc"].append(epoch_val_acc)
results["val_loss"] = epoch_val_loss
logging.info(f"Epoch {epoch} val loss {epoch_val_loss}")
logging.info(f"Epoch {epoch} val acc {epoch_val_acc}")
if epoch_val_acc > best_val_acc or (
epoch_val_acc == best_val_acc and epoch_val_loss < best_val_loss
):
best_model_state_dict = OrderedDict(
{k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
)
best_epoch = epoch
best_val_acc = epoch_val_acc
best_val_loss = epoch_val_loss
epoch_data = {
"epoch": epoch,
"results": results,
"best_val_acc": best_val_acc,
"done": 0,
}
torch.save(evidence_classifier.state_dict(), model_save_file)
torch.save(epoch_data, epoch_save_file)
logging.debug(
f"Epoch {epoch} new best model with val accuracy {epoch_val_acc}"
)
if epoch - best_epoch > patience:
logging.info(f"Exiting after epoch {epoch} due to no improvement")
epoch_data["done"] = 1
torch.save(epoch_data, epoch_save_file)
break
epoch_data["done"] = 1
epoch_data["results"] = results
torch.save(epoch_data, epoch_save_file)
evidence_classifier.load_state_dict(best_model_state_dict)
evidence_classifier = evidence_classifier.to(device=device)
evidence_classifier.eval()
# test
test_classifier = BertForSequenceClassificationTest.from_pretrained(
model_params["bert_dir"], num_labels=len(evidence_classes)
).to(device)
orig_lrp_classifier = BertForClsOrigLrp.from_pretrained(
model_params["bert_dir"], num_labels=len(evidence_classes)
).to(device)
if os.path.exists(epoch_save_file):
logging.info(f"Restoring model from {model_save_file}")
test_classifier.load_state_dict(torch.load(model_save_file))
orig_lrp_classifier.load_state_dict(torch.load(model_save_file))
test_classifier.eval()
orig_lrp_classifier.eval()
test_batch_size = 1
logging.info(
f"Testing with {len(test) // test_batch_size} batches with {len(test)} examples"
)
# explainability
explanations = Generator(test_classifier)
explanations_orig_lrp = Generator(orig_lrp_classifier)
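    # Select the attribution method to run. "transformer_attribution" uses the relevance
    # Generator wrapped around the test classifier; the remaining attribution methods use
    # the LRP-instrumented classifier. "ground_truth" renders the gold rationale instead
    # of an attribution, and "generate_all" stitches previously generated per-method PDFs
    # into one comparison figure.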
method = "transformer_attribution"
method_folder = {
"transformer_attribution": "ours",
"partial_lrp": "partial_lrp",
"last_attn": "last_attn",
"attn_gradcam": "attn_gradcam",
"lrp": "lrp",
"rollout": "rollout",
"ground_truth": "ground_truth",
"generate_all": "generate_all",
}
method_expl = {
"transformer_attribution": explanations.generate_LRP,
"partial_lrp": explanations_orig_lrp.generate_LRP_last_layer,
"last_attn": explanations_orig_lrp.generate_attn_last_layer,
"attn_gradcam": explanations_orig_lrp.generate_attn_gradcam,
"lrp": explanations_orig_lrp.generate_full_lrp,
"rollout": explanations_orig_lrp.generate_rollout,
}
os.makedirs(os.path.join(args.output_dir, method_folder[method]), exist_ok=True)
result_files = []
for i in range(5, 85, 5):
result_files.append(
open(
os.path.join(
args.output_dir, "{0}/identifier_results_{1}.json"
).format(method_folder[method], i),
"w",
)
)
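    # One open results file per top-k threshold (k = 5, 10, ..., 80); each receives one
    # ERASER-style hard-rationale record per test document.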
j = 0
for batch_start in range(0, len(test), test_batch_size):
batch_elements = test[
batch_start : min(batch_start + test_batch_size, len(test))
]
targets = [evidence_classes[s.classification] for s in batch_elements]
targets = torch.tensor(targets, dtype=torch.long, device=device)
samples_encoding = [
interned_documents[extract_docid_from_dataset_element(s)]
for s in batch_elements
]
input_ids = (
torch.stack(
[
samples_encoding[i]["input_ids"]
for i in range(len(samples_encoding))
]
)
.squeeze(1)
.to(device)
)
attention_masks = (
torch.stack(
[
samples_encoding[i]["attention_mask"]
for i in range(len(samples_encoding))
]
)
.squeeze(1)
.to(device)
)
preds = test_classifier(
input_ids=input_ids, attention_mask=attention_masks
)[0]
for s in batch_elements:
doc_name = extract_docid_from_dataset_element(s)
inp = documents[doc_name].split()
classification = "neg" if targets.item() == 0 else "pos"
is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
if method == "generate_all":
file_name = "{0}_{1}_{2}.tex".format(
j, classification, is_classification_correct
)
GT_global = os.path.join(
args.output_dir, "{0}/visual_results_{1}.pdf"
).format(method_folder["ground_truth"], j)
GT_ours = os.path.join(
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
).format(
method_folder["transformer_attribution"],
j,
classification,
is_classification_correct,
)
CF_ours = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
method_folder["transformer_attribution"], j
)
GT_partial = os.path.join(
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
).format(
method_folder["partial_lrp"],
j,
classification,
is_classification_correct,
)
CF_partial = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
method_folder["partial_lrp"], j
)
GT_gradcam = os.path.join(
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
).format(
method_folder["attn_gradcam"],
j,
classification,
is_classification_correct,
)
CF_gradcam = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
method_folder["attn_gradcam"], j
)
GT_lrp = os.path.join(
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
).format(
method_folder["lrp"],
j,
classification,
is_classification_correct,
)
CF_lrp = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
method_folder["lrp"], j
)
GT_lastattn = os.path.join(
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
).format(
method_folder["last_attn"],
j,
classification,
is_classification_correct,
)
GT_rollout = os.path.join(
args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
).format(
method_folder["rollout"],
j,
classification,
is_classification_correct,
)
with open(file_name, "w") as f:
f.write(
r"""\documentclass[varwidth]{standalone}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}
{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{
\setlength{\tabcolsep}{2pt} % Default value: 6pt
\begin{tabular}{ccc}
\includegraphics[width=0.32\linewidth]{"""
+ GT_global
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ GT_ours
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ CF_ours
+ """}\\\\
(a) & (b) & (c)\\\\
\includegraphics[width=0.32\linewidth]{"""
+ GT_partial
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ CF_partial
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ GT_gradcam
+ """}\\\\
(d) & (e) & (f)\\\\
\includegraphics[width=0.32\linewidth]{"""
+ CF_gradcam
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ GT_lrp
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ CF_lrp
+ """}\\\\
(g) & (h) & (i)\\\\
\includegraphics[width=0.32\linewidth]{"""
+ GT_lastattn
+ """}&
\includegraphics[width=0.32\linewidth]{"""
+ GT_rollout
+ """}&\\\\
(j) & (k)&\\\\
\end{tabular}
}}}
\end{CJK*}
\end{document}"""
)
j += 1
break
if method == "ground_truth":
inp_cropped = get_input_words(inp, tokenizer, input_ids[0])
cam = torch.zeros(len(inp_cropped))
for evidence in extract_evidence_from_dataset_element(s):
start_idx = evidence.start_token
if start_idx >= len(cam):
break
end_idx = evidence.end_token
cam[start_idx:end_idx] = 1
generate(
inp_cropped,
cam,
(
os.path.join(
args.output_dir, "{0}/visual_results_{1}.tex"
).format(method_folder[method], j)
),
color="green",
)
j = j + 1
break
text = tokenizer.convert_ids_to_tokens(input_ids[0])
classification = "neg" if targets.item() == 0 else "pos"
is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
target_idx = targets.item()
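            # Attribution map for the target (gold) class; negative relevances are clipped
            # at zero before rendering, and a counterfactual map for the opposite class is
            # generated below for the gradient/relevance-based methods.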
cam_target = method_expl[method](
input_ids=input_ids,
attention_mask=attention_masks,
index=target_idx,
)[0]
cam_target = cam_target.clamp(min=0)
generate(
text,
cam_target,
(
os.path.join(args.output_dir, "{0}/{1}_GT_{2}_{3}.tex").format(
method_folder[method],
j,
classification,
is_classification_correct,
)
),
)
if method in [
"transformer_attribution",
"partial_lrp",
"attn_gradcam",
"lrp",
]:
cam_false_class = method_expl[method](
input_ids=input_ids,
attention_mask=attention_masks,
index=1 - target_idx,
)[0]
cam_false_class = cam_false_class.clamp(min=0)
generate(
text,
cam_false_class,
(
os.path.join(args.output_dir, "{0}/{1}_CF.tex").format(
method_folder[method], j
)
),
)
cam = cam_target
cam = scores_per_word_from_scores_per_token(
inp, tokenizer, input_ids[0], cam
)
j = j + 1
doc_name = extract_docid_from_dataset_element(s)
hard_rationales = []
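            # For each threshold k, mark the k highest-scoring words as single-token hard
            # rationales and append an ERASER-style record to the k-th results file (note
            # that hard_rationales accumulates across thresholds).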
for res, i in enumerate(range(5, 85, 5)):
print("calculating top ", i)
_, indices = cam.topk(k=i)
for index in indices.tolist():
hard_rationales.append(
{"start_token": index, "end_token": index + 1}
)
result_dict = {
"annotation_id": doc_name,
"rationales": [
{
"docid": doc_name,
"hard_rationale_predictions": hard_rationales,
}
],
}
result_files[res].write(json.dumps(result_dict) + "\n")
for i in range(len(result_files)):
result_files[i].close()
if __name__ == "__main__":
main()