|
from transformers import Pipeline |
|
import torch |
|
from typing import Union |
|
|
|
|
|
class DocumentSentenceRelevancePipeline(Pipeline):
    """Pipeline scoring a response's relevance against a document at both the
    document level and the per-sentence level.

    Each input example is a dict with keys:
        question: str        -- the question text.
        context:  list[str]  -- the document, pre-split into sentences.
        response: str        -- the response text.

    The model is expected to return an object exposing ``doc_logits`` with
    shape [batch, 2] and ``sent_logits`` with shape [batch, num_sentences, 2]
    (established by the ``[:, 1]`` / ``[:, :, 1]`` indexing in
    :meth:`postprocess`); class index 1 is treated as the "relevant" class.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route the optional ``threshold`` kwarg (default 0.5) to postprocess().

        Returns the standard (preprocess_kwargs, forward_kwargs,
        postprocess_kwargs) triple expected by ``transformers.Pipeline``.
        """
        threshold = kwargs.get("threshold", 0.5)
        return {}, {}, {"threshold": threshold}

    def preprocess(self, inputs):
        """Tokenize one example into model-ready tensors.

        Segment 0 is question + response; segment 1 is the concatenation of
        the individually encoded sentences. ``sentence_positions`` holds the
        index of each sentence's [CLS] token inside the combined sequence so
        the model can pool a representation per sentence.

        Returns a dict of long tensors, each with a leading batch dim of 1:
        ``input_ids``, ``attention_mask``, ``token_type_ids``,
        ``sentence_positions``.
        """
        question = inputs.get("question", "")
        context = inputs.get("context", [""])
        response = inputs.get("response", "")

        q_enc = self.tokenizer(question, add_special_tokens=True, truncation=False, padding=False)
        r_enc = self.tokenizer(response, add_special_tokens=True, truncation=False, padding=False)

        question_ids = q_enc["input_ids"]
        response_ids = r_enc["input_ids"]

        # Encode each sentence separately so every sentence keeps its own
        # special tokens (its [CLS] doubles as the sentence-boundary marker).
        document_sentences_ids = [
            self.tokenizer(s, add_special_tokens=True, truncation=False, padding=False)["input_ids"]
            for s in context
        ]

        ids = question_ids + response_ids
        pair_ids = [tok for s_ids in document_sentences_ids for tok in s_ids]

        # Truncate only the document segment ("only_second") so the
        # question/response segment always survives intact.
        total_length = len(ids) + len(pair_ids)
        if total_length > self.tokenizer.model_max_length:
            num_tokens_to_remove = total_length - self.tokenizer.model_max_length
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=num_tokens_to_remove,
                truncation_strategy="only_second",
                stride=0,
            )
        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)
        attention_mask = [1] * len(combined_ids)

        # Every [CLS] inside the document segment marks the start of one
        # (surviving) sentence. NOTE: the original also counted the matches
        # in an unused `found_sentences` variable; that dead code is removed.
        offset = len(ids)
        sentence_positions = [
            offset + i
            for i, tok_id in enumerate(pair_ids)
            if tok_id == self.tokenizer.cls_token_id
        ]

        return {
            "input_ids": torch.tensor([combined_ids], dtype=torch.long),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
            "token_type_ids": torch.tensor([token_types], dtype=torch.long),
            "sentence_positions": torch.tensor([sentence_positions], dtype=torch.long),
        }

    def _forward(self, model_inputs):
        """Run the model on the preprocessed tensors and return its raw output."""
        return self.model(**model_inputs)

    def __call__(self, inputs: Union[dict[str, str], list[dict[str, str]]], **kwargs):
        """Run the pipeline on one example or a list of examples.

        Always returns a list with one entry per input example, shaped:
            {"document": {"label": str, "score": float},
             "sentences": [{"sentence": str, "label": str, "score": float}, ...]}
        """
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)

        pipeline_outputs = []
        for example, output in zip(inputs, model_outputs):
            # zip() pairs each input sentence with its prediction; if
            # truncation dropped trailing sentences, the surplus inputs are
            # silently skipped (same behavior as the original).
            sentence_entries = [
                {"sentence": sent, "label": label, "score": score}
                for sent, label, score in zip(
                    example["context"],
                    output["sentences"]["label"],
                    output["sentences"]["score"],
                )
            ]
            pipeline_outputs.append(
                {"document": output["document"], "sentences": sentence_entries}
            )
        return pipeline_outputs

    def postprocess(self, model_outputs, threshold = 0.5):
        """Convert model logits into labeled, scored predictions.

        A class-1 probability strictly above ``threshold`` selects class 1,
        otherwise class 0; the reported score is the probability of the
        selected class. Labels are mapped through ``model.config.id2label``.
        """
        document_probabilities = torch.softmax(model_outputs.doc_logits, dim=-1)
        sentence_probabilities = torch.softmax(model_outputs.sent_logits, dim=-1)

        document_best_class = (document_probabilities[:, 1] > threshold).long()
        sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()

        # FIX: the original `document_probabilities[:, document_best_class]`
        # cross-indexes with a batch-sized tensor and only coincidentally
        # yields the right value for batch size 1 (an [N, N] matrix for
        # batch N). Gather the chosen class's probability per row instead.
        document_score = document_probabilities.gather(
            1, document_best_class.unsqueeze(1)
        ).squeeze(1)

        sentence_best_class = sentence_best_class.squeeze()
        sentence_probabilities = sentence_probabilities.squeeze()

        # A single-sentence document squeezes down to 0-d; restore the
        # sentence dimension so the indexing below stays uniform.
        if sentence_best_class.dim() == 0:
            sentence_best_class = sentence_best_class.unsqueeze(0)
            sentence_probabilities = sentence_probabilities.unsqueeze(0)

        sentence_indices = torch.arange(len(sentence_best_class))
        sentence_scores = sentence_probabilities[sentence_indices, sentence_best_class]

        best_document_label = self.model.config.id2label[document_best_class.item()]
        best_sentence_labels = [
            self.model.config.id2label[label]
            for label in sentence_best_class.tolist()
        ]

        document_output = {"label": best_document_label, "score": document_score.item()}
        sentence_output = {"label": best_sentence_labels, "score": sentence_scores.tolist()}
        return {"document": document_output, "sentences": sentence_output}