from transformers import Pipeline
import torch
from typing import Union


class DocumentSentenceRelevancePipeline(Pipeline):
    """Joint document/sentence relevance pipeline.

    Takes a ``{"question", "context", "response"}`` dict (``context`` is a list of
    sentences), encodes question+response as the first segment and all context
    sentences as the second segment, and returns a relevance label/score for the
    whole document plus one per context sentence. The model is expected to emit
    ``doc_logits`` (batch, 2) and ``sent_logits`` (batch, num_sentences, 2).
    """

    def _sanitize_parameters(self, **kwargs):
        """Route the optional ``threshold`` kwarg (default 0.5) to postprocess."""
        threshold = kwargs.get("threshold", 0.5)
        return {}, {}, {"threshold": threshold}

    def preprocess(self, inputs: dict) -> dict:
        """Tokenize one input example into model tensors.

        Each context sentence is encoded with its own special tokens; the
        per-sentence [CLS] positions are collected so the model can pool a
        representation per sentence.
        """
        question = inputs.get("question", "")
        context = inputs.get("context", [""])
        response = inputs.get("response", "")

        q_enc = self.tokenizer(question, add_special_tokens=True, truncation=False, padding=False)
        r_enc = self.tokenizer(response, add_special_tokens=True, truncation=False, padding=False)

        # First segment: question tokens followed by response tokens
        # (each encoding keeps its own special tokens, as in the original).
        ids = q_enc["input_ids"] + r_enc["input_ids"]

        # Second segment: all context sentences flattened, each with its own
        # special tokens so every sentence contributes a [CLS] anchor.
        pair_ids = []
        for sentence in context:
            s_enc = self.tokenizer(sentence, add_special_tokens=True, truncation=False, padding=False)
            pair_ids.extend(s_enc["input_ids"])

        # Truncate only the sentence segment when the pair exceeds the model
        # limit; question+response is preserved in full.
        total_length = len(ids) + len(pair_ids)
        if total_length > self.tokenizer.model_max_length:
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=total_length - self.tokenizer.model_max_length,
                truncation_strategy="only_second",
                stride=0,
            )

        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)
        attention_mask = [1] * len(combined_ids)

        # Absolute position of each sentence's [CLS] token in the combined
        # sequence. Sentences fully truncated away simply contribute none.
        offset = len(ids)
        sentence_positions = [
            offset + i
            for i, tok_id in enumerate(pair_ids)
            if tok_id == self.tokenizer.cls_token_id
        ]

        return {
            "input_ids": torch.tensor([combined_ids], dtype=torch.long),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
            "token_type_ids": torch.tensor([token_types], dtype=torch.long),
            "sentence_positions": torch.tensor([sentence_positions], dtype=torch.long),
        }

    def _forward(self, model_inputs):
        """Run the model on the preprocessed tensors."""
        return self.model(**model_inputs)

    def __call__(self, inputs: Union[dict[str, str], list[dict[str, str]]], **kwargs):
        """Run the pipeline and re-attach the original sentence strings.

        Always returns a list of result dicts, one per input example, each of
        the form ``{"document": {...}, "sentences": [{"sentence", "label",
        "score"}, ...]}``.
        """
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)

        pipeline_outputs = []
        for item, output in zip(inputs, model_outputs):
            # NOTE(review): if preprocess truncated away some sentences there
            # will be fewer labels/scores than context sentences; zip silently
            # drops the trailing ones — confirm this is the intended contract.
            final_output = {
                "document": output["document"],
                "sentences": [
                    {"sentence": sent, "label": label, "score": score}
                    for sent, label, score in zip(
                        item["context"],
                        output["sentences"]["label"],
                        output["sentences"]["score"],
                    )
                ],
            }
            pipeline_outputs.append(final_output)
        return pipeline_outputs

    def postprocess(self, model_outputs, threshold: float = 0.5) -> dict:
        """Convert raw logits into labeled probabilities.

        Class 1 is treated as the positive ("relevant") class: an item is
        labeled positive when its class-1 probability exceeds ``threshold``.
        """
        doc_logits = model_outputs.doc_logits    # indexed as (batch, 2) below
        sent_logits = model_outputs.sent_logits  # indexed as (batch, sentences, 2) below

        document_probabilities = torch.softmax(doc_logits, dim=-1)
        sentence_probabilities = torch.softmax(sent_logits, dim=-1)

        document_best_class = (document_probabilities[:, 1] > threshold).long()
        sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()

        # FIX: select the chosen class's probability per row with gather().
        # The original `document_probabilities[:, document_best_class]` built a
        # (batch, batch) matrix and only worked for batch size 1.
        document_score = document_probabilities.gather(1, document_best_class.unsqueeze(-1))

        sentence_best_class = sentence_best_class.squeeze()
        sentence_probabilities = sentence_probabilities.squeeze()
        if sentence_best_class.dim() == 0:
            # Single-sentence case: squeeze() dropped the sentence axis; restore it.
            sentence_best_class = sentence_best_class.unsqueeze(0)
            sentence_probabilities = sentence_probabilities.unsqueeze(0)

        sentence_indices = torch.arange(len(sentence_best_class))
        sentence_scores = sentence_probabilities[sentence_indices, sentence_best_class]

        # .cpu() is a no-op on CPU tensors and avoids a crash when the model
        # ran on an accelerator.
        best_document_label = self.model.config.id2label[
            document_best_class.cpu().numpy().item()
        ]
        best_sentence_labels = [
            self.model.config.id2label[label]
            for label in sentence_best_class.cpu().numpy().tolist()
        ]

        document_output = {
            "label": best_document_label,
            "score": document_score.cpu().numpy().item(),
        }
        sentence_output = {
            "label": best_sentence_labels,
            "score": sentence_scores.cpu().numpy().tolist(),
        }
        return {"document": document_output, "sentences": sentence_output}