resrer-demo / model.py
seonglae's picture
fix: change reader to dpr natural questions
b750381
raw
history blame
4.21 kB
from contextlib import nullcontext
from re import sub
from typing import List, Tuple, TypedDict

import torch
from transformers import AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader
from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
from transformers import QuestionAnsweringPipeline
# Only let transformers print log messages at ERROR level or above.
logging.set_verbosity_error()

# True when a CUDA device is visible; the loaders and inference helpers in this
# module branch on it to place models and tensors on GPU 0.
cuda = torch.cuda.is_available()

# Forwarded to QuestionAnsweringPipeline as ``max_answer_len`` by ask_reader.
max_answer_len = 8
@torch.inference_mode()
def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]) -> List[str]:
    """Summarize a batch of texts with a Pegasus-X model.

    The texts are tokenized together (padded to the longest one and truncated),
    passed through ``model.generate`` and decoded back to strings.

    Args:
        tokenizer: Tokenizer paired with ``model`` (e.g. from ``get_summarizer``).
        model: Sequence-to-sequence model used for generation.
        input_texts: Texts to summarize. An empty list yields an empty list.

    Returns:
        The decoded summaries, in input order, with special tokens removed.
    """
    if not input_texts:
        # Nothing to tokenize; avoid handing an empty batch to the tokenizer/generate.
        return []

    inputs = tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True)
    if cuda:
        inputs = inputs.to(0)

    # On GPU, restrict scaled-dot-product attention to the flash kernel;
    # on CPU no kernel selection is applied.
    kernel_ctx = (
        torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
        if cuda
        else nullcontext()
    )
    with kernel_ctx:
        # The batch is padded, so pass the attention mask explicitly so padded
        # positions of the shorter inputs are not attended to.
        summary_ids = model.generate(input_ids=inputs["input_ids"],
                                     attention_mask=inputs["attention_mask"])

    return tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)
def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    """Load the Pegasus-X summarizer and its fast tokenizer.

    Args:
        model_id: Hugging Face model id or local path of the checkpoint.

    Returns:
        ``(tokenizer, model)``. When CUDA is available the model is moved to
        GPU 0 and wrapped with ``torch.compile``.
    """
    summary_tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    summary_model = PegasusXForConditionalGeneration.from_pretrained(model_id)

    if not cuda:
        return summary_tokenizer, summary_model

    # NOTE(review): torch.compile returns a wrapper module, so the returned object
    # is not literally a PegasusXForConditionalGeneration. Attributes the wrapper
    # does not define (such as ``.generate``) appear to be forwarded to the
    # original module, which may mean generation never runs the compiled
    # forward. Confirm whether compiling here has any effect.
    return summary_tokenizer, torch.compile(summary_model.to(0))
class AnswerInfo(TypedDict):
    """One extractive-QA prediction, as returned per question by ask_reader.

    The keys mirror the dicts produced by transformers' QuestionAnsweringPipeline.
    """

    # Confidence score the pipeline assigns to the predicted span.
    score: float
    # Start index of the answer in the context (character offset per the pipeline docs).
    start: int
    # End index of the answer in the context.
    end: int
    # Answer text; ask_reader strips some punctuation characters from it.
    answer: str
@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Extract an answer span for each (question, context) pair.

    Args:
        tokenizer: Tokenizer paired with ``model``.
        model: Extractive question-answering model.
        questions: Questions to answer; paired element-wise with ``ctxs``.
        ctxs: Context passages, one per question.

    Returns:
        One AnswerInfo per question, in input order. The ``answer`` text has the
        characters ``. ( ) " ' ,`` removed. ``start``/``end`` are left as the
        pipeline reported them, so they refer to the unstripped span. An empty
        ``questions`` list yields an empty list.
    """
    if not questions:
        return []

    # NOTE(review): get_reader returns a DPRReader/DPRReaderTokenizer pair, and
    # that tokenizer's call signature (questions/titles/texts) appears to differ
    # from the (question, context) interface this pipeline expects. Confirm the
    # combination works end to end.
    pipeline = QuestionAnsweringPipeline(
        model=model, tokenizer=tokenizer, device='cuda' if cuda else 'cpu',
        max_answer_len=max_answer_len)

    # On GPU, restrict scaled-dot-product attention to the flash kernel.
    kernel_ctx = (
        torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
        if cuda
        else nullcontext()
    )
    with kernel_ctx:
        answer_infos = pipeline(question=questions, context=ctxs)

    # For a single-example batch the pipeline returns a bare dict instead of a
    # one-element list. Iterating that dict would yield its keys and crash below,
    # so normalize to a list.
    if isinstance(answer_infos, dict):
        answer_infos = [answer_infos]

    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos
def get_reader(model_id="facebook/dpr-reader-single-nq-base"):
    """Load a DPR reader and its tokenizer.

    Args:
        model_id: Hugging Face model id or local path of the reader checkpoint.

    Returns:
        ``(tokenizer, model)``; the model is placed on GPU 0 when CUDA is available.
    """
    reader_tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
    reader = DPRReader.from_pretrained(model_id)
    placed_reader = reader.to(0) if cuda else reader
    return reader_tokenizer, placed_reader
@torch.inference_mode()
def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
    """Encode a batch of questions with a DPR question encoder.

    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        tokenizer: Tokenizer paired with ``model`` (e.g. from ``get_dpr_encoder``).
        model: DPR question encoder.
        questions: Question strings to encode; padded and truncated as one batch.

    Returns:
        The encoder's ``pooler_output`` (per the DPR docs, one embedding per
        question). When CUDA is available the inputs are moved to GPU 0, so the
        result lives wherever ``model`` does.
    """
    batch_dict = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
    if cuda:
        batch_dict = batch_dict.to(0)

    # On GPU, restrict scaled-dot-product attention to the flash kernel.
    kernel_ctx = (
        torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
        if cuda
        else nullcontext()
    )
    with kernel_ctx:
        embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    return embeddings
def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load a DPR question encoder and its tokenizer.

    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        model_id (str, optional): Hugging Face model id or path. Defaults to the
            Natural Questions encoder; "facebook/dpr-question_encoder-multiset-base"
            is an alternative.

    Returns:
        ``(tokenizer, model)``, in that order. The model is moved to GPU 0 when
        CUDA is available.
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model