from transformers import Pipeline
import torch
from typing import Union


class DocumentSentenceRelevancePipeline(Pipeline):
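    """Joint document- and sentence-level relevance classification.

    Expects inputs of the form
        {"question": str, "context": list[str], "response": str}
    and returns, per input, a document-level prediction plus one label/score
    pair per context sentence. The model output must expose `doc_logits` and
    `sent_logits` (see `postprocess`).
    """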
    
    def _sanitize_parameters(self, **kwargs):
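        # The three returned dicts route kwargs to preprocess(), _forward(),
        # and postprocess() respectively; only `threshold` is configurable.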
        threshold = kwargs.get("threshold", 0.5)
        return {}, {}, {"threshold": threshold}

    def preprocess(self, inputs):
        question = inputs.get("question", "")
        context = inputs.get("context", [""])
        response = inputs.get("response", "")

        # Tokenize the question and response separately, keeping their special
        # tokens; length control happens jointly further down.
        q_enc = self.tokenizer(question, add_special_tokens=True, truncation=False, padding=False)
        r_enc = self.tokenizer(response, add_special_tokens=True, truncation=False, padding=False)

        question_ids = q_enc["input_ids"]
        response_ids = r_enc["input_ids"]

        # Each context sentence keeps its own special tokens so that its
        # leading [CLS] can serve as a per-sentence marker below.
        document_sentences_ids = [
            self.tokenizer(s, add_special_tokens=True, truncation=False, padding=False)["input_ids"]
            for s in context
        ]

        # Segment A is question + response; segment B is the flattened context.
        ids = question_ids + response_ids
        pair_ids = []
        for s_ids in document_sentences_ids:
            pair_ids.extend(s_ids)

        # If the combined sequence is too long, truncate only the context
        # segment ("only_second") so the question and response stay intact.
        total_length = len(ids) + len(pair_ids)
        if total_length > self.tokenizer.model_max_length:
            num_tokens_to_remove = total_length - self.tokenizer.model_max_length
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=num_tokens_to_remove,
                truncation_strategy="only_second",
                stride=0,
            )
        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)
        attention_mask = [1] * len(combined_ids)

        # Record where each sentence's [CLS] marker landed in the combined
        # sequence; the model uses these positions to produce per-sentence
        # logits. Truncation can drop trailing sentences and their markers.
        offset = len(ids)
        sentence_positions = [
            offset + i
            for i, tok_id in enumerate(pair_ids)
            if tok_id == self.tokenizer.cls_token_id
        ]

        input_ids = torch.tensor([combined_ids], dtype=torch.long)
        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
        token_type_ids = torch.tensor([token_types], dtype=torch.long)
        sentence_positions = torch.tensor([sentence_positions], dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "sentence_positions": sentence_positions
        }

    def _forward(self, model_inputs):
        # The model is expected to accept `sentence_positions` alongside the
        # usual encoder inputs and to return `doc_logits` and `sent_logits`.
        return self.model(**model_inputs)

    def __call__(self, inputs: Union[dict, list[dict]], **kwargs):
        # Accept a single example or a batch; always return a list of results.
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)
        pipeline_outputs = []
        for i, output in enumerate(model_outputs):
            # Re-attach the original sentence strings to the per-sentence
            # labels and scores produced by postprocess(). zip() stops at the
            # shorter list, so sentences whose [CLS] markers were truncated
            # away in preprocess() are dropped here.
            sentences = inputs[i]["context"]
            final_output = {
                "document": output["document"],
                "sentences": [
                    {"sentence": sent, "label": label, "score": score}
                    for sent, label, score in zip(
                        sentences,
                        output["sentences"]["label"],
                        output["sentences"]["score"],
                    )
                ],
            }
            pipeline_outputs.append(final_output)
        return pipeline_outputs

    def postprocess(self, model_outputs, threshold=0.5):
        doc_logits = model_outputs.doc_logits
        sent_logits = model_outputs.sent_logits
        document_probabilities = torch.softmax(doc_logits, dim=-1)   # [1, 2]
        sentence_probabilities = torch.softmax(sent_logits, dim=-1)  # [1, num_sentences, 2]

        # An item is positive when its class-1 probability exceeds `threshold`;
        # the reported score is the probability of whichever class was chosen.
        document_best_class = (document_probabilities[:, 1] > threshold).long()
        sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()
        document_score = document_probabilities[0, document_best_class]

        sentence_best_class = sentence_best_class.squeeze()
        sentence_probabilities = sentence_probabilities.squeeze()

        # squeeze() collapses the single-sentence case to a 0-dim tensor;
        # restore the sentence dimension so the indexing below is uniform.
        if sentence_best_class.dim() == 0:
            sentence_best_class = sentence_best_class.unsqueeze(0)
            sentence_probabilities = sentence_probabilities.unsqueeze(0)

        sentence_indices = torch.arange(len(sentence_best_class))
        sentence_scores = sentence_probabilities[sentence_indices, sentence_best_class]

        # .item()/.tolist() work on any device, unlike .numpy() on CUDA tensors.
        best_document_label = self.model.config.id2label[document_best_class.item()]
        best_sentence_labels = [
            self.model.config.id2label[label] for label in sentence_best_class.tolist()
        ]

        document_output = {"label": best_document_label, "score": document_score.item()}
        sentence_output = {"label": best_sentence_labels, "score": sentence_scores.tolist()}
        return {"document": document_output, "sentences": sentence_output}
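

# Usage sketch, kept as comments so importing the module has no side effects.
# The model id below is a hypothetical placeholder; loading a custom model
# class from the Hub typically requires trust_remote_code=True.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     model_id = "your-org/your-relevance-model"  # hypothetical checkpoint
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
#     pipe = DocumentSentenceRelevancePipeline(model=model, tokenizer=tokenizer)
#
#     result = pipe(
#         {
#             "question": "What causes tides?",
#             "context": [
#                 "Tides are caused by the gravitational pull of the Moon.",
#                 "The stock market closed higher today.",
#             ],
#             "response": "The Moon's gravity drives the tides.",
#         },
#         threshold=0.5,
#     )
#     # result[0]["document"]  -> {"label": ..., "score": ...}
#     # result[0]["sentences"] -> one {"sentence", "label", "score"} per context sentence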