Spaces:
Build error
Build error
completed llama-3.1-70b generic prompt
Browse files- eval_modules/calc_repetitions_v2d.py +1281 -0
- llm_toolkit/translation_utils_v2.py +766 -0
- notebooks/00f_Data Analysis_Fine_Tuned_RPP_Generic_Prompt.ipynb +0 -0
- notebooks/03a_RAPGeT_v2_Data Analysis_Chat_Template.ipynb +0 -0
- notebooks/03b_RAPGeT_v2_Data Analysis_Generic_Prompt.ipynb +0 -0
- results/mac-results_rpp_with_mnt_2048_generic_prompt_metrics.csv +25 -21
- results/mac-results_rpp_with_mnt_2048_metrics.csv +31 -43
eval_modules/calc_repetitions_v2d.py
ADDED
@@ -0,0 +1,1281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import math
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import matplotlib.ticker as mtick
|
8 |
+
import seaborn as sns
|
9 |
+
import nltk
|
10 |
+
import evaluate
|
11 |
+
import traceback
|
12 |
+
|
13 |
+
# Hugging Face `evaluate` metric objects, loaded once at import time.
# NOTE(review): neither is referenced in this portion of the file —
# presumably used by code further down; confirm before removing.
# Loading may download the metric scripts on first use.
bert_score = evaluate.load("bertscore")
meteor = evaluate.load("meteor")

# Emitted at import time so notebooks show which module version was picked up.
print(f"loading: {__file__}")
|
17 |
+
|
18 |
+
# Earlier drafts of the patterns, kept for reference:
# pattern_non_word_char_repetition = re.compile(r"\s{5,}")
# pattern_text_repetitions = re.compile(r"(.{5}.*)\s*((\1)\s*)+", re.M | re.DOTALL)

# final version
# Matches runs of 5+ whitespace/non-word characters (collapsed to "\t" by
# del_non_word_char_repetition below).
pattern_non_word_char_repetition = re.compile(r"[\s\W]{5,}")
# Matches a fragment of at least 5 characters followed by one or more
# repetitions of itself, optionally separated by whitespace/punctuation.
pattern_text_repetitions = re.compile(
    r"(?P<repeat>.{5}.*?)(?:[\s\W]*(?P=repeat))+", re.M | re.DOTALL | re.IGNORECASE
)
# Explanation of the Regex Pattern:
# (?P<repeat>.{5}.*?): Captures any sequence of characters with minimal length of 5 and names this group repeat.
# .*?: Matches zero or more characters, non-greedily (as few as possible).
# (?:[\s\W]*(?P=repeat))+: A non-capturing group that matches one or more repetitions of:
# [\s\W]*: Zero or more whitespace or non-word characters (spaces, punctuation, etc.).
# (?P=repeat): A backreference to the named group repeat.
|
32 |
+
|
33 |
+
|
34 |
+
def del_non_word_char_repetition(text, debug=False):
    """Collapse long runs of whitespace/non-word characters to a tab.

    Every match of ``pattern_non_word_char_repetition`` (5+ consecutive
    whitespace or non-word characters) is replaced by a single "\\t".

    Returns a ``(cleaned_text, removed_char_count)`` tuple. Non-string
    input (e.g. NaN from a dataframe) is passed through with a count of 0.
    """
    if not isinstance(text, str):
        return text, 0

    if debug:
        print("----detect non-word characters repetition----")

    original_len = len(text)
    cleaned = pattern_non_word_char_repetition.sub("\t", text)
    removed = original_len - len(cleaned)

    if debug and removed:
        print(f"removed non-word characters repetition: {removed}")
    return cleaned, removed
|
46 |
+
|
47 |
+
|
48 |
+
# final version for repetition detection
|
49 |
+
# final version for repetition detection
def detect_text_repetitions(text, debug=False):
    """Count the characters that belong to repeated text fragments.

    For each match of ``pattern_text_repetitions``, the span length minus
    the length of the first captured fragment is accumulated — i.e. only
    the *extra* (repeated) characters are counted, not the original
    occurrence. Non-string input yields 0.
    """
    if not isinstance(text, str):
        return 0

    if debug:
        print("----detect text repetitions----")

    total = 0
    for match in pattern_text_repetitions.finditer(text):
        if debug:
            print(match)
            # Dump every capture group for inspection (1-based group numbers).
            for idx, _group in enumerate(match.groups(), start=1):
                print(
                    "Group {groupNum} found at {start}-{end}: `{group}`".format(
                        groupNum=idx,
                        start=match.start(idx),
                        end=match.end(idx),
                        group=match.group(idx),
                    )
                )

        begin, finish = match.span()
        total += finish - begin - len(match.group(1))

    return total
|
74 |
+
|
75 |
+
|
76 |
+
def detect_repetitions(text, debug=False):
    """Run both repetition detectors over *text* and combine their scores.

    Returns a 3-tuple ``(non_word_count, text_count, total)`` where
    ``non_word_count`` is the number of characters removed by collapsing
    non-word-character runs, ``text_count`` is the number of characters in
    repeated text fragments (measured on the cleaned text), and ``total``
    is their sum.
    """
    cleaned_text, non_word_count = del_non_word_char_repetition(text, debug=debug)
    text_count = detect_text_repetitions(cleaned_text, debug=debug)

    scores = (non_word_count, text_count, non_word_count + text_count)
    if debug:
        print(scores)
    return scores
|
88 |
+
|
89 |
+
|
90 |
+
def detect_scores(text, debug=False):
    """Adapter for ``Series.apply``: repetition scores as a pandas Series.

    Wraps :func:`detect_repetitions` so the three scores (newline score,
    repetition score, total) land in three dataframe columns.
    """
    scores = detect_repetitions(text, debug=debug)
    return pd.Series(list(scores))
|
95 |
+
|
96 |
+
|
97 |
+
def load_with_newline_and_repetition_scores(result_file, force_recalculate=False):
    """Load a results CSV and ensure repetition-score columns exist.

    If any of newline_score / repetition_score / total_repetitions /
    nrr / rr is missing (or *force_recalculate* is set), the missing
    scores are computed from the "answer" column and the CSV is
    rewritten in place with the new columns.

    NOTE: this function mutates the file on disk when it recalculates.
    """
    print(f"loading result file: {result_file}")
    # comment="#" skips commented header rows; malformed rows warn instead of raising.
    df = pd.read_csv(result_file, comment="#", on_bad_lines="warn")

    if (
        force_recalculate
        or "newline_score" not in df.columns
        or "repetition_score" not in df.columns
        or "total_repetitions" not in df.columns
        or "nrr" not in df.columns
        or "rr" not in df.columns
    ):
        # Only re-run the (expensive) regex scoring when the raw score
        # columns themselves are missing; nrr/rr alone can be derived.
        if (
            force_recalculate
            or "newline_score" not in df.columns
            or "repetition_score" not in df.columns
            or "total_repetitions" not in df.columns
        ):
            df[["newline_score", "repetition_score", "total_repetitions"]] = df[
                "answer"
            ].apply(detect_scores)

        # Answer length in characters; NaN / non-string answers count as 0.
        df["answer_len"] = df["answer"].apply(
            lambda x: len(x) if isinstance(x, str) else 0
        )

        # nrr = non-repetition ratio: fraction of the answer that is NOT
        # repetition. Empty answers are defined as fully non-repetitive (1).
        df["nrr"] = df.apply(
            lambda x: (
                1
                if x["answer_len"] == 0
                else 1 - (x["newline_score"] + x["repetition_score"]) / x["answer_len"]
            ),
            axis=1,
        )

        # rr = repetition ratio, the complement of nrr.
        df["rr"] = df["nrr"].apply(lambda x: 1 - x)

        # Persist the derived columns so subsequent loads are cheap.
        df.to_csv(result_file, index=False)

    return df
|
137 |
+
|
138 |
+
|
139 |
+
def replace_last(source_string, old_string, new_string):
    """Replace the last occurrence of *old_string* in *source_string*.

    If *old_string* does not occur, *source_string* is returned
    unchanged. (The previous implementation unconditionally joined the
    rpartition parts; since str.rpartition puts the whole string into
    the tail when the separator is absent, a miss incorrectly produced
    ``new_string + source_string``.)
    """
    head, sep, tail = source_string.rpartition(old_string)
    if not sep:  # separator not found — leave the string untouched
        return source_string
    return head + new_string + tail
|
142 |
+
|
143 |
+
|
144 |
+
def load_for_repetition_penalty(
    csv_result_file, repetition_penalty, force_recalculate=False
):
    """Load the result CSV for one repetition-penalty setting.

    Derives the per-penalty filename by swapping the base file's ".csv"
    suffix for "_RP_<penalty>.csv", then delegates to
    :func:`load_with_newline_and_repetition_scores`.
    """
    suffix = f"_RP_{repetition_penalty:.3f}.csv"
    result_file = replace_last(csv_result_file, ".csv", suffix)
    return load_with_newline_and_repetition_scores(
        result_file, force_recalculate=force_recalculate
    )
|
153 |
+
|
154 |
+
|
155 |
+
def calc_adjusted_performance(f, r):
    """Dampen performance score *f* by repetition count *r*.

    Divides by log10(10 + r): r == 0 leaves the score unchanged
    (log10(10) == 1) and larger repetition counts shrink it slowly.
    """
    damping = math.log10(r + 10)
    return f / damping
|
157 |
+
|
158 |
+
|
159 |
+
def calculate_adjusted_performance(row):
    """Row-wise adapter for ``DataFrame.apply``.

    Applies the log10(10 + total_repetitions) damping to the row's
    precision and recall and returns them as a two-element Series
    (adjusted_precision, adjusted_recall).
    """
    damping = math.log10(10 + row["total_repetitions"])
    return pd.Series(
        [row["precision"] / damping, row["recall"] / damping]
    )
|
164 |
+
|
165 |
+
|
166 |
+
def load_performance_df(csv_result_file, repetition_penalty):
    """Load the evaluated-results JSON for one repetition penalty.

    The eval file path is derived from the CSV result path by swapping
    the ".csv" suffix for "_RP_<penalty>-t2_evaluated.json" and the
    "/results/" directory segment for "/eval/".
    """
    json_suffix = f"_RP_{repetition_penalty:.3f}-t2_evaluated.json"
    result_file = replace_last(csv_result_file, ".csv", json_suffix).replace(
        "/results/", "/eval/"
    )
    print(f"loading json file: {result_file}")
    return pd.read_json(result_file)
|
175 |
+
|
176 |
+
|
177 |
+
def calculate_performance_score(
    csv_result_file, repetition_penalty, force_recalculate=False
):
    """Build (or load) the per-penalty performance-score CSV.

    Loads "<base>_rpp_<penalty>.csv" if it exists; otherwise starts from
    an empty frame and forces recalculation. When recalculating, answers
    and precision/recall/f1 are pulled from the evaluated JSON, then the
    repetition scores and repetition-adjusted precision/recall are
    derived and the CSV is (re)written in place.
    """
    result_file = replace_last(
        csv_result_file, ".csv", f"_rpp_{repetition_penalty:.2f}.csv"
    )

    if os.path.exists(result_file):
        print(f"loading result file: {result_file}")
        df = load_with_newline_and_repetition_scores(
            result_file, force_recalculate=force_recalculate
        )
    else:
        print(f"re-creating result file: {result_file}")
        df = pd.DataFrame()
        force_recalculate = True

    # NOTE(review): `"f2" in df.columns` triggers recalculation when a
    # stale "f2" column is PRESENT; combined with `"f1" not in`, this may
    # be intentional (recompute if old schema detected) — confirm it was
    # not meant to be `"f2" not in df.columns`.
    if force_recalculate or "f2" in df.columns or "f1" not in df.columns:
        try:
            perf_df = load_performance_df(csv_result_file, repetition_penalty)
            # Drop any stale metric columns before repopulating from the
            # freshly loaded eval JSON.
            df.drop(
                columns=[
                    "precision",
                    "recall",
                    "f1",
                    "f2",
                    "entities_in_answer",
                    "entities_in_question",
                    "word_count",
                ],
                errors="ignore",
                inplace=True,
            )

            df["id"] = perf_df["id"]
            df["question"] = perf_df["question"]
            df["answer"] = perf_df["pred_answer"]
            # Token count of the predicted answer; NaN answers count as 0.
            df["word_count"] = df["answer"].apply(
                lambda x: len(nltk.word_tokenize(x)) if isinstance(x, str) else 0
            )
            df["ground_truth"] = perf_df["ground_truth"]

            df["eval_gemini_1.0_pro"] = perf_df["eval_gemini_1.0_pro"]
            # "score" is a (precision, recall, f1) sequence per row.
            df["precision"] = perf_df["score"].apply(lambda x: x[0])
            df["recall"] = perf_df["score"].apply(lambda x: x[1])
            df["f1"] = perf_df["score"].apply(lambda x: x[2])
        except Exception as e:
            # Best-effort: missing/partial eval JSON is tolerated and the
            # repetition scores below are still computed.
            print(f"\tignored error: {e}")
            # traceback.print_exc()

        df[["newline_score", "repetition_score", "total_repetitions"]] = df[
            "answer"
        ].apply(detect_scores)

        # Repetition-adjusted precision/recall (log10 damping per row).
        df[["adjusted_precision", "adjusted_recall"]] = df.apply(
            calculate_adjusted_performance, axis=1
        )

        df.to_csv(result_file, index=False)
        print(f"performance scores saved to result file: {result_file}")

    # print(f"df len: {len(df)}")

    return df
