Spaces:
Running
Running
File size: 3,039 Bytes
d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 d6bdb02 1434337 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from typing import List, Optional
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch
from haystack.nodes.base import BaseComponent
from haystack.modeling.utils import initialize_device_settings
from haystack.schema import Document, Answer, Span
class EntailmentChecker(BaseComponent):
    """
    Node that checks the entailment between every document content and the query.

    For each document, the document content is used as the premise and the
    query as the hypothesis of a Natural Language Inference model. The resulting
    per-label probabilities are stored in the document metadata under the
    ``entailment_info`` key.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub.
            Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model).
            Defaults to ``model_name_or_path`` when not given.
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Stored on the instance as ``self.batch_size``; the
            methods of this class currently process one document at a time.
        :raises ValueError: If the model config has no ``entailment`` label.
        """
        super().__init__()
        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.batch_size = batch_size

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.model.to(str(self.devices[0]))

        # Read the labels from the config of the model that was actually loaded,
        # so they always match the requested revision (no second config fetch).
        id2label = self.model.config.id2label
        # Labels ordered by class index, so they line up with the logits.
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError(
                "The model config must contain entailment value in the id2label dict."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Add ``entailment_info`` to the metadata of every document.

        :param query: Text used as the hypothesis.
        :param documents: Documents whose content is used as the premise.
            They are modified in place.
        :return: Tuple of ``{"documents": documents}`` and the edge name ``"output_1"``.
        """
        for doc in documents:
            doc.meta["entailment_info"] = self.get_entailment(
                premise=doc.content, hypotesis=query
            )
        return {"documents": documents}, "output_1"

    def run_batch(self, queries: List[str], documents: List[List[Document]]):
        """
        Batch variant of ``run``: one list of documents per query.

        :param queries: Texts used as hypotheses.
        :param documents: For each query, the documents to check against it.
            They are modified in place.
        :return: Tuple of ``{"documents": documents}`` and the edge name ``"output_1"``.
        :raises ValueError: If ``queries`` and ``documents`` differ in length.
        """
        if len(queries) != len(documents):
            raise ValueError(
                "run_batch expects one list of documents per query, got "
                f"{len(queries)} queries and {len(documents)} document lists."
            )
        for query, query_documents in zip(queries, documents):
            self.run(query=query, documents=query_documents)
        return {"documents": documents}, "output_1"

    def get_entailment(self, premise, hypotesis):
        """
        Compute the label probabilities for a premise/hypothesis pair.

        :param premise: Text of the premise.
        :param hypotesis: Text of the hypothesis (parameter name kept as is so
            existing keyword callers keep working).
        :return: Dict mapping each lowercased model label to its probability (float).
        """
        with torch.no_grad():
            # Pass the two texts as a sentence pair so the tokenizer inserts the
            # model-specific separators itself, and truncate so that long
            # documents do not exceed the maximum input length of the model.
            inputs = self.tokenizer(
                premise, hypotesis, truncation=True, return_tensors="pt"
            ).to(self.devices[0])
            logits = self.model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0, :]
        # Plain Python floats keep the document metadata easy to serialize.
        return dict(zip(self.labels, probs.cpu().tolist()))
|