BERT-F1
Browse files
- .env.example +2 -81
- app.py +29 -7
- eval_modules/utils.py +13 -2
.env.example
CHANGED
@@ -1,81 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
# LLM_MODEL_TYPE=hftgi
|
4 |
-
# LLM_MODEL_TYPE=ollama
|
5 |
-
# LLM_MODEL_TYPE=google
|
6 |
-
# LLM_MODEL_TYPE=vllm
|
7 |
-
|
8 |
-
HUGGINGFACE_AUTH_TOKEN=
|
9 |
-
|
10 |
-
HFTGI_SERVER_URL=
|
11 |
-
|
12 |
-
OPENAI_API_KEY=
|
13 |
-
|
14 |
-
GOOGLE_API_KEY=
|
15 |
-
|
16 |
-
# if unset, default to "gpt-3.5-turbo"
|
17 |
-
OPENAI_MODEL_NAME=
|
18 |
-
|
19 |
-
# GEMINI_MODEL_NAME=gemini-1.5-pro-latest
|
20 |
-
|
21 |
-
# OLLAMA_MODEL_NAME=orca2:7b
|
22 |
-
# OLLAMA_MODEL_NAME=mistral:7b
|
23 |
-
# OLLAMA_MODEL_NAME=gemma:7b
|
24 |
-
# OLLAMA_MODEL_NAME=llama2:7b
|
25 |
-
OLLAMA_MODEL_NAME=llama3:8b
|
26 |
-
|
27 |
-
OLLAMA_RP=1.15
|
28 |
-
HF_RP=1.15
|
29 |
-
|
30 |
-
LANGCHAIN_DEBUG=false
|
31 |
-
BATCH_SIZE=1
|
32 |
-
APPLY_CHAT_TEMPLATE_FOR_RAG=true
|
33 |
-
|
34 |
-
# cpu, mps or cuda:0 - if unset, use whatever detected
|
35 |
-
HF_EMBEDDINGS_DEVICE_TYPE=
|
36 |
-
HF_PIPELINE_DEVICE_TYPE=
|
37 |
-
|
38 |
-
# uncomment one of the below to load corresponding quantized model
|
39 |
-
# LOAD_QUANTIZED_MODEL=4bit
|
40 |
-
# LOAD_QUANTIZED_MODEL=8bit
|
41 |
-
|
42 |
-
QA_WITH_RAG=true
|
43 |
-
# QA_WITH_RAG=false
|
44 |
-
|
45 |
-
RETRIEVER_TYPE=questions_file
|
46 |
-
# RETRIEVER_TYPE=vectorstore
|
47 |
-
|
48 |
-
QUESTIONS_FILE_PATH="./data/datasets/ms_macro.json"
|
49 |
-
|
50 |
-
DISABLE_MODEL_PRELOADING=true
|
51 |
-
CHAT_HISTORY_ENABLED=false
|
52 |
-
SHOW_PARAM_SETTINGS=false
|
53 |
-
SHARE_GRADIO_APP=false
|
54 |
-
|
55 |
-
# if unset, default to "hkunlp/instructor-xl"
|
56 |
-
HF_EMBEDDINGS_MODEL_NAME="hkunlp/instructor-large"
|
57 |
-
|
58 |
-
# number of cpu cores - used to set n_threads for GPT4ALL & LlamaCpp models
|
59 |
-
NUMBER_OF_CPU_CORES=
|
60 |
-
|
61 |
-
USING_TORCH_BFLOAT16=true
|
62 |
-
|
63 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-3b"
|
64 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-7b"
|
65 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-12b"
|
66 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="TheBloke/wizardLM-7B-HF"
|
67 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="TheBloke/vicuna-7B-1.1-HF"
|
68 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="nomic-ai/gpt4all-j"
|
69 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="nomic-ai/gpt4all-falcon"
|
70 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="lmsys/fastchat-t5-3b-v1.0"
|
71 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-7b-chat-hf"
|
72 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-13b-chat-hf"
|
73 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-70b-chat-hf"
|
74 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
|
75 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-70B-Instruct"
|
76 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Orca-2-7b"
|
77 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Orca-2-13b"
|
78 |
-
HUGGINGFACE_MODEL_NAME_OR_PATH="google/gemma-1.1-2b-it"
|
79 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="google/gemma-1.1-7b-it"
|
80 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-128k-instruct"
|
81 |
-
# HUGGINGFACE_MODEL_NAME_OR_PATH="mistralai/Mistral-7B-Instruct-v0.2"
|
|
|
1 |
+
HF_TOKEN=
|
2 |
+
MODEL_NAME=microsoft/Phi-3.5-mini-instruct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,10 +1,30 @@
|
|
1 |
import json
|
2 |
import os
|
|
|
|
|
3 |
import gradio as gr
|
4 |
-
from
|
5 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from eval_modules.calc_repetitions_v2e import detect_repetitions
|
7 |
|
|
|
|
|
|
|
|
|
|
|
8 |
questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./ms_macro.json"
|
9 |
|
10 |
questions = json.loads(open(questions_file_path).read())
|
@@ -18,7 +38,8 @@ For more information on `huggingface_hub` Inference API support, please check th
|
|
18 |
"""
|
19 |
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
20 |
# client = InferenceClient("HuggingFaceH4/zephyr-7b-gemma-v0.1")
|
21 |
-
client = InferenceClient("microsoft/Phi-3.5-mini-instruct")
|
|
|
22 |
|
23 |
|
24 |
def chat(
|
@@ -74,11 +95,11 @@ def chat(
|
|
74 |
answer = partial_text
|
75 |
(whitespace_score, repetition_score, total_repetitions) = detect_repetitions(answer)
|
76 |
partial_text += "\n\nRepetition Metrics:\n"
|
77 |
-
partial_text += f"1.
|
78 |
-
partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
|
79 |
partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
|
80 |
partial_text += (
|
81 |
-
f"1.
|
82 |
)
|
83 |
|
84 |
if index >= 0: # RAG
|
@@ -87,11 +108,12 @@ def chat(
|
|
87 |
if "wellFormedAnswers" in questions[index]
|
88 |
else "answers"
|
89 |
)
|
90 |
-
scores =
|
91 |
|
92 |
partial_text += "\n\n Performance Metrics:\n"
|
93 |
partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
|
94 |
partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
|
|
|
95 |
|
96 |
partial_text += f"\n\nGround truth: {questions[index][key][0]}\n"
|
97 |
|
|
|
1 |
import json
|
2 |
import os
|
3 |
+
import sys
|
4 |
+
import evaluate
|
5 |
import gradio as gr
|
6 |
+
from dotenv import find_dotenv, load_dotenv
|
7 |
+
from huggingface_hub import InferenceClient, login
|
8 |
+
|
9 |
+
found_dotenv = find_dotenv(".env")
|
10 |
+
|
11 |
+
if len(found_dotenv) == 0:
|
12 |
+
found_dotenv = find_dotenv(".env.example")
|
13 |
+
print(f"loading env vars from: {found_dotenv}")
|
14 |
+
load_dotenv(found_dotenv, override=False)
|
15 |
+
|
16 |
+
path = os.path.dirname(found_dotenv)
|
17 |
+
print(f"Adding {path} to sys.path")
|
18 |
+
sys.path.append(path)
|
19 |
+
|
20 |
+
from eval_modules.utils import calc_perf_scores
|
21 |
from eval_modules.calc_repetitions_v2e import detect_repetitions
|
22 |
|
23 |
+
model_name = os.getenv("MODEL_NAME") or "microsoft/Phi-3.5-mini-instruct"
|
24 |
+
hf_token = os.getenv("HF_TOKEN")
|
25 |
+
|
26 |
+
login(token=hf_token, add_to_git_credential=True)
|
27 |
+
|
28 |
questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./ms_macro.json"
|
29 |
|
30 |
questions = json.loads(open(questions_file_path).read())
|
|
|
38 |
"""
|
39 |
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
40 |
# client = InferenceClient("HuggingFaceH4/zephyr-7b-gemma-v0.1")
|
41 |
+
# client = InferenceClient("microsoft/Phi-3.5-mini-instruct")
|
42 |
+
client = InferenceClient(model_name, token=hf_token)
|
43 |
|
44 |
|
45 |
def chat(
|
|
|
95 |
answer = partial_text
|
96 |
(whitespace_score, repetition_score, total_repetitions) = detect_repetitions(answer)
|
97 |
partial_text += "\n\nRepetition Metrics:\n"
|
98 |
+
partial_text += f"1. EWC Repetition Score: {whitespace_score:.3f}\n"
|
99 |
+
partial_text += f"1. Text Repetition Score: {repetition_score:.3f}\n"
|
100 |
partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
|
101 |
partial_text += (
|
102 |
+
f"1. Repetition Ratio: {total_repetitions / len(answer):.3f}\n"
|
103 |
)
|
104 |
|
105 |
if index >= 0: # RAG
|
|
|
108 |
if "wellFormedAnswers" in questions[index]
|
109 |
else "answers"
|
110 |
)
|
111 |
+
scores = calc_perf_scores([answer], [questions[index][key]], debug=True)
|
112 |
|
113 |
partial_text += "\n\n Performance Metrics:\n"
|
114 |
partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
|
115 |
partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
|
116 |
+
partial_text += f'1. BERT-F1: {scores["bert_scores"]["f1"][0]:.3f}\n'
|
117 |
|
118 |
partial_text += f"\n\nGround truth: {questions[index][key][0]}\n"
|
119 |
|
eval_modules/utils.py
CHANGED
@@ -173,9 +173,10 @@ def ensure_model_is_downloaded(llm_model_type):
|
|
173 |
|
174 |
bleu = evaluate.load("bleu")
|
175 |
rouge = evaluate.load("rouge")
|
|
|
176 |
|
177 |
|
178 |
-
def calc_bleu_rouge_scores(predictions, references, debug=False):
|
179 |
if debug:
|
180 |
print("predictions:", predictions)
|
181 |
print("references:", references)
|
@@ -184,7 +185,17 @@ def calc_bleu_rouge_scores(predictions, references, debug=False):
|
|
184 |
predictions=predictions, references=references, max_order=1
|
185 |
)
|
186 |
rouge_scores = rouge.compute(predictions=predictions, references=references)
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
if debug:
|
190 |
print("result:", result)
|
|
|
173 |
|
174 |
bleu = evaluate.load("bleu")
|
175 |
rouge = evaluate.load("rouge")
|
176 |
+
bert_score = evaluate.load("bertscore")
|
177 |
|
178 |
|
179 |
+
def calc_perf_scores(predictions, references, debug=False):
|
180 |
if debug:
|
181 |
print("predictions:", predictions)
|
182 |
print("references:", references)
|
|
|
185 |
predictions=predictions, references=references, max_order=1
|
186 |
)
|
187 |
rouge_scores = rouge.compute(predictions=predictions, references=references)
|
188 |
+
bert_scores = bert_score.compute(
|
189 |
+
predictions=predictions,
|
190 |
+
references=references,
|
191 |
+
lang="en",
|
192 |
+
model_type="microsoft/deberta-large-mnli",
|
193 |
+
)
|
194 |
+
result = {
|
195 |
+
"bleu_scores": bleu_scores,
|
196 |
+
"rouge_scores": rouge_scores,
|
197 |
+
"bert_scores": bert_scores,
|
198 |
+
}
|
199 |
|
200 |
if debug:
|
201 |
print("result:", result)
|