clean up code
Browse files
- app.py +3 -5
- eval_modules/utils.py +5 -76
app.py
CHANGED
@@ -98,9 +98,7 @@ def chat(
|
|
98 |
partial_text += f"1. Text Repetition Score: {repetition_score:.3f}\n"
|
99 |
partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
|
100 |
rr = total_repetitions / len(answer) if len(answer) > 0 else 0
|
101 |
-
partial_text += (
|
102 |
-
f"1. Repetition Ratio: {rr:.3f}\n"
|
103 |
-
)
|
104 |
|
105 |
if index >= 0: # RAG
|
106 |
key = (
|
@@ -114,9 +112,9 @@ def chat(
|
|
114 |
partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
|
115 |
partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
|
116 |
perf = scores["bert_scores"]["f1"][0]
|
117 |
-
partial_text += f
|
118 |
nrr = 1 - rr
|
119 |
-
partial_text += f
|
120 |
|
121 |
partial_text += f"\n\nGround truth: {questions[index][key][0]}\n"
|
122 |
|
|
|
98 |
partial_text += f"1. Text Repetition Score: {repetition_score:.3f}\n"
|
99 |
partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
|
100 |
rr = total_repetitions / len(answer) if len(answer) > 0 else 0
|
101 |
+
partial_text += f"1. Repetition Ratio: {rr:.3f}\n"
|
|
|
|
|
102 |
|
103 |
if index >= 0: # RAG
|
104 |
key = (
|
|
|
112 |
partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
|
113 |
partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
|
114 |
perf = scores["bert_scores"]["f1"][0]
|
115 |
+
partial_text += f"1. BERT-F1: {perf:.3f}\n"
|
116 |
nrr = 1 - rr
|
117 |
+
partial_text += f"1. RAP-BERT-F1: {perf * nrr * nrr * nrr:.3f}\n"
|
118 |
|
119 |
partial_text += f"\n\nGround truth: {questions[index][key][0]}\n"
|
120 |
|
eval_modules/utils.py
CHANGED
@@ -7,6 +7,11 @@ import pandas as pd
|
|
7 |
|
8 |
print(f"loading: {__file__}")
|
9 |
|
|
|
|
|
|
|
|
|
|
|
10 |
# pattern_non_word_char_repetition = re.compile(r"\s{5,}")
|
11 |
# pattern_text_repetitions = re.compile(r"(.{5}.*)\s*((\1)\s*)+", re.M | re.DOTALL)
|
12 |
|
@@ -81,12 +86,6 @@ def detect_repetitions(text, debug=False):
|
|
81 |
return result
|
82 |
|
83 |
|
84 |
-
|
85 |
-
bleu = evaluate.load("bleu")
|
86 |
-
rouge = evaluate.load("rouge")
|
87 |
-
bert_score = evaluate.load("bertscore")
|
88 |
-
|
89 |
-
|
90 |
def calc_perf_scores(predictions, references, debug=False):
|
91 |
if debug:
|
92 |
print("predictions:", predictions)
|
@@ -112,73 +111,3 @@ def calc_perf_scores(predictions, references, debug=False):
|
|
112 |
print("result:", result)
|
113 |
|
114 |
return result
|
115 |
-
|
116 |
-
|
117 |
-
def calc_metrics(df):
|
118 |
-
predictions = [df["answer"][i] for i in range(len(df))]
|
119 |
-
references = [df["ground_truth"][i] for i in range(len(df))]
|
120 |
-
|
121 |
-
return calc_bleu_rouge_scores(predictions, references)
|
122 |
-
|
123 |
-
|
124 |
-
pattern_abnormal_newlines = re.compile(r"\n{5,}")
|
125 |
-
pattern_text_repetitions = re.compile(r"\b(\w.+?)\b(\1+)", re.M | re.DOTALL)
|
126 |
-
exception_pattern = re.compile(r"(\w+\.)\1")
|
127 |
-
|
128 |
-
|
129 |
-
# final version for repetition detection
|
130 |
-
def detect_repetitions(
|
131 |
-
text, debug=False, pattern_text_repetitions=pattern_text_repetitions
|
132 |
-
):
|
133 |
-
subtotals = [0, 0]
|
134 |
-
|
135 |
-
if isinstance(text, str):
|
136 |
-
patterns = [pattern_abnormal_newlines, pattern_text_repetitions]
|
137 |
-
for i, pattern in enumerate(patterns):
|
138 |
-
if debug:
|
139 |
-
print(
|
140 |
-
f"----detect {'abnormal newlines' if i == 0 else 'text repetitions'}----"
|
141 |
-
)
|
142 |
-
matches = pattern.finditer(text)
|
143 |
-
for match in matches:
|
144 |
-
if debug:
|
145 |
-
print(match)
|
146 |
-
for groupNum in range(0, len(match.groups())):
|
147 |
-
groupNum = groupNum + 1
|
148 |
-
print(
|
149 |
-
"Group {groupNum} found at {start}-{end}: `{group}`".format(
|
150 |
-
groupNum=groupNum,
|
151 |
-
start=match.start(groupNum),
|
152 |
-
end=match.end(groupNum),
|
153 |
-
group=match.group(groupNum),
|
154 |
-
)
|
155 |
-
)
|
156 |
-
|
157 |
-
if exception_pattern.match(match[0]):
|
158 |
-
if debug:
|
159 |
-
print("ignored: ", match[0])
|
160 |
-
continue
|
161 |
-
|
162 |
-
start, end = match.span()
|
163 |
-
subtotals[i] += end - start
|
164 |
-
|
165 |
-
result = (subtotals[0], subtotals[1], subtotals[0] + subtotals[1])
|
166 |
-
|
167 |
-
if debug:
|
168 |
-
print(result)
|
169 |
-
return result
|
170 |
-
|
171 |
-
|
172 |
-
def detect_abnormal_newlines(text, debug=False):
|
173 |
-
return detect_repetitions(text, debug=debug)[0]
|
174 |
-
|
175 |
-
|
176 |
-
def detect_text_repetitions(text, debug=False):
|
177 |
-
return detect_repetitions(text, debug=debug)[1]
|
178 |
-
|
179 |
-
|
180 |
-
def detect_repetition_scores(text, debug=False):
|
181 |
-
newline_score, repetition_score, total_repetitions = detect_repetitions(
|
182 |
-
text, debug=debug
|
183 |
-
)
|
184 |
-
return pd.Series([newline_score, repetition_score, total_repetitions])
|
|
|
7 |
|
8 |
print(f"loading: {__file__}")
|
9 |
|
10 |
+
|
11 |
+
bleu = evaluate.load("bleu")
|
12 |
+
rouge = evaluate.load("rouge")
|
13 |
+
bert_score = evaluate.load("bertscore")
|
14 |
+
|
15 |
# pattern_non_word_char_repetition = re.compile(r"\s{5,}")
|
16 |
# pattern_text_repetitions = re.compile(r"(.{5}.*)\s*((\1)\s*)+", re.M | re.DOTALL)
|
17 |
|
|
|
86 |
return result
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
def calc_perf_scores(predictions, references, debug=False):
|
90 |
if debug:
|
91 |
print("predictions:", predictions)
|
|
|
111 |
print("result:", result)
|
112 |
|
113 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|