"""Extractive question answering over (question, context) pairs.

Wraps a Hugging Face ``QuestionAnsweringPipeline`` running on CUDA and
post-processes the extracted answer spans.
"""
from re import sub
from typing import TypedDict, List, Dict

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader, logging
from transformers import QuestionAnsweringPipeline

# Upper bound on the length of an extracted answer span, forwarded to the
# pipeline as ``max_answer_len``.
max_answer_len = 8

# Characters stripped from every extracted answer: . ( ) " ' ,
_ANSWER_STRIP_PATTERN = r'[.\(\)"\',]'

logging.set_verbosity_error()


class AnswerInfo(TypedDict):
    """One answer as returned by the QA pipeline (after punctuation cleanup)."""

    score: float  # pipeline confidence for this span
    start: int  # span start as reported by the pipeline
    end: int  # span end as reported by the pipeline
    # Answer text with the characters in _ANSWER_STRIP_PATTERN removed; it may
    # therefore no longer equal the raw context slice described by start/end.
    answer: str


@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Extract one answer per (question, context) pair.

    Args:
        tokenizer: Tokenizer matching ``model`` (e.g. from ``get_reader``).
        model: Extractive QA model, expected to already live on the GPU.
        questions: Questions, paired positionally with ``ctxs``.
        ctxs: Contexts to extract answers from; same length as ``questions``.

    Returns:
        A list with one ``AnswerInfo`` per pair, in input order, with
        punctuation stripped from each ``answer``. Empty input yields ``[]``.
    """
    if not questions:
        return []

    # NOTE(review): this restricts SDPA to the flash backend only. Models that
    # route attention through SDPA with fp32 inputs may then fail with "no
    # available kernel"; Longformer-style models do not use SDPA, so the
    # context manager is presumably a no-op for the default model — confirm.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                        enable_mem_efficient=False):
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cuda',
            max_answer_len=max_answer_len)
        raw = pipeline(question=questions, context=ctxs)

    # For a single (question, context) pair the pipeline returns a bare dict
    # rather than a one-element list; re-wrap it so iteration below walks
    # answers instead of the dict's keys.
    answer_infos: List[AnswerInfo] = [raw] if isinstance(raw, dict) else list(raw)

    for answer_info in answer_infos:
        answer_info['answer'] = sub(_ANSWER_STRIP_PATTERN, '', answer_info['answer'])
    return answer_infos


def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
    """Load a tokenizer and extractive QA model and move the model to GPU 0.

    The Auto classes resolve the concrete architecture from the checkpoint
    (the default is a Longformer SQuAD v2 model), which is what
    ``QuestionAnsweringPipeline`` in ``ask_reader`` expects. DPR reader
    classes cannot load this checkpoint and do not fit the pipeline's
    tokenizer calling convention.

    Returns:
        ``(tokenizer, model)`` ready to pass to ``ask_reader``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
    return tokenizer, model