davidberenstein1957 HF staff commited on
Commit
34f14d6
·
1 Parent(s): 5e92245

chore: added `run_batch`

Browse files
Files changed (1) hide show
  1. app_utils/entailment_checker.py +15 -1
app_utils/entailment_checker.py CHANGED
@@ -95,7 +95,21 @@ class EntailmentChecker(BaseComponent):
95
  return entailment_checker_result, "output_1"
96
 
97
  def run_batch(self, queries: List[str], documents: List[Document]):
98
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  def get_entailment_dict(self, probs):
101
  entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}
 
95
  return entailment_checker_result, "output_1"
96
 
97
  def run_batch(self, queries: List[str], documents: List[Document]):
98
+ entailment_checker_result_batch = []
99
+ entailment_info_batch = self.get_entailment_batch(premise_batch=documents, hypotesis_batch=queries)
100
+ for doc, entailment_info in zip(documents, entailment_info_batch):
101
+ doc.meta["entailment_info"] = entailment_info
102
+ aggregate_entailment_info = {
103
+ "contradiction": round(entailment_info["contradiction"] / doc.score),
104
+ "neutral": round(entailment_info["neutral"] / doc.score),
105
+ "entailment": round(entailment_info["entailment"] / doc.score),
106
+ }
107
+ entailment_checker_result_batch.append({
108
+ "documents": [doc],
109
+ "aggregate_entailment_info": aggregate_entailment_info,
110
+ })
111
+ return entailment_checker_result_batch, "output_1"
112
+
113
 
114
  def get_entailment_dict(self, probs):
115
  entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)}