from typing import Union

import torch
from transformers import Pipeline


class DocumentSentenceRelevancePipeline(Pipeline):
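    """Custom pipeline that pairs a question and a response with a list of
    context sentences and predicts one label/score for the whole document
    plus one label/score for each context sentence.
    """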

    def _sanitize_parameters(self, **kwargs):
        # "threshold" is the only runtime kwarg; it is forwarded to postprocess().
        threshold = kwargs.get("threshold", 0.5)
        return {}, {}, {"threshold": threshold}

    def preprocess(self, inputs):
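        """Tokenize one example and pack it into a single encoder input.

        The question and response (each with their own special tokens) form
        the first segment; every context sentence keeps its own [CLS]/[SEP]
        and is concatenated into the second segment.
        """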
        question = inputs.get("question", "")
        context = inputs.get("context", [""])
        response = inputs.get("response", "")

        q_enc = self.tokenizer(question, add_special_tokens=True, truncation=False, padding=False)
        r_enc = self.tokenizer(response, add_special_tokens=True, truncation=False, padding=False)
        question_ids = q_enc["input_ids"]
        response_ids = r_enc["input_ids"]

        document_sentences_ids = []
        for s in context:
            s_enc = self.tokenizer(s, add_special_tokens=True, truncation=False, padding=False)
            document_sentences_ids.append(s_enc["input_ids"])
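
        # First segment: question + response tokens. Second segment: all
        # context sentence tokens, concatenated in order.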
        ids = question_ids + response_ids
        pair_ids = []
        for s_ids in document_sentences_ids:
            pair_ids.extend(s_ids)
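
        # If the packed input exceeds the model's maximum length, truncate
        # only the context side and keep the question/response intact.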
        total_length = len(ids) + len(pair_ids)
        if total_length > self.tokenizer.model_max_length:
            num_tokens_to_remove = total_length - self.tokenizer.model_max_length
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=num_tokens_to_remove,
                truncation_strategy="only_second",
                stride=0,
            )

        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)
        attention_mask = [1] * len(combined_ids)
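
        # Record where each context sentence's [CLS] token sits in the packed
        # input so the model can pool one representation per sentence.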
        sentence_positions = []
        current_pos = len(ids)
        for i, tok_id in enumerate(pair_ids):
            if tok_id == self.tokenizer.cls_token_id:
                sentence_positions.append(current_pos + i)
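
        # Batch dimension of 1: the base Pipeline feeds one example at a time.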
        input_ids = torch.tensor([combined_ids], dtype=torch.long)
        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
        token_type_ids = torch.tensor([token_types], dtype=torch.long)
        sentence_positions = torch.tensor([sentence_positions], dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "sentence_positions": sentence_positions,
        }

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def __call__(self, inputs: Union[dict, list[dict]], **kwargs):
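        """Run the pipeline on one example or a list of examples.

        Each example is a dict with "question" (str), "context" (a list of
        sentence strings), and "response" (str).
        """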
        # Accept a single example or a list of examples.
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)

        # Re-attach the original context sentences to the per-sentence labels
        # and scores produced by postprocess().
        pipeline_outputs = []
        for i, output in enumerate(model_outputs):
            sentences = inputs[i]["context"]
            final_output = {
                "document": output["document"],
                "sentences": [
                    {"sentence": sent, "label": label, "score": score}
                    for sent, label, score in zip(
                        sentences,
                        output["sentences"]["label"],
                        output["sentences"]["score"],
                    )
                ],
            }
            pipeline_outputs.append(final_output)
        return pipeline_outputs

    def postprocess(self, model_outputs, threshold=0.5):
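        """Turn document and sentence logits into labels and scores.

        `threshold` is applied to the probability of class 1: an item is
        assigned class 1 when that probability exceeds the threshold,
        otherwise class 0.
        """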
        doc_logits = model_outputs.doc_logits
        sent_logits = model_outputs.sent_logits
        document_probabilities = torch.softmax(doc_logits, dim=-1)
        sentence_probabilities = torch.softmax(sent_logits, dim=-1)

        # Predict class 1 whenever its probability exceeds the threshold.
        document_best_class = (document_probabilities[:, 1] > threshold).long()
        sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()
        document_score = document_probabilities[:, document_best_class]

        # Drop the batch dimension; preprocess() always builds a batch of one.
        sentence_best_class = sentence_best_class.squeeze()
        sentence_probabilities = sentence_probabilities.squeeze()
        if len(sentence_best_class.shape) == 0:
            # A single context sentence squeezes down to a scalar; restore the sentence axis.
            sentence_best_class = sentence_best_class.unsqueeze(0)
            sentence_probabilities = sentence_probabilities.unsqueeze(0)

        # Probability of the predicted class for each sentence.
        sentence_indices = torch.arange(len(sentence_best_class))
        sentence_scores = sentence_probabilities[sentence_indices, sentence_best_class]

        # Map class indices to string labels via the model config.
        best_document_label = self.model.config.id2label[document_best_class.item()]
        best_sentence_labels = [
            self.model.config.id2label[label] for label in sentence_best_class.tolist()
        ]

        document_output = {"label": best_document_label, "score": document_score.item()}
        sentence_output = {"label": best_sentence_labels, "score": sentence_scores.tolist()}
        return {"document": document_output, "sentences": sentence_output}