NGrov commited on
Commit
23ceea2
·
1 Parent(s): 5068739

evaluate script

Browse files
Files changed (2) hide show
  1. db_schemas.json +0 -0
  2. evaluate_with_db.py +67 -0
db_schemas.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluate_with_db.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+ import json
3
+ import sqlite3
4
+ from tqdm import tqdm
5
+ from typing import List
6
+ import os
7
+ from pathlib import Path
8
+
9
+ db_schemas_path = "db_schemas.json"
10
+ model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
11
+
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
13
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
14
+
15
+
16
+ def query_db(question: str, db_path: str) -> dict:
17
+ try:
18
+ # assert db_path.endswith('.sqlite')
19
+ con = sqlite3.connect(db_path)
20
+ cur = con.cursor()
21
+ cur.execute(question)
22
+ data = cur.fetchall()
23
+ return json.dumps(data)
24
+ except Exception as e:
25
+ print(question, " ", e)
26
+ pass
27
+
28
+
29
+ def evaluate(eval_dataset: List[dict]):
30
+ reference = []
31
+ gen_queries = []
32
+
33
+ with open(db_schemas_path, "r") as schemas:
34
+ db_schema_dict = json.load(schemas)
35
+
36
+ for data in tqdm(eval_dataset, total=len(eval_dataset), desc="Executing queries"):
37
+ question = data["question"]
38
+ schema = data["db_id"]
39
+
40
+ filenames = [
41
+ i for i in os.listdir(Path(DB_PATH, schema)) if i.endswith(SQLITE_SUFFIX)
42
+ ]
43
+ path_to_db = Path(DB_PATH, schema, filenames[0])
44
+
45
+ input_text = " ".join(
46
+ ["Question: ", question, "Schema:", db_schema_dict[schema]]
47
+ )
48
+ model_inputs = tokenizer(input_text, return_tensors="pt")
49
+ outputs = model.generate(**model_inputs, max_length=512)
50
+
51
+ output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
52
+ reference.append(query_db(data["query"], path_to_db))
53
+ gen_queries.append(query_db(output_text, path_to_db))
54
+
55
+ equal_results = [ref == q for ref, q in zip(reference, gen_queries)]
56
+ eq_results_when_reference_works = [
57
+ ref == q for ref, q in zip(reference, gen_queries) if ref is not None
58
+ ]
59
+ num_of_working_ref = len([ref for ref in reference if ref is not None])
60
+ print("Length of eval dataset: ", len(eval_dataset))
61
+ print("Working references: ", num_of_working_ref)
62
+ print("Correct queries in labels: ", num_of_working_ref / len(eval_dataset))
63
+ print("Accuracy with whole dataset: ", sum(equal_results) / len(eval_dataset))
64
+ print(
65
+ "Accuracy with only working references: ",
66
+ sum(eq_results_when_reference_works) / num_of_working_ref,
67
+ )