Liyan06
commited on
Commit
·
8c6cca8
1
Parent(s):
87eb2c4
claim format debug
Browse files- minicheck/minicheck.py +7 -4
minicheck/minicheck.py
CHANGED
@@ -4,7 +4,7 @@ import sys
|
|
4 |
sys.path.append("..")
|
5 |
|
6 |
from minicheck.inference import Inferencer
|
7 |
-
from typing import List
|
8 |
import numpy as np
|
9 |
|
10 |
|
@@ -18,7 +18,7 @@ class MiniCheck:
|
|
18 |
max_input_length=max_input_length,
|
19 |
)
|
20 |
|
21 |
-
def score(self,
|
22 |
'''
|
23 |
pred_labels: 0 / 1 (0: unsupported, 1: supported)
|
24 |
max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
|
@@ -26,8 +26,11 @@ class MiniCheck:
|
|
26 |
support_prob_per_chunk: the probability of "supported" for each chunk
|
27 |
'''
|
28 |
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
|
33 |
pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
|
|
|
4 |
sys.path.append("..")
|
5 |
|
6 |
from minicheck.inference import Inferencer
|
7 |
+
from typing import List, Dict
|
8 |
import numpy as np
|
9 |
|
10 |
|
|
|
18 |
max_input_length=max_input_length,
|
19 |
)
|
20 |
|
21 |
+
def score(self, inputs: Dict) -> List[float]:
|
22 |
'''
|
23 |
pred_labels: 0 / 1 (0: unsupported, 1: supported)
|
24 |
max_support_probs: the probability of "supported" for the chunk that determin the final pred_label
|
|
|
26 |
support_prob_per_chunk: the probability of "supported" for each chunk
|
27 |
'''
|
28 |
|
29 |
+
docs = inputs['docs']
|
30 |
+
claims = inputs['claims']
|
31 |
+
|
32 |
+
assert isinstance(docs, list) or isinstance(docs, np.ndarray), f"docs must be a list or np.ndarray"
|
33 |
+
assert isinstance(claims, list) or isinstance(claims, np.ndarray), f"claims must be a list or np.ndarray"
|
34 |
|
35 |
max_support_prob, used_chunk, support_prob_per_chunk = self.model.fact_check(docs, claims)
|
36 |
pred_label = [1 if prob > 0.5 else 0 for prob in max_support_prob]
|