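"""Custom pipeline that scores a response document and each context sentence
for relevance to a question.

Assumes a model whose forward pass returns an output object with
``doc_logits`` of shape [batch, 2] and ``sent_logits`` of shape
[batch, num_sentences, 2], and a tokenizer that prepends a CLS token to
every encoded sentence (the shape comments below reflect this assumption).
"""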
from typing import Union

import torch
from transformers import Pipeline


class DocumentSentenceRelevancePipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # Route the decision threshold to postprocess(); preprocess and
        # _forward take no extra keyword arguments.
        threshold = kwargs.get("threshold", 0.5)
        return {}, {}, {"threshold": threshold}
    def preprocess(self, inputs):
        question = inputs.get("question", "")
        context = inputs.get("context", [""])
        response = inputs.get("response", "")

        # Encode each field separately; truncation is deferred until the
        # combined length is known.
        q_enc = self.tokenizer(question, add_special_tokens=True, truncation=False, padding=False)
        r_enc = self.tokenizer(response, add_special_tokens=True, truncation=False, padding=False)
        question_ids = q_enc["input_ids"]
        response_ids = r_enc["input_ids"]

        document_sentences_ids = []
        for s in context:
            s_enc = self.tokenizer(s, add_special_tokens=True, truncation=False, padding=False)
            document_sentences_ids.append(s_enc["input_ids"])

        # Segment 0 holds the question and response; segment 1 holds the
        # concatenated context sentences.
        ids = question_ids + response_ids
        pair_ids = []
        for s_ids in document_sentences_ids:
            pair_ids.extend(s_ids)

        # If the pair exceeds the model's window, truncate only the context
        # segment so the question and response are kept intact.
        total_length = len(ids) + len(pair_ids)
        if total_length > self.tokenizer.model_max_length:
            num_tokens_to_remove = total_length - self.tokenizer.model_max_length
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=num_tokens_to_remove,
                truncation_strategy="only_second",
                stride=0,
            )

        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)
        attention_mask = [1] * len(combined_ids)

        # Every encoded sentence starts with a CLS token; record the index
        # of each CLS in the combined sequence so the model can pool a
        # per-sentence representation from it.
        sentence_positions = []
        current_pos = len(ids)
        for i, tok_id in enumerate(pair_ids):
            if tok_id == self.tokenizer.cls_token_id:
                sentence_positions.append(current_pos + i)

        return {
            "input_ids": torch.tensor([combined_ids], dtype=torch.long),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
            "token_type_ids": torch.tensor([token_types], dtype=torch.long),
            "sentence_positions": torch.tensor([sentence_positions], dtype=torch.long),
        }
    def _forward(self, model_inputs):
        return self.model(**model_inputs)
    def __call__(self, inputs: Union[dict, list[dict]], **kwargs):
        """Accept one example or a list of examples; each example is a dict
        with keys "question" (str), "context" (list of sentence strings),
        and "response" (str)."""
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)

        # Re-attach the original context sentences to the per-sentence labels
        # and scores produced by postprocess(). If truncation dropped trailing
        # sentences, zip() drops their entries as well.
        pipeline_outputs = []
        for i, output in enumerate(model_outputs):
            sentences = inputs[i]["context"]
            final_output = {
                "document": output["document"],
                "sentences": [
                    {"sentence": sent, "label": label, "score": score}
                    for sent, label, score in zip(
                        sentences,
                        output["sentences"]["label"],
                        output["sentences"]["score"],
                    )
                ],
            }
            pipeline_outputs.append(final_output)
        return pipeline_outputs
    def postprocess(self, model_outputs, threshold=0.5):
        doc_logits = model_outputs.doc_logits    # [batch=1, 2]
        sent_logits = model_outputs.sent_logits  # [batch=1, num_sentences, 2]
        document_probabilities = torch.softmax(doc_logits, dim=-1)
        sentence_probabilities = torch.softmax(sent_logits, dim=-1)

        # Predict class 1 ("relevant") whenever its probability clears the
        # threshold, rather than taking a plain argmax.
        document_best_class = (document_probabilities[:, 1] > threshold).long()
        sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()
        document_score = document_probabilities[0, document_best_class]

        # Drop the batch dimension; restore a sentence dimension when a
        # single-sentence input was squeezed down to a scalar.
        sentence_best_class = sentence_best_class.squeeze()
        sentence_probabilities = sentence_probabilities.squeeze()
        if sentence_best_class.dim() == 0:
            sentence_best_class = sentence_best_class.unsqueeze(0)
            sentence_probabilities = sentence_probabilities.unsqueeze(0)

        # Probability of the predicted class for each sentence.
        sentence_indices = torch.arange(len(sentence_best_class))
        sentence_scores = sentence_probabilities[sentence_indices, sentence_best_class]

        best_document_label = self.model.config.id2label[document_best_class.item()]
        best_sentence_labels = [
            self.model.config.id2label[label] for label in sentence_best_class.tolist()
        ]

        document_output = {"label": best_document_label, "score": document_score.item()}
        sentence_output = {"label": best_sentence_labels, "score": sentence_scores.tolist()}
        return {"document": document_output, "sentences": sentence_output}