Spaces:
Running
Running
Commit
·
34f14d6
1
Parent(s):
5e92245
chore: added `run_batch`
Browse files
app_utils/entailment_checker.py
CHANGED
@@ -95,7 +95,21 @@ class EntailmentChecker(BaseComponent):
|
|
95 |
return entailment_checker_result, "output_1"
|
96 |
|
97 |
def run_batch(self, queries: List[str], documents: List[Document]):
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
def get_entailment_dict(self, probs):
|
101 |
entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
|
|
|
95 |
return entailment_checker_result, "output_1"
|
96 |
|
97 |
def run_batch(self, queries: List[str], documents: List[Document]):
|
98 |
+
entailment_checker_result_batch = []
|
99 |
+
entailment_info_batch = self.get_entailment_batch(premise_batch=documents, hypotesis_batch=queries)
|
100 |
+
for doc, entailment_info in zip(documents, entailment_info_batch):
|
101 |
+
doc.meta["entailment_info"] = entailment_info
|
102 |
+
aggregate_entailment_info = {
|
103 |
+
"contradiction": round(entailment_info["contradiction"] / doc.score),
|
104 |
+
"neutral": round(entailment_info["neutral"] / doc.score),
|
105 |
+
"entailment": round(entailment_info["entailment"] / doc.score),
|
106 |
+
}
|
107 |
+
entailment_checker_result_batch.append({
|
108 |
+
"documents": [doc],
|
109 |
+
"aggregate_entailment_info": aggregate_entailment_info,
|
110 |
+
})
|
111 |
+
return entailment_checker_result_batch, "output_1"
|
112 |
+
|
113 |
|
114 |
def get_entailment_dict(self, probs):
|
115 |
entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
|