from typing import Dict

import torch
from transformers import BertForQuestionAnswering, BertTokenizer

# Run on the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_segment_ids_aka_token_type_ids(tokenizer, input_ids):
    # Search the input_ids for the first instance of the `[SEP]` token.
    sep_index = input_ids.index(tokenizer.sep_token_id)
    # The number of segment A tokens includes the [SEP] token itself.
    num_seg_a = sep_index + 1
    # The remainder are segment B.
    num_seg_b = len(input_ids) - num_seg_a
    # Construct the list of 0s and 1s.
    segment_ids = [0] * num_seg_a + [1] * num_seg_b
    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids), \
        'There should be a segment_id for every input token.'
    return segment_ids
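
# A quick sketch of what the function above produces (hypothetical token
# IDs; 102 stands in for BERT's [SEP] ID):
#   input_ids   = [101, 2054, 2003, 102, 4937, 102]  # [CLS] what is [SEP] cat [SEP]
#   segment_ids = [0,   0,    0,    0,   1,    1]
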
def to_model(
    model: BertForQuestionAnswering,
    input_ids,
    segment_ids
) -> tuple:
    # Run the input through the model, moving the input tensors to the
    # same device as the model's weights.
    output = model(
        torch.tensor([input_ids]).to(device),  # The tokens representing our input text.
        token_type_ids=torch.tensor([segment_ids]).to(device)  # The segment IDs to differentiate question from answer_text.
    )
    return output.start_logits, output.end_logits
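
# Shape note: for a batch of one sequence, start_logits and end_logits
# each come back as a [1, sequence_length] tensor, one score per token.
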
def get_answer(
    start_scores,
    end_scores,
    input_ids,
    tokenizer: BertTokenizer
) -> str:
    '''Side note:
    - It's a little naive to pick the highest-scoring start and end
      independently: the model might predict an end token that comes
      before the start token.
    - The correct implementation is to pick the pair with the highest
      total score for which end >= start (see the sketch below).
    '''
    # Find the tokens with the highest `start` and `end` scores.
    answer_start = torch.argmax(start_scores).item()
    answer_end = torch.argmax(end_scores).item()
    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Start with the first token.
    answer = tokens[answer_start]
    # Append the remaining answer tokens, reconstructing any words that
    # got broken down into subwords.
    for i in range(answer_start + 1, answer_end + 1):
        # If it's a subword token, recombine it with the previous token.
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        # Otherwise, add a space and then the token.
        else:
            answer += ' ' + tokens[i]
    return answer
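
# A sketch of the fix mentioned in the docstring above (not wired into
# the handler): search the (start, end) pairs with end >= start and take
# the highest combined score. The `max_answer_len` cap is an assumption
# to keep the search cheap; the loop is fast enough for BERT's
# 512-token limit.
def get_best_span(start_scores, end_scores, max_answer_len=30):
    start = start_scores.squeeze(0)
    end = end_scores.squeeze(0)
    best_score, best_span = float('-inf'), (0, 0)
    for s in range(len(start)):
        # Only consider ends at or after the start, within the length cap.
        for e in range(s, min(s + max_answer_len, len(end))):
            score = (start[s] + end[e]).item()
            if score > best_score:
                best_score, best_span = score, (s, e)
    return best_span
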
class EndpointHandler:
    def __init__(self, path=""):
        self.model = BertForQuestionAnswering.from_pretrained(path).to(device)
        self.model.eval()
        self.tokenizer = BertTokenizer.from_pretrained(path)

    def __call__(
        self,
        data: Dict[str, str]
    ) -> str:
        """
        Args:
            data (:obj:`dict`):
                must include a "question" string and a "context" string
                in which to search for the answer.
        Returns:
            :obj:`str`: the answer span extracted from the context.
        """
        question = data.pop("question", None)
        context = data.pop("context", None)
        # Tokenize the question/context pair; `encode` adds the special
        # [CLS] and [SEP] tokens automatically.
        input_ids = self.tokenizer.encode(question, context)
        segment_ids = get_segment_ids_aka_token_type_ids(
            self.tokenizer,
            input_ids
        )
        # Run the prediction without tracking gradients.
        with torch.inference_mode():
            start_scores, end_scores = to_model(
                self.model,
                input_ids,
                segment_ids
            )
        answer = get_answer(
            start_scores,
            end_scores,
            input_ids,
            self.tokenizer
        )
        return answer
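
# Minimal local smoke test; a sketch, assuming the SQuAD-finetuned
# checkpoint named below is available (any BertForQuestionAnswering
# checkpoint or local path should work in its place).
if __name__ == "__main__":
    handler = EndpointHandler(
        "bert-large-uncased-whole-word-masking-finetuned-squad"
    )
    print(handler({
        "question": "What token does BERT use to separate segments?",
        "context": "BERT marks the boundary between two segments with a "
                   "special [SEP] token.",
    }))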