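"""Model helpers for a DPR-based QA pipeline: question encoding for retrieval,
an extractive reader, and Pegasus-X passage summarization."""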
from typing import List, Tuple, TypedDict
from re import sub

import torch
from transformers import (AutoModelForQuestionAnswering, AutoTokenizer,
                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer,
                          DPRReader, DPRReaderTokenizer,
                          PegasusTokenizerFast, PegasusXForConditionalGeneration,
                          QuestionAnsweringPipeline, logging)
# Module-level settings: GPU availability, reader answer-length cap, quieter HF logging
cuda = torch.cuda.is_available()
max_answer_len = 8
logging.set_verbosity_error()
@torch.inference_mode()
def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]) -> List[str]:
    """Summarize a batch of texts with a Pegasus-X model."""
    inputs = tokenizer(input_texts, padding=True,
                       return_tensors='pt', truncation=True)
    if cuda:
        inputs = inputs.to(0)
        # Force the FlashAttention kernel for scaled dot-product attention
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            summary_ids = model.generate(inputs["input_ids"])
    else:
        summary_ids = model.generate(inputs["input_ids"])
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)
    return summaries
def get_summarizer(model_id="seonglae/resrer-pegasus-x") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    """Load the summarization tokenizer and model, compiling on GPU when available."""
    tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    model = PegasusXForConditionalGeneration.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
        model = torch.compile(model)
    return tokenizer, model
class AnswerInfo(TypedDict):
    """Answer record returned by QuestionAnsweringPipeline."""
    score: float
    start: int
    end: int
    answer: str
@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Run extractive QA over question/context pairs and post-process the answers."""
    if cuda:
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            pipeline = QuestionAnsweringPipeline(
                model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
            answer_infos: List[AnswerInfo] = pipeline(
                question=questions, context=ctxs)
    else:
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cpu', max_answer_len=max_answer_len)
        answer_infos = pipeline(
            question=questions, context=ctxs)
    # The pipeline returns a single dict for a single pair; normalize to a list
    if not isinstance(answer_infos, list):
        answer_infos = [answer_infos]
    # Strip punctuation the reader tends to include in extracted spans
    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos
def get_reader(model_id="facebook/dpr-reader-single-nq-base") -> Tuple[DPRReaderTokenizer, DPRReader]:
    """Load the DPR reader tokenizer and model, moving the model to GPU when available."""
    tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
    model = DPRReader.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model
@torch.inference_mode()
def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder,
                        questions: List[str]) -> torch.FloatTensor:
    """Encode questions with the DPR question encoder.
    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        tokenizer: DPR question tokenizer
        model: DPR question encoder
        questions (List[str]): question strings to encode
    """
    batch_dict = tokenizer(questions, return_tensors="pt",
                           padding=True, truncation=True)
    if cuda:
        batch_dict = batch_dict.to(0)
        # Force the FlashAttention kernel for scaled dot-product attention
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    else:
        embeddings = model(**batch_dict).pooler_output
    return embeddings
def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load the DPR question encoder and its tokenizer.
    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        model_id (str, optional): defaults to the NQ encoder;
            "facebook/dpr-question_encoder-multiset-base" is an alternative
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model
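

if __name__ == "__main__":
    # Minimal end-to-end smoke test. This is an illustrative sketch, not part of
    # the original module: the question and context strings are made up, and the
    # default checkpoints above must be downloadable.
    question = "Who wrote the opera Tosca?"
    context = ("Tosca is an opera in three acts by Giacomo Puccini, "
               "first performed in Rome in 1900.")

    # 1. Embed the question for dense retrieval
    enc_tokenizer, encoder = get_dpr_encoder()
    embeddings = encode_dpr_question(enc_tokenizer, encoder, [question])
    print(embeddings.shape)

    # 2. Summarize the (retrieved) passage before reading
    sum_tokenizer, summarizer = get_summarizer()
    summaries = summarize_text(sum_tokenizer, summarizer, [context])

    # 3. Extract a short answer span from the summary
    reader_tokenizer, reader = get_reader()
    answers = ask_reader(reader_tokenizer, reader, [question], summaries)
    print(answers[0]['answer'], answers[0]['score'])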