param-bharat committed on
Commit cc1fe5d
1 Parent(s): e803a9e

Upload DocumentSentenceRelevancePipeline

Files changed (2)
  1. config.json +9 -0
  2. pipeline.py +35 -14
config.json CHANGED
@@ -7,6 +7,15 @@
     "AutoModel": "modeling.MultiHeadModel"
   },
   "classifier_dropout": 0.1,
+  "custom_pipelines": {
+    "context-relevance": {
+      "impl": "pipeline.DocumentSentenceRelevancePipeline",
+      "pt": [
+        "AutoModel"
+      ],
+      "tf": []
+    }
+  },
   "encoder_name": "tasksource/deberta-small-long-nli",
   "id2label": {
     "0": "irrelevant",
pipeline.py CHANGED
@@ -3,15 +3,8 @@ import torch
 from typing import Union
 
 
-
-def convert_to_list(data):
-    first_list = next(iter(data.values()))
-    return [
-        {key: values[i] for key, values in data.items()}
-        for i in range(len(first_list))
-    ]
-
 class DocumentSentenceRelevancePipeline(Pipeline):
+
     def _sanitize_parameters(self, **kwargs):
         threshold = kwargs.get("threshold", 0.5)
         return {}, {}, {"threshold": threshold}
@@ -82,10 +75,29 @@ class DocumentSentenceRelevancePipeline(Pipeline):
         pipeline_outputs = []
         for i, output in enumerate(model_outputs):
             sentences = inputs[i]["context"]
-            output["sentences"]["sentence"] = sentences
-            output['sentences'] = convert_to_list(output['sentences'])
-            pipeline_outputs.append(output)
-            return pipeline_outputs if len(pipeline_outputs) > 1 else pipeline_outputs[0]
+            sentences_dict = {
+                "sentence": sentences,
+                "label": output["sentences"]["label"],
+                "score": output["sentences"]["score"]
+            }
+            # Create the final output structure
+            final_output = {
+                "document": output["document"],
+                "sentences": [
+                    {
+                        "sentence": sent,
+                        "label": label,
+                        "score": score
+                    }
+                    for sent, label, score in zip(
+                        sentences_dict["sentence"],
+                        sentences_dict["label"],
+                        sentences_dict["score"]
+                    )
+                ]
+            }
+            pipeline_outputs.append(final_output)
+        return pipeline_outputs
 
     def postprocess(self, model_outputs, threshold = 0.5):
         doc_logits = model_outputs.doc_logits
@@ -96,14 +108,23 @@ class DocumentSentenceRelevancePipeline(Pipeline):
         document_best_class = (document_probabilities[:, 1] > threshold).long()
         sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()
         document_score = document_probabilities[:, document_best_class]
+
         sentence_best_class = sentence_best_class.squeeze()
-        batch_indices = torch.arange(sentence_probabilities.size(1))
-        sentence_scores = sentence_probabilities.squeeze()[batch_indices, sentence_best_class]
+        sentence_probabilities = sentence_probabilities.squeeze()
+
+        if len(sentence_best_class.shape) == 0:
+            sentence_best_class = sentence_best_class.unsqueeze(0)
+            sentence_probabilities = sentence_probabilities.unsqueeze(0)
+
+        batch_indices = torch.arange(len(sentence_best_class))
+        sentence_scores = sentence_probabilities[batch_indices, sentence_best_class]
+
         best_document_label = document_best_class.numpy().item()
         best_document_label = self.model.config.id2label[best_document_label]
 
         best_sentence_labels = sentence_best_class.numpy().tolist()
         best_sentence_labels = [self.model.config.id2label[label] for label in best_sentence_labels]
+
         document_output = {"label": best_document_label, "score": document_score.numpy().item()}
         sentence_output = {"label": best_sentence_labels, "score": sentence_scores.numpy().tolist()}
         return {"document": document_output, "sentences": sentence_output}