Spaces:
Sleeping
Sleeping
supported flag APPLY_CHAT_TEMPLATE_FOR_RAG
Browse files- .env.example +4 -0
- app_modules/init.py +2 -10
- app_modules/llm_chat_chain.py +2 -37
- app_modules/llm_inference.py +38 -0
- app_modules/llm_qa_chain.py +12 -4
- app_modules/utils.py +62 -216
- qa_chain_test.py +12 -2
- requirements.txt +3 -1
.env.example
CHANGED
@@ -25,7 +25,11 @@ OPENAI_MODEL_NAME=
|
|
25 |
OLLAMA_MODEL_NAME=llama3:8b
|
26 |
|
27 |
OLLAMA_RP=1.15
|
|
|
28 |
|
|
|
|
|
|
|
29 |
|
30 |
# cpu, mps or cuda:0 - if unset, use whatever detected
|
31 |
HF_EMBEDDINGS_DEVICE_TYPE=
|
|
|
25 |
OLLAMA_MODEL_NAME=llama3:8b
|
26 |
|
27 |
OLLAMA_RP=1.15
|
28 |
+
HF_RP=1.15
|
29 |
|
30 |
+
LANGCHAIN_DEBUG=false
|
31 |
+
BATCH_SIZE=1
|
32 |
+
APPLY_CHAT_TEMPLATE_FOR_RAG=true
|
33 |
|
34 |
# cpu, mps or cuda:0 - if unset, use whatever detected
|
35 |
HF_EMBEDDINGS_DEVICE_TYPE=
|
app_modules/init.py
CHANGED
@@ -10,7 +10,7 @@ from langchain.vectorstores.chroma import Chroma
|
|
10 |
from langchain.vectorstores.faiss import FAISS
|
11 |
|
12 |
from app_modules.llm_loader import LLMLoader
|
13 |
-
from app_modules.utils import get_device_types, init_settings
|
14 |
|
15 |
found_dotenv = find_dotenv(".env")
|
16 |
|
@@ -53,21 +53,13 @@ def app_init():
|
|
53 |
using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
|
54 |
llm_model_type = os.environ.get("LLM_MODEL_TYPE")
|
55 |
|
56 |
-
debug_metrics = os.getenv("DEBUG_METRICS", "false").lower() == "true"
|
57 |
-
|
58 |
-
if debug_metrics:
|
59 |
-
start = timer()
|
60 |
-
load_spacy_model()
|
61 |
-
end = timer()
|
62 |
-
print(f"Completed in {end - start:.3f}s")
|
63 |
-
|
64 |
qa_with_rag = os.getenv("QA_WITH_RAG", "true").lower() == "true"
|
65 |
print(f"qa_with_rag: {qa_with_rag}")
|
66 |
|
67 |
retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
|
68 |
print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)
|
69 |
|
70 |
-
if qa_with_rag and not retrieve_from_questions_file
|
71 |
print(f"hf_embeddings_model_name: {hf_embeddings_model_name}")
|
72 |
start = timer()
|
73 |
embeddings = HuggingFaceInstructEmbeddings(
|
|
|
10 |
from langchain.vectorstores.faiss import FAISS
|
11 |
|
12 |
from app_modules.llm_loader import LLMLoader
|
13 |
+
from app_modules.utils import get_device_types, init_settings
|
14 |
|
15 |
found_dotenv = find_dotenv(".env")
|
16 |
|
|
|
53 |
using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
|
54 |
llm_model_type = os.environ.get("LLM_MODEL_TYPE")
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
qa_with_rag = os.getenv("QA_WITH_RAG", "true").lower() == "true"
|
57 |
print(f"qa_with_rag: {qa_with_rag}")
|
58 |
|
59 |
retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
|
60 |
print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)
|
61 |
|
62 |
+
if qa_with_rag and not retrieve_from_questions_file:
|
63 |
print(f"hf_embeddings_model_name: {hf_embeddings_model_name}")
|
64 |
start = timer()
|
65 |
embeddings = HuggingFaceInstructEmbeddings(
|
app_modules/llm_chat_chain.py
CHANGED
@@ -6,7 +6,7 @@ from langchain.chains import ConversationChain, LLMChain
|
|
6 |
from langchain.prompts import PromptTemplate
|
7 |
from langchain.chains.base import Chain
|
8 |
|
9 |
-
from app_modules.llm_inference import LLMInference
|
10 |
from app_modules.utils import CustomizedConversationSummaryBufferMemory
|
11 |
from langchain.chains import LLMChain
|
12 |
from langchain.globals import get_debug
|
@@ -15,23 +15,6 @@ chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "tr
|
|
15 |
B_INST, E_INST = "[INST]", "[/INST]"
|
16 |
|
17 |
|
18 |
-
def get_system_prompt_and_user_message(orca=False):
|
19 |
-
# system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
20 |
-
system_prompt = (
|
21 |
-
"You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
|
22 |
-
if orca
|
23 |
-
else "You are a chatbot having a conversation with a human."
|
24 |
-
)
|
25 |
-
|
26 |
-
user_message = "{input}"
|
27 |
-
|
28 |
-
if chat_history_enabled:
|
29 |
-
user_message = "Chat History:\n\n{history} \n\n" + user_message
|
30 |
-
system_prompt += " Read the chat history to get context."
|
31 |
-
|
32 |
-
return system_prompt, user_message
|
33 |
-
|
34 |
-
|
35 |
def create_llama_2_prompt_template():
|
36 |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
37 |
|
@@ -141,19 +124,7 @@ class ChatChain(LLMInference):
|
|
141 |
if not isinstance(inputs, list):
|
142 |
inputs = {"input": inputs["question"]}
|
143 |
elif self.llm_loader.llm_model_type == "huggingface":
|
144 |
-
inputs = [
|
145 |
-
[
|
146 |
-
{
|
147 |
-
"role": "system",
|
148 |
-
"content": self.get_system_message(i),
|
149 |
-
},
|
150 |
-
{
|
151 |
-
"role": "user",
|
152 |
-
"content": self.get_user_message(i),
|
153 |
-
},
|
154 |
-
]
|
155 |
-
for i in inputs
|
156 |
-
]
|
157 |
else:
|
158 |
inputs = [{"input": i["question"]} for i in inputs]
|
159 |
|
@@ -161,9 +132,3 @@ class ChatChain(LLMInference):
|
|
161 |
print("_process_inputs:", json.dumps(inputs, indent=4))
|
162 |
|
163 |
return inputs
|
164 |
-
|
165 |
-
def get_system_message(self, input) -> Chain:
|
166 |
-
return get_system_prompt_and_user_message()[0]
|
167 |
-
|
168 |
-
def get_user_message(self, input) -> Chain:
|
169 |
-
return input["question"]
|
|
|
6 |
from langchain.prompts import PromptTemplate
|
7 |
from langchain.chains.base import Chain
|
8 |
|
9 |
+
from app_modules.llm_inference import LLMInference, get_system_prompt_and_user_message
|
10 |
from app_modules.utils import CustomizedConversationSummaryBufferMemory
|
11 |
from langchain.chains import LLMChain
|
12 |
from langchain.globals import get_debug
|
|
|
15 |
B_INST, E_INST = "[INST]", "[/INST]"
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def create_llama_2_prompt_template():
|
19 |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
20 |
|
|
|
124 |
if not isinstance(inputs, list):
|
125 |
inputs = {"input": inputs["question"]}
|
126 |
elif self.llm_loader.llm_model_type == "huggingface":
|
127 |
+
inputs = [self.apply_chat_template(input["question"]) for input in inputs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
else:
|
129 |
inputs = [{"input": i["question"]} for i in inputs]
|
130 |
|
|
|
132 |
print("_process_inputs:", json.dumps(inputs, indent=4))
|
133 |
|
134 |
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
app_modules/llm_inference.py
CHANGED
@@ -14,6 +14,25 @@ from langchain.chains.base import Chain
|
|
14 |
from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
|
15 |
from app_modules.utils import remove_extra_spaces
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
class LLMInference(metaclass=abc.ABCMeta):
|
19 |
def __init__(self, llm_loader):
|
@@ -143,3 +162,22 @@ class LLMInference(metaclass=abc.ABCMeta):
|
|
143 |
|
144 |
t.join()
|
145 |
return que.get()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
|
15 |
from app_modules.utils import remove_extra_spaces
|
16 |
|
17 |
+
chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "true"
|
18 |
+
|
19 |
+
|
20 |
+
def get_system_prompt_and_user_message(orca=False):
|
21 |
+
# system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
22 |
+
system_prompt = (
|
23 |
+
"You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
|
24 |
+
if orca
|
25 |
+
else "You are a chatbot having a conversation with a human."
|
26 |
+
)
|
27 |
+
|
28 |
+
user_message = "{input}"
|
29 |
+
|
30 |
+
if chat_history_enabled:
|
31 |
+
user_message = "Chat History:\n\n{history} \n\n" + user_message
|
32 |
+
system_prompt += " Read the chat history to get context."
|
33 |
+
|
34 |
+
return system_prompt, user_message
|
35 |
+
|
36 |
|
37 |
class LLMInference(metaclass=abc.ABCMeta):
|
38 |
def __init__(self, llm_loader):
|
|
|
162 |
|
163 |
t.join()
|
164 |
return que.get()
|
165 |
+
|
166 |
+
def apply_chat_template(self, user_message):
|
167 |
+
result = (
|
168 |
+
[]
|
169 |
+
if self.llm_loader.model_name.lower().startswith("gemma")
|
170 |
+
else [
|
171 |
+
{
|
172 |
+
"role": "system",
|
173 |
+
"content": get_system_prompt_and_user_message()[0],
|
174 |
+
}
|
175 |
+
]
|
176 |
+
)
|
177 |
+
result.append(
|
178 |
+
{
|
179 |
+
"role": "user",
|
180 |
+
"content": user_message,
|
181 |
+
}
|
182 |
+
)
|
183 |
+
return result
|
app_modules/llm_qa_chain.py
CHANGED
@@ -6,12 +6,17 @@ from langchain.chains import ConversationalRetrievalChain
|
|
6 |
from langchain.chains.base import Chain
|
7 |
from app_modules.llm_inference import LLMInference
|
8 |
from app_modules.utils import CustomizedConversationSummaryBufferMemory
|
|
|
9 |
from langchain_core.retrievers import BaseRetriever
|
10 |
from langchain_core.documents import Document
|
11 |
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
12 |
from langchain.globals import get_debug
|
13 |
|
14 |
retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
|
|
|
|
|
|
|
|
|
15 |
|
16 |
if retrieve_from_questions_file:
|
17 |
questions_file_path = os.getenv("QUESTIONS_FILE_PATH")
|
@@ -108,8 +113,11 @@ class QAChain(LLMInference):
|
|
108 |
# find the query in the df
|
109 |
filtered = df[df["question"].str.lower() == query.lower()]
|
110 |
|
111 |
-
context = filtered.iloc[0]["context"]
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
6 |
from langchain.chains.base import Chain
|
7 |
from app_modules.llm_inference import LLMInference
|
8 |
from app_modules.utils import CustomizedConversationSummaryBufferMemory
|
9 |
+
|
10 |
from langchain_core.retrievers import BaseRetriever
|
11 |
from langchain_core.documents import Document
|
12 |
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
13 |
from langchain.globals import get_debug
|
14 |
|
15 |
retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
|
16 |
+
apply_chat_template_for_rag = os.getenv("APPLY_CHAT_TEMPLATE_FOR_RAG") == "true"
|
17 |
+
|
18 |
+
print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)
|
19 |
+
print(f"apply_chat_template_for_rag: {apply_chat_template_for_rag}", flush=True)
|
20 |
|
21 |
if retrieve_from_questions_file:
|
22 |
questions_file_path = os.getenv("QUESTIONS_FILE_PATH")
|
|
|
113 |
# find the query in the df
|
114 |
filtered = df[df["question"].str.lower() == query.lower()]
|
115 |
|
116 |
+
context = filtered.iloc[0]["context"] if len(filtered) > 0 else ""
|
117 |
|
118 |
+
if apply_chat_template_for_rag:
|
119 |
+
return self.apply_chat_template(
|
120 |
+
f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}"
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
return f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}\n\nHelpful Answer:"
|
app_modules/utils.py
CHANGED
@@ -7,7 +7,8 @@ import os
|
|
7 |
import platform
|
8 |
import re
|
9 |
from pathlib import Path
|
10 |
-
|
|
|
11 |
import requests
|
12 |
import torch
|
13 |
from tqdm import tqdm
|
@@ -186,234 +187,79 @@ class CustomizedConversationSummaryBufferMemory(ConversationSummaryBufferMemory)
|
|
186 |
)
|
187 |
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
return 0
|
192 |
-
distance = distance_calculator.evaluate_string_pairs(
|
193 |
-
prediction=entry1, prediction_b=entry2
|
194 |
-
)
|
195 |
-
# print(f"entry1: {entry1}, entry2: {entry2}, distance: {distance['score']}")
|
196 |
-
return distance["score"]
|
197 |
-
|
198 |
-
|
199 |
-
def FindInList(entry, elist, distance_calculator=None, debug=False):
|
200 |
-
for item in elist:
|
201 |
-
if distance_calculator is not None:
|
202 |
-
distance = CalculateDistance(entry, item, distance_calculator)
|
203 |
-
if distance < distance_threshold:
|
204 |
-
if debug:
|
205 |
-
print(
|
206 |
-
f"FindInList - matched by distance {distance:.3f}: {entry} - {item}"
|
207 |
-
)
|
208 |
-
return True
|
209 |
-
if entry == item:
|
210 |
-
return True
|
211 |
-
return False
|
212 |
-
|
213 |
-
|
214 |
-
def CalculatePRF1F2(
|
215 |
-
goldAnswerList, predAnswerList, distance_calculator=None, debug=False
|
216 |
-
):
|
217 |
-
if len(goldAnswerList) == 0:
|
218 |
-
if len(predAnswerList) == 0:
|
219 |
-
return [
|
220 |
-
1.0,
|
221 |
-
1.0,
|
222 |
-
1.0,
|
223 |
-
1.0,
|
224 |
-
] # consider it 'correct' when there is no labeled answer, and also no predicted answer
|
225 |
-
else:
|
226 |
-
return [
|
227 |
-
0.0,
|
228 |
-
1.0,
|
229 |
-
0.0,
|
230 |
-
0.0,
|
231 |
-
] # precision=0 and recall=1 when there is no labeled answer, but has some predicted answer(s)
|
232 |
-
elif len(predAnswerList) == 0:
|
233 |
-
return [
|
234 |
-
1.0,
|
235 |
-
0.0,
|
236 |
-
0.0,
|
237 |
-
0.0,
|
238 |
-
] # precision=1 and recall=0 when there is labeled answer(s), but no predicted answer
|
239 |
-
else:
|
240 |
-
glist = goldAnswerList
|
241 |
-
plist = predAnswerList
|
242 |
-
|
243 |
-
tp = 1e-40 # numerical trick
|
244 |
-
fp = 0.0
|
245 |
-
fn = 0.0
|
246 |
-
|
247 |
-
for gentry in glist:
|
248 |
-
if FindInList(
|
249 |
-
gentry, plist, distance_calculator=distance_calculator, debug=True
|
250 |
-
):
|
251 |
-
tp += 1
|
252 |
-
else:
|
253 |
-
fn += 1
|
254 |
-
for pentry in plist:
|
255 |
-
if not FindInList(pentry, glist, distance_calculator=distance_calculator):
|
256 |
-
fp += 1
|
257 |
-
|
258 |
-
precision = tp / (tp + fp)
|
259 |
-
recall = tp / (tp + fn)
|
260 |
-
|
261 |
-
f1 = (2 * precision * recall) / (precision + recall)
|
262 |
-
f2 = (5 * precision * recall) / (4 * precision + recall)
|
263 |
-
return [precision, recall, f1, f2]
|
264 |
-
|
265 |
-
|
266 |
-
nlp = None
|
267 |
-
distance_threshold = 0.05
|
268 |
-
|
269 |
-
|
270 |
-
def load_spacy_model():
|
271 |
-
import spacy
|
272 |
-
|
273 |
-
global nlp
|
274 |
-
if nlp is not None:
|
275 |
-
return nlp
|
276 |
-
|
277 |
-
global distance_threshold
|
278 |
-
distance_threshold = float(os.getenv("DISTANCE_THRESHOLD", "0.05"))
|
279 |
|
280 |
-
spacy_model_name = os.getenv("SPACY_MODEL_NAME", "en_core_web_trf")
|
281 |
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
nlp = spacy.load(spacy_model_name)
|
286 |
-
print(f"loaded spacy model from {spacy_model_name}")
|
287 |
-
return nlp
|
288 |
-
except OSError:
|
289 |
-
print(f"downloading spacy model {spacy_model_name}")
|
290 |
-
spacy.cli.download(spacy_model_name)
|
291 |
-
print(f"downloaded spacy model {spacy_model_name}")
|
292 |
-
|
293 |
-
|
294 |
-
def clean_text(text):
|
295 |
-
text = text.lower()
|
296 |
-
text = text.replace('"', "")
|
297 |
-
text = text.replace(".", "")
|
298 |
-
# text = text.replace("ō", "o")
|
299 |
-
return text
|
300 |
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
-
def get_entities_in_text(text, debug=False):
|
303 |
-
nlp = load_spacy_model()
|
304 |
-
doc = nlp(text)
|
305 |
-
entities_in_text = []
|
306 |
-
for word in doc.ents:
|
307 |
-
if debug:
|
308 |
-
print(word.text, word.label_)
|
309 |
-
entity = clean_text(word.text)
|
310 |
-
if entity not in entities_in_text:
|
311 |
-
entities_in_text.append(entity)
|
312 |
|
313 |
-
|
314 |
-
|
|
|
315 |
|
316 |
|
317 |
-
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
if debug:
|
322 |
-
print(
|
323 |
-
|
324 |
|
325 |
-
print("entities_in_question ---------------")
|
326 |
-
entities_in_question = get_entities_in_text(question["question"], debug)
|
327 |
|
328 |
-
|
329 |
-
|
330 |
|
331 |
-
print("done with NER with spaCy -----------")
|
332 |
|
333 |
-
|
|
|
334 |
|
335 |
-
predAnswerList = [
|
336 |
-
pentry
|
337 |
-
for pentry in entities_in_answer
|
338 |
-
if not FindInList(pentry, entities_in_question)
|
339 |
-
]
|
340 |
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
print(f"pred_answers: {predAnswerList}")
|
345 |
-
|
346 |
-
precision, recall, f1, f2 = CalculatePRF1F2(
|
347 |
-
ground_truth,
|
348 |
-
predAnswerList,
|
349 |
-
debug=debug,
|
350 |
-
distance_calculator=distance_calculator,
|
351 |
-
)
|
352 |
-
print(f"precision: {precision}, recall: {recall}, f1: {f1}, f2: {f2}")
|
353 |
-
else:
|
354 |
-
precision = 0.0
|
355 |
-
recall = 0.0
|
356 |
-
f1 = 0.0
|
357 |
-
f2 = 0.0
|
358 |
-
entities_in_answer = []
|
359 |
-
entities_in_question = []
|
360 |
-
|
361 |
-
return (
|
362 |
-
precision,
|
363 |
-
recall,
|
364 |
-
f1,
|
365 |
-
f2,
|
366 |
-
entities_in_answer,
|
367 |
-
ground_truth,
|
368 |
-
entities_in_question,
|
369 |
)
|
370 |
-
|
371 |
-
|
372 |
-
def calculate_metrics_gemini(question, answer, debug=False):
|
373 |
-
precision = 0.0
|
374 |
-
recall = 0.0
|
375 |
-
f1 = 0.0
|
376 |
-
|
377 |
-
return (precision, recall, f1)
|
378 |
-
|
379 |
-
|
380 |
-
if __name__ == "__main__":
|
381 |
-
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
382 |
-
from langchain.evaluation import load_evaluator
|
383 |
-
|
384 |
-
hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
|
385 |
-
print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
|
386 |
-
print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
|
387 |
-
|
388 |
-
hf_embeddings_model_name = "hkunlp/instructor-large"
|
389 |
-
print(f"hf_embeddings_model_name: {hf_embeddings_model_name}")
|
390 |
-
embeddings = HuggingFaceInstructEmbeddings(
|
391 |
-
model_name=hf_embeddings_model_name,
|
392 |
-
model_kwargs={"device": hf_embeddings_device_type},
|
393 |
-
)
|
394 |
-
|
395 |
-
hf_evaluator = load_evaluator("pairwise_embedding_distance", embeddings=embeddings)
|
396 |
-
|
397 |
-
question = {
|
398 |
-
"question": "what does jamaican people speak",
|
399 |
-
"entities_in_question": ["jamaican"],
|
400 |
-
"answers": ["jamaican english", "jamaican creole english language"],
|
401 |
-
}
|
402 |
-
answer = "Jamaican people primarily speak Jamaican Patois, which is an English-based creole language with significant West African influences. It is spoken as a native language by the majority of Jamaicans and also exists in various forms among Jamaican expatriates and non-Jamaicans in different parts of the world. The phonology of Jamaican Patois includes around 21 consonants (with some dialectal variation regarding the status of /h/ as a phoneme) and between nine and sixteen vowels, some of which are capable of nasalization or lengthening. There are also instances of palatalization in Jamaican Patois, where certain consonants appear to be phonemic in some dialects but may be considered phonetic in others. For example, the palatal stops [c], [ɟ], and [ɲ] may be analyzed as phonemes or as instances of phonetic palatalization depending on the account."
|
403 |
-
calculate_metrics(question, answer, distance_calculator=hf_evaluator, debug=True)
|
404 |
-
|
405 |
-
question = {
|
406 |
-
"question": "who is governor of ohio 2011",
|
407 |
-
"entities_in_question": ["2011"],
|
408 |
-
"answers": ["john kasich", "return j. meigs, jr.", "ted strickland"],
|
409 |
-
}
|
410 |
-
answer = "The lieutenant governor of Ohio in 2011 was Mary Taylor, who served alongside Governor John Kasich. She assumed office on January 10, 2011, after being elected as the lieutenant governor in the 2010 election. During her tenure, she faced criticism for using the state airplane for personal errands and reportedly had high turnover among her staff."
|
411 |
-
calculate_metrics(question, answer, distance_calculator=hf_evaluator, debug=True)
|
412 |
-
|
413 |
-
question = {
|
414 |
-
"question": "where is the fukushima daiichi nuclear power station",
|
415 |
-
"entities_in_question": ["the fukushima daiichi nuclear power station"],
|
416 |
-
"answers": ["japan", "okuma"],
|
417 |
-
}
|
418 |
-
answer = "The Fukushima Daiichi Nuclear Power Station is located in the towns of Ōkuma and Futaba in Fukushima Prefecture, Japan."
|
419 |
-
calculate_metrics(question, answer, distance_calculator=hf_evaluator, debug=True)
|
|
|
7 |
import platform
|
8 |
import re
|
9 |
from pathlib import Path
|
10 |
+
import evaluate
|
11 |
+
import pandas as pd
|
12 |
import requests
|
13 |
import torch
|
14 |
from tqdm import tqdm
|
|
|
187 |
)
|
188 |
|
189 |
|
190 |
+
bleu = evaluate.load("bleu")
|
191 |
+
rouge = evaluate.load("rouge")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
|
|
193 |
|
194 |
+
def calc_metrics(df):
|
195 |
+
predictions = [df["answer"][i] for i in range(len(df))]
|
196 |
+
references = [df["ground_truth"][i] for i in range(len(df))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
+
bleu_scores = bleu.compute(
|
199 |
+
predictions=predictions, references=references, max_order=1
|
200 |
+
)
|
201 |
+
rouge_scores = rouge.compute(predictions=predictions, references=references)
|
202 |
+
return {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
+
pattern_abnormal_newlines = re.compile(r"\n{5,}")
|
206 |
+
pattern_text_repetitions = re.compile(r"\b(\w.+?)\b(\1+)", re.M | re.DOTALL)
|
207 |
+
exception_pattern = re.compile(r"(\w+\.)\1")
|
208 |
|
209 |
|
210 |
+
# final version for repetition detection
|
211 |
+
def detect_repetitions(
|
212 |
+
text, debug=False, pattern_text_repetitions=pattern_text_repetitions
|
213 |
+
):
|
214 |
+
subtotals = [0, 0]
|
215 |
+
|
216 |
+
if isinstance(text, str):
|
217 |
+
patterns = [pattern_abnormal_newlines, pattern_text_repetitions]
|
218 |
+
for i, pattern in enumerate(patterns):
|
219 |
+
if debug:
|
220 |
+
print(
|
221 |
+
f"----detect {'abnormal newlines' if i == 0 else 'text repetitions'}----"
|
222 |
+
)
|
223 |
+
matches = pattern.finditer(text)
|
224 |
+
for match in matches:
|
225 |
+
if debug:
|
226 |
+
print(match)
|
227 |
+
for groupNum in range(0, len(match.groups())):
|
228 |
+
groupNum = groupNum + 1
|
229 |
+
print(
|
230 |
+
"Group {groupNum} found at {start}-{end}: `{group}`".format(
|
231 |
+
groupNum=groupNum,
|
232 |
+
start=match.start(groupNum),
|
233 |
+
end=match.end(groupNum),
|
234 |
+
group=match.group(groupNum),
|
235 |
+
)
|
236 |
+
)
|
237 |
+
|
238 |
+
if exception_pattern.match(match[0]):
|
239 |
+
if debug:
|
240 |
+
print("ignored: ", match[0])
|
241 |
+
continue
|
242 |
+
|
243 |
+
start, end = match.span()
|
244 |
+
subtotals[i] += end - start
|
245 |
+
|
246 |
+
result = (subtotals[0], subtotals[1], subtotals[0] + subtotals[1])
|
247 |
|
248 |
if debug:
|
249 |
+
print(result)
|
250 |
+
return result
|
251 |
|
|
|
|
|
252 |
|
253 |
+
def detect_abnormal_newlines(text, debug=False):
|
254 |
+
return detect_repetitions(text, debug=debug)[0]
|
255 |
|
|
|
256 |
|
257 |
+
def detect_text_repetitions(text, debug=False):
|
258 |
+
return detect_repetitions(text, debug=debug)[1]
|
259 |
|
|
|
|
|
|
|
|
|
|
|
260 |
|
261 |
+
def detect_repetition_scores(text, debug=False):
|
262 |
+
newline_score, repetition_score, total_repetitions = detect_repetitions(
|
263 |
+
text, debug=debug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
)
|
265 |
+
return pd.Series([newline_score, repetition_score, total_repetitions])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qa_chain_test.py
CHANGED
@@ -12,7 +12,7 @@ if chatting:
|
|
12 |
|
13 |
from app_modules.init import app_init
|
14 |
from app_modules.llm_qa_chain import QAChain
|
15 |
-
from app_modules.utils import print_llm_response
|
16 |
|
17 |
llm_loader, qa_chain = app_init()
|
18 |
|
@@ -116,7 +116,9 @@ if __name__ == "__main__":
|
|
116 |
query = df["question"][i]
|
117 |
id = df["id"][i]
|
118 |
|
119 |
-
ground_truth = question[
|
|
|
|
|
120 |
|
121 |
word_count = len(nltk.word_tokenize(answer))
|
122 |
|
@@ -128,6 +130,10 @@ if __name__ == "__main__":
|
|
128 |
"ground_truth": ground_truth,
|
129 |
}
|
130 |
|
|
|
|
|
|
|
|
|
131 |
pd.options.display.float_format = "{:.3f}".format
|
132 |
print(df2.describe())
|
133 |
|
@@ -147,6 +153,8 @@ if __name__ == "__main__":
|
|
147 |
df2.to_csv(csv_file, mode="a", index=False, header=True)
|
148 |
print(f"test results saved to file: {csv_file}")
|
149 |
|
|
|
|
|
150 |
df = pd.DataFrame(
|
151 |
{
|
152 |
"model": [llm_loader.model_name],
|
@@ -154,6 +162,8 @@ if __name__ == "__main__":
|
|
154 |
"word_count": [word_count],
|
155 |
"inference_time": [total_time],
|
156 |
"inference_speed": [word_count / total_time],
|
|
|
|
|
157 |
}
|
158 |
)
|
159 |
|
|
|
12 |
|
13 |
from app_modules.init import app_init
|
14 |
from app_modules.llm_qa_chain import QAChain
|
15 |
+
from app_modules.utils import print_llm_response, calc_metrics, detect_repetition_scores
|
16 |
|
17 |
llm_loader, qa_chain = app_init()
|
18 |
|
|
|
116 |
query = df["question"][i]
|
117 |
id = df["id"][i]
|
118 |
|
119 |
+
ground_truth = question[
|
120 |
+
"wellFormedAnswers" if "wellFormedAnswers" in question else "answers"
|
121 |
+
]
|
122 |
|
123 |
word_count = len(nltk.word_tokenize(answer))
|
124 |
|
|
|
130 |
"ground_truth": ground_truth,
|
131 |
}
|
132 |
|
133 |
+
df2[["newline_score", "repetition_score", "total_repetitions"]] = df2[
|
134 |
+
"answer"
|
135 |
+
].apply(detect_repetition_scores)
|
136 |
+
|
137 |
pd.options.display.float_format = "{:.3f}".format
|
138 |
print(df2.describe())
|
139 |
|
|
|
153 |
df2.to_csv(csv_file, mode="a", index=False, header=True)
|
154 |
print(f"test results saved to file: {csv_file}")
|
155 |
|
156 |
+
scores = calc_metrics(df2)
|
157 |
+
|
158 |
df = pd.DataFrame(
|
159 |
{
|
160 |
"model": [llm_loader.model_name],
|
|
|
162 |
"word_count": [word_count],
|
163 |
"inference_time": [total_time],
|
164 |
"inference_speed": [word_count / total_time],
|
165 |
+
"bleu1": [scores["bleu_scores"]["bleu"]],
|
166 |
+
"rougeL": [scores["rouge_scores"]["rougeL"]],
|
167 |
}
|
168 |
)
|
169 |
|
requirements.txt
CHANGED
@@ -9,4 +9,6 @@ gradio==4.26.0
|
|
9 |
spaces==0.27.1
|
10 |
black==24.4.0
|
11 |
chardet==5.2.0
|
12 |
-
sentencepiece==0.2.0
|
|
|
|
|
|
9 |
spaces==0.27.1
|
10 |
black==24.4.0
|
11 |
chardet==5.2.0
|
12 |
+
sentencepiece==0.2.0
|
13 |
+
evaluate==0.4.2
|
14 |
+
rouge_score==0.1.2
|