from contextlib import nullcontext
from re import sub
from typing import List, Tuple, TypedDict

import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
from transformers import AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader
from transformers import QuestionAnsweringPipeline
from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast

# True when torch reports a usable CUDA device; drives device placement and
# attention-kernel selection in every helper below.
cuda = torch.cuda.is_available()
# Forwarded to QuestionAnsweringPipeline as its max_answer_len setting.
max_answer_len = 8
logging.set_verbosity_error()


def _sdp_backend():
    """Return the context manager under which model forward passes are run.

    With CUDA available this is the ``torch.backends.cuda.sdp_kernel`` context
    with only the flash kernel enabled; otherwise it is a no-op context so the
    callers can use a single ``with`` statement for both cases.
    """
    if cuda:
        return torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False)
    return nullcontext()


@torch.inference_mode()
def summarize_text(tokenizer: PegasusTokenizerFast,
                   model: PegasusXForConditionalGeneration,
                   input_texts: List[str]):
    """Summarize a batch of texts with a Pegasus-X model.

    Args:
        tokenizer: tokenizer used to encode the inputs and decode the outputs.
        model: generation model; expected to live on CUDA device 0 when CUDA
            is available (see ``get_summarizer``).
        input_texts: texts to summarize, tokenized together as one padded,
            truncated batch.

    Returns:
        The decoded summaries as returned by ``tokenizer.batch_decode``.
    """
    inputs = tokenizer(input_texts, padding=True,
                       return_tensors='pt', truncation=True)
    if cuda:
        inputs = inputs.to(0)
    with _sdp_backend():
        # The batch is padded, so hand the tokenizer's attention mask to
        # generate() explicitly instead of relying on it being inferred.
        summary_ids = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # NOTE(review): batch_size is kept from the original call; it does not look
    # like a documented batch_decode option - confirm whether it is needed.
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False,
                                       batch_size=len(input_texts))
    return summaries


def get_summarizer(model_id="seonglae/resrer-pegasus-x") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    """Load the summarization tokenizer and model.

    Args:
        model_id: identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        ``(tokenizer, model)``. When CUDA is available the model is moved to
        device 0 and wrapped with ``torch.compile``.
    """
    tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    model = PegasusXForConditionalGeneration.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
        model = torch.compile(model)
    return tokenizer, model


class AnswerInfo(TypedDict):
    """One answer record as returned by the question-answering pipeline."""
    # Score reported by the pipeline for this answer.
    score: float
    # Start/end of the answer span - presumably offsets into the context;
    # verify against the pipeline documentation.
    start: int
    end: int
    # Answer text; ask_reader strips some punctuation from it before returning.
    answer: str


@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Answer questions against contexts with a question-answering pipeline.

    Args:
        tokenizer: tokenizer handed to the pipeline.
        model: reader model handed to the pipeline.
        questions: questions forwarded to the pipeline as ``question``.
        ctxs: contexts forwarded to the pipeline as ``context``.

    Returns:
        A list of answer dicts. A single (non-list) pipeline result is wrapped
        in a list, and each ``answer`` has the characters ``. ( ) " ' ,``
        removed in place.
    """
    device = 'cuda' if cuda else 'cpu'
    with _sdp_backend():
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device=device,
            max_answer_len=max_answer_len)
        answer_infos: List[AnswerInfo] = pipeline(
            question=questions, context=ctxs)
    # The pipeline may hand back a bare dict instead of a list; normalize it.
    if not isinstance(answer_infos, list):
        answer_infos = [answer_infos]
    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos


def get_reader(model_id="facebook/dpr-reader-single-nq-base") -> Tuple[DPRReaderTokenizer, DPRReader]:
    """Load the DPR reader tokenizer and model.

    Args:
        model_id: identifier passed to ``from_pretrained`` for both objects.

    Returns:
        ``(tokenizer, model)``; the model is moved to CUDA device 0 when CUDA
        is available.
    """
    tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
    model = DPRReader.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model


@torch.inference_mode()
def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer,
                        model: DPRQuestionEncoder,
                        questions: List[str]) -> torch.FloatTensor:
    """Encode questions using the DPR question encoder.

    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        tokenizer: DPR question tokenizer used to build the padded,
            truncated batch.
        model: DPR question encoder; expected to live on CUDA device 0 when
            CUDA is available (see ``get_dpr_encoder``).
        questions: question strings to encode as one batch.

    Returns:
        The encoder's ``pooler_output`` for the batch.
    """
    batch_dict = tokenizer(questions, return_tensors="pt",
                           padding=True, truncation=True)
    if cuda:
        batch_dict = batch_dict.to(0)
    with _sdp_backend():
        embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    return embeddings


def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load the DPR question-encoder tokenizer and model.

    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        model_id: identifier passed to ``from_pretrained``. The default is the
            NQ checkpoint; "facebook/dpr-question_encoder-multiset-base" is
            the alternative named in the original notes.

    Returns:
        ``(tokenizer, model)`` - tokenizer first; the model is moved to CUDA
        device 0 when CUDA is available.
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model