|
241 |
+
|
242 |
+
|
243 |
+
def adjust_perf_scores_with_repetition_penalty(result, precision, recall):
    """Apply the RAP log10 damping to per-penalty precision/recall lists.

    For each dataframe in ``result["df_list_repetition_penalty"]`` the
    mean newline score and mean repetition score are summed, and the
    corresponding precision/recall entries are divided by
    ``log10(10 + newline + repetition)``.

    Returns the adjusted ``(precision, recall)`` lists.
    """
    penalty_dfs = result["df_list_repetition_penalty"]
    dampings = [
        math.log10(10 + df["newline_score"].mean() + df["repetition_score"].mean())
        for df in penalty_dfs
    ]

    adjusted_precision = [p / d for p, d in zip(precision, dampings)]
    adjusted_recall = [r / d for r, d in zip(recall, dampings)]
    return adjusted_precision, adjusted_recall
|
262 |
+
|
263 |
+
|
264 |
+
def plot_performance_scores(
    result,
    models=None,
    title="Performance",
):
    """Plot F1 and RAP-F1 vs repetition penalty, one figure per model.

    For each model, computes mean precision/recall per repetition-penalty
    dataframe, derives F1 and the repetition-adjusted F1 ("RAP - F1"),
    highlights the best of each with a vertical band, and shows a line
    plot. Purely a side-effect function (matplotlib); returns None.
    """
    if models is None:
        models = result.keys()
    for model in models:
        print(f"model: {model}")
        # df_overall holds one row per repetition penalty; the `df` names
        # inside the comprehensions below do NOT leak (Python 3 scoping),
        # so this binding is still intact when repetition_penalties is read.
        df = result[model]["df_overall"]

        # Calculate the statistics
        precision = [
            df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
        ]
        recall = [
            df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
        ]
        # Harmonic mean of the per-penalty mean precision/recall.
        f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
        best_f1 = max(f1)
        best_f1_index = f1.index(best_f1)

        # Repetition-adjusted ("RAP") variants of the same curves.
        precision, recall = adjust_perf_scores_with_repetition_penalty(
            result[model], precision, recall
        )
        afrp = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]

        # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
        best_afrp = max(afrp)
        best_afrp_index = afrp.index(best_afrp)

        # Per-question adjusted scores (computed but currently only used
        # by the commented-out "Per-question RAP" plot below).
        adjusted_precision = [
            df["adjusted_precision"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        adjusted_recall = [
            df["adjusted_recall"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        afrp2 = [
            2 * (p * r) / (p + r) for p, r in zip(adjusted_precision, adjusted_recall)
        ]
        best_afrp2 = max(afrp2)
        best_afrp2_index = afrp2.index(best_afrp2)

        repetition_penalties = list(df["repetition_penalty"])

        # line plot for precision, recall, f1
        plt.figure(figsize=(10, 6))

        # Blue band marks the penalty with the best plain F1.
        plt.axvspan(
            repetition_penalties[best_f1_index] - 0.01,
            repetition_penalties[best_f1_index] + 0.01,
            alpha=0.5,
            edgecolor="none",
            facecolor="blue",
        )

        # plt.axvspan(
        #     repetition_penalties[best_afrp2_index] - 0.01,
        #     repetition_penalties[best_afrp2_index] + 0.01,
        #     alpha=0.5,
        #     edgecolor="none",
        #     facecolor="green",
        # )

        # Orange band marks the penalty with the best RAP-F1.
        plt.axvspan(
            repetition_penalties[best_afrp_index] - 0.01,
            repetition_penalties[best_afrp_index] + 0.01,
            alpha=0.5,
            edgecolor="none",
            facecolor="orange",
        )

        plt.plot(repetition_penalties, f1, label="F1", marker="D", color="blue")
        # plt.plot(
        #     repetition_penalties,
        #     afrp2,
        #     label="Per-question RAP - F1",
        #     marker="s",
        #     color="green",
        # )
        plt.plot(
            repetition_penalties,
            afrp,
            label="RAP - F1",
            marker="o",
            color="orange",
        )
        plt.xlabel("Repetition Penalties")
        plt.ylabel("Score")
        # plt.xlim(0.99, 1.31)
        # y in percentage
        plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
        plt.title(f"{model} {title}")
        plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")

        plt.show()
|
362 |
+
|
363 |
+
|
364 |
+
def plot_best_afrp(
    result,
    models=None,
    title="Models with Best RAP - F1",
    ref_result=None,
):
    """Bar-chart each model's best RAP-F1 (and its plain F1) across penalties.

    For every model, finds the repetition penalty whose repetition-adjusted
    F1 is highest, records the corresponding F1 and mean total repetitions,
    optionally appends reference models loaded from CSV (for which RAP-F1
    is taken equal to F1 and repetitions to 0), and draws a grouped bar
    chart.

    Returns ``(data_pivoted, best_mtr)``: the metric-by-model table that was
    plotted and the list of mean-total-repetition values per model.
    """
    # Initialize lists to store the statistics
    model_names = []
    best_f1 = []
    best_afrp = []
    best_repetition_penalty = []
    best_mtr = []  # mean total repetitions at each model's best penalty

    if models is None:
        models = result.keys()
    for model in models:
        print(f"model: {model}")
        df = result[model]["df_overall"]

        # Calculate the statistics
        precision = [
            df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
        ]
        recall = [
            df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
        ]
        # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
        f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]

        newline_score = [
            df["newline_score"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        # print(f"newline_score: {newline_score}")

        repetition_score = [
            df["repetition_score"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        # print(f"repetition_score: {repetition_score}")

        # RAP-F1: F1 damped by log10(10 + mean repetitions) per penalty.
        afrp = [
            f / math.log10(10 + n + r)
            for f, n, r in zip(f1, newline_score, repetition_score)
        ]

        best_afrp.append(max(afrp))
        best_afrp_index = afrp.index(best_afrp[-1])
        best_repetition_penalty.append(df["repetition_penalty"][best_afrp_index])

        best_f1.append(f1[best_afrp_index])
        best_mtr.append(
            newline_score[best_afrp_index] + repetition_score[best_afrp_index]
        )

        # print(
        #     f"best repetition penalty: {best_repetition_penalty[-1]}, best afrp: {best_afrp[-1]}, f1: {best_f1[-1]}"
        # )

        # NOTE: `df` is rebound to the best-penalty dataframe here; it is
        # not read again before the loop ends.
        df = result[model]["df_list_repetition_penalty"][best_afrp_index]

        model_names.append(
            f"{model} (RP={best_repetition_penalty[-1]})"
        )  # Add the model name to the list

    if ref_result is not None:
        print("ref_result:", ref_result)
        for model in ref_result.keys():
            model_names.append(model)
            df = pd.read_csv(ref_result[model])
            # df = df[df["id"].isin(wikidata_df["id"])]

            p = df["precision"].mean()
            r = df["recall"].mean()

            # Reference models carry no repetition data, so RAP-F1 == F1.
            f1 = 2 * p * r / (p + r) if p + r > 0 else 0
            best_f1.append(f1)
            best_afrp.append(f1)
            best_mtr.append(0)

    print("model_names:", model_names)
    # print("best_f1:", best_f1)
    # print("best_afrp:", best_afrp)

    # Create a DataFrame with the statistics
    data = pd.DataFrame(
        {
            "Model": model_names,
            "RAP - F1": best_afrp,
            "F1": best_f1,
        }
    )

    # Melt the DataFrame to a long format
    data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")

    # Pivot the DataFrame to a wide format
    data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")

    # make sure the columns are following the order of the models
    data_pivoted = data_pivoted[model_names]

    # make sure the metric rows appear in the order RAP - F1, then F1
    data_pivoted = data_pivoted.reindex(["RAP - F1", "F1"])

    # Plot the statistics
    plt.figure(figsize=(15, 6))
    ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")

    # Set the rotation of the x-axis labels to 0 degrees
    plt.xticks(rotation=0)

    # Format the y-axis to display as percentage
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # get the max value of the y-axis
    a1 = max(best_afrp)
    a2 = max(best_f1)

    # 12% headroom above the tallest bar so the value labels fit.
    max_value = max([a1, a2]) * 1.12
    print("max_value:", max_value)

    ax.set_ylim(0, max_value)

    # Add the values above each bar
    for p in ax.patches:
        ax.annotate(
            f"{p.get_height() * 100:.1f}",
            (p.get_x() + p.get_width() / 2.0, p.get_height()),
            ha="center",
            va="bottom",
            xytext=(0, 10),
            textcoords="offset points",
            rotation=90,
        )

    plt.show()
    return data_pivoted, best_mtr
|
505 |
+
|
506 |
+
|
507 |
+
def plot_best_performance(
    result,
    models=None,
    title="Models with Best F1 Score",
    adjusted_f1=False,
    ref_result=None,
):
    """Bar-chart each model's best precision/recall/F1 across penalties.

    For every model, finds the repetition penalty with the highest F1
    (repetition-adjusted first when *adjusted_f1* is True), records the
    precision/recall at that penalty plus the mean total repetitions,
    optionally appends reference models loaded from CSV, and draws a
    grouped bar chart.

    Returns ``(data_pivoted, best_mtr)``: the metric-by-model table that
    was plotted and the mean-total-repetition list per model.
    """
    # Initialize lists to store the statistics
    model_names = []
    best_precision = []
    best_recall = []
    best_f1 = []
    best_repetition_penalty = []
    best_mtr = []  # mean total repetitions at each model's best penalty

    if models is None:
        models = result.keys()
    for model in models:
        print(f"model: {model}")
        df = result[model]["df_overall"]

        # Calculate the statistics
        precision = [
            df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
        ]
        recall = [
            df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
        ]
        newline_score = [
            df["newline_score"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]

        repetition_score = [
            df["repetition_score"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]

        # Optionally replace precision/recall with their RAP-damped versions
        # BEFORE computing F1, so the argmax is over adjusted scores.
        if adjusted_f1:
            precision, recall = adjust_perf_scores_with_repetition_penalty(
                result[model], precision, recall
            )

        # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
        f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]

        best_f1.append(max(f1))
        best_f1_index = f1.index(best_f1[-1])
        best_repetition_penalty.append(df["repetition_penalty"][best_f1_index])

        best_precision.append(precision[best_f1_index])
        best_recall.append(recall[best_f1_index])
        best_mtr.append(newline_score[best_f1_index] + repetition_score[best_f1_index])

        print(
            f"best repetition penalty: {best_repetition_penalty[-1]}, best f1: {best_f1[-1]}, precision: {best_precision[-1]}, recall: {best_recall[-1]}"
        )

        # Rebind df to the best-penalty dataframe for the summary print below.
        df = result[model]["df_list_repetition_penalty"][best_f1_index]

        model_names.append(
            f"{model} (RP={best_repetition_penalty[-1]})"
        )  # Add the model name to the list

        # print sum for columns: newline_score, repetition_score
        print(
            f"newline_score: {df['newline_score'].sum()}, repetition_score: {df['repetition_score'].sum()}"
        )

    if ref_result is not None:
        print("ref_result:", ref_result)
        for model in ref_result.keys():
            model_names.append(model)
            df = pd.read_csv(ref_result[model])
            # df = df[df["id"].isin(wikidata_df["id"])]

            best_precision.append(df["precision"].mean())
            best_recall.append(df["recall"].mean())
            # NOTE(review): unlike plot_best_afrp, no p + r > 0 guard here —
            # all-zero reference scores would raise ZeroDivisionError.
            f1 = (
                2
                * (best_precision[-1] * best_recall[-1])
                / (best_precision[-1] + best_recall[-1])
            )
            # best_f1.append(df["f1"].mean())
            best_f1.append(f1)
            best_mtr.append(0)

    # Create a DataFrame with the statistics; column labels reflect whether
    # the scores were repetition-adjusted.
    data = (
        pd.DataFrame(
            {
                "Model": model_names,
                "Adjusted Precision with RP": best_precision,
                "Adjusted Recall with RP": best_recall,
                "Adjusted F1 with RP": best_f1,
            }
        )
        if adjusted_f1
        else pd.DataFrame(
            {
                "Model": model_names,
                "Precision": best_precision,
                "Recall": best_recall,
                "F1": best_f1,
            }
        )
    )
    columns = list(data.columns)

    # Melt the DataFrame to a long format
    data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")

    # Pivot the DataFrame to a wide format
    data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")

    # make sure the columns are following the order of the models
    data_pivoted = data_pivoted[model_names]

    # make sure three groups in the order of precision, recall, f1
    data_pivoted = data_pivoted.reindex(columns[1:])

    # Plot the statistics
    plt.figure(figsize=(10, 6))
    ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")

    # Set the rotation of the x-axis labels to 0 degrees
    plt.xticks(rotation=0)

    # Format the y-axis to display as percentage
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # get the max value of the y-axis
    a1 = max(best_precision)
    a2 = max(best_recall)
    a3 = max(best_f1)

    # 12% headroom above the tallest bar so the value labels fit.
    max_value = max([a1, a2, a3]) * 1.12
    print("max_value:", max_value)

    ax.set_ylim(0, max_value)

    # Add the values above each bar
    for p in ax.patches:
        ax.annotate(
            f"{p.get_height() * 100:.1f}",
            (p.get_x() + p.get_width() / 2.0, p.get_height()),
            ha="center",
            va="bottom",
            xytext=(0, 10),
            textcoords="offset points",
            rotation=90,
        )

    plt.show()
    return data_pivoted, best_mtr
|
665 |
+
|
666 |
+
|
667 |
+
def plot_best_performance_ms_macro(
    result,
    models=None,
    title="Models with Best RAP - Performance",
    ref_result=None,
    skip_generic_prompt=False,
    include_adjusted_performance=True,
):
    """Bar-plot the best (repetition-adjusted) performance per model for MS MACRO.

    For every model, picks the repetition penalty that maximizes the
    repetition-adjusted F1 (or the raw F1 when
    ``include_adjusted_performance`` is False) and plots those best scores
    side by side.

    Parameters
    ----------
    result : dict
        Maps model name -> {"df_overall": DataFrame with per-penalty
        ``bleu1``/``rougeL``/``repetition_penalty`` columns,
        "df_list_repetition_penalty": list of per-penalty DataFrames with
        ``newline_score``/``repetition_score`` columns}.
    models : iterable or None
        Subset of models to plot; defaults to all keys of ``result``.
    title : str
        Figure title.
    ref_result : dict or None
        Optional map of reference model name -> CSV path; each CSV is
        expected to hold ``bleu1``/``rougeL`` in its first row.
    skip_generic_prompt : bool
        Skip models whose name contains "generic prompt".
    include_adjusted_performance : bool
        Select/plot RAP-adjusted scores instead of raw Bleu-1/Rouge-L.

    Returns
    -------
    (DataFrame, list)
        The pivoted metric-by-model table that was plotted, and the mean
        total repetitions (newline + repetition score) at each model's best
        penalty (0 for reference models).
    """
    # Initialize lists to store the statistics
    model_names = []
    best_f1 = []
    best_afrp = []
    best_repetition_penalty = []
    best_bleu1 = []
    best_rougeL = []
    best_mtr = []

    if models is None:
        models = result.keys()
    for model in models:
        if skip_generic_prompt and "generic prompt" in model:
            continue
        print(f"model: {model}")
        df = result[model]["df_overall"]

        # Calculate the statistics: F1 is the harmonic mean of Bleu-1 and Rouge-L.
        bleu1 = [x for x in df["bleu1"]]
        rougeL = [x for x in df["rougeL"]]
        f1 = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]

        # Mean repetition diagnostics per repetition-penalty setting.
        newline_score = [
            df["newline_score"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        # print(f"newline_score: {newline_score}")

        repetition_score = [
            df["repetition_score"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        # print(f"repetition_score: {repetition_score}")

        # Adjusted F1: penalize F1 by log10(10 + total repetitions).
        afrp = [
            f / math.log10(10 + n + r)
            for f, n, r in zip(f1, newline_score, repetition_score)
        ]

        # Index of the best score under the chosen criterion (adjusted or raw).
        best_afrp.append(max(afrp if include_adjusted_performance else f1))
        best_afrp_index = (
            afrp.index(best_afrp[-1])
            if include_adjusted_performance
            else f1.index(best_afrp[-1])
        )
        best_repetition_penalty.append(df["repetition_penalty"][best_afrp_index])

        best_f1.append(f1[best_afrp_index])
        best_bleu1.append(bleu1[best_afrp_index])
        best_rougeL.append(rougeL[best_afrp_index])
        best_mtr.append(
            newline_score[best_afrp_index] + repetition_score[best_afrp_index]
        )

        # print(
        #     f"best repetition penalty: {best_repetition_penalty[-1]}, best afrp: {best_afrp[-1]}, f1: {best_f1[-1]}"
        # )

        # NOTE: rebinds `df` to the per-penalty frame; the value is unused below.
        df = result[model]["df_list_repetition_penalty"][best_afrp_index]

        model_names.append(
            f"{model} (RP={best_repetition_penalty[-1]})"
        )  # Add the model name to the list

    # Optionally append reference results read straight from CSV files.
    if ref_result is not None:
        print("ref_result:", ref_result)
        for model in ref_result.keys():
            model_names.append(model)
            df = pd.read_csv(ref_result[model], comment="#", on_bad_lines="warn")
            # df = df[df["id"].isin(wikidata_df["id"])]

            p = df["bleu1"][0]
            best_bleu1.append(p)

            r = df["rougeL"][0]
            best_rougeL.append(r)

            f1 = 2 * p * r / (p + r) if p + r > 0 else 0
            best_f1.append(f1)
            best_afrp.append(f1)
            # Reference results carry no repetition information.
            best_mtr.append(0)

    # print("model_names:", model_names)
    # print("best_f1:", best_f1)
    # print("best_afrp:", best_afrp)

    # Create a DataFrame with the statistics; column set depends on whether
    # adjusted performance is plotted.
    data = (
        pd.DataFrame(
            {
                "Model": model_names,
                "RAP - Perf Score": best_afrp,
                "Overall Perf Score": best_f1,
            }
        )
        if include_adjusted_performance
        else pd.DataFrame(
            {
                "Model": model_names,
                "Bleu-1": best_bleu1,
                "Rouge-L": best_rougeL,
                "Overall Perf Score": best_f1,
            }
        )
    )

    # Melt the DataFrame to a long format
    data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")

    # Pivot the DataFrame to a wide format
    data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")

    # make sure the columns are following the order of the models
    data_pivoted = data_pivoted[model_names]

    # make sure the metric rows keep the original column order
    columns = list(data.columns)
    data_pivoted = data_pivoted.reindex(columns[1:])

    # Plot the statistics
    plt.figure(figsize=(10, 6))
    ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")

    # Set the rotation of the x-axis labels to 0 degrees
    plt.xticks(rotation=0)

    # Format the y-axis to display as percentage
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # get the max value of the y-axis
    a1 = max(best_afrp)
    a2 = max(best_f1)
    a3 = max(best_bleu1)
    a4 = max(best_rougeL)

    # 12% headroom above the tallest bar for the value labels.
    max_value = (
        max([a1, a2] if include_adjusted_performance else [a1, a2, a3, a4]) * 1.12
    )
    print("max_value:", max_value)

    # Set the y-axis limit
    ax.set_ylim(0, max_value)

    # Add the values above each bar
    for p in ax.patches:
        ax.annotate(
            f"{p.get_height() * 100:.1f}",
            (p.get_x() + p.get_width() / 2.0, p.get_height()),
            ha="center",
            va="bottom",
            xytext=(0, 10),
            textcoords="offset points",
            rotation=90,
        )

    plt.show()
    return data_pivoted, best_mtr
|
833 |
+
|
834 |
+
|
835 |
+
# All open-source chat models evaluated in this study, listed roughly in
# increasing order of parameter count (2B up to 70B).
all_open_source_models = [
    "gemma-1.1-2b-it",
    "Phi-3-mini-128k-instruct",
    "gemma-1.1-7b-it",
    "Llama-2-7b-chat-hf",
    "Mistral-7B-Instruct-v0.2",
    "Meta-Llama-3-8B-Instruct",
    "Llama-2-13b-chat-hf",
    "Llama-2-70b-chat-hf",
    "Meta-Llama-3-70B-Instruct",
]
|
846 |
+
|
847 |
+
|
848 |
+
def load_for_repetition_penalty_ms_macro(
    csv_result_file, repetition_penalty, force_recalculate=False
):
    """Load the per-repetition-penalty MS MACRO result frame.

    Derives the per-penalty file name from ``csv_result_file`` by replacing
    its final ``.csv`` with ``_rpp_<penalty>.csv`` (two decimals), then loads
    it with newline/repetition scores attached.

    Parameters
    ----------
    csv_result_file : str
        Path of the base result CSV.
    repetition_penalty : float
        Penalty value encoded in the per-penalty file name.
    force_recalculate : bool
        Forwarded to ``load_with_newline_and_repetition_scores``.

    Returns
    -------
    pandas.DataFrame
        The loaded, score-augmented result frame.
    """
    rpp_suffix = f"_rpp_{repetition_penalty:.2f}.csv"
    per_rpp_file = replace_last(csv_result_file, ".csv", rpp_suffix)
    return load_with_newline_and_repetition_scores(
        per_rpp_file, force_recalculate=force_recalculate
    )
|
859 |
+
|
860 |
+
|
861 |
+
# MS MACRO
def plot_performance_scores_ms_macro(
    result,
    models=None,
    title="Performance",
):
    """Line-plot raw vs repetition-adjusted performance per repetition penalty.

    For each model, draws the F1 (harmonic mean of Bleu-1 and Rouge-L) and
    the repetition-adjusted F1 against the repetition penalties, and shades
    narrow vertical bands at the penalty that maximizes each curve
    (blue = raw F1, orange = adjusted).

    Parameters
    ----------
    result : dict
        Maps model name -> {"df_overall": DataFrame with ``bleu1``,
        ``rougeL`` and ``repetition_penalty`` columns, ...}.
    models : iterable or None
        Subset of models to plot; defaults to all keys of ``result``.
    title : str
        Suffix appended to the model name in each figure title.
    """
    if models is None:
        models = result.keys()
    for model in models:
        print(f"model: {model}")
        df = result[model]["df_overall"]
        # print(result[model]["df_list_repetition_penalty"][0].describe())

        # Calculate the statistics: raw F1 from overall Bleu-1/Rouge-L.
        bleu1 = list(df["bleu1"])
        rougeL = list(df["rougeL"])
        f1 = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
        best_f1 = max(f1)
        best_f1_index = f1.index(best_f1)

        # Rebind bleu1/rougeL to their repetition-adjusted counterparts.
        bleu1, rougeL = adjust_perf_scores_with_repetition_penalty(
            result[model], bleu1, rougeL
        )
        afrp = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]

        # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
        best_afrp = max(afrp)
        best_afrp_index = afrp.index(best_afrp)

        repetition_penalties = list(df["repetition_penalty"])

        # line plot for raw and adjusted performance scores
        plt.figure(figsize=(10, 6))

        # Blue band: penalty with the best raw F1.
        plt.axvspan(
            repetition_penalties[best_f1_index] - 0.01,
            repetition_penalties[best_f1_index] + 0.01,
            alpha=0.5,
            edgecolor="none",
            facecolor="blue",
        )

        # Orange band: penalty with the best adjusted F1.
        plt.axvspan(
            repetition_penalties[best_afrp_index] - 0.01,
            repetition_penalties[best_afrp_index] + 0.01,
            alpha=0.5,
            edgecolor="none",
            facecolor="orange",
        )

        plt.plot(
            repetition_penalties,
            f1,
            label="Overall Perf Score",
            marker="D",
            color="blue",
        )
        plt.plot(
            repetition_penalties,
            afrp,
            label="RAP - Perf Score",
            marker="o",
            color="orange",
        )

        plt.xlabel("Repetition Penalties")
        plt.ylabel("Score")
        # plt.xlim(0.99, 1.31)
        # y in percentage
        plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
        plt.title(f"{model} {title}")
        plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")

        plt.show()
|
935 |
+
|
936 |
+
|
937 |
+
def plot_repetition_factors(result, groups):
    """Plot mean total repetitions vs repetition penalty, one figure per group.

    For every entry of ``groups``, draws a figure containing a seaborn line
    per model whose (lower-cased) name contains the group string.

    Parameters
    ----------
    result : dict
        Maps model name -> {"df_overall": DataFrame with a
        ``repetition_penalty`` column, "df_list_repetition_penalty": list of
        per-penalty DataFrames with a ``total_repetitions`` column}.
    groups : iterable of str
        Substrings used to bucket models into one figure each.
    """
    for group in groups:
        # Plot the statistics
        plt.figure(figsize=(10, 6))

        max_value = 0
        for model in result.keys():
            # Only plot models belonging to the current group.
            if not group in model.lower():
                continue
            print(f"model: {model}")
            df = result[model]["df_overall"]
            repetition_panelties = [
                repetition_penalty for repetition_penalty in df["repetition_penalty"]
            ]

            # Mean total repetitions at each penalty setting.
            mean_score = [
                # math.log10(10 + df["total_repetitions"].mean())
                df["total_repetitions"].mean()
                for df in result[model]["df_list_repetition_penalty"]
            ]

            sns.lineplot(x=repetition_panelties, y=mean_score, label=model)

            # Track the tallest line for the y-axis limit.
            new_max = max(mean_score)
            if new_max > max_value:
                max_value = new_max

        # 5% headroom above the tallest line.
        max_value = max_value * 1.05
        # if max_value < 1.5:
        #     max_value = 1.5
        # set ylimit
        plt.ylim(0, max_value)

        # show grid
        plt.grid(True)
        plt.xlabel("Repetition Penalties")
        plt.ylabel("Mean Total Repetitions")
        plt.title("Mean Total Repetitions vs Repetition Penalties")
        plt.legend()

        plt.show()
|
978 |
+
|
979 |
+
|
980 |
+
def plot_repetition_factors_by_group(result, group_filter=None):
    """Plot mean total repetitions vs repetition penalty on a single figure.

    Draws one seaborn line per model (optionally restricted to names
    containing ``group_filter``), cycling through a fixed set of markers and
    colors.

    Parameters
    ----------
    result : dict
        Maps model name -> {"df_overall": DataFrame with a
        ``repetition_penalty`` column, "df_list_repetition_penalty": list of
        per-penalty DataFrames with a ``total_repetitions`` column}.
    group_filter : str or None
        When given, only models whose name contains this substring are
        plotted.
    """
    markers = ["D", "o", "s", "x"]
    colors = ["blue", "orange", "green", "red"]

    # Plot the statistics
    plt.figure(figsize=(10, 6))
    index = 0
    max_value = 0

    for model in result.keys():
        if group_filter is not None and group_filter not in model:
            continue

        print(f"model: {model}")

        df = result[model]["df_overall"]
        repetition_panelties = [
            repetition_penalty for repetition_penalty in df["repetition_penalty"]
        ]

        # Calculate the statistics: mean total repetitions per penalty.
        mean_score = [
            # math.log10(10 + df["total_repetitions"].mean())
            df["total_repetitions"].mean()
            for df in result[model]["df_list_repetition_penalty"]
        ]
        if len(mean_score) != len(repetition_panelties):
            print(
                f"model: {model} has different length of repetition penalties and mean score"
            )
            print("repetition_panelties:", len(repetition_panelties))
            print("mean_score:", len(mean_score))
            continue

        new_max = max(mean_score)
        if new_max > max_value:
            max_value = new_max

        # FIX: wrap around the marker/color lists so that plotting more than
        # len(markers) models no longer raises IndexError; styles repeat.
        sns.lineplot(
            x=repetition_panelties,
            y=mean_score,
            label=model,
            marker=markers[index % len(markers)],
            color=colors[index % len(colors)],
        )

        index += 1

    # 5% headroom above the tallest line.
    max_value = max_value * 1.05
    # if max_value < 1.5:
    #     max_value = 1.5
    # set ylimit
    plt.ylim(0, max_value)

    plt.xlabel("Repetition Penalties")
    plt.ylabel("Mean Total Repetitions")
    plt.title("Mean Total Repetitions vs Repetition Penalties")
    plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")

    plt.show()
|
1041 |
+
|
1042 |
+
|
1043 |
+
# MS MACRO result CSVs: for each model, three prompting setups
# (RAG - Generic Prompt, RAG - Chat Template, Non-RAG), suffixed "_mm".
ms_marco_csv_result_files = [
    "data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/gemma-1.1-2b-it(RAG - Chat Template)_mm.csv",
    "data/results_v2/gemma-1.1-2b-it(Non-RAG)_mm.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(RAG - Chat Template)_mm.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(Non-RAG)_mm.csv",
    "data/results_v2/gemma-1.1-7b-it(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/gemma-1.1-7b-it(RAG - Chat Template)_mm.csv",
    "data/results_v2/gemma-1.1-7b-it(Non-RAG)_mm.csv",
    "data/results_v2/Llama-2-7b-chat-hf(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Llama-2-7b-chat-hf(RAG - Chat Template)_mm.csv",
    "data/results_v2/Llama-2-7b-chat-hf(Non-RAG)_mm.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Chat Template)_mm.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(Non-RAG)_mm.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Chat Template)_mm.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(Non-RAG)_mm.csv",
    "data/results_v2/Llama-2-13b-chat-hf(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Llama-2-13b-chat-hf(RAG - Chat Template)_mm.csv",
    "data/results_v2/Llama-2-13b-chat-hf(Non-RAG)_mm.csv",
    "data/results_v2/Llama-2-70b-chat-hf(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Llama-2-70b-chat-hf(RAG - Chat Template)_mm.csv",
    "data/results_v2/Llama-2-70b-chat-hf(Non-RAG)_mm.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Generic Prompt)_mm.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Chat Template)_mm.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(Non-RAG)_mm.csv",
]
|
1072 |
+
|
1073 |
+
# WebQSP result CSVs: same models and prompting setups as the MS MACRO list
# above, but suffixed "_wd" (WebQSP/Wikidata).
webqsp_csv_result_files = [
    "data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/gemma-1.1-2b-it(RAG - Chat Template)_wd.csv",
    "data/results_v2/gemma-1.1-2b-it(Non-RAG)_wd.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(RAG - Chat Template)_wd.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(Non-RAG)_wd.csv",
    "data/results_v2/gemma-1.1-7b-it(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/gemma-1.1-7b-it(RAG - Chat Template)_wd.csv",
    "data/results_v2/gemma-1.1-7b-it(Non-RAG)_wd.csv",
    "data/results_v2/Llama-2-7b-chat-hf(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Llama-2-7b-chat-hf(RAG - Chat Template)_wd.csv",
    "data/results_v2/Llama-2-7b-chat-hf(Non-RAG)_wd.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Chat Template)_wd.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(Non-RAG)_wd.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Chat Template)_wd.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(Non-RAG)_wd.csv",
    "data/results_v2/Llama-2-13b-chat-hf(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Llama-2-13b-chat-hf(RAG - Chat Template)_wd.csv",
    "data/results_v2/Llama-2-13b-chat-hf(Non-RAG)_wd.csv",
    "data/results_v2/Llama-2-70b-chat-hf(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Llama-2-70b-chat-hf(RAG - Chat Template)_wd.csv",
    "data/results_v2/Llama-2-70b-chat-hf(Non-RAG)_wd.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Chat Template)_wd.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(Non-RAG)_wd.csv",
]
|
1102 |
+
|
1103 |
+
|
1104 |
+
def calc_rap_scores(result, precision="precision", recall="recall"):
    """Compute repetition diagnostics and RAP scores for one model's results.

    Parameters
    ----------
    result : dict
        Holds "df_list_repetition_penalty" (one DataFrame per repetition
        penalty, each with ``newline_score`` and ``repetition_score``
        columns) and "df_overall" (one row per penalty, with an
        ``answer_len`` column).
    precision, recall : str
        Column names. When ``precision`` exists in the per-penalty frames,
        per-frame means are used; otherwise values come from "df_overall".

    Returns
    -------
    tuple of five lists/series, one entry per repetition penalty:
        (mean newline score, mean repetition score, F1, RAP, NRR) where
        NRR = 1 - repetitions/answer_len and RAP = F1 * NRR^3.
    """
    rpp_frames = result["df_list_repetition_penalty"]

    newline_score = [frame["newline_score"].mean() for frame in rpp_frames]
    repetition_score = [frame["repetition_score"].mean() for frame in rpp_frames]

    # Prefer per-frame means when the metric column is present per penalty;
    # otherwise fall back to the overall summary frame.
    if precision in rpp_frames[0].columns:
        precision_vals = [frame[precision].mean() for frame in rpp_frames]
        recall_vals = [frame[recall].mean() for frame in rpp_frames]
    else:
        precision_vals = result["df_overall"][precision]
        recall_vals = result["df_overall"][recall]

    # Harmonic mean of precision and recall.
    f1 = [2 * (p * r) / (p + r) for p, r in zip(precision_vals, recall_vals)]

    # Non-repetition ratio: share of the answer that is not repetition.
    answer_lens = result["df_overall"]["answer_len"]
    nrr = [
        1 - (nl + rep) / length
        for nl, rep, length in zip(newline_score, repetition_score, answer_lens)
    ]

    # Repetition-adjusted performance: cube the ratio to punish repetition.
    rap = [score * ratio * ratio * ratio for score, ratio in zip(f1, nrr)]

    return newline_score, repetition_score, f1, rap, nrr
|
1139 |
+
|
1140 |
+
|
1141 |
+
def get_model_name(csv_result_file):
    """Extract the model name from a result-file path.

    Splits the path on '/' and '_' and returns the fourth token, which for
    paths like ``data/results_v2/<model>(<setup>)_mm.csv`` is the model name
    together with its prompting setup.
    """
    tokens = re.split(r"[_/]", csv_result_file)
    print(f"parts: {tokens}")
    return tokens[3]
|
1146 |
+
|
1147 |
+
|
1148 |
+
def load_webqsp_result(csv_result_files, force_recalculate=False, save=False):
    """Load WebQSP result CSVs and enrich them with repetition/RAP metrics.

    For each file: reads the overall frame, loads one per-penalty frame via
    ``calculate_performance_score``, attaches mean answer lengths, then
    computes newline/repetition/F1/RAP/NRR columns with ``calc_rap_scores``.
    Files that fail to process are logged and skipped.

    Parameters
    ----------
    csv_result_files : iterable of str
        Paths of per-model result CSVs.
    force_recalculate : bool
        Forwarded to ``calculate_performance_score``.
    save : bool
        When True, writes the enriched overall frame back to the CSV.

    Returns
    -------
    dict
        model name -> {"df_overall", "df_list_repetition_penalty", "file"}.
    """
    result = {}
    for i, csv_result_file in enumerate(csv_result_files):
        try:
            df = pd.read_csv(csv_result_file)
            model_name = get_model_name(csv_result_file)
            print(f"\tmodel_name: {model_name}")

            # One recalculated frame per repetition penalty listed in the CSV.
            dfs = [
                calculate_performance_score(
                    csv_result_file,
                    repetition_penalty,
                    force_recalculate=force_recalculate,
                )
                for repetition_penalty in df["repetition_penalty"]
            ]

            # Mean answer length per penalty (non-string answers count as 0).
            answer_lens = []
            for df_rpp in dfs:
                df_rpp["answer_len"] = df_rpp["answer"].apply(
                    lambda x: len(x) if isinstance(x, str) else 0
                )
                answer_lens.append(df_rpp["answer_len"].mean())
            df["answer_len"] = answer_lens

            result[model_name] = {
                "df_overall": df,
                "df_list_repetition_penalty": dfs,
                "file": csv_result_file,
            }
            # Derive repetition and RAP metrics; uses default
            # "precision"/"recall" columns.
            newline_score, repetition_score, perf, rap, nrr = calc_rap_scores(
                result[model_name]
            )
            df["newline_score"] = newline_score
            df["repetition_score"] = repetition_score
            df["total_repetitions"] = df["newline_score"] + df["repetition_score"]
            df["perf"] = perf
            df["nrr"] = nrr
            df["rap"] = rap
            # rr (repetition ratio) is the complement of nrr.
            df["rr"] = df["nrr"].apply(lambda x: 1 - x)
            if save:
                df.to_csv(csv_result_file, index=False)
        except Exception as e:
            # Best-effort: log the failure and continue with remaining files.
            print(f"Error: {e}")
            traceback.print_exc()

    return result
|
1195 |
+
|
1196 |
+
|
1197 |
+
def load_ms_marco_result(
    csv_result_files, force_recalculate=False, calc_bertscore=False, save=False
):
    """Load MS MACRO result CSVs and enrich them with quality and RAP metrics.

    For each file: reads the overall frame, loads one per-penalty frame,
    attaches mean answer lengths, computes BERTScore or METEOR per penalty
    when the column is missing, then derives repetition/F1/RAP/NRR columns
    via ``calc_rap_scores`` (using the quality metric as both precision and
    recall). Files that fail to process are logged and skipped.

    Parameters
    ----------
    csv_result_files : iterable of str
        Paths of per-model result CSVs.
    force_recalculate : bool
        Forwarded to ``load_for_repetition_penalty_ms_macro``.
    calc_bertscore : bool
        When True, compute per-row BERTScore F1 (deberta-large-mnli) instead
        of corpus METEOR.
    save : bool
        Write the enriched overall frame back to the CSV (forced to True
        when the metric column had to be computed).

    Returns
    -------
    dict
        model name -> {"df_overall", "df_list_repetition_penalty", "file"}.
    """
    result = {}
    for csv_result_file in csv_result_files:
        try:
            df = pd.read_csv(csv_result_file)
            model_name = get_model_name(csv_result_file)
            print(f"\tmodel_name: {model_name}")

            dfs = [
                load_for_repetition_penalty_ms_macro(
                    csv_result_file,
                    repetition_penalty,
                    force_recalculate=force_recalculate,
                )
                for repetition_penalty in df["repetition_penalty"]
            ]

            # Mean answer length per penalty (column assumed precomputed in
            # the per-penalty frames).
            answer_lens = []
            for df_rpp in dfs:
                answer_lens.append(df_rpp["answer_len"].mean())
            df["answer_len"] = answer_lens

            col = "bert_score" if calc_bertscore else "meteor"
            score_unavailable = col not in df.columns

            if score_unavailable:
                # Metric must be computed now; persist it afterwards.
                save = True
                bert_meteor_scores = []
                bert_score_references = None
                for df_rpp in dfs:
                    if calc_bertscore:
                        bert_meteor_score = 0

                        # Row-by-row BERTScore against the first ground truth.
                        for i, row in df_rpp.iterrows():
                            answer = row["answer"]
                            if not isinstance(answer, str):
                                answer = ""
                            bert_meteor_score += bert_score.compute(
                                predictions=[answer],
                                references=[row["ground_truth"][0]],
                                lang="en",
                                model_type="microsoft/deberta-large-mnli",
                            )["f1"][0]
                        # get average of bertscore
                        bert_meteor_score = bert_meteor_score / len(df_rpp)

                        print(f"bert_score: {bert_meteor_score}")
                    else:
                        # Corpus-level METEOR over all rows of this frame.
                        bert_meteor_score = meteor.compute(
                            predictions=df_rpp["answer"],
                            references=df_rpp["ground_truth"],
                        )["meteor"]

                    bert_meteor_scores.append(bert_meteor_score)

                df[col] = bert_meteor_scores

            result[model_name] = {
                "df_overall": df,
                "df_list_repetition_penalty": dfs,
                "file": csv_result_file,
            }
            # The single quality metric serves as both precision and recall,
            # so F1 degenerates to the metric itself.
            newline_score, repetition_score, perf, rap, nrr = calc_rap_scores(
                result[model_name],
                precision=col,
                recall=col,
            )
            df["newline_score"] = newline_score
            df["repetition_score"] = repetition_score
            df["total_repetitions"] = df["newline_score"] + df["repetition_score"]
            df["perf"] = perf
            df["nrr"] = nrr
            df["rap"] = rap
            # rr (repetition ratio) is the complement of nrr.
            df["rr"] = df["nrr"].apply(lambda x: 1 - x)

            if save:
                df.to_csv(csv_result_file, index=False)
        except Exception as e:
            # Best-effort: log the failure and continue with remaining files.
            print("An error occurred:", e)
            traceback.print_exc()
            print(f"csv_result_file: {csv_result_file}")

    return result
|
llm_toolkit/translation_utils_v2.py
ADDED
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import glob
|
4 |
+
import pandas as pd
|
5 |
+
import evaluate
|
6 |
+
import seaborn as sns
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from datasets import load_dataset
|
9 |
+
from langchain_openai import ChatOpenAI
|
10 |
+
from langchain_core.prompts import ChatPromptTemplate
|
11 |
+
from tqdm import tqdm
|
12 |
+
from eval_modules.calc_repetitions_v2d import *
|
13 |
+
from llm_toolkit.llm_utils import load_tokenizer, print_row_details
|
14 |
+
|
15 |
+
print(f"loading {__file__}")

# Hugging Face `evaluate` metric objects, loaded once at import time so all
# callers (e.g. calc_metrics below) share the same instances.
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
accuracy = evaluate.load("accuracy")
sacrebleu = evaluate.load("sacrebleu")
comet = evaluate.load("comet")
|
23 |
+
|
24 |
+
|
25 |
+
def extract_answer(text, debug=False):
    """Strip chat-template scaffolding from a model response.

    Applies three regex passes: drop everything up to an ``assistant`` /
    ``[/INST]`` marker, cut from the first ``<...>`` special token onward,
    and remove llama-3 style ``...end_header_id|>\\n\\n`` prefixes.
    Non-string or empty inputs are left untouched.
    """
    if text and isinstance(text, str):
        flags = re.DOTALL | re.MULTILINE

        # Remove the begin and end tokens up to the assistant marker.
        text = re.sub(r".*?(assistant|\[/INST\]).+?\b", "", text, flags=flags)
        if debug:
            print("--------\nstep 1:", text)

        # Cut from the first special token (e.g. <eos>) to the end.
        text = re.sub(r"<.+?>.*", "", text, flags=flags)
        if debug:
            print("--------\nstep 2:", text)

        # Drop llama-3 style headers ending in "end_header_id|>\n\n".
        text = re.sub(r".*?end_header_id\|>\n\n", "", text, flags=flags)
        if debug:
            print("--------\nstep 3:", text)

    return text
|
45 |
+
|
46 |
+
|
47 |
+
def calc_metrics(references, predictions, sources=None, debug=False):
    """Compute translation quality metrics for a batch of predictions.

    Cleans each prediction with ``extract_answer`` first, then computes
    COMET, METEOR, SacreBLEU, BLEU (max_order=4), ROUGE and exact-match
    accuracy against the references.

    Parameters
    ----------
    references, predictions : sequences of equal length
        Gold translations and raw model outputs.
    sources : sequence or None
        Source sentences, forwarded to COMET.
    debug : bool
        When True, also record the indices of exact matches under
        "correct_ids".

    Returns
    -------
    dict
        Keys: "comet", "meteor", "sacrebleu", "bleu_scores",
        "rouge_scores", "accuracy" (and "correct_ids" when ``debug``).
    """
    assert len(references) == len(
        predictions
    ), f"lengths are difference: {len(references)} != {len(predictions)}"

    predictions = [extract_answer(text) for text in predictions]
    results = {}

    results["comet"] = comet.compute(
        predictions=predictions, references=references, sources=sources
    )["mean_score"]

    results["meteor"] = meteor.compute(predictions=predictions, references=references)[
        "meteor"
    ]

    results["sacrebleu"] = sacrebleu.compute(
        predictions=predictions, references=references
    )

    results["bleu_scores"] = bleu.compute(
        predictions=predictions, references=references, max_order=4
    )
    results["rouge_scores"] = rouge.compute(
        predictions=predictions, references=references
    )

    correct = [1 if ref == pred else 0 for ref, pred in zip(references, predictions)]
    # FIX: use a local name distinct from the module-level `accuracy` metric
    # object loaded via evaluate.load("accuracy") — the original shadowed it.
    exact_match_rate = sum(correct) / len(references)

    results["accuracy"] = exact_match_rate
    if debug:
        correct_ids = [i for i, c in enumerate(correct) if c == 1]
        results["correct_ids"] = correct_ids

    return results
|
83 |
+
|
84 |
+
|
85 |
+
def save_results(model_name, results_path, dataset, predictions, debug=False):
    """Store one model's predictions as a column of a shared results CSV.

    If the CSV already exists, it is loaded and extended; otherwise it is
    seeded from ``dataset`` (with the bulky "text"/"prompt" columns dropped)
    after creating any missing parent directories.

    Parameters
    ----------
    model_name : str
        Column name under which the predictions are stored.
    results_path : str
        Path of the results CSV.
    dataset : datasets.Dataset or similar
        Only used (via ``to_pandas``) when the CSV does not exist yet.
    predictions : sequence
        One prediction per dataset row.
    debug : bool
        Print the first row of the frame before saving.
    """
    if os.path.exists(results_path):
        df = pd.read_csv(results_path, on_bad_lines="warn")
    else:
        # First model for this file: make sure the parent directories exist
        # and seed the frame from the evaluation dataset.
        parent_dir = os.path.dirname(results_path)
        os.makedirs(parent_dir, exist_ok=True)
        df = dataset.to_pandas()
        df.drop(columns=["text", "prompt"], inplace=True, errors="ignore")

    df[model_name] = predictions

    if debug:
        print(df.head(1))

    df.to_csv(results_path, index=False)
|
103 |
+
|
104 |
+
|
105 |
+
# Default system message for the chat-style translation prompts built in
# load_translation_dataset below.
system_prompt = "You are a helpful assistant that translates Chinese to English."
|
106 |
+
|
107 |
+
|
108 |
+
def get_few_shot_prompt(dataset, num_shots=5):
    """Build the Chinese→English translation prompt template.

    Starts with a fixed instruction, optionally appends the first
    ``num_shots`` entries of ``dataset`` as worked examples, and ends with a
    ``{input}`` placeholder slot for the sentence to translate.

    Parameters
    ----------
    dataset : indexable of mappings
        Rows with "chinese" and "english" fields used as examples.
    num_shots : int
        Number of leading dataset rows to embed as examples (0 = zero-shot).

    Returns
    -------
    str
        Prompt template containing a ``{input}`` placeholder.
    """
    parts = [
        "You will be given a Chinese sentence to translate. If it is an incomplete sentence, or if you are unsure about the meaning, simply copy the input text as your output. Do not output any additional sentence such as explanation or reasoning.\n\n"
    ]

    if num_shots > 0:
        shot_lines = ["Example Translations:\n"]
        for idx in range(num_shots):
            row = dataset[idx]
            shot_lines.append(f"Chinese: {row['chinese']}\n")
            shot_lines.append(f"English: {row['english']}\n")
        parts.append("".join(shot_lines))
        parts.append("\n")

    parts.append("Chinese: {input}\nEnglish:")
    return "".join(parts)
|
119 |
+
|
120 |
+
|
121 |
+
def load_translation_dataset(
    data_path, tokenizer=None, num_shots=0, for_openai=False, using_chat_template=True
):
    """Load (and, on first use, create) the train/test translation datasets.

    If the derived ``*-train.tsv`` file does not exist, the source TSV is
    read, rows with empty "chinese"/"english" fields are dropped, an 80/20
    train/test split is made and both splits are saved to TSV. The splits
    are then (re)loaded and, when a tokenizer is given or OpenAI formatting
    is requested, mapped to add "text" (prompt + gold answer) and "prompt"
    columns.

    Parameters
    ----------
    data_path : str
        Path of the source TSV with "chinese" and "english" columns.
    tokenizer : transformers tokenizer or None
        Used to apply a chat template and append the EOS token.
    num_shots : int
        Number of few-shot examples (taken from the train split) to embed.
    for_openai : bool
        When True, emit OpenAI-style message lists instead of flat strings.
    using_chat_template : bool
        When False, keep the raw few-shot prompt instead of applying the
        tokenizer's chat template.

    Returns
    -------
    datasets.DatasetDict
        With "train" and "test" splits, optionally formatted.
    """
    train_data_file = data_path.replace(".tsv", "-train.tsv")
    test_data_file = data_path.replace(".tsv", "-test.tsv")

    if not os.path.exists(train_data_file):
        print("generating train/test data files")
        dataset = load_dataset(
            "csv", data_files=data_path, delimiter="\t", split="train"
        )
        print(len(dataset))
        # Drop rows with a missing/empty source or target sentence.
        dataset = dataset.filter(lambda x: x["chinese"] and x["english"])

        # NOTE(review): split is not seeded, so regenerating the files
        # produces a different partition — confirm this is intended.
        datasets = dataset.train_test_split(test_size=0.2)
        print(len(dataset))

        # Convert to pandas DataFrame
        train_df = pd.DataFrame(datasets["train"])
        test_df = pd.DataFrame(datasets["test"])

        # Save to TSV
        train_df.to_csv(train_data_file, sep="\t", index=False)
        test_df.to_csv(test_data_file, sep="\t", index=False)

    print("loading train/test data files")
    datasets = load_dataset(
        "csv",
        data_files={"train": train_data_file, "test": test_data_file},
        delimiter="\t",
    )

    if tokenizer or for_openai:
        # Few-shot examples always come from the train split.
        translation_prompt = get_few_shot_prompt(datasets["train"], num_shots)

        def formatting_prompts_func(examples):
            # Batched map callback: builds "text"/"prompt" for each row.
            inputs = examples["chinese"]
            outputs = examples["english"]

            # Two-message skeleton: system prompt + user slot (filled below).
            messages = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                None,
            ]

            model_name = os.getenv("MODEL_NAME")

            # if "mistral" in model_name.lower():
            #     messages = messages[1:]

            texts = []
            prompts = []
            for input, output in zip(inputs, outputs):
                prompt = translation_prompt.format(input=input)
                messages[-1] = {"role": "user", "content": prompt}

                if for_openai:
                    # OpenAI format: prompt is the message list; text adds
                    # the gold answer as an assistant turn.
                    prompts.append(messages.copy())
                    text = messages.copy()
                    text.append(
                        {
                            "role": "assistant",
                            "content": output,
                        }
                    )
                    texts.append(text)
                else:
                    # HF format: render via chat template (or keep the raw
                    # prompt) and append gold answer + EOS for training.
                    prompt = (
                        tokenizer.apply_chat_template(
                            messages, tokenize=False, add_generation_prompt=True
                        )
                        if using_chat_template
                        else prompt
                    )

                    prompts.append(prompt)
                    texts.append(prompt + output + tokenizer.eos_token)

            return {"text": texts, "prompt": prompts}

        datasets = datasets.map(
            formatting_prompts_func,
            batched=True,
        )

    print(datasets)
    return datasets
|
210 |
+
|
211 |
+
|
212 |
+
def count_entries_with_max_tokens(entries, max_tokens):
    """
    Count the number of entries with the max output tokens or more.

    Parameters:
        entries (iterable of int): token counts for each entry.
        max_tokens (int): the maximum token threshold.

    Returns:
        int: the number of entries with token counts >= max_tokens.
    """
    # Generator sum replaces the manual counter loop (same result, idiomatic).
    return sum(1 for tokens in entries if tokens >= max_tokens)
|
228 |
+
|
229 |
+
|
230 |
+
def detect_repetition_scores(row, col, debug=False):
    """Repetition scores of row[col], net of the ground-truth baseline.

    Returns a pd.Series of four values: excess newline/EWS score, excess
    repetition score, excess total repetitions (each floored at zero), and
    the length of the answer text.
    """
    # Non-string cells (NaN etc.) are scored as empty output.
    answer = row[col] if isinstance(row[col], str) else ""
    ews, rep, total = detect_repetitions(answer, debug=debug)

    # Subtract the scores measured on the reference translation so that only
    # model-introduced repetition remains; negatives are clamped to zero.
    ews -= row["ground_truth_ews_score"]
    rep -= row["ground_truth_repetition_score"]
    total -= row["ground_truth_total_repetitions"]

    return pd.Series([max(ews, 0), max(rep, 0), max(total, 0), len(answer)])
|
248 |
+
|
249 |
+
|
250 |
+
# Matches CJK Unified Ideographs (U+4E00..U+9FFF); compiled once at import time.
_CHINESE_CHAR_PATTERN = re.compile(r"[\u4e00-\u9fff]")


def count_chinese_characters(text):
    """Return the number of Chinese (CJK Unified Ideograph) characters in text.

    Non-string input (e.g. NaN from pandas) yields 0.
    """
    # Fix idiom: `isinstance(...) is False` -> `not isinstance(...)`.
    if not isinstance(text, str):
        return 0

    return len(_CHINESE_CHAR_PATTERN.findall(text))
|
262 |
+
|
263 |
+
|
264 |
+
def get_metrics(df, max_output_tokens=2048, variant="rpp", existing_metrics_df=None):
    """Compute translation-quality and repetition metrics for every model column of df.

    Parameters:
        df (pd.DataFrame): first two columns are the source ("chinese") and
            reference ("english"); each further column holds one model's
            predictions and is named like "<model>[/checkpoint-N]/<variant>-<value>".
        max_output_tokens (int): budget used to count (likely truncated) outputs.
        variant (str): hyper-parameter encoded in the column names ("rpp" or "shots").
        existing_metrics_df (pd.DataFrame | None): previously computed metrics;
            matching rows are reused instead of recomputing calc_metrics.

    Returns:
        pd.DataFrame: one row per prediction column with quality scores
        (comet/meteor/spbleu/bleu_1/rouge_l), repetition scores, NRR, RAP,
        translation completeness and truncation counts.

    Side effects: mutates df in place, adding ground-truth score columns and
    several per-model helper columns.
    """
    # Skip the first two columns; the remaining column names become "model" rows.
    metrics_df = pd.DataFrame(df.columns.T)[2:]
    metrics_df.rename(columns={0: "model"}, inplace=True)
    # Split "<model>/<variant>-<value>" into the variant value...
    metrics_df[variant] = metrics_df["model"].apply(
        lambda x: x.split(f"{variant}-")[-1]
    )
    # ...and the bare model name (checkpoint suffix stripped).
    metrics_df["model"] = metrics_df["model"].apply(
        lambda x: x.split(f"/{variant}-")[0].split("/checkpoint")[0]
    )

    metrics_df.reset_index(inplace=True)
    metrics_df = metrics_df.drop(columns=["index"])

    # Keep only real model names; exclude helper columns this function adds to df.
    models = [
        model
        for model in metrics_df["model"].unique()
        if ("/" in model or "gpt" in model)
        and "ground_truth_" not in model
        and "count_" not in model
        and "output_" not in model
    ]
    print(models)

    tokenizers = {model: load_tokenizer(model) for model in models}

    # Per-column accumulators; appended once per prediction column below.
    comet = []
    meteor = []
    spbleu = []
    bleu_1 = []
    rouge_l = []
    ews_score = []
    repetition_score = []
    total_repetitions = []
    nrr = []
    num_max_output_tokens = []
    translation_completeness = []
    columns = df.columns[2:]

    # Baseline repetition scores measured on the reference translations;
    # detect_repetition_scores subtracts these per row.
    df[
        [
            "ground_truth_ews_score",
            "ground_truth_repetition_score",
            "ground_truth_total_repetitions",
        ]
    ] = df["english"].apply(detect_scores)

    new_col = f"count_chinese_characters-ground_truth"
    df[new_col] = df["chinese"].apply(count_chinese_characters)

    for col in columns:
        metrics = None
        if existing_metrics_df is not None:
            parts = col.split(f"/{variant}-")
            if len(parts) == 1:
                # NOTE(review): `break` aborts the whole loop as soon as one
                # column lacks the variant marker — presumably `continue` was
                # intended; confirm before relying on metric reuse.
                break
            print(parts)
            # "rpp" values are floats (e.g. 1.02); shot counts are ints.
            val = float(parts[1]) if variant == "rpp" else int(parts[1])
            result = existing_metrics_df[
                existing_metrics_df["model"] == parts[0].split("/checkpoint")[0]
            ]

            for i, row in result.iterrows():
                # print(i, row[variant], val)
                if row[variant] == val:
                    print(f"Using existing metrics for {col}")
                    metrics = row.to_dict()
                    # print(metrics)
                    break

        if metrics is None:
            print(f"Calculating metrics for {col}")
            metrics = calc_metrics(
                df["english"], df[col], sources=df["chinese"], debug=True
            )
            print(f"{col}: {metrics}")

        comet.append(metrics["comet"])
        meteor.append(metrics["meteor"])
        # Cached rows store scores flat ("spbleu"); fresh calc_metrics results
        # nest them ("sacrebleu"/"bleu_scores"/"rouge_scores").
        spbleu.append(
            metrics["spbleu"] if "spbleu" in metrics else metrics["sacrebleu"]["score"]
        )
        bleu_1.append(
            metrics["bleu_1"] if "bleu_1" in metrics else metrics["bleu_scores"]["bleu"]
        )
        rouge_l.append(
            metrics["rouge_l"]
            if "rouge_l" in metrics
            else metrics["rouge_scores"]["rougeL"]
        )

        # Per-row repetition scores for this column (overwritten each iteration).
        df[["ews_score", "repetition_score", "total_repetitions", "answer_len"]] = df.apply(
            lambda x: detect_repetition_scores(x, col), axis=1
        )
        ews_score.append(df["ews_score"].mean())
        repetition_score.append(df["repetition_score"].mean())
        total_repetitions.append(df["total_repetitions"].mean())

        # NRR = non-repetition ratio: share of output length not made of repetitions.
        nrr.append(1 - df["total_repetitions"].mean() / df["answer_len"].mean())

        model = col.split(f"/{variant}")[0].split("/checkpoint")[0]

        new_col = f"ground_truth_tokens-{model}"
        df[new_col] = df["english"].apply(
            lambda x: len(tokenizers[model](x)["input_ids"])
        )

        # 1 if the output still contains Chinese characters (incomplete translation).
        new_col = f"count_chinese_characters-{col}"
        df[new_col] = df[col].apply(
            lambda x: 1 if count_chinese_characters(x) > 0 else 0
        )
        translation_completeness.append(1 - df[new_col].sum() / len(df))

        new_col = f"output_tokens-{col}"
        df[new_col] = df[col].apply(
            lambda x: (
                len(tokenizers[model](x)["input_ids"]) if isinstance(x, str) else 0
            )
        )

        # Outputs that hit the token budget are likely truncated.
        num_max_output_tokens.append(
            count_entries_with_max_tokens(df[new_col], max_output_tokens)
        )

    metrics_df["comet"] = comet
    metrics_df["meteor"] = meteor
    metrics_df["spbleu"] = spbleu
    metrics_df["bleu_1"] = bleu_1
    metrics_df["rouge_l"] = rouge_l
    metrics_df["ews_score"] = ews_score
    metrics_df["repetition_score"] = repetition_score
    metrics_df["total_repetitions"] = total_repetitions
    metrics_df["nrr"] = nrr
    # RAP = repetition-adjusted performance: COMET discounted by exp(nrr - 1).
    metrics_df["rap"] = metrics_df.apply(
        lambda x: x["comet"] * math.exp(x["nrr"] - 1), axis=1
    )

    metrics_df["translation_completeness"] = translation_completeness
    metrics_df["num_max_output_tokens"] = num_max_output_tokens

    if variant != "rpp":
        metrics_df[variant] = metrics_df[variant].astype(int)

    return metrics_df
|
407 |
+
|
408 |
+
|
409 |
+
def analyze_translation_results(df, col, max_new_tokens=300, repetition_threshold=100):
    """Print diagnostics for one model-output column of the results DataFrame.

    Reports (1) rows with excessive repetition, (2) rows whose output hit the
    token budget (likely truncated), and (3) rows whose output still contains
    Chinese characters (incomplete translations).

    Parameters:
        df (pd.DataFrame): results table with "chinese", "english", the model
            column `col`, and the helper columns "output_tokens-<col>" /
            "count_chinese_characters-<col>" added by get_metrics.
        col (str): name of the model-output column to analyze.
        max_new_tokens (int): token budget used to flag truncated outputs.
        repetition_threshold (int): total_repetitions cutoff for reporting.

    Side effects: adds/overwrites the per-row repetition columns on df; prints.
    """
    df[["ews_score", "repetition_score", "total_repetitions", "answer_len"]] = df.apply(
        lambda x: detect_repetition_scores(x, col), axis=1
    )
    rows = df.query(f"total_repetitions > {repetition_threshold}")
    print(
        f"*** Found {len(rows)} rows with total_repetitions > {repetition_threshold} for {col}"
    )

    for i in range(len(rows)):
        row = rows.iloc[i]
        print(row["chinese"])
        print("=" * 80)
        print(row["english"])
        print("=" * 80)
        output = row[col]
        print(output)
        print("=" * 80)
        # Re-run with debug=True to print where the repetitions occur.
        detect_repetitions(output, debug=True)

    output_tokens = f"output_tokens-{col}"
    df2 = df[df[output_tokens] >= max_new_tokens][
        ["chinese", "english", col, output_tokens]
    ]

    print(
        f"\n*** Found {len(df2)} rows with output_tokens >= {max_new_tokens} for {col}"
    )
    print_row_details(df2, range(len(df2)))

    # Bug fix: this local used to be named "count_chinese_characters", shadowing
    # the module-level count_chinese_characters() helper within this function.
    chinese_chars_col = f"count_chinese_characters-{col}"
    df3 = df[df[chinese_chars_col] > 0][
        ["chinese", "english", col, chinese_chars_col]
    ]

    print(f"\n*** Found {len(df3)} rows with incomplete translations for {col}")
    print_row_details(df3, range(len(df3)))
|
446 |
+
|
447 |
+
|
448 |
+
def plot_metrics(metrics_df, figsize=(14, 5), ylim=(0, 0.44)):
    """Grouped bar chart of METEOR / BLEU-1 / ROUGE-L per model.

    Each model gets a distinct hatch pattern so bars remain distinguishable
    in grayscale; bar values are annotated above each bar.
    """
    plt.figure(figsize=figsize)
    melted = pd.melt(
        metrics_df, id_vars="model", value_vars=["meteor", "bleu_1", "rouge_l"]
    )

    ax = sns.barplot(x="variable", y="value", hue="model", data=melted)

    # Hatch patterns cycled across models.
    hatch_cycle = ["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*", "//", "\\\\"]
    unique_models = metrics_df["model"].unique()
    hatch_for_model = {
        name: hatch_cycle[idx % len(hatch_cycle)]
        for idx, name in enumerate(unique_models)
    }

    # Assign each bar the hatch of the model it belongs to.
    n_metrics = len(melted["variable"].unique())
    for idx, patch in enumerate(ax.patches):
        owner = melted["model"].iloc[idx // n_metrics]
        patch.set_hatch(hatch_for_model[owner])

    # Keep the legend swatches in sync with the bar hatches.
    handles, _labels = ax.get_legend_handles_labels()
    for handle, owner in zip(handles, unique_models):
        handle.set_hatch(hatch_for_model[owner])

    ax.set_xticklabels(["METEOR", "BLEU-1", "ROUGE-L"])

    # Annotate each non-empty bar with its value.
    for patch in ax.patches:
        if patch.get_height() == 0:
            continue
        ax.annotate(
            f"{patch.get_height():.2f}",
            (patch.get_x() + patch.get_width() / 2.0, patch.get_height()),
            ha="center",
            va="center",
            xytext=(0, 10),
            textcoords="offset points",
        )

    ax.set(ylim=ylim, ylabel="Scores", xlabel="Metrics")
    plt.legend(bbox_to_anchor=(0.5, -0.1), loc="upper center")
    plt.show()
|
492 |
+
|
493 |
+
|
494 |
+
def plot_times(perf_df, ylim=0.421):
    """Stacked train/eval time bars per model, optionally overlaid with METEOR.

    Train time (red) forms the base of each bar with eval time (orange) stacked
    on top; if a "meteor" column exists, it is drawn as a line on a secondary
    y-axis capped at `ylim`. Non-zero bar segments are annotated with values.
    """
    fig, time_ax = plt.subplots(figsize=(12, 10))

    time_ax.set_xlabel("Models")
    time_ax.set_ylabel("Time (mins)")
    time_ax.set_xticks(range(len(perf_df["model"])))  # Set x-ticks positions
    time_ax.set_xticklabels(perf_df["model"], rotation=90)

    # Draw "train-time" first so it sits at the bottom of the stack.
    time_ax.bar(
        perf_df["model"],
        perf_df["train-time(mins)"],
        color="tab:red",
        label="train-time",
    )

    # Stack "eval-time" on top of "train-time".
    time_ax.bar(
        perf_df["model"],
        perf_df["eval-time(mins)"],
        bottom=perf_df["train-time(mins)"],
        color="orange",
        label="eval-time",
    )

    time_ax.tick_params(axis="y")
    time_ax.legend(loc="upper left")

    if "meteor" in perf_df.columns:
        meteor_ax = time_ax.twinx()
        meteor_color = "tab:blue"
        meteor_ax.set_ylabel("METEOR", color=meteor_color)
        meteor_ax.plot(
            perf_df["model"],
            perf_df["meteor"],
            color=meteor_color,
            marker="o",
            label="meteor",
        )
        meteor_ax.tick_params(axis="y", labelcolor=meteor_color)
        meteor_ax.legend(loc="upper right")
        meteor_ax.set_ylim(meteor_ax.get_ylim()[0], ylim)

    # Annotate every non-zero bar segment with its value.
    for patch in time_ax.patches:
        height = patch.get_height()
        if height == 0:  # Skip bars with height 0
            continue
        time_ax.annotate(
            f"{height:.2f}",
            (patch.get_x() + patch.get_width() / 2.0, patch.get_y() + height),
            ha="center",
            va="center",
            xytext=(0, -10),
            textcoords="offset points",
        )

    fig.tight_layout()
    plt.show()
|
557 |
+
|
558 |
+
|
559 |
+
def translate_via_openai(
    text, translation_prompt, max_tokens=None, model="gpt-4o-mini", base_url=None
):
    """Translate one Chinese sentence to English via an OpenAI-compatible chat API.

    Parameters:
        text (str): the Chinese input substituted into the prompt's {input} slot.
        translation_prompt (str): human-turn prompt template with an {input} placeholder.
        max_tokens (int | None): generation cap forwarded to the API.
        model (str): chat model name.
        base_url (str | None): alternate endpoint for OpenAI-compatible servers.

    Returns:
        str: the model's translation.
    """
    chat_model = ChatOpenAI(
        model=model,
        temperature=0,
        max_tokens=max_tokens,
        timeout=None,
        max_retries=2,
        base_url=base_url,
    )

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant that translates Chinese to English.",
            ),
            (
                "human",
                translation_prompt,
            ),
        ]
    )

    # Pipe the prompt into the model and fill the {input} placeholder.
    response = (chat_prompt | chat_model).invoke({"input": text})
    return response.content
|
592 |
+
|
593 |
+
|
594 |
+
def eval_openai(num_shots, datasets, model="gpt-4o-mini", max_new_tokens=300):
    """Translate the test split with an OpenAI chat model and return predictions.

    Few-shot examples come from the training split; each test sentence is
    translated one at a time (with a tqdm progress bar).
    """
    prompt_template = get_few_shot_prompt(datasets["train"], num_shots=num_shots)
    test_split = datasets["test"]
    results = []

    for idx in tqdm(range(len(test_split))):
        results.append(
            translate_via_openai(
                test_split["chinese"][idx],
                prompt_template,
                model=model,
                max_tokens=max_new_tokens,
            )
        )

    return results
|
610 |
+
|
611 |
+
|
612 |
+
def convert_time_to_seconds(time_str):
    """Convert a "HH:MM:SS", "MM:SS" or "SS" string to total seconds.

    Strings with more than three colon-separated components yield 0,
    matching the original fall-through behavior.
    """
    components = [int(chunk) for chunk in time_str.split(":")]

    if len(components) > 3:
        return 0

    # Horner-style accumulation: each extra component scales the total by 60,
    # covering SS, MM:SS and HH:MM:SS uniformly.
    total_seconds = 0
    for value in components:
        total_seconds = total_seconds * 60 + value

    return total_seconds
|
632 |
+
|
633 |
+
|
634 |
+
def process_log_file(log_file, total_entries, variant):
    """Extract per-model evaluation times from one training/eval log file.

    Pairs each "<model>/<variant>-<value> metrics:" line with the most recent
    tqdm-style completion line ("[MM:SS<00:00") seen before it, and converts
    the elapsed time to seconds-per-entry.

    Parameters:
        log_file (str): path to the log file.
        total_entries (int): number of evaluated entries (divisor for per-entry time).
        variant (str): hyper-parameter name embedded in the metrics lines.

    Returns:
        pd.DataFrame with columns "model", <variant>, "eval_time".
    """
    # tqdm completion line, e.g. "[12:34<00:00" — captures the elapsed time.
    time_pattern = re.compile(r"\[(.{5,10})<00:00")
    # e.g. "org/model/rpp-1.02 metrics:" — captures model path and variant value.
    metrics_pattern = re.compile(rf"(.*)/{variant}-(.*) metrics:")

    model = []
    shots = []
    eval_time = []

    i = 0  # current line number, for error reporting

    with open(log_file, "r") as f:
        try:
            for line in f:
                i += 1
                matches = time_pattern.search(line)
                if matches:
                    # Remember the most recent progress line; it is consumed by
                    # the next metrics line below.
                    time_pattern_matches = matches
                else:
                    matches = metrics_pattern.search(line)
                    if matches:
                        metrics_pattern_matches = matches
                        groups = metrics_pattern_matches.groups()

                        model.append(groups[0].split("/checkpoint")[0])
                        shots.append(groups[1])

                        # NOTE(review): assumes a time line always precedes the
                        # first metrics line; otherwise time_pattern_matches is
                        # unbound here and the except below swallows the error,
                        # leaving model/shots longer than eval_time — confirm.
                        groups = time_pattern_matches.groups()
                        time_str = groups[0]
                        # Average per-entry eval time in seconds.
                        eval_time.append(
                            convert_time_to_seconds(time_str) / total_entries
                        )
        except Exception as e:
            print(f"Error processing log file: {log_file} at line {i}: {line}")
            print(e)

    df = pd.DataFrame(
        {
            "model": model,
            variant: shots,
            "eval_time": eval_time,
        }
    )
    return df
|
677 |
+
|
678 |
+
|
679 |
+
def load_eval_times(logs_folder, total_entries=1133, variant="shots"):
    """Aggregate per-model eval times from every log file in a folder.

    Later files override earlier ones for the same (model, variant) pair.
    """
    log_files = sorted(glob.glob(os.path.join(logs_folder, "*")))

    # Seed with an empty typed frame, then concatenate all parsed logs at once.
    frames = [pd.DataFrame({"model": [], variant: [], "eval_time": []})]
    for log_file in log_files:
        print(f"Loading content of {log_file}")
        frames.append(process_log_file(log_file, total_entries, variant))

    time_df = pd.concat(frames, ignore_index=True)

    # "rpp" values stay as strings; shot counts are converted to int.
    time_df[variant] = time_df[variant].apply(
        lambda x: x if variant == "rpp" else int(x)
    )
    # Keep the last occurrence of each duplicate (model, variant) pair.
    return time_df.drop_duplicates(subset=["model", variant], keep="last")
|
696 |
+
|
697 |
+
|
698 |
+
def load_alpaca_data(data_path):
    """Build (or load a cached copy of) the Alpaca-format training file.

    The cache lives at a fixed path ("data/alpaca_mac.json"); when present it
    is returned directly, otherwise the training split of `data_path` is
    converted to Alpaca records (system/instruction/input/output) and saved.
    """
    alpaca_data_path = "data/alpaca_mac.json"

    if os.path.exists(alpaca_data_path):
        print("loading existing data from:", alpaca_data_path)
        return pd.read_json(alpaca_data_path, orient="records", lines=False)

    datasets = load_translation_dataset(data_path)
    prompt_template = get_few_shot_prompt(datasets["train"], num_shots=0)

    train_df = datasets["train"].to_pandas()
    train_df["instruction"] = train_df.apply(
        lambda row: prompt_template.format(input=row["chinese"]), axis=1
    )

    n_rows = len(train_df)
    df_alpaca = pd.DataFrame(
        {
            "system": [system_prompt] * n_rows,
            "instruction": train_df["instruction"].to_list(),
            "input": [""] * n_rows,  # Alpaca "input" field is unused here
            "output": train_df["english"].to_list(),
        }
    )

    df_alpaca.to_json(alpaca_data_path, orient="records", lines=False, indent=2)

    return df_alpaca
|
726 |
+
|
727 |
+
|
728 |
+
def load_openai_training_data(
    data_path, openai_data_path="datasets/mac/openai-training.jsonl"
):
    """Build (or load a cached copy of) the OpenAI fine-tuning JSONL file.

    Each training row becomes one chat conversation: system prompt, user turn
    holding the formatted translation prompt, and assistant turn holding the
    reference English translation.
    """
    if os.path.exists(openai_data_path):
        print("loading existing data from:", openai_data_path)
        return pd.read_json(openai_data_path, orient="records", lines=True)

    datasets = load_translation_dataset(data_path)
    prompt_template = get_few_shot_prompt(datasets["train"], num_shots=0)

    train_df = datasets["train"].to_pandas()

    conversations = []
    for _, row in train_df.iterrows():
        conversations.append(
            [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": prompt_template.format(input=row["chinese"]),
                },
                {
                    "role": "assistant",
                    "content": row["english"],
                },
            ]
        )

    df_openai = pd.DataFrame(
        {
            "messages": conversations,
        }
    )
    df_openai.to_json(openai_data_path, orient="records", lines=True)
    return df_openai
|
notebooks/00f_Data Analysis_Fine_Tuned_RPP_Generic_Prompt.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/03a_RAPGeT_v2_Data Analysis_Chat_Template.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/03b_RAPGeT_v2_Data Analysis_Generic_Prompt.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
results/mac-results_rpp_with_mnt_2048_generic_prompt_metrics.csv
CHANGED
@@ -1,21 +1,25 @@
|
|
1 |
-
model,rpp,comet,meteor,spbleu,bleu_1,rouge_l,ews_score,repetition_score,total_repetitions,rap,translation_completeness,num_max_output_tokens
|
2 |
-
|
3 |
-
internlm/internlm2_5-7b-chat,1.
|
4 |
-
internlm/internlm2_5-7b-chat,1.
|
5 |
-
internlm/internlm2_5-7b-chat,1.
|
6 |
-
internlm/internlm2_5-7b-chat,1.
|
7 |
-
internlm/internlm2_5-7b-chat,1.
|
8 |
-
|
9 |
-
microsoft/Phi-3.5-mini-instruct,1.
|
10 |
-
microsoft/Phi-3.5-mini-instruct,1.
|
11 |
-
microsoft/Phi-3.5-mini-instruct,1.
|
12 |
-
microsoft/Phi-3.5-mini-instruct,1.
|
13 |
-
microsoft/Phi-3.5-mini-instruct,1.
|
14 |
-
|
15 |
-
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.
|
16 |
-
shenzhi-wang/
|
17 |
-
shenzhi-wang/
|
18 |
-
shenzhi-wang/
|
19 |
-
shenzhi-wang/
|
20 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.
|
21 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.
|
|
|
|
|
|
|
|
|
|
1 |
+
model,rpp,comet,meteor,spbleu,bleu_1,rouge_l,ews_score,repetition_score,total_repetitions,nrr,rap,translation_completeness,num_max_output_tokens
|
2 |
+
internlm/internlm2_5-7b-chat,1.00,0.7357995069773978,0.4297612514398102,15.060226683930628,0.1506022668393063,0.4097577795330234,0.04942630185348632,9.235657546337158,9.285083848190645,0.9247496423462088,0.6824623187873116,1.0,2
|
3 |
+
internlm/internlm2_5-7b-chat,1.02,0.7377187550620283,0.4246676977198055,14.728605282752795,0.147286052827528,0.4063246630867048,0.06972639011473963,5.35657546337158,5.4210061782877315,0.953789668507456,0.7044041934116965,1.0,1
|
4 |
+
internlm/internlm2_5-7b-chat,1.04,0.7371160490183523,0.4173352728374962,13.846403511622256,0.1384640351162226,0.3988121301027288,0.06884377758164166,5.315092674315975,5.383053839364519,0.9549885977018281,0.7046730517422506,1.0,1
|
5 |
+
internlm/internlm2_5-7b-chat,1.06,0.7338597697698218,0.3997609847704189,12.213374588416173,0.1221337458841617,0.3841365748920261,0.05825242718446602,5.275375110326567,5.332744924977935,0.9561764257893248,0.7023939203825759,1.0,1
|
6 |
+
internlm/internlm2_5-7b-chat,1.08,0.7318234702626478,0.3881614120395272,11.369735763522288,0.1136973576352228,0.372963223209074,0.06707855251544571,5.283318623124448,5.345101500441306,0.9570359334539392,0.7010472282786626,1.0,1
|
7 |
+
internlm/internlm2_5-7b-chat,1.10,0.7288648442604431,0.3784182249483568,10.377989030628608,0.103779890306286,0.3618424457502351,0.05207413945278023,5.288614298323036,5.340688437775817,0.957823935317488,0.6987634348896543,1.0,1
|
8 |
+
microsoft/Phi-3.5-mini-instruct,1.00,0.710605339281136,0.3788926591792472,9.70032874202361,0.097003287420236,0.3556134739443916,5.390997352162401,12.997352162400706,18.368049426301855,0.8624429902835613,0.6192816563968827,1.0,4
|
9 |
+
microsoft/Phi-3.5-mini-instruct,1.02,0.7150978385770836,0.3741049510326346,9.910633597905436,0.0991063359790543,0.3453160556383774,3.586054721977052,7.001765225066196,10.567519858781994,0.9183516206245184,0.6590312746582365,1.0,2
|
10 |
+
microsoft/Phi-3.5-mini-instruct,1.04,0.7074641684778791,0.3538698731015666,9.19721270538052,0.0919721270538052,0.3225824135517728,0.05119152691968226,0.05560458958517211,0.10150044130626655,0.9991834532118691,0.7068867266696017,1.0,0
|
11 |
+
microsoft/Phi-3.5-mini-instruct,1.06,0.6962301708225224,0.3252854575717334,6.967166383106307,0.069671663831063,0.2948764736589108,0.0353045013239188,0.06796116504854369,0.09796999117387467,0.9992538065947363,0.6957108422443861,1.0,0
|
12 |
+
microsoft/Phi-3.5-mini-instruct,1.08,0.6823413657174107,0.301599095293242,5.452744292893752,0.0545274429289375,0.2726387617958179,0.07678729037952339,0.04766107678729038,0.11297440423654016,0.9991814653050001,0.6817830741574145,1.0,0
|
13 |
+
microsoft/Phi-3.5-mini-instruct,1.10,0.6717851540206916,0.2885734336603344,4.751039447225815,0.0475103944722581,0.2604284999048123,0.08031774051191527,0.02383053839364519,0.10414827890556046,0.999281171568508,0.6713024292710505,1.0,0
|
14 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.00,0.739080294072365,0.4490104515425626,6.7013404492782405,0.0670134044927823,0.4196181637680596,0.36716681376875554,139.80935569285083,140.15798764342455,0.5164419894213406,0.4557063153750335,0.999117387466902,15
|
15 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.02,0.743018615750854,0.4514907128972251,8.545954556237808,0.085459545562378,0.4214940415288087,1.0035304501323918,67.00353045013239,67.98852603706973,0.7071153729164131,0.5543722941592363,1.0,6
|
16 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.04,0.7432195577780335,0.4517500968367987,10.080425294411064,0.1008042529441106,0.4200973007348334,0.01059135039717564,35.19770520741395,35.18358340688438,0.8244802169835412,0.6235766669370041,1.0,6
|
17 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.06,0.7430821573139815,0.4484154407825542,10.37470506193322,0.1037470506193321,0.4160289393328045,1.8005295675198587,26.880847308031775,28.656663724624888,0.8478345432646117,0.6381932626507077,1.0,3
|
18 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.08,0.7435937259684909,0.4407733547418294,10.930453247368872,0.1093045324736887,0.4113063412348818,0.09267431597528684,12.007943512797882,12.072374227714034,0.9329421050825354,0.6953650250037441,1.0,3
|
19 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.10,0.7427059700687901,0.4358940590119784,11.381344076286156,0.1138134407628615,0.4062980635945339,0.0176522506619594,11.914386584289497,11.905560458958517,0.9312420672746087,0.6933551155892559,1.0,3
|
20 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.00,0.7222260562908512,0.4039898602650971,13.461179673541356,0.1346117967354136,0.3819960428004565,0.05736981465136805,5.87378640776699,5.9179170344218885,0.9486112388485238,0.6860492554476049,1.0,1
|
21 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.02,0.723643534970515,0.4051102919608809,13.18537912294539,0.1318537912294539,0.3824621732976229,0.06266548984995587,5.840247131509267,5.8914386584289495,0.9486127363429205,0.6873967610154514,1.0,1
|
22 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.04,0.7238812581796301,0.4039456988919502,13.314773371306682,0.1331477337130668,0.3813737464821349,0.05736981465136805,5.845542806707855,5.889673433362754,0.948840810819099,0.6877794238881113,1.0,1
|
23 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.06,0.7252625281686607,0.4012797167602334,13.19924345265053,0.1319924345265053,0.3798291332004637,0.06266548984995587,5.847308031774051,5.884377758164166,0.9494061847846709,0.6894815110844906,1.0,1
|
24 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.08,0.7261167238322592,0.3987395126194482,12.656486100206328,0.1265648610020633,0.376975448872996,0.05648720211827008,5.820829655781112,5.864077669902913,0.9499856972945303,0.6906937144718889,1.0,1
|
25 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.10,0.7264630642225547,0.3964859769229444,12.284961706379857,0.1228496170637985,0.3744555065346823,0.04942630185348632,0.09267431597528684,0.12886142983230361,0.9988510902838437,0.7256289030293478,1.0,0
|
results/mac-results_rpp_with_mnt_2048_metrics.csv
CHANGED
@@ -1,43 +1,31 @@
|
|
1 |
-
model,rpp,comet,meteor,spbleu,bleu_1,rouge_l,ews_score,repetition_score,total_repetitions,rap,translation_completeness,num_max_output_tokens
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
shenzhi-wang/
|
27 |
-
shenzhi-wang/
|
28 |
-
shenzhi-wang/
|
29 |
-
shenzhi-wang/
|
30 |
-
shenzhi-wang/
|
31 |
-
shenzhi-wang/
|
32 |
-
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.00,0.7426396049131678,0.433632501662176,15.209540658023398,0.1520954065802339,0.4089208235151474,0.0,5.798764342453663,5.798764342453663,0.619577239776096,1.0,1
|
33 |
-
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.02,0.7436477056353469,0.4329054166518245,15.19102241646024,0.1519102241646024,0.4068967964789407,0.0,5.77846425419241,5.77846425419241,0.6207074516423631,1.0,1
|
34 |
-
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.04,0.7440943776351209,0.4320478700956207,15.05135166158296,0.1505135166158296,0.4062008380201262,0.0,0.11827007943512798,0.11827007943512798,0.7403141356055205,1.0,0
|
35 |
-
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.06,0.7426502735395928,0.4275429314912545,14.449130821290163,0.1444913082129016,0.4001409979222783,0.0,0.176522506619594,0.176522506619594,0.7370491441666694,1.0,0
|
36 |
-
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.08,0.7408098006080129,0.4206626658729054,13.933703757385222,0.1393370375738522,0.3964824268676203,0.0,0.21888790820829657,0.21888790820829657,0.7339083936807739,1.0,0
|
37 |
-
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.10,0.7392685912871718,0.4111211240399151,13.303738403756984,0.1330373840375698,0.3870959581563503,0.0,0.13857016769638128,0.13857016769638128,0.7348764469929082,1.0,0
|
38 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.00,0.7240239171358935,0.4068335357738006,13.565136550617618,0.1356513655061761,0.3866395067055498,0.0,0.1059135039717564,0.1059135039717564,0.7207261791424407,1.0,0
|
39 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.02,0.7263097057327799,0.4064914781094827,13.42987641622816,0.1342987641622816,0.3863697821025159,0.0,6.238305383936452,6.238305383936452,0.5999878425713037,1.0,1
|
40 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.04,0.7276128307708258,0.4054859896994975,13.295092218891954,0.1329509221889195,0.3851203729935697,0.0,0.1297440423654016,0.1297440423654016,0.7235619893938705,1.0,0
|
41 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.06,0.7276865132383193,0.4014727027723293,13.10860799057166,0.1310860799057166,0.3804952786306688,0.0,0.20741394527802295,0.20741394527802295,0.7212559915451495,1.0,0
|
42 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.08,0.726393195584298,0.3987018836449559,12.850537785783194,0.1285053778578319,0.3788945955746495,0.0,0.2903795233892321,0.2903795233892321,0.717473994791502,1.0,0
|
43 |
-
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.10,0.7244012304511832,0.3932239948456176,12.361161644811926,0.1236116164481192,0.3733413807007665,0.0,0.1500441306266549,0.1500441306266549,0.7197459635880831,1.0,0
|
|
|
1 |
+
model,rpp,comet,meteor,spbleu,bleu_1,rouge_l,ews_score,repetition_score,total_repetitions,nrr,rap,translation_completeness,num_max_output_tokens
|
2 |
+
internlm/internlm2_5-7b-chat,1.00,0.739699612254078,0.4289996929258777,14.734881589173108,0.1473488158917311,0.4096466800937898,0.05383936451897617,12.606354810238305,12.646954986760813,0.8963919016630513,0.6668973100142501,1.0,2
|
3 |
+
internlm/internlm2_5-7b-chat,1.02,0.740223803961056,0.4266246904302194,14.583816688798017,0.1458381668879802,0.4071727106228415,0.06266548984995587,9.849073256840247,9.910856134157106,0.916784004505773,0.6811186927169239,1.0,1
|
4 |
+
internlm/internlm2_5-7b-chat,1.04,0.7398856264610577,0.4154585167056314,13.534659133050225,0.1353465913305021,0.3968657713589718,0.07237422771403354,6.529567519858782,6.596646072374227,0.9439099437148217,0.6995278161884907,1.0,1
|
5 |
+
internlm/internlm2_5-7b-chat,1.06,0.7379362287241489,0.4039588647855378,12.346740971499404,0.1234674097149939,0.3872447044295494,0.06796116504854369,6.533980582524272,6.596646072374227,0.9449043529541853,0.6983788795529117,0.999117387466902,1
|
6 |
+
internlm/internlm2_5-7b-chat,1.08,0.7319988705684732,0.3873176839854818,11.075674965706344,0.1107567496570634,0.3724352909668609,0.05207413945278023,9.83495145631068,9.881729920564872,0.920975175928344,0.6763793903551222,0.999117387466902,1
|
7 |
+
internlm/internlm2_5-7b-chat,1.10,0.7295350462119345,0.3769306874386757,10.305163787094209,0.1030516378709421,0.3634496155759507,0.07855251544571933,6.527802294792586,6.596646072374227,0.9470732363646663,0.6919271286047086,0.999117387466902,1
|
8 |
+
microsoft/Phi-3.5-mini-instruct,1.00,0.7107840433177544,0.3796831545348129,8.71296896471494,0.0871296896471493,0.3589874395901284,10.670785525154457,17.93821712268314,28.58340688437776,0.7979259092866101,0.5807350079921869,1.0,6
|
9 |
+
microsoft/Phi-3.5-mini-instruct,1.02,0.7164765837070485,0.3780585837553919,10.291240080163629,0.1029124008016362,0.3546952732427276,3.585172109443954,7.1403353927625774,10.705207413945278,0.914860909301493,0.6580010154348458,1.0,2
|
10 |
+
microsoft/Phi-3.5-mini-instruct,1.04,0.7111233387336411,0.3547161333845742,8.966881655527896,0.0896688165552789,0.3300979657678754,3.6125330979699912,0.07325684024713151,3.685789938217123,0.9702657286890148,0.6902898733301204,1.0,1
|
11 |
+
microsoft/Phi-3.5-mini-instruct,1.06,0.7024363270136286,0.3298733737040869,7.076233088011138,0.0707623308801113,0.3019513312669543,0.04589585172109444,0.05207413945278023,0.0970873786407767,0.9992496538175567,0.7019094542904326,1.0,0
|
12 |
+
microsoft/Phi-3.5-mini-instruct,1.08,0.6882111219210848,0.3054541022592767,5.105510599247868,0.0510551059924786,0.2736030007297014,3.3609885260370698,0.06443071491615181,3.414827890556046,0.9764915329416268,0.6722210212114398,1.0,1
|
13 |
+
microsoft/Phi-3.5-mini-instruct,1.10,0.6712992989638161,0.2903831801547132,4.091958857999118,0.0409195885799911,0.251653275009876,0.32215357458075905,0.06531332744924978,0.3786407766990291,0.9977125977744483,0.6697655223041806,1.0,0
|
14 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.00,0.7501818982248062,0.4611110508507017,17.87914973742753,0.1787914973742752,0.4340662057009564,0.00706090026478376,0.1262135922330097,0.11650485436893204,0.9990152266843727,0.7494435027453233,1.0,0
|
15 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.02,0.7485114382045625,0.4571517219079576,17.436884594979905,0.174368845949799,0.4311385932640979,0.00706090026478376,0.11562224183583407,0.1059135039717564,0.9991036950172912,0.7478408442461326,1.0,0
|
16 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.04,0.7500591586357918,0.4560467960364254,17.440173470996626,0.1744017347099662,0.4302844557731285,0.00706090026478376,0.13062665489849956,0.1209179170344219,0.9989818138577363,0.7492958474569671,1.0,0
|
17 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.06,0.748812871571673,0.4520416361219855,16.89523258317781,0.168952325831778,0.4260026774745837,0.00706090026478376,0.0997352162400706,0.09002647837599294,0.999249541999897,0.7482511297698003,1.0,0
|
18 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.08,0.7473851635144647,0.4442106511292453,16.16623784482793,0.1616623784482792,0.4195129470585874,0.01059135039717564,0.13062665489849956,0.12444836716681378,0.9989631285573515,0.7466106228007235,1.0,0
|
19 |
+
shenzhi-wang/Llama3.1-70B-Chinese-Chat,1.10,0.7465709781131172,0.4379837926138161,15.60172257624066,0.1560172257624066,0.4132562932940978,0.01059135039717564,0.07855251544571933,0.06531332744924978,0.9994618690596525,0.7461693332490859,1.0,0
|
20 |
+
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.00,0.7426396049131678,0.433632501662176,15.209540658023398,0.1520954065802339,0.4089208235151474,0.00353045013239188,3.901147396293027,3.889673433362754,0.9677537371860069,0.7190742424186368,1.0,1
|
21 |
+
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.02,0.7436477056353469,0.4329054166518245,15.19102241646024,0.1519102241646024,0.4068967964789407,0.0,3.8905560458958517,3.8693733451015007,0.9679787303975633,0.7202123788620469,1.0,1
|
22 |
+
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.04,0.7440943776351209,0.4320478700956207,15.05135166158296,0.1505135166158296,0.4062008380201262,0.00353045013239188,0.1526919682259488,0.13503971756398941,0.9988310348779463,0.7432250654569489,1.0,0
|
23 |
+
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.06,0.7426502735395928,0.4275429314912545,14.449130821290163,0.1444913082129016,0.4001409979222783,0.00706090026478376,0.13768755516328332,0.13327449249779347,0.9988583611812559,0.7418029189370554,1.0,0
|
24 |
+
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.08,0.7408098006080129,0.4206626658729054,13.933703757385222,0.1393370375738522,0.3964824268676203,0.00353045013239188,0.1297440423654016,0.11738746690203,0.9990003006614552,0.7400695835992323,1.0,0
|
25 |
+
shenzhi-wang/Llama3.1-8B-Chinese-Chat,1.10,0.7392685912871718,0.4111211240399151,13.303738403756984,0.1330373840375698,0.3870959581563503,0.00353045013239188,0.12180052956751986,0.10944395410414828,0.9990805075005376,0.7385891517800307,1.0,0
|
26 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.00,0.7240239171358935,0.4068335357738006,13.565136550617618,0.1356513655061761,0.3866395067055498,0.0529567519858782,0.1209179170344219,0.1676963812886143,0.9984771126055,0.7229221493862132,1.0,0
|
27 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.02,0.7263097057327799,0.4064914781094827,13.42987641622816,0.1342987641622816,0.3863697821025159,0.06001765225066196,6.236540158870256,6.294792586054722,0.9458252309188138,0.6880088812871149,1.0,1
|
28 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.04,0.7276128307708258,0.4054859896994975,13.295092218891954,0.1329509221889195,0.3851203729935697,0.05207413945278023,0.1297440423654016,0.16946160635481025,0.9984590566537452,0.726492484037711,1.0,0
|
29 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.06,0.7276865132383193,0.4014727027723293,13.10860799057166,0.1310860799057166,0.3804952786306688,0.05207413945278023,0.13415710503089143,0.18446601941747573,0.9983249715485598,0.7264686378975903,1.0,0
|
30 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.08,0.726393195584298,0.3987018836449559,12.850537785783194,0.1285053778578319,0.3788945955746495,0.05648720211827008,0.15357458075904679,0.21006178287731686,0.9981128634521912,0.7250236850699381,1.0,0
|
31 |
+
shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat,1.10,0.7244012304511832,0.3932239948456176,12.361161644811926,0.1236116164481192,0.3733413807007665,0.05030891438658429,0.08561341571050309,0.13592233009708737,0.9987782625942087,0.7235167427869905,1.0,0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|