from typing import Dict, List, Tuple
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def print_tokens_with_ids(tokenizer: BertTokenizer, input_ids: List[int]) -> None:
    '''Debugging helper. BERT only needs the token IDs, but to inspect the
    tokenizer's behavior it helps to print the token strings alongside them.'''
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # For each token and its ID...
    for token, token_id in zip(tokens, input_ids):
        # If this is the [SEP] token, add some space around it to make it stand out.
        if token_id == tokenizer.sep_token_id:
            print('')
        # Print the token string and its ID in two columns.
        print('{:<12} {:>6,}'.format(token, token_id))
        if token_id == tokenizer.sep_token_id:
            print('')
def get_segment_ids_aka_token_type_ids(
    tokenizer: BertTokenizer,
    input_ids: List[int]
) -> List[int]:
    '''Build the segment IDs (token_type_ids) that tell BERT which tokens
    belong to the question (segment A) and which to the context (segment B).'''
    # Search the input_ids for the first instance of the `[SEP]` token.
    sep_index = input_ids.index(tokenizer.sep_token_id)
    # The number of segment A tokens includes the [SEP] token itself.
    num_seg_a = sep_index + 1
    # The remainder are segment B.
    num_seg_b = len(input_ids) - num_seg_a
    # Construct the list of 0s and 1s.
    segment_ids = [0] * num_seg_a + [1] * num_seg_b
    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids), \
        'There should be a segment_id for every input token.'
    return segment_ids
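# Side note: the helper above reproduces what the tokenizer can already do.
# A minimal sketch of the built-in route (assuming a standard Hugging Face
# BERT tokenizer; not specific to this handler):
#
#     encoded = tokenizer(question, context)
#     encoded["input_ids"]       # same IDs as tokenizer.encode(question, context)
#     encoded["token_type_ids"]  # same values as get_segment_ids_aka_token_type_ids(...)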
def to_model(
    model: BertForQuestionAnswering,
    input_ids: List[int],
    segment_ids: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Run the input through the model. The input tensors must live on the
    # same device as the model, so move them explicitly.
    output = model(
        # The tokens representing our input text, as a batch of size 1.
        torch.tensor([input_ids]).to(device),
        # The segment IDs to differentiate question from context.
        token_type_ids=torch.tensor([segment_ids]).to(device)
    )
    # `output` also exposes fields such as hidden_states, attentions, and
    # loss when the model is configured to return them; for span prediction
    # we only need the start/end logits.
    return output.start_logits, output.end_logits
def get_answer(
    start_scores: torch.Tensor,
    end_scores: torch.Tensor,
    input_ids: List[int],
    tokenizer: BertTokenizer
) -> str:
    '''Side note:
    - It is a little naive to pick the highest scores for start and end
      independently: the model may predict an end token that comes before
      the start token.
    - The correct implementation is to pick the highest total score over
      all pairs for which end >= start (see the sketch below).
    '''
    # Find the tokens with the highest `start` and `end` scores.
    answer_start = torch.argmax(start_scores).item()
    answer_end = torch.argmax(end_scores).item()
    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Start with the first token.
    answer = tokens[answer_start]
    # Select the remaining answer tokens, recombining subwords as we go.
    for i in range(answer_start + 1, answer_end + 1):
        # If it's a subword token, recombine it with the previous token.
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        # Otherwise, add a space then the token.
        else:
            answer += ' ' + tokens[i]
    return answer
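# A minimal sketch of the "correct" selection the docstring above describes:
# score every (start, end) pair with end >= start and take the best total.
# This function is an illustration, not part of the handler's call path;
# the O(n^2) scan and the `max_answer_len` cap are assumptions.
def get_best_span(
    start_scores: torch.Tensor,
    end_scores: torch.Tensor,
    max_answer_len: int = 30
) -> Tuple[int, int]:
    # Drop the batch dimension: shapes go from (1, seq_len) to (seq_len,).
    start = start_scores.squeeze(0)
    end = end_scores.squeeze(0)
    best_score = float('-inf')
    best_span = (0, 0)
    for s in range(len(start)):
        # Only consider ends at or after the start, within a length cap.
        for e in range(s, min(s + max_answer_len, len(end))):
            score = start[s].item() + end[e].item()
            if score > best_score:
                best_score = score
                best_span = (s, e)
    return best_span
# get_answer could use get_best_span(...) in place of its two argmax calls.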
class EndpointHandler:
def __init__(self, path=""):
self.model = BertForQuestionAnswering.from_pretrained(path).to(device)
self.tokenizer = BertTokenizer.from_pretrained(path)
    def __call__(
        self,
        data: Dict[str, str | bytes]
    ) -> str:
        """
        Args:
            data (:obj:`dict`):
                expects the keys `question` and `context`, each a string.
        Returns:
            The answer span extracted from `context`, as a string.
        """
        # Fail loudly with a KeyError if either required key is missing.
        question = data["question"]
        context = data["context"]
        input_ids = self.tokenizer.encode(question, context)
segment_ids = get_segment_ids_aka_token_type_ids(
self.tokenizer,
input_ids
)
# run prediction
with torch.inference_mode():
start_scores, end_scores = to_model(
self.model,
input_ids,
segment_ids
)
answer = get_answer(
start_scores,
end_scores,
input_ids,
self.tokenizer
)
return answer
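# Example local invocation (the model path and inputs below are hypothetical,
# for illustration only; on Hugging Face Inference Endpoints the handler is
# instantiated and called by the serving runtime):
#
#     handler = EndpointHandler(path="deepset/bert-base-cased-squad2")
#     answer = handler({
#         "question": "What does BERT stand for?",
#         "context": "BERT stands for Bidirectional Encoder "
#                    "Representations from Transformers.",
#     })
#     print(answer)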