Liyan06 commited on
Commit
8c6cca8
·
1 Parent(s): 87eb2c4

claim format debug

Browse files
Files changed (1) hide show
  1. minicheck/minicheck.py +7 -4
minicheck/minicheck.py CHANGED
@@ -4,7 +4,7 @@ import sys
4
  sys.path.append("..")
5
 
6
  from minicheck.inference import Inferencer
7
- from typing import List
8
  import numpy as np
9
 
10
 
@@ -18,7 +18,7 @@ class MiniCheck:
18
  max_input_length=max_input_length,
19
  )
20
 
21
- def score(self, docs: List[str], claims: List[str]) -> List[float]:
22
  '''
23
  pred_labels: 0 / 1 (0: unsupported, 1: supported)
24
  max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
@@ -26,8 +26,11 @@ class MiniCheck:
26
  support_prob_per_chunk: the probability of "supported" for each chunk
27
  '''
28
 
29
- assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray, instead of {type(docs)};\n{docs}"
30
- assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray, instead of {type(claims)};\n{claims}"
 
 
 
31
 
32
  max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
33
  pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
 
4
  sys.path.append("..")
5
 
6
  from minicheck.inference import Inferencer
7
+ from typing import List, Dict
8
  import numpy as np
9
 
10
 
 
18
  max_input_length=max_input_length,
19
  )
20
 
21
+ def score(self, inputs: Dict) -> List[float]:
22
  '''
23
  pred_labels: 0 / 1 (0: unsupported, 1: supported)
24
  max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
 
26
  support_prob_per_chunk: the probability of "supported" for each chunk
27
  '''
28
 
29
+ docs = inputs['docs']
30
+ claims = inputs['claims']
31
+
32
+ assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
33
+ assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
34
 
35
  max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
36
  pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]