from typing import TypedDict, List
from re import sub

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, logging
from transformers import QuestionAnsweringPipeline
max_answer_len = 8
logging.set_verbosity_error()

class AnswerInfo(TypedDict):
    """One answer span returned by the QA pipeline."""
    score: float
    start: int
    end: int
    answer: str

@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Run extractive QA over paired question/context lists and return one answer per pair."""
    # Prefer the flash-attention scaled-dot-product kernel while the reader runs.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
        answer_infos: List[AnswerInfo] = pipeline(
            question=questions, context=ctxs)
        # Strip stray punctuation from the extracted answer spans.
        for answer_info in answer_infos:
            answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
        return answer_infos

def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
    # The default checkpoint is a Longformer QA model, so load it with the Auto
    # classes; the DPRReader classes expect a DPR-architecture checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
    return tokenizer, model
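
# Example usage (a minimal sketch, not part of the original module). It assumes a
# CUDA-capable GPU, since get_reader() moves the model to device 0 and ask_reader()
# builds the pipeline with device='cuda'; the questions and contexts are placeholders.
if __name__ == "__main__":
    tokenizer, model = get_reader()
    answers = ask_reader(
        tokenizer, model,
        questions=["Where was the meeting held?", "Who chaired the meeting?"],
        ctxs=["The meeting was held in Geneva.", "The meeting was chaired by Dr. Lee."],
    )
    for info in answers:
        print(info['answer'], info['score'])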