Spaces:
Sleeping
Sleeping
rahulkiitk
commited on
Commit
·
632338f
1
Parent(s):
91d52f9
Updating Leaderboard code
Browse files- app.py +222 -0
- requirements.txt +10 -0
- sample_prediction.csv +10 -0
- script.py +374 -0
- submissions/.DS_Store +0 -0
- submissions/baseline/baseline.csv +6 -0
- submissions/modify.sh +28 -0
- tests/test.json +0 -0
- tests/test_sql.json +0 -0
- uploads.py +382 -0
app.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
5 |
+
from huggingface_hub import HfApi
|
6 |
+
from uploads import add_new_eval
|
7 |
+
|
8 |
+
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
|
9 |
+
CITATION_BUTTON_TEXT = r"""@inproceedings{kumar-etal-2024-booksql,
|
10 |
+
title = "BookSQL: A Large Scale Text-to-SQL Dataset for Accounting Domain",
|
11 |
+
author = "Kumar, Rahul and Raja, Amar and Harsola, Shrutendra and Subrahmaniam, Vignesh and Modi, Ashutosh",
|
12 |
+
booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics",
|
13 |
+
month = "march",
|
14 |
+
year = "2024",
|
15 |
+
address = "Mexico City, Mexico",
|
16 |
+
publisher = "Association for Computational Linguistics"
|
17 |
+
}"""
|
18 |
+
|
19 |
+
api = HfApi()
|
20 |
+
TOKEN = os.environ.get("TOKEN", None)
|
21 |
+
LEADERBOARD_PATH = f"Exploration-lab/BookSQL-Leaderboard"
|
22 |
+
|
23 |
+
|
24 |
+
def restart_space():
|
25 |
+
api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
|
26 |
+
|
27 |
+
|
28 |
+
# Function to load data from a given CSV file
|
29 |
+
def baseline_load_data(tasks):
|
30 |
+
# version = version.replace("%", "p")
|
31 |
+
file_path = f"submissions/baseline/baseline.csv" # Replace with your file paths
|
32 |
+
df = pd.read_csv(file_path)
|
33 |
+
|
34 |
+
# we only want specific columns and in a specific order
|
35 |
+
|
36 |
+
# column_names = [
|
37 |
+
# "Method",
|
38 |
+
# "Submitted By",
|
39 |
+
# "L-NER",
|
40 |
+
# "RR",
|
41 |
+
# "CJPE",
|
42 |
+
# "BAIL",
|
43 |
+
# "LSI",
|
44 |
+
# "PCR",
|
45 |
+
# "SUMM",
|
46 |
+
# "Average",
|
47 |
+
# ]
|
48 |
+
column_names = [
|
49 |
+
"Method",
|
50 |
+
"Submitted By",
|
51 |
+
"EMA",
|
52 |
+
"EX",
|
53 |
+
"BLEU-4",
|
54 |
+
"ROUGE-L"
|
55 |
+
]
|
56 |
+
if tasks is None:
|
57 |
+
breakpoint()
|
58 |
+
# based on the tasks, remove the columns that are not needed
|
59 |
+
if "EMA" not in tasks:
|
60 |
+
column_names.remove("EMA")
|
61 |
+
if "EX" not in tasks:
|
62 |
+
column_names.remove("EX")
|
63 |
+
if "BLEU-4" not in tasks:
|
64 |
+
column_names.remove("BLEU-4")
|
65 |
+
if "ROUGE-L" not in tasks:
|
66 |
+
column_names.remove("ROUGE-L")
|
67 |
+
|
68 |
+
df = df[column_names]
|
69 |
+
# df = df.sort_values(by="Average", ascending=False)
|
70 |
+
df = df.drop_duplicates(subset=["Method"], keep="first")
|
71 |
+
|
72 |
+
return df
|
73 |
+
|
74 |
+
|
75 |
+
def load_data(tasks):
|
76 |
+
baseline_df = baseline_load_data(tasks)
|
77 |
+
|
78 |
+
return baseline_df
|
79 |
+
|
80 |
+
|
81 |
+
# Function for searching in the leaderboard
|
82 |
+
def search_leaderboard(df, query):
|
83 |
+
if query == "":
|
84 |
+
return df
|
85 |
+
else:
|
86 |
+
return df[df["Method"].str.contains(query)]
|
87 |
+
|
88 |
+
|
89 |
+
# Function to change the version of the leaderboard
|
90 |
+
def change_version(tasks):
|
91 |
+
new_df = load_data(tasks)
|
92 |
+
return new_df
|
93 |
+
|
94 |
+
|
95 |
+
# Initialize Gradio app
|
96 |
+
demo = gr.Blocks()
|
97 |
+
|
98 |
+
with demo:
|
99 |
+
gr.Markdown(
|
100 |
+
"""
|
101 |
+
## 🥇 BookSQL Leaderboard
|
102 |
+
Given the importance and wide prevalence of business databases across the world, the proposed dataset, BookSQL focuses on the finance and accounting domain. Accounting databases are used across a wide spectrum of industries like construction, healthcare, retail, educational services, insurance, restaurant, real estate, etc. Business in these industries arranges their financial transactions into their own different set of categories (called a chart of accounts Industry Details in accounting terminology.
|
103 |
+
Text-to-SQL system developed on BookSQL will be robust at handling various types of accounting databases. The total size of the dataset is 1 million. The dataset is prepared under financial experts' supervision, and the dataset's statistics are provided in below table. The dataset consists of 27 businesses, and each business has around 35k - 40k transactions
|
104 |
+
Read more at [https://exploration-lab.github.io/BookSQL/](https://exploration-lab.github.io/BookSQL/).
|
105 |
+
Please follow this format for uploading prediction file (https://huggingface.co/spaces/Exploration-Lab/BookSQL/blob/main/sample_prediction.csv)
|
106 |
+
"""
|
107 |
+
)
|
108 |
+
|
109 |
+
with gr.Row():
|
110 |
+
with gr.Accordion("📙 Citation", open=False):
|
111 |
+
citation_button = gr.Textbox(
|
112 |
+
value=CITATION_BUTTON_TEXT,
|
113 |
+
label=CITATION_BUTTON_LABEL,
|
114 |
+
elem_id="citation-button",
|
115 |
+
show_copy_button=True,
|
116 |
+
) # .style(show_copy_button=True)
|
117 |
+
|
118 |
+
with gr.Tabs():
|
119 |
+
with gr.TabItem("Leaderboard"):
|
120 |
+
|
121 |
+
with gr.Row():
|
122 |
+
# tasks_checkbox = gr.CheckboxGroup(
|
123 |
+
# label="Select Tasks",
|
124 |
+
# choices=["L-NER", "RR", "CJPE", "BAIL", "LSI", "PCR", "SUMM"],
|
125 |
+
# value=["L-NER", "RR", "CJPE", "BAIL", "LSI", "PCR", "SUMM"],
|
126 |
+
# )
|
127 |
+
tasks_checkbox = gr.CheckboxGroup(
|
128 |
+
label="Select Tasks",
|
129 |
+
choices=["EMA","EX","BLEU-4","ROUGE-L"],
|
130 |
+
value=["EMA","EX","BLEU-4","ROUGE-L"],
|
131 |
+
)
|
132 |
+
|
133 |
+
with gr.Row():
|
134 |
+
search_bar = gr.Textbox(
|
135 |
+
placeholder="Search for methods...",
|
136 |
+
show_label=False,
|
137 |
+
)
|
138 |
+
|
139 |
+
leaderboard_table = gr.components.Dataframe(
|
140 |
+
value=load_data(
|
141 |
+
# "baseline",
|
142 |
+
["EMA","EX","BLEU-4","ROUGE-L"],
|
143 |
+
),
|
144 |
+
interactive=True,
|
145 |
+
visible=True,
|
146 |
+
)
|
147 |
+
|
148 |
+
# version_dropdown.change(
|
149 |
+
# change_version,
|
150 |
+
# inputs=[model_dropdown, version_dropdown, tasks_checkbox],
|
151 |
+
# outputs=leaderboard_table,
|
152 |
+
# )
|
153 |
+
|
154 |
+
# model_dropdown.change(
|
155 |
+
# change_version,
|
156 |
+
# inputs=[model_dropdown, version_dropdown, tasks_checkbox],
|
157 |
+
# outputs=leaderboard_table,
|
158 |
+
# )
|
159 |
+
|
160 |
+
search_bar.change(
|
161 |
+
search_leaderboard,
|
162 |
+
inputs=[
|
163 |
+
leaderboard_table,
|
164 |
+
search_bar,
|
165 |
+
# tasks_checkbox
|
166 |
+
],
|
167 |
+
outputs=leaderboard_table,
|
168 |
+
)
|
169 |
+
|
170 |
+
tasks_checkbox.change(
|
171 |
+
change_version,
|
172 |
+
inputs=[tasks_checkbox],
|
173 |
+
outputs=leaderboard_table,
|
174 |
+
)
|
175 |
+
|
176 |
+
with gr.Accordion("Submit a new model for evaluation"):
|
177 |
+
with gr.Row():
|
178 |
+
with gr.Column():
|
179 |
+
method_name_textbox = gr.Textbox(label="Method name")
|
180 |
+
url_textbox = gr.Textbox(label="Url to model information")
|
181 |
+
with gr.Column():
|
182 |
+
organisation = gr.Textbox(label="Organisation")
|
183 |
+
mail = gr.Textbox(label="Contact email")
|
184 |
+
file_output = gr.File()
|
185 |
+
|
186 |
+
submit_button = gr.Button("Submit Eval")
|
187 |
+
submission_result = gr.Markdown()
|
188 |
+
submit_button.click(
|
189 |
+
add_new_eval,
|
190 |
+
[
|
191 |
+
method_name_textbox,
|
192 |
+
url_textbox,
|
193 |
+
file_output,
|
194 |
+
organisation,
|
195 |
+
mail,
|
196 |
+
],
|
197 |
+
submission_result,
|
198 |
+
)
|
199 |
+
|
200 |
+
gr.Markdown(
|
201 |
+
"""
|
202 |
+
## Quick Links
|
203 |
+
|
204 |
+
- [**GitHub Repository**](https://github.com/exploration-lab/BookSQL): Access the source code, fine-tuning scripts, and additional resources for the BookSQL dataset.
|
205 |
+
- [**arXiv Paper**](#): Detailed information about the BookSQL dataset and its significance in unlearning tasks.
|
206 |
+
- [**Dataset on Hugging Face**](https://huggingface.co/datasets/Exploration-Lab/BookSQL): Direct link to download the BookSQL dataset.
|
207 |
+
|
208 |
+
|
209 |
+
"""
|
210 |
+
)
|
211 |
+
|
212 |
+
# scheduler = BackgroundScheduler()
|
213 |
+
# scheduler.add_job(restart_space, "interval", seconds=1800)
|
214 |
+
# scheduler.start()
|
215 |
+
# demo.queue(default_concurrency_limit=40).launch()
|
216 |
+
|
217 |
+
# demo.launch()
|
218 |
+
scheduler = BackgroundScheduler()
|
219 |
+
scheduler.add_job(restart_space, "interval", seconds=3600)
|
220 |
+
scheduler.start()
|
221 |
+
# demo.launch(debug=True)
|
222 |
+
demo.launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
seaborn
|
2 |
+
scipy
|
3 |
+
datasets==2.14.5
|
4 |
+
gradio
|
5 |
+
huggingface-hub==0.18.0
|
6 |
+
numpy==1.24.2
|
7 |
+
APScheduler==3.10.1
|
8 |
+
evaluate
|
9 |
+
rouge_score
|
10 |
+
sqlparse
|
sample_prediction.csv
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,id,pred_sql
|
2 |
+
0,0,SELECT * FROM employees WHERE department_id = 3
|
3 |
+
1,1,UPDATE customers SET status = 'active' WHERE customer_id = 42
|
4 |
+
2,2,"INSERT INTO orders (order_id, order_date, customer_id) VALUES (1001, '2023-06-01', 5)"
|
5 |
+
3,3,DELETE FROM products WHERE product_id = 200
|
6 |
+
4,4,"SELECT name, salary FROM employees WHERE salary > 50000"
|
7 |
+
5,5,UPDATE products SET price = price * 1.1 WHERE category = 'electronics'
|
8 |
+
6,6,"INSERT INTO users (user_id, username, email) VALUES (10, 'jdoe', 'jdoe@example.com')"
|
9 |
+
7,7,DELETE FROM sessions WHERE last_active < '2023-01-01'
|
10 |
+
8,8,SELECT DISTINCT category FROM products
|
script.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from email.utils import parseaddr
|
2 |
+
from huggingface_hub import HfApi
|
3 |
+
import os
|
4 |
+
import datetime
|
5 |
+
import pandas as pd
|
6 |
+
import json
|
7 |
+
|
8 |
+
import evaluate as nlp_evaluate
|
9 |
+
import re
|
10 |
+
import sqlite3
|
11 |
+
import random
|
12 |
+
from tqdm import tqdm
|
13 |
+
import sys
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
|
17 |
+
from get_exact_and_f1_score.ext_services.jsql_parser import JSQLParser
|
18 |
+
from get_exact_and_f1_score.metrics.partial_match_eval.evaluate import evaluate
|
19 |
+
|
20 |
+
random.seed(10001)
|
21 |
+
|
22 |
+
bleu = nlp_evaluate.load("bleu")
|
23 |
+
rouge = nlp_evaluate.load('rouge')
|
24 |
+
|
25 |
+
|
26 |
+
LEADERBOARD_PATH = "Exploration-Lab/BookSQL-Leaderboard"
|
27 |
+
RESULTS_PATH = "Exploration-Lab/BookSQL-Leaderboard-results"
|
28 |
+
api = HfApi()
|
29 |
+
TOKEN = os.environ.get("TOKEN", None)
|
30 |
+
YEAR_VERSION = "2024"
|
31 |
+
|
32 |
+
sqlite_path = "accounting/accounting_for_testing.sqlite"
|
33 |
+
|
34 |
+
|
35 |
+
_jsql_parser = JSQLParser.create()
|
36 |
+
|
37 |
+
def format_error(msg):
|
38 |
+
return f"<p style='color: red; font-size: 20px; text-align: center;'>{msg}</p>"
|
39 |
+
|
40 |
+
|
41 |
+
def format_warning(msg):
|
42 |
+
return f"<p style='color: orange; font-size: 20px; text-align: center;'>{msg}</p>"
|
43 |
+
|
44 |
+
|
45 |
+
def format_log(msg):
|
46 |
+
return f"<p style='color: green; font-size: 20px; text-align: center;'>{msg}</p>"
|
47 |
+
|
48 |
+
|
49 |
+
def model_hyperlink(link, model_name):
|
50 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
51 |
+
|
52 |
+
|
53 |
+
def input_verification(method_name, url, path_to_file, organisation, mail):
|
54 |
+
for input in [method_name, url, path_to_file, organisation, mail]:
|
55 |
+
if input == "":
|
56 |
+
return format_warning("Please fill all the fields.")
|
57 |
+
|
58 |
+
# Very basic email parsing
|
59 |
+
_, parsed_mail = parseaddr(mail)
|
60 |
+
if not "@" in parsed_mail:
|
61 |
+
return format_warning("Please provide a valid email adress.")
|
62 |
+
|
63 |
+
if path_to_file is None:
|
64 |
+
return format_warning("Please attach a file.")
|
65 |
+
|
66 |
+
return parsed_mail
|
67 |
+
|
68 |
+
def replace_current_date_and_now(_sql, _date):
|
69 |
+
_sql = _sql.replace('current_date', "\'"+_date+"\'")
|
70 |
+
_sql = _sql.replace(', now', ", \'"+_date+"\'")
|
71 |
+
return _sql
|
72 |
+
|
73 |
+
def remove_gold_Non_exec(data,df1, sqlite_path):
|
74 |
+
|
75 |
+
con = sqlite3.connect(sqlite_path)
|
76 |
+
cur = con.cursor()
|
77 |
+
|
78 |
+
out, non_exec=[], []
|
79 |
+
new_df = df1.copy()
|
80 |
+
new_df.loc[:, 'Exec/Non-Exec'] = 0
|
81 |
+
for i,s in tqdm(enumerate(data)):
|
82 |
+
_sql = str(s).replace('"', "'").lower()
|
83 |
+
_sql = replace_current_date_and_now(_sql, '2022-06-01')
|
84 |
+
_sql = replace_percent_symbol_y(_sql)
|
85 |
+
try:
|
86 |
+
cur.execute(_sql)
|
87 |
+
res = cur.fetchall()
|
88 |
+
out.append(i)
|
89 |
+
except:
|
90 |
+
non_exec.append(i)
|
91 |
+
print("_sql: ", _sql)
|
92 |
+
|
93 |
+
new_df.loc[out, 'Exec/Non-Exec'] = 1
|
94 |
+
con.close()
|
95 |
+
return out, non_exec, new_df
|
96 |
+
|
97 |
+
def remove_data_from_index(data, ind_list):
|
98 |
+
new_data=[]
|
99 |
+
for i in ind_list:
|
100 |
+
new_data.append(data[i])
|
101 |
+
return new_data
|
102 |
+
|
103 |
+
def get_exec_match_acc(gold, pred):
|
104 |
+
assert len(gold)==len(pred)
|
105 |
+
count=0
|
106 |
+
goldd = [re.sub(' +', ' ', str(g).replace("'", '"').lower()) for g in gold]
|
107 |
+
predd = [re.sub(' +', ' ', str(p).replace("'", '"').lower()) for p in pred]
|
108 |
+
# for g, p in zip(gold, pred):
|
109 |
+
# #extra space, double quotes, lower_case
|
110 |
+
# gg = re.sub(' +', ' ', str(g).replace("'", '"').lower())
|
111 |
+
# gg = re.sub(' +', ' ', str(p).replace("'", '"').lower())
|
112 |
+
# if gold==pred:
|
113 |
+
# count+=1
|
114 |
+
|
115 |
+
goldd = _jsql_parser.translate_batch(goldd)
|
116 |
+
predd = _jsql_parser.translate_batch(predd)
|
117 |
+
pcm_f1_scores = evaluate(goldd, predd)
|
118 |
+
pcm_em_scores = evaluate(goldd, predd, exact_match=True)
|
119 |
+
|
120 |
+
_pcm_f1_scores, _pcm_em_scores=[], []
|
121 |
+
for f1, em in zip(pcm_f1_scores, pcm_em_scores):
|
122 |
+
if type(f1)==float and type(em)==float:
|
123 |
+
_pcm_f1_scores.append(f1)
|
124 |
+
_pcm_em_scores.append(em)
|
125 |
+
|
126 |
+
assert len(_pcm_f1_scores) == len(_pcm_em_scores)
|
127 |
+
|
128 |
+
jsql_error_count=0 ####JSQLError
|
129 |
+
for i, score in enumerate(pcm_f1_scores):
|
130 |
+
if type(score)==str:
|
131 |
+
jsql_error_count+=1
|
132 |
+
|
133 |
+
print("JSQLError in sql: ", jsql_error_count)
|
134 |
+
|
135 |
+
return sum(_pcm_em_scores) / len(_pcm_em_scores), sum(_pcm_f1_scores) / len(_pcm_f1_scores)
|
136 |
+
|
137 |
+
def replace_percent_symbol_y(_sql):
|
138 |
+
_sql = _sql.replace('%y', "%Y")
|
139 |
+
return _sql
|
140 |
+
|
141 |
+
|
142 |
+
def get_exec_results(sqlite_path, scores, df, flag, gold_sql_map_res={}):
|
143 |
+
|
144 |
+
con = sqlite3.connect(sqlite_path)
|
145 |
+
cur = con.cursor()
|
146 |
+
|
147 |
+
i,j,count=0,0,0
|
148 |
+
out,non_exec={},{}
|
149 |
+
new_df = df.copy()
|
150 |
+
responses=[]
|
151 |
+
for s in tqdm(scores):
|
152 |
+
_sql = str(s).replace('"', "'").lower()
|
153 |
+
_sql = replace_current_date_and_now(_sql, '2022-06-01')
|
154 |
+
_sql = replace_percent_symbol_y(_sql)
|
155 |
+
try:
|
156 |
+
cur.execute(_sql)
|
157 |
+
res = cur.fetchall()
|
158 |
+
out[i] = str(res)
|
159 |
+
except Exception as err:
|
160 |
+
non_exec[i]=err
|
161 |
+
i+=1
|
162 |
+
|
163 |
+
if flag=='g':
|
164 |
+
new_df.loc[list(out.keys()), 'GOLD_res'] = list(out.values())
|
165 |
+
# assert len(gold_sql_map_res)==count
|
166 |
+
if flag=='p':
|
167 |
+
new_df.loc[list(out.keys()), 'PRED_res'] = list(out.values())
|
168 |
+
if flag=='d':
|
169 |
+
new_df.loc[list(out.keys()), 'DEBUG_res'] = list(out.values())
|
170 |
+
|
171 |
+
con.close()
|
172 |
+
return out, non_exec, new_df
|
173 |
+
|
174 |
+
def get_scores(gold_dict, pred_dict):
|
175 |
+
exec_count, non_exec_count=0, 0
|
176 |
+
none_count=0
|
177 |
+
correct_sql, incorrect_sql = [], []
|
178 |
+
for k, res in pred_dict.items():
|
179 |
+
if k in gold_dict:
|
180 |
+
if gold_dict[k]==str(None) or str(None) in gold_dict[k]:
|
181 |
+
none_count+=1
|
182 |
+
continue
|
183 |
+
if res==gold_dict[k]:
|
184 |
+
exec_count+=1
|
185 |
+
correct_sql.append(k)
|
186 |
+
else:
|
187 |
+
non_exec_count+=1
|
188 |
+
incorrect_sql.append(k)
|
189 |
+
|
190 |
+
return exec_count, non_exec_count, none_count, correct_sql, incorrect_sql
|
191 |
+
|
192 |
+
def get_total_gold_none_count(gold_dict):
|
193 |
+
none_count, ok_count=0, 0
|
194 |
+
for k, res in gold_dict.items():
|
195 |
+
if res==str(None) or str(None) in res:
|
196 |
+
none_count+=1
|
197 |
+
else: ok_count+=1
|
198 |
+
return ok_count, none_count
|
199 |
+
|
200 |
+
|
201 |
+
def evaluate(df):
|
202 |
+
# df - [id, pred_sql]
|
203 |
+
pred_sql = df['pred_sql'].to_list()
|
204 |
+
ids = df['id'].to_list()
|
205 |
+
f = open(f"tests/test.json")
|
206 |
+
questions_and_ids = json.load(f)
|
207 |
+
ts = open(f"tests/test_sql.json")
|
208 |
+
gold_sql = json.load(ts)
|
209 |
+
|
210 |
+
gold_sql_list=[]
|
211 |
+
pred_sql_list=[]
|
212 |
+
questions_list=[]
|
213 |
+
for idx, pred in zip(ids, pred_sql):
|
214 |
+
ques = questions_and_ids[idx]['Query']
|
215 |
+
gd_sql = gold_sql[idx]['SQL']
|
216 |
+
gold_sql_list.append(gd_sql)
|
217 |
+
pred_sql_list.append(pred_sql_list)
|
218 |
+
questions_list.append(ques)
|
219 |
+
|
220 |
+
df = pd.DataFrame({'NLQ':questions_list, 'GOLD SQL':gold_sql_list, 'PREDICTED SQL':pred_sql_list})
|
221 |
+
|
222 |
+
test_size = len(df)
|
223 |
+
|
224 |
+
pred_score = df['PREDICTED SQL'].str.lower().values
|
225 |
+
# debug_score = df['DEBUGGED SQL'].str.lower().values
|
226 |
+
gold_score1 = df['GOLD SQL'].str.lower().values
|
227 |
+
|
228 |
+
|
229 |
+
print("Checking non-exec Gold sql query")
|
230 |
+
gold_exec, gold_not_exec, new_df = remove_gold_Non_exec(gold_score1, df, sqlite_path)
|
231 |
+
print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
|
232 |
+
print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))
|
233 |
+
|
234 |
+
|
235 |
+
prev_non_exec_df = new_df[new_df['Exec/Non-Exec'] == 0]
|
236 |
+
new_df = new_df[new_df['Exec/Non-Exec']==1]
|
237 |
+
|
238 |
+
prev_non_exec_df.reset_index(inplace=True)
|
239 |
+
new_df.reset_index(inplace=True)
|
240 |
+
|
241 |
+
#Removing Non-exec sql from data
|
242 |
+
print(f"Removing {len(gold_not_exec)} non-exec sql query from all Gold/Pred/Debug")
|
243 |
+
gold_score1 = remove_data_from_index(gold_score1, gold_exec)
|
244 |
+
pred_score = remove_data_from_index(pred_score, gold_exec)
|
245 |
+
# debug_score = remove_data_from_index(debug_score, gold_exec)
|
246 |
+
gold_score = [[x] for x in gold_score1]
|
247 |
+
|
248 |
+
assert len(gold_score) == len(pred_score) #== len(debug_score)
|
249 |
+
|
250 |
+
pred_bleu_score = bleu.compute(predictions=pred_score, references=gold_score)
|
251 |
+
pred_rouge_score = rouge.compute(predictions=pred_score, references=gold_score)
|
252 |
+
pred_exact_match, pred_partial_f1_score = get_exec_match_acc(gold_score1, pred_score)
|
253 |
+
|
254 |
+
print("PREDICTED_vs_GOLD Final bleu_score: ", pred_bleu_score['bleu'])
|
255 |
+
print("PREDICTED_vs_GOLD Final rouge_score: ", pred_rouge_score['rougeL'])
|
256 |
+
print("PREDICTED_vs_GOLD Exact Match Accuracy: ", pred_exact_match)
|
257 |
+
print("PREDICTED_vs_GOLD Partial CM F1 score: ", pred_partial_f1_score)
|
258 |
+
print()
|
259 |
+
|
260 |
+
|
261 |
+
new_df.loc[:, 'GOLD_res'] = str(None)
|
262 |
+
new_df.loc[:, 'PRED_res'] = str(None)
|
263 |
+
# new_df.loc[:, 'DEBUG_res'] = str(None)
|
264 |
+
|
265 |
+
print("Getting Gold results")
|
266 |
+
# gout_res_dict, gnon_exec_err_dict, gold_sql_map_res = get_exec_results(cur, gold_score1, 'g')
|
267 |
+
gout_res_dict, gnon_exec_err_dict, new_df = get_exec_results(sqlite_path, gold_score1, new_df, 'g')
|
268 |
+
|
269 |
+
total_gold_ok_count, total_gold_none_count = get_total_gold_none_count(gout_res_dict)
|
270 |
+
print("Total Gold None count: ", total_gold_none_count)
|
271 |
+
|
272 |
+
print("Getting Pred results")
|
273 |
+
pout_res_dict, pnon_exec_err_dict, new_df = get_exec_results(sqlite_path, pred_score, new_df, 'p')
|
274 |
+
# print("Getting Debug results")
|
275 |
+
# dout_res_dict, dnon_exec_err_dict = get_exec_results(cur, debug_score, 'd')
|
276 |
+
|
277 |
+
print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
|
278 |
+
print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))
|
279 |
+
print()
|
280 |
+
print("PRED Total exec SQL query: {}/{}".format(len(pout_res_dict), len(pred_score)))
|
281 |
+
print("PRED Total non-exec SQL query: {}/{}".format(len(pnon_exec_err_dict), len(pred_score)))
|
282 |
+
print()
|
283 |
+
# print("DEBUG Total exec SQL query: {}/{}".format(len(dout_res_dict), len(debug_score)))
|
284 |
+
# print("DEBUG Total non-exec SQL query: {}/{}".format(len(dnon_exec_err_dict), len(debug_score)))
|
285 |
+
# print()
|
286 |
+
pred_correct_exec_acc_count, pred_incorrect_exec_acc_count, pred_none_count, pred_correct_sql, pred_incorrect_sql = get_scores(gout_res_dict, pout_res_dict)
|
287 |
+
# debug_correct_exec_acc_count, debug_incorrect_exec_acc_count, debug_none_count, debug_correct_sql, debug_incorrect_sql = get_scores(gout_res_dict, dout_res_dict)
|
288 |
+
# print("PRED_vs_GOLD None_count: ", total_gold_none_count)
|
289 |
+
print("PRED_vs_GOLD Correct_Exec_count without None: ", pred_correct_exec_acc_count)
|
290 |
+
print("PRED_vs_GOLD Incorrect_Exec_count without None: ", pred_incorrect_exec_acc_count)
|
291 |
+
print("PRED_vs_GOLD Exec_Accuracy: ", pred_correct_exec_acc_count/total_gold_ok_count)
|
292 |
+
print()
|
293 |
+
|
294 |
+
return pred_exact_match, pred_correct_exec_acc_count/total_gold_ok_count, pred_partial_f1_score, pred_bleu_score['bleu'], pred_rouge_score['rougeL']
|
295 |
+
|
296 |
+
def add_new_eval(
|
297 |
+
method_name: str,
|
298 |
+
url: str,
|
299 |
+
path_to_file: str,
|
300 |
+
organisation: str,
|
301 |
+
mail: str,
|
302 |
+
):
|
303 |
+
|
304 |
+
parsed_mail = input_verification(
|
305 |
+
method_name,
|
306 |
+
url,
|
307 |
+
path_to_file,
|
308 |
+
organisation,
|
309 |
+
mail,
|
310 |
+
)
|
311 |
+
|
312 |
+
# load the file
|
313 |
+
df = pd.read_csv(path_to_file)
|
314 |
+
submission_df = pd.read_csv(path_to_file)
|
315 |
+
|
316 |
+
# modify the df to include metadata
|
317 |
+
df["Method"] = method_name
|
318 |
+
df["url"] = url
|
319 |
+
df["organisation"] = organisation
|
320 |
+
df["mail"] = parsed_mail
|
321 |
+
df["timestamp"] = datetime.datetime.now()
|
322 |
+
|
323 |
+
submission_df = pd.read_csv(path_to_file)
|
324 |
+
submission_df["Method"] = method_name
|
325 |
+
submission_df["Submitted By"] = organisation
|
326 |
+
# upload to spaces using the hf api at
|
327 |
+
|
328 |
+
path_in_repo = f"submissions/{method_name}"
|
329 |
+
file_name = f"{method_name}-{organisation}-{datetime.datetime.now().strftime('%Y-%m-%d')}.csv"
|
330 |
+
|
331 |
+
EM, EX, PCM_F1, BLEU, ROUGE = evaluate(submission_df)
|
332 |
+
submission_df['EM'] = EM
|
333 |
+
submission_df['EX'] = EX
|
334 |
+
# submission_df['PCM_F1'] = PCM_F1
|
335 |
+
submission_df['BLEU'] = BLEU
|
336 |
+
submission_df['ROUGE'] = ROUGE
|
337 |
+
|
338 |
+
# upload the df to spaces
|
339 |
+
import io
|
340 |
+
|
341 |
+
buffer = io.BytesIO()
|
342 |
+
df.to_csv(buffer, index=False) # Write the DataFrame to a buffer in CSV format
|
343 |
+
buffer.seek(0) # Rewind the buffer to the beginning
|
344 |
+
|
345 |
+
api.upload_file(
|
346 |
+
repo_id=RESULTS_PATH,
|
347 |
+
path_in_repo=f"{path_in_repo}/{file_name}",
|
348 |
+
path_or_fileobj=buffer,
|
349 |
+
token=TOKEN,
|
350 |
+
repo_type="dataset",
|
351 |
+
)
|
352 |
+
# read the leaderboard
|
353 |
+
leaderboard_df = pd.read_csv(f"submissions/baseline/baseline.csv")
|
354 |
+
|
355 |
+
# append the new submission_df csv to the leaderboard
|
356 |
+
# leaderboard_df = leaderboard_df._append(submission_df)
|
357 |
+
leaderboard_df = pd.concat([leaderboard_df, submission_df], ignore_index=True)
|
358 |
+
|
359 |
+
# save the new leaderboard
|
360 |
+
# leaderboard_df.to_csv(f"submissions/baseline/baseline.csv", index=False)
|
361 |
+
leaderboard_buffer = io.BytesIO()
|
362 |
+
leaderboard_df.to_csv(leaderboard_buffer, index=False)
|
363 |
+
leaderboard_buffer.seek(0)
|
364 |
+
api.upload_file(
|
365 |
+
repo_id=LEADERBOARD_PATH,
|
366 |
+
path_in_repo=f"submissions/baseline/baseline.csv",
|
367 |
+
path_or_fileobj=leaderboard_buffer,
|
368 |
+
token=TOKEN,
|
369 |
+
repo_type="space",
|
370 |
+
)
|
371 |
+
|
372 |
+
return format_log(
|
373 |
+
f"Method {method_name} submitted by {organisation} successfully. \nPlease refresh the leaderboard, and wait a bit to see the score displayed"
|
374 |
+
)
|
submissions/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
submissions/baseline/baseline.csv
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Method,Submitted By,EMA,EX,BLEU-4,ROUGE-L
|
2 |
+
SEDE,IITK,0.43,0.443,0.69,0.83
|
3 |
+
UniSAr,IITK,0.43,0.47,0.72,0.8
|
4 |
+
RESDSQL,IITK,0.52,0.54,0.74,0.81
|
5 |
+
DIN-SQL + GPT4,IITK,0.09,0.08,0.43,0.68
|
6 |
+
Dfew + GPT4,IITK,0.48,0.67,0.86,0.9
|
submissions/modify.sh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Loop through each CSV file in the current directory
|
4 |
+
for csv_file in *.csv; do
|
5 |
+
# Check if the file is a regular file
|
6 |
+
if [ -f "$csv_file" ]; then
|
7 |
+
echo "Processing $csv_file..."
|
8 |
+
|
9 |
+
# Temporary file
|
10 |
+
temp_file=$(mktemp)
|
11 |
+
|
12 |
+
# Check if the file has a header
|
13 |
+
if head -1 "$csv_file" | grep -q "Submitted By"; then
|
14 |
+
echo "The 'Submitted By' column already exists in $csv_file."
|
15 |
+
continue
|
16 |
+
fi
|
17 |
+
|
18 |
+
# Add 'Submitted By' column header and 'Baseline' entry for each row
|
19 |
+
awk -v OFS="," 'NR==1 {print $0, "Submitted By"} NR>1 {print $0, "Baseline"}' "$csv_file" > "$temp_file"
|
20 |
+
|
21 |
+
# Move the temporary file to original file
|
22 |
+
mv "$temp_file" "$csv_file"
|
23 |
+
|
24 |
+
echo "Column 'Submitted By' added successfully with 'Baseline' entry in each row for $csv_file."
|
25 |
+
fi
|
26 |
+
done
|
27 |
+
|
28 |
+
echo "All CSV files processed."
|
tests/test.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tests/test_sql.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
uploads.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from email.utils import parseaddr
|
2 |
+
from huggingface_hub import HfApi
|
3 |
+
import os
|
4 |
+
import datetime
|
5 |
+
import pandas as pd
|
6 |
+
import json
|
7 |
+
|
8 |
+
import evaluate as nlp_evaluate
|
9 |
+
import re
|
10 |
+
import sqlite3
|
11 |
+
import random
|
12 |
+
from tqdm import tqdm
|
13 |
+
import sys
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
from sqlparse import parse
|
17 |
+
|
18 |
+
random.seed(10001)
|
19 |
+
|
20 |
+
bleu = nlp_evaluate.load("bleu")
|
21 |
+
rouge = nlp_evaluate.load('rouge')
|
22 |
+
|
23 |
+
|
24 |
+
LEADERBOARD_PATH = "Exploration-Lab/BookSQL-Leaderboard"
|
25 |
+
RESULTS_PATH = "Exploration-Lab/BookSQL-Leaderboard"
|
26 |
+
api = HfApi()
|
27 |
+
TOKEN = os.environ.get("TOKEN", None)
|
28 |
+
YEAR_VERSION = "2024"
|
29 |
+
|
30 |
+
sqlite_path = "accounting/accounting_for_testing.sqlite"
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
def format_error(msg):
|
35 |
+
return f"<p style='color: red; font-size: 20px; text-align: center;'>{msg}</p>"
|
36 |
+
|
37 |
+
|
38 |
+
def format_warning(msg):
|
39 |
+
return f"<p style='color: orange; font-size: 20px; text-align: center;'>{msg}</p>"
|
40 |
+
|
41 |
+
|
42 |
+
def format_log(msg):
|
43 |
+
return f"<p style='color: green; font-size: 20px; text-align: center;'>{msg}</p>"
|
44 |
+
|
45 |
+
|
46 |
+
def model_hyperlink(link, model_name):
|
47 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
48 |
+
|
49 |
+
|
50 |
+
def input_verification(method_name, url, path_to_file, organisation, mail):
|
51 |
+
for input in [method_name, url, path_to_file, organisation, mail]:
|
52 |
+
if input == "":
|
53 |
+
return format_warning("Please fill all the fields.")
|
54 |
+
|
55 |
+
# Very basic email parsing
|
56 |
+
_, parsed_mail = parseaddr(mail)
|
57 |
+
if not "@" in parsed_mail:
|
58 |
+
return format_warning("Please provide a valid email adress.")
|
59 |
+
|
60 |
+
if path_to_file is None:
|
61 |
+
return format_warning("Please attach a file.")
|
62 |
+
|
63 |
+
return parsed_mail
|
64 |
+
|
65 |
+
def replace_current_date_and_now(_sql, _date):
|
66 |
+
_sql = _sql.replace('current_date', "\'"+_date+"\'")
|
67 |
+
_sql = _sql.replace(', now', ", \'"+_date+"\'")
|
68 |
+
return _sql
|
69 |
+
|
70 |
+
def remove_gold_Non_exec(data,df1, sqlite_path):
|
71 |
+
|
72 |
+
con = sqlite3.connect(sqlite_path)
|
73 |
+
cur = con.cursor()
|
74 |
+
|
75 |
+
out, non_exec=[], []
|
76 |
+
new_df = df1.copy()
|
77 |
+
new_df.loc[:, 'Exec/Non-Exec'] = 0
|
78 |
+
for i,s in tqdm(enumerate(data)):
|
79 |
+
_sql = str(s).replace('"', "'").lower()
|
80 |
+
_sql = replace_current_date_and_now(_sql, '2022-06-01')
|
81 |
+
_sql = replace_percent_symbol_y(_sql)
|
82 |
+
try:
|
83 |
+
cur.execute(_sql)
|
84 |
+
res = cur.fetchall()
|
85 |
+
out.append(i)
|
86 |
+
except:
|
87 |
+
non_exec.append(i)
|
88 |
+
# print("_sql: ", _sql)
|
89 |
+
|
90 |
+
new_df.loc[out, 'Exec/Non-Exec'] = 1
|
91 |
+
con.close()
|
92 |
+
return out, non_exec, new_df
|
93 |
+
|
94 |
+
def remove_data_from_index(data, ind_list):
|
95 |
+
new_data=[]
|
96 |
+
for i in ind_list:
|
97 |
+
new_data.append(data[i])
|
98 |
+
return new_data
|
99 |
+
|
100 |
+
def parse_query(query):
|
101 |
+
parsed = parse(query)[0]
|
102 |
+
return parsed
|
103 |
+
|
104 |
+
def normalize_query(query):
|
105 |
+
# Remove comments
|
106 |
+
query = re.sub(r'--.*', '', query)
|
107 |
+
query = re.sub(r'/\*.*?\*/', '', query, flags=re.DOTALL)
|
108 |
+
|
109 |
+
# Remove extra whitespace
|
110 |
+
query = re.sub(r'\s+', ' ', query)
|
111 |
+
|
112 |
+
# Strip leading and trailing whitespace
|
113 |
+
query = query.strip()
|
114 |
+
|
115 |
+
return query.lower()
|
116 |
+
|
117 |
+
def get_exec_match_acc(gold, pred):
|
118 |
+
assert len(gold)==len(pred)
|
119 |
+
correct_sql_count=0
|
120 |
+
count=0
|
121 |
+
goldd = [re.sub(' +', ' ', str(g).replace("'", '"').lower()) for g in gold]
|
122 |
+
predd = [re.sub(' +', ' ', str(p).replace("'", '"').lower()) for p in pred]
|
123 |
+
# for g, p in zip(gold, pred):
|
124 |
+
# #extra space, double quotes, lower_case
|
125 |
+
# gg = re.sub(' +', ' ', str(g).replace("'", '"').lower())
|
126 |
+
# gg = re.sub(' +', ' ', str(p).replace("'", '"').lower())
|
127 |
+
# if gold==pred:
|
128 |
+
# count+=1
|
129 |
+
|
130 |
+
for q1, q2 in zip(goldd, predd):
|
131 |
+
q1 = normalize_query(q1)
|
132 |
+
q2 = normalize_query(q2)
|
133 |
+
|
134 |
+
parsed_query1 = parse_query(q1)
|
135 |
+
parsed_query2 = parse_query(q2)
|
136 |
+
|
137 |
+
if str(parsed_query1) == str(parsed_query2):
|
138 |
+
correct_sql_count+=1
|
139 |
+
|
140 |
+
return correct_sql_count/len(goldd), 0
|
141 |
+
|
142 |
+
def replace_percent_symbol_y(_sql):
|
143 |
+
_sql = _sql.replace('%y', "%Y")
|
144 |
+
return _sql
|
145 |
+
|
146 |
+
|
147 |
+
def get_exec_results(sqlite_path, scores, df, flag, gold_sql_map_res={}):
|
148 |
+
|
149 |
+
con = sqlite3.connect(sqlite_path)
|
150 |
+
cur = con.cursor()
|
151 |
+
|
152 |
+
i,j,count=0,0,0
|
153 |
+
out,non_exec={},{}
|
154 |
+
new_df = df.copy()
|
155 |
+
responses=[]
|
156 |
+
for s in tqdm(scores):
|
157 |
+
_sql = str(s).replace('"', "'").lower()
|
158 |
+
_sql = replace_current_date_and_now(_sql, '2022-06-01')
|
159 |
+
_sql = replace_percent_symbol_y(_sql)
|
160 |
+
try:
|
161 |
+
cur.execute(_sql)
|
162 |
+
res = cur.fetchall()
|
163 |
+
out[i] = str(res)
|
164 |
+
except Exception as err:
|
165 |
+
non_exec[i]=err
|
166 |
+
i+=1
|
167 |
+
|
168 |
+
if flag=='g':
|
169 |
+
new_df.loc[list(out.keys()), 'GOLD_res'] = list(out.values())
|
170 |
+
# assert len(gold_sql_map_res)==count
|
171 |
+
if flag=='p':
|
172 |
+
new_df.loc[list(out.keys()), 'PRED_res'] = list(out.values())
|
173 |
+
if flag=='d':
|
174 |
+
new_df.loc[list(out.keys()), 'DEBUG_res'] = list(out.values())
|
175 |
+
|
176 |
+
con.close()
|
177 |
+
return out, non_exec, new_df
|
178 |
+
|
179 |
+
def get_scores(gold_dict, pred_dict):
|
180 |
+
exec_count, non_exec_count=0, 0
|
181 |
+
none_count=0
|
182 |
+
correct_sql, incorrect_sql = [], []
|
183 |
+
for k, res in pred_dict.items():
|
184 |
+
if k in gold_dict:
|
185 |
+
if gold_dict[k]==str(None) or str(None) in gold_dict[k]:
|
186 |
+
none_count+=1
|
187 |
+
continue
|
188 |
+
if res==gold_dict[k]:
|
189 |
+
exec_count+=1
|
190 |
+
correct_sql.append(k)
|
191 |
+
else:
|
192 |
+
non_exec_count+=1
|
193 |
+
incorrect_sql.append(k)
|
194 |
+
|
195 |
+
return exec_count, non_exec_count, none_count, correct_sql, incorrect_sql
|
196 |
+
|
197 |
+
def get_total_gold_none_count(gold_dict):
|
198 |
+
none_count, ok_count=0, 0
|
199 |
+
for k, res in gold_dict.items():
|
200 |
+
if res==str(None) or str(None) in res:
|
201 |
+
none_count+=1
|
202 |
+
else: ok_count+=1
|
203 |
+
return ok_count, none_count
|
204 |
+
|
205 |
+
|
206 |
+
def Evaluate(df):
|
207 |
+
# df - [id, pred_sql]
|
208 |
+
pred_sql = df['pred_sql'].to_list()
|
209 |
+
ids = df['id'].to_list()
|
210 |
+
f = open(f"tests/test.json")
|
211 |
+
questions_and_ids = json.load(f)
|
212 |
+
ts = open(f"tests/test_sql.json")
|
213 |
+
gold_sql = json.load(ts)
|
214 |
+
|
215 |
+
gold_sql_list=[]
|
216 |
+
pred_sql_list=[]
|
217 |
+
questions_list=[]
|
218 |
+
for idx, pred in zip(ids, pred_sql):
|
219 |
+
ques = questions_and_ids[idx]['Query']
|
220 |
+
gd_sql = gold_sql[idx]['SQL']
|
221 |
+
gold_sql_list.append(gd_sql)
|
222 |
+
pred_sql_list.append(pred)
|
223 |
+
questions_list.append(ques)
|
224 |
+
|
225 |
+
df = pd.DataFrame({'NLQ':questions_list, 'GOLD SQL':gold_sql_list, 'PREDICTED SQL':pred_sql_list})
|
226 |
+
|
227 |
+
test_size = len(df)
|
228 |
+
|
229 |
+
pred_score = df['PREDICTED SQL'].str.lower().values
|
230 |
+
# debug_score = df['DEBUGGED SQL'].str.lower().values
|
231 |
+
gold_score1 = df['GOLD SQL'].str.lower().values
|
232 |
+
|
233 |
+
|
234 |
+
print("Checking non-exec Gold sql query")
|
235 |
+
gold_exec, gold_not_exec, new_df = remove_gold_Non_exec(gold_score1, df, sqlite_path)
|
236 |
+
print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
|
237 |
+
print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))
|
238 |
+
|
239 |
+
|
240 |
+
prev_non_exec_df = new_df[new_df['Exec/Non-Exec'] == 0]
|
241 |
+
new_df = new_df[new_df['Exec/Non-Exec']==1]
|
242 |
+
|
243 |
+
prev_non_exec_df.reset_index(inplace=True)
|
244 |
+
new_df.reset_index(inplace=True)
|
245 |
+
|
246 |
+
#Removing Non-exec sql from data
|
247 |
+
print(f"Removing {len(gold_not_exec)} non-exec sql query from all Gold/Pred/Debug ")
|
248 |
+
gold_score1 = remove_data_from_index(gold_score1, gold_exec)
|
249 |
+
pred_score = remove_data_from_index(pred_score, gold_exec)
|
250 |
+
# debug_score = remove_data_from_index(debug_score, gold_exec)
|
251 |
+
gold_score = [[x] for x in gold_score1]
|
252 |
+
|
253 |
+
assert len(gold_score) == len(pred_score) #== len(debug_score)
|
254 |
+
|
255 |
+
pred_bleu_score = bleu.compute(predictions=pred_score, references=gold_score)
|
256 |
+
pred_rouge_score = rouge.compute(predictions=pred_score, references=gold_score)
|
257 |
+
pred_exact_match, pred_partial_f1_score = get_exec_match_acc(gold_score1, pred_score)
|
258 |
+
|
259 |
+
print("PREDICTED_vs_GOLD Final bleu_score: ", pred_bleu_score['bleu'])
|
260 |
+
print("PREDICTED_vs_GOLD Final rouge_score: ", pred_rouge_score['rougeL'])
|
261 |
+
print("PREDICTED_vs_GOLD Exact Match Accuracy: ", pred_exact_match)
|
262 |
+
# print("PREDICTED_vs_GOLD Partial CM F1 score: ", pred_partial_f1_score)
|
263 |
+
print()
|
264 |
+
|
265 |
+
|
266 |
+
new_df.loc[:, 'GOLD_res'] = str(None)
|
267 |
+
new_df.loc[:, 'PRED_res'] = str(None)
|
268 |
+
# new_df.loc[:, 'DEBUG_res'] = str(None)
|
269 |
+
|
270 |
+
print("Getting Gold results")
|
271 |
+
# gout_res_dict, gnon_exec_err_dict, gold_sql_map_res = get_exec_results(cur, gold_score1, 'g')
|
272 |
+
gout_res_dict, gnon_exec_err_dict, new_df = get_exec_results(sqlite_path, gold_score1, new_df, 'g')
|
273 |
+
|
274 |
+
total_gold_ok_count, total_gold_none_count = get_total_gold_none_count(gout_res_dict)
|
275 |
+
print("Total Gold None count: ", total_gold_none_count)
|
276 |
+
|
277 |
+
print("Getting Pred results")
|
278 |
+
pout_res_dict, pnon_exec_err_dict, new_df = get_exec_results(sqlite_path, pred_score, new_df, 'p')
|
279 |
+
# print("Getting Debug results")
|
280 |
+
# dout_res_dict, dnon_exec_err_dict = get_exec_results(cur, debug_score, 'd')
|
281 |
+
|
282 |
+
print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
|
283 |
+
print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))
|
284 |
+
print()
|
285 |
+
print("PRED Total exec SQL query: {}/{}".format(len(pout_res_dict), len(pred_score)))
|
286 |
+
print("PRED Total non-exec SQL query: {}/{}".format(len(pnon_exec_err_dict), len(pred_score)))
|
287 |
+
print()
|
288 |
+
# print("DEBUG Total exec SQL query: {}/{}".format(len(dout_res_dict), len(debug_score)))
|
289 |
+
# print("DEBUG Total non-exec SQL query: {}/{}".format(len(dnon_exec_err_dict), len(debug_score)))
|
290 |
+
# print()
|
291 |
+
pred_correct_exec_acc_count, pred_incorrect_exec_acc_count, pred_none_count, pred_correct_sql, pred_incorrect_sql = get_scores(gout_res_dict, pout_res_dict)
|
292 |
+
# debug_correct_exec_acc_count, debug_incorrect_exec_acc_count, debug_none_count, debug_correct_sql, debug_incorrect_sql = get_scores(gout_res_dict, dout_res_dict)
|
293 |
+
# print("PRED_vs_GOLD None_count: ", total_gold_none_count)
|
294 |
+
print("PRED_vs_GOLD Correct_Exec_count without None: ", pred_correct_exec_acc_count)
|
295 |
+
print("PRED_vs_GOLD Incorrect_Exec_count without None: ", pred_incorrect_exec_acc_count)
|
296 |
+
print("PRED_vs_GOLD Exec_Accuracy: ", pred_correct_exec_acc_count/total_gold_ok_count)
|
297 |
+
print()
|
298 |
+
|
299 |
+
return pred_exact_match, pred_correct_exec_acc_count/total_gold_ok_count, pred_partial_f1_score, pred_bleu_score['bleu'], pred_rouge_score['rougeL']
|
300 |
+
|
301 |
+
def add_new_eval(
|
302 |
+
method_name: str,
|
303 |
+
url: str,
|
304 |
+
path_to_file: str,
|
305 |
+
organisation: str,
|
306 |
+
mail: str,
|
307 |
+
):
|
308 |
+
|
309 |
+
parsed_mail = input_verification(
|
310 |
+
method_name,
|
311 |
+
url,
|
312 |
+
path_to_file,
|
313 |
+
organisation,
|
314 |
+
mail,
|
315 |
+
)
|
316 |
+
|
317 |
+
# load the file
|
318 |
+
df = pd.read_csv(path_to_file)
|
319 |
+
submission_df = pd.read_csv(path_to_file)
|
320 |
+
|
321 |
+
# modify the df to include metadata
|
322 |
+
df["Method"] = method_name
|
323 |
+
df["url"] = url
|
324 |
+
df["organisation"] = organisation
|
325 |
+
df["mail"] = parsed_mail
|
326 |
+
df["timestamp"] = datetime.datetime.now()
|
327 |
+
|
328 |
+
submission_df = pd.read_csv(path_to_file)
|
329 |
+
submission_df["Method"] = method_name
|
330 |
+
submission_df["Submitted By"] = organisation
|
331 |
+
# upload to spaces using the hf api at
|
332 |
+
|
333 |
+
path_in_repo = f"submissions/{method_name}"
|
334 |
+
file_name = f"{method_name}-{organisation}-{datetime.datetime.now().strftime('%Y-%m-%d')}.csv"
|
335 |
+
|
336 |
+
EM, EX, PCM_F1, BLEU, ROUGE = Evaluate(submission_df)
|
337 |
+
sub_df = pd.DataFrame()
|
338 |
+
sub_df["Method"] = method_name
|
339 |
+
sub_df["Submitted By"] = organisation
|
340 |
+
sub_df['EMA'] = EM
|
341 |
+
sub_df['EX'] = EX
|
342 |
+
# submission_df['PCM_F1'] = PCM_F1
|
343 |
+
sub_df['BLEU-4'] = BLEU
|
344 |
+
sub_df['ROUGE-L'] = ROUGE
|
345 |
+
|
346 |
+
# upload the df to spaces
|
347 |
+
import io
|
348 |
+
|
349 |
+
buffer = io.BytesIO()
|
350 |
+
df.to_csv(buffer, index=False) # Write the DataFrame to a buffer in CSV format
|
351 |
+
buffer.seek(0) # Rewind the buffer to the beginning
|
352 |
+
|
353 |
+
api.upload_file(
|
354 |
+
repo_id=RESULTS_PATH,
|
355 |
+
path_in_repo=f"{path_in_repo}/{file_name}",
|
356 |
+
path_or_fileobj=buffer,
|
357 |
+
token=TOKEN,
|
358 |
+
repo_type="dataset",
|
359 |
+
)
|
360 |
+
# read the leaderboard
|
361 |
+
leaderboard_df = pd.read_csv(f"submissions/baseline/baseline.csv")
|
362 |
+
|
363 |
+
# append the new submission_df csv to the leaderboard
|
364 |
+
# leaderboard_df = leaderboard_df._append(submission_df)
|
365 |
+
leaderboard_df = pd.concat([leaderboard_df, sub_df], ignore_index=True)
|
366 |
+
|
367 |
+
# save the new leaderboard
|
368 |
+
# leaderboard_df.to_csv(f"submissions/baseline/baseline.csv", index=False)
|
369 |
+
leaderboard_buffer = io.BytesIO()
|
370 |
+
leaderboard_df.to_csv(leaderboard_buffer, index=False)
|
371 |
+
leaderboard_buffer.seek(0)
|
372 |
+
api.upload_file(
|
373 |
+
repo_id=LEADERBOARD_PATH,
|
374 |
+
path_in_repo=f"submissions/baseline/baseline.csv",
|
375 |
+
path_or_fileobj=leaderboard_buffer,
|
376 |
+
token=TOKEN,
|
377 |
+
repo_type="space",
|
378 |
+
)
|
379 |
+
|
380 |
+
return format_log(
|
381 |
+
f"Method {method_name} submitted by {organisation} successfully. \nPlease refresh the leaderboard, and wait a bit to see the score displayed"
|
382 |
+
)
|