|
|
|
"""context |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1qLh1aASQj5HIENPZpHQltTuShZny_567 |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import os
import json
from pprint import pprint

# The module is called `wandb`; it was imported under the misspelled name
# `wanb`, which made every later `wandb.*` call fail with a NameError.
import wandb

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm.notebook import tqdm
from transformers import BertForQuestionAnswering, BertTokenizer, BertTokenizerFast

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Authenticate with Weights & Biases before the sweep is created.
wandb.login()
|
|
|
|
|
PROJECT_NAME = "context"
ENTITY = None

# Quantity the sweep optimises; the name must match the key that the training
# pipeline passes to wandb.log.
metric = {
    'name': 'Validation accuracy',
    'goal': 'maximize',
}

# Search space of the sweep, one entry per hyper-parameter.
parameters_dict = {
    'epochs': {'values': [1]},
    'optimizer': {'values': ['sgd', 'adam']},
    'momentum': {
        'distribution': 'uniform',
        'min': 0.5,
        'max': 0.99,
    },
    'batch_size': {
        'distribution': 'q_log_uniform_values',
        'q': 8,
        'min': 16,
        'max': 256,
    },
}

# Random search over the space above.
sweep_config = {
    'method': 'random',
    'metric': metric,
    'parameters': parameters_dict,
}

pprint(sweep_config)

sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME, entity=ENTITY)
|
|
|
|
|
from google.colab import drive

# Make Google Drive available under /content/drive so the fine-tuned model
# can be stored outside the Colab session.
drive.mount('/content/drive')

# exist_ok=True replaces the separate "exists?" check, so there is no window
# between checking for the folder and creating it.
os.makedirs('/content/drive/MyDrive/BERT-SQuAD', exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
"""Load the training dataset and take a look at it""" |
|
# Parse the raw SQuAD training file into a nested dict/list structure.
with open('train-v2.0.json', 'rb') as train_file:
    squad = json.load(train_file)

# Bare expression kept from the notebook: it shows one context paragraph when
# run as a cell and has no effect when the file runs as a script.
squad['data'][150]['paragraphs'][0]['context']
|
|
|
"""Load the dev dataset and take a look at it""" |
|
def read_data(path):
    """Flatten a SQuAD-format JSON file into three parallel lists.

    One entry is produced per answer, so a question with several answers is
    repeated and a question whose ``answers`` list is empty contributes
    nothing.

    Args:
        path: Location of a JSON file with the
            ``data -> paragraphs -> qas -> answers`` nesting.

    Returns:
        A ``(contexts, questions, answers)`` tuple of equally long lists;
        ``answers`` holds the answer objects exactly as stored in the file.
    """
    with open(path, 'rb') as handle:
        dataset = json.load(handle)

    def iter_examples():
        # Walk the nesting once and emit one (context, question, answer)
        # triple per answer.
        for article in dataset['data']:
            for paragraph in article['paragraphs']:
                passage_text = paragraph['context']
                for qa in paragraph['qas']:
                    query = qa['question']
                    for reference in qa['answers']:
                        yield passage_text, query, reference

    rows = list(iter_examples())
    contexts = [row[0] for row in rows]
    questions = [row[1] for row in rows]
    answers = [row[2] for row in rows]
    return contexts, questions, answers
|
|
|
|
|
|
|
""" |
|
The answers are dictionaries with the answer text and an integer that indicates the start index of the answer in the context.
|
""" |
|
# Flatten both SQuAD splits into parallel context / question / answer lists.
train_contexts, train_questions, train_answers = read_data('train-v2.0.json')
valid_contexts, valid_questions, valid_answers = read_data('dev-v2.0.json')
|
|
|
|
|
|
|
def end_idx(answers, contexts):
    """Add an exclusive ``answer_end`` character index to every answer, in place.

    The recorded ``answer_start`` may be off by one or two characters, so the
    answer text is looked for at the recorded position and at positions
    shifted one and two characters to the left. When a shifted position
    matches, ``answer_start`` is corrected as well.

    Args:
        answers: Dicts with ``'text'`` and ``'answer_start'`` keys; each one
            is mutated.
        contexts: Context strings, aligned with ``answers``.

    Returns:
        None. After the call every answer has an ``'answer_end'`` key.
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        nominal_start = answer['answer_start']
        nominal_end = nominal_start + len(gold_text)

        for shift in (0, 1, 2):
            start = nominal_start - shift
            end = nominal_end - shift
            # A negative start would silently index from the end of the
            # string, so such candidates are never accepted.
            if start >= 0 and context[start:end] == gold_text:
                answer['answer_start'] = start
                answer['answer_end'] = end
                break
        else:
            # No alignment found: keep the recorded start and still provide
            # an end index so later lookups of 'answer_end' do not fail.
            answer['answer_end'] = nominal_end
|
|
|
|
|
""""Tokenization""" |
|
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Each example is encoded as a (context, question) pair, with truncation and
# padding enabled.
train_encodings = tokenizer(
    train_contexts,
    train_questions,
    truncation=True,
    padding=True,
)
valid_encodings = tokenizer(
    valid_contexts,
    valid_questions,
    truncation=True,
    padding=True,
)
|
|
|
|
|
|
|
|
|
def add_token_positions(encodings, answers, max_length=None):
    """Convert character-level answer spans to token indices, in place.

    Adds ``'start_positions'`` and ``'end_positions'`` to ``encodings``.

    Args:
        encodings: Tokenizer output providing ``char_to_token(batch_index,
            char_index)`` and ``update``.
        answers: Dicts with ``'answer_start'`` and an exclusive
            ``'answer_end'`` character index, aligned with ``encodings``.
        max_length: Index stored when a character has no token
            (``char_to_token`` returned None). Defaults to the global
            tokenizer's ``model_max_length``.

    Returns:
        None.
    """
    if max_length is None:
        max_length = tokenizer.model_max_length

    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        first_char = answer['answer_start']
        # 'answer_end' is exclusive (start + len(text)), so the last answer
        # character sits one position earlier. Looking up 'answer_end' itself
        # points past the answer and usually resolves to no token at all.
        last_char = max(first_char, answer['answer_end'] - 1)

        start_token = encodings.char_to_token(i, first_char)
        end_token = encodings.char_to_token(i, last_char)

        # None means the character has no token in this encoding.
        start_positions.append(max_length if start_token is None else start_token)
        end_positions.append(max_length if end_token is None else end_token)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
|
|
|
|
|
"""Dataloader for the training dataset""" |
|
class DatasetRetriever(Dataset):
    """Map-style dataset over tokenizer encodings.

    ``encodings`` must offer ``items()`` yielding ``(field, column)`` pairs
    and an ``input_ids`` attribute whose length is the number of examples.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        # One tensor per field, taken from row ``idx`` of every column.
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        return sample
|
|
|
|
|
# The two label-preparation helpers were defined but never invoked, so the
# encodings had no 'start_positions' / 'end_positions' and the training loop
# failed on the first batch. Run them before the encodings are wrapped.
end_idx(train_answers, train_contexts)
end_idx(valid_answers, valid_contexts)
add_token_positions(train_encodings, train_answers)
add_token_positions(valid_encodings, valid_answers)

train_dataset = DatasetRetriever(train_encodings)
valid_dataset = DatasetRetriever(valid_encodings)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16)

model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
def _build_optimizer(name, parameters, momentum, learning_rate=5e-5):
    """Create the optimizer selected by the sweep's ``optimizer`` value.

    ``momentum`` is only used by SGD; 'adam' keeps the AdamW setup the
    pipeline used before the sweep values were wired in.

    Raises:
        ValueError: If ``name`` is neither 'sgd' nor 'adam'.
    """
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=learning_rate, momentum=momentum)
    if name == 'adam':
        return torch.optim.AdamW(parameters, lr=learning_rate)
    raise ValueError(f"Unsupported optimizer in sweep config: {name!r}")


def pipeline():
    """Run one sweep trial: train the global ``model`` and log its accuracy.

    Reads ``epochs``, ``optimizer``, ``momentum`` and ``batch_size`` from
    ``wandb.config``, trains on ``train_dataset``, evaluates on
    ``valid_loader`` and logs the per-step training loss and the final
    'Validation accuracy' to Weights & Biases.
    """
    with wandb.init(config=None):
        config = wandb.config

        # NOTE(review): the module-level model is reused by every trial, so a
        # trial starts from the weights left by the previous one. Confirm
        # whether each trial should start from a fresh checkpoint instead.
        model.to(device)

        # The swept optimizer, momentum and batch size used to be ignored
        # (hard-coded AdamW and a fixed loader), which made all trials equal.
        optimizer = _build_optimizer(config.optimizer, model.parameters(), config.momentum)
        run_train_loader = DataLoader(
            train_dataset,
            batch_size=int(config.batch_size),
            shuffle=True,
        )

        model.train()
        for epoch in range(config.epochs):
            loop = tqdm(run_train_loader, leave=True)
            for batch in loop:
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                start_positions = batch['start_positions'].to(device)
                end_positions = batch['end_positions'].to(device)
                outputs = model(
                    input_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions,
                )
                loss = outputs[0]
                loss.backward()
                optimizer.step()

                loop.set_description(f'Epoch {epoch+1}')
                loop.set_postfix(loss=loss.item())
                # This is the training loss (it was logged under the key
                # 'Validation Loss'); log a plain float instead of a tensor
                # that still references the autograd graph.
                wandb.log({'Training Loss': loss.item()})

        model.eval()
        acc = []
        last_batch = None
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                start_true = batch['start_positions'].to(device)
                end_true = batch['end_positions'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)

                start_pred = torch.argmax(outputs['start_logits'], dim=1)
                end_pred = torch.argmax(outputs['end_logits'], dim=1)

                # Fraction of exactly matching start / end tokens per batch.
                acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
                acc.append(((end_pred == end_true).sum() / len(end_pred)).item())
                last_batch = (start_true, end_true, start_pred, end_pred)

        if not acc:
            # Empty validation loader: nothing to average or report.
            return

        acc = sum(acc) / len(acc)

        # Show true vs. predicted positions for the last validation batch.
        start_true, end_true, start_pred, end_pred = last_batch
        print("\n\nT/P\tanswer_start\tanswer_end\n")
        for i in range(len(start_true)):
            print(f"true\t{start_true[i]}\t{end_true[i]}\n"
                  f"pred\t{start_pred[i]}\t{end_pred[i]}\n")
        wandb.log({'Validation accuracy': acc})
|
|
|
|
|
# Launch four trials of the sweep; each trial runs pipeline().
wandb.agent(sweep_id, pipeline, count=4)

"""Save the model so we dont have to train it again"""
model_path = '/content/drive/MyDrive/BERT-SQuAD'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

"""Load the model"""
# model_path still holds the folder assigned above.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForQuestionAnswering.from_pretrained(model_path)
model = model.to(device)
|
|
|
|
|
|
|
|
|
def get_prediction(context, answer):
    """Return the answer text the global model extracts from ``context``.

    Args:
        context: Passage in which the answer is searched.
        answer: The *question* text. The parameter name is kept for backward
            compatibility; the body used to read an undefined ``question``
            variable instead of this argument.

    Returns:
        The decoded tokens between the highest-scoring start and end
        positions; an empty string when the end precedes the start.
    """
    question = answer

    # Pairs were encoded as (context, question) for training, so inference
    # uses the same order. Truncation keeps over-long inputs within the
    # model's maximum length.
    inputs = tokenizer(context, question, return_tensors='pt', truncation=True).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    answer_start = int(torch.argmax(outputs['start_logits']))
    answer_end = int(torch.argmax(outputs['end_logits'])) + 1

    span_ids = inputs['input_ids'][0][answer_start:answer_end]
    span_tokens = tokenizer.convert_ids_to_tokens(span_ids)
    return tokenizer.convert_tokens_to_string(span_tokens)
|
|
|
|
|
""" |
|
Question testing |
|
|
|
Official SQuAD evaluation script--> |
|
https://colab.research.google.com/github/fastforwardlabs/ff14_blog/blob/master/_notebooks/2020-06-09-Evaluating_BERT_on_SQuAD.ipynb#scrollTo=MzPlHgWEBQ8D |
|
""" |
|
|
|
def normalize_text(s):
    """Canonicalise an answer string for comparison.

    Steps, in order: lower-case, drop ASCII punctuation, replace the
    articles "a", "an" and "the" with a space, collapse whitespace.
    """
    import re
    import string

    lowered = s.lower()
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE)
    return " ".join(without_articles.split())
|
|
|
def exact_match(prediction, truth):
    """Return True when both strings are equal after ``normalize_text``."""
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return normalized_prediction == normalized_truth
|
|
|
def compute_f1(prediction, truth):
    """Token-level F1 between a predicted and a reference answer.

    Both strings are normalised with ``normalize_text`` and split on
    whitespace. Overlap is counted per occurrence, so a token repeated in
    both answers counts as often as it is shared; counting distinct tokens
    only gave identical answers with repeated words a score below 1.

    Returns:
        ``int(pred_tokens == truth_tokens)`` when either side has no tokens,
        0 when nothing overlaps, otherwise the F1 rounded to two decimals.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either side is empty, the score is 1 only when both are empty.
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    # Multiset intersection: each shared occurrence counts once.
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if overlap == 0:
        return 0

    prec = overlap / len(pred_tokens)
    rec = overlap / len(truth_tokens)

    return round(2 * (prec * rec) / (prec + rec), 2)
|
|
|
def question_answer(context, question, answer):
    """Predict an answer for ``question`` and print it with its scores.

    Prints the question, the prediction, the reference answer, the exact
    match flag and the F1 score. Returns None.
    """
    prediction = get_prediction(context, question)
    em_score = exact_match(prediction, answer)
    f1_score = compute_f1(prediction, answer)

    report_lines = (
        f'Question: {question}',
        f'Prediction: {prediction}',
        f'True Answer: {answer}',
        f'Exact match: {em_score}',
        f'F1 score: {f1_score}\n',
    )
    for line in report_lines:
        print(line)
|
|
|
# Hand-written passage used to try the model on questions outside SQuAD.
context = """Space exploration is a very exciting field of research. It is the
frontier of Physics and no doubt will change the understanding of science.
However, it does come at a cost. A normal space shuttle costs about 1.5 billion dollars to make.
The annual budget of NASA, which is a premier space exploring organization is about 17 billion.
So the question that some people ask is that whether it is worth it."""

# Questions about the passage and the reference answer for each of them.
questions = [
    "What wil change the understanding of science?",
    "What is the main idea in the paragraph?",
]

answers = [
    "Space Exploration",
    "The cost of space exploration is too high",
]
|
|
|
""" |
|
VISUALISATION IN PROGRESS |
|
|
|
for question, answer in zip(questions, answers): |
|
question_answer(context, question, answer) |
|
|
|
#Visualize the start scores |
|
plt.rcParams["figure.figsize"]=(20,10) |
|
ax=sns.barplot(x=token_labels,y=start_scores) |
|
ax.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center") |
|
ax.grid(True) |
|
plt.title("Start word scores") |
|
plt.show() |
|
|
|
#Visualize the end scores |
|
plt.rcParams["figure.figsize"]=(20,10) |
|
ax=sns.barplot(x=token_labels,y=end_scores) |
|
ax.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center") |
|
ax.grid(True) |
|
plt.title("End word scores") |
|
plt.show() |
|
|
|
#Visualize both the scores |
|
scores=[] |
|
for (i,token_label) in enumerate(token_labels): |
|
# Add the token's start score as one row. |
|
scores.append({'token_label':token_label, |
|
'score':start_scores[i], |
|
'marker':'start'}) |
|
|
|
# Add the token's end score as another row. |
|
scores.append({'token_label': token_label, |
|
'score': end_scores[i], |
|
'marker': 'end'}) |
|
|
|
df=pd.DataFrame(scores) |
|
group_plot=sns.catplot(x="token_label",y="score",hue="marker",data=df, |
|
kind="bar",height=6,aspect=4) |
|
|
|
group_plot.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center") |
|
group_plot.ax.grid(True) |
|
""" |
|
|