tdoehmen commited on
Commit
b3eb06a
1 Parent(s): cef725d

added eval sql

Browse files
Files changed (4) hide show
  1. MODEL_README.md +156 -0
  2. app.py +27 -2
  3. requirements.txt +1 -0
  4. validate_sql.py +57 -0
MODEL_README.md ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: llama2
3
+ inference:
4
+ parameters:
5
+ do_sample: false
6
+ max_length: 200
7
+ widget:
8
+ - text: "CREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- how many stadiums in total?\n\nSELECT"
9
+ example_title: "Number stadiums"
10
+ - text: "CREATE TABLE work_orders ( ID NUMBER, CREATED_AT TEXT, COST FLOAT, INVOICE_AMOUNT FLOAT, IS_DUE BOOLEAN, IS_OPEN BOOLEAN, IS_OVERDUE BOOLEAN, COUNTRY_NAME TEXT, )\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- how many work orders are open?\n\nSELECT"
11
+ example_title: "Open work orders"
12
+ - text: "CREATE TABLE stadium ( stadium_id number, location text, name text, capacity number, highest number, lowest number, average number )\n\nCREATE TABLE singer ( singer_id number, name text, country text, song_name text, song_release_year text, age number, is_male others )\n\nCREATE TABLE concert ( concert_id number, concert_name text, theme text, stadium_id text, year text )\n\nCREATE TABLE singer_in_concert ( concert_id number, singer_id text )\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?\n\nSELECT"
13
+ example_title: "Stadium capacity"
14
+ ---
15
+
16
+ # DuckDB-NSQL-7B
17
+
18
+ ## Model Description
19
+
20
+ NSQL is a family of autoregressive open-source large foundation models (FMs) designed specifically for SQL generation tasks.
21
+
22
+ In this repository we are introducing a new member of NSQL, DuckDB-NSQL. It's based on Meta's original [Llama-2 7B model](https://huggingface.co/meta-llama/Llama-2-7b) and further pre-trained on a dataset of general SQL queries and then fine-tuned on a dataset composed of DuckDB text-to-SQL pairs.
23
+
24
+ ## Training Data
25
+
26
+ The general SQL queries are the SQL subset from [The Stack](https://huggingface.co/datasets/bigcode/the-stack), containing 1M training samples. These samples were transpiled to DuckDB SQL using [sqlglot](https://github.com/tobymao/sqlglot). The labeled text-to-SQL pairs come from [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL) and were also transpiled to DuckDB SQL, augmented with 200k synthetically generated DuckDB SQL queries based on the DuckDB v0.9.2 documentation.
27
+
28
+ ## Evaluation Data
29
+
30
+ We evaluate our models on a DuckDB-specific benchmark that contains 75 text-to-SQL pairs. The benchmark is available [here](https://github.com/NumbersStationAI/DuckDB-NSQL/).
31
+
32
+ ## Training Procedure
33
+
34
+ DuckDB-NSQL was trained using cross-entropy loss to maximize the likelihood of sequential inputs. For finetuning on text-to-SQL pairs, we only compute the loss over the SQL portion of the pair. The model is trained using 80GB A100s, leveraging data and model parallelism. We pre-trained for 3 epochs and fine-tuned for 10 epochs.
35
+
36
+ ## Intended Use and Limitations
37
+
38
+ The model was designed for text-to-SQL generation tasks from given table schema and natural language prompts. The model works best with the prompt format defined below.
39
+ In contrast to existing text-to-SQL models, the SQL generation is not constrained to `SELECT` statements, but can generate any valid DuckDB SQL statement, including statements for official DuckDB extensions.
40
+
41
+ ## How to Use
42
+
43
+ Example 1:
44
+
45
+ ```python
46
+ import torch
47
+ from transformers import AutoTokenizer, AutoModelForCausalLM
48
+ tokenizer = AutoTokenizer.from_pretrained("motherduckdb/nsql-duckdb-7B")
49
+ model = AutoModelForCausalLM.from_pretrained("motherduckdb/nsql-duckdb-7B", torch_dtype=torch.bfloat16)
50
+
51
+ text = """CREATE TABLE stadium (
52
+ stadium_id number,
53
+ location text,
54
+ name text,
55
+ capacity number,
56
+ highest number,
57
+ lowest number,
58
+ average number
59
+ )
60
+
61
+ CREATE TABLE singer (
62
+ singer_id number,
63
+ name text,
64
+ country text,
65
+ song_name text,
66
+ song_release_year text,
67
+ age number,
68
+ is_male others
69
+ )
70
+
71
+ CREATE TABLE concert (
72
+ concert_id number,
73
+ concert_name text,
74
+ theme text,
75
+ stadium_id text,
76
+ year text
77
+ )
78
+
79
+ CREATE TABLE singer_in_concert (
80
+ concert_id number,
81
+ singer_id text
82
+ )
83
+
84
+ -- Using valid DuckDB SQL, answer the following questions for the tables provided above.
85
+
86
+ -- What is the maximum, the average, and the minimum capacity of stadiums ?
87
+
88
+ SELECT"""
89
+
90
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
91
+
92
+ generated_ids = model.generate(input_ids, max_length=500)
93
+ print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
94
+ ```
95
+
96
+ Example 2:
97
+
98
+ ```python
99
+ import torch
100
+ from transformers import AutoTokenizer, AutoModelForCausalLM
101
+ tokenizer = AutoTokenizer.from_pretrained("motherduckdb/nsql-duckdb-7B")
102
+ model = AutoModelForCausalLM.from_pretrained("motherduckdb/nsql-duckdb-7B", torch_dtype=torch.bfloat16)
103
+
104
+ text = """CREATE TABLE stadium (
105
+ stadium_id number,
106
+ location text,
107
+ name text,
108
+ capacity number,
109
+ )
110
+
111
+ -- Using valid DuckDB SQL, answer the following questions for the tables provided above.
112
+
113
+ -- how many stadiums in total?
114
+
115
+ SELECT"""
116
+
117
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
118
+
119
+ generated_ids = model.generate(input_ids, max_length=500)
120
+ print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
121
+ ```
122
+
123
+ Example 3:
124
+
125
+ ```python
126
+ import torch
127
+ from transformers import AutoTokenizer, AutoModelForCausalLM
128
+ tokenizer = AutoTokenizer.from_pretrained("motherduckdb/nsql-duckdb-7B")
129
+ model = AutoModelForCausalLM.from_pretrained("motherduckdb/nsql-duckdb-7B", torch_dtype=torch.bfloat16)
130
+
131
+ text = """CREATE TABLE work_orders (
132
+ ID NUMBER,
133
+ CREATED_AT TEXT,
134
+ COST FLOAT,
135
+ INVOICE_AMOUNT FLOAT,
136
+ IS_DUE BOOLEAN,
137
+ IS_OPEN BOOLEAN,
138
+ IS_OVERDUE BOOLEAN,
139
+ COUNTRY_NAME TEXT,
140
+ )
141
+
142
+ -- Using valid DuckDB SQL, answer the following questions for the tables provided above.
143
+
144
+ -- how many work orders are open?
145
+
146
+ SELECT"""
147
+
148
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
149
+
150
+ generated_ids = model.generate(input_ids, max_length=500)
151
+ print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
152
+ ```
153
+
154
+
155
+
156
+ For more information (e.g., run with your local database), please find examples in [this repository](https://github.com/NumbersStationAI/DuckDB-NSQL).
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import streamlit as st
2
  import requests
3
-
 
4
 
5
  PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}{context}\n### Question:\n{question}\n\n### Response:\n"""
6
  INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
 
 
7
 
8
  def generate_prompt(question, schema):
9
  input = ""
@@ -41,6 +44,24 @@ def generate_sql(question, schema):
41
  with s.post(url, json=body, headers=headers) as resp:
42
  return resp.json()["choices"][0]["text"]
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  st.title("DuckDB-NSQL-7B Demo")
45
 
46
  expander = st.expander("Customize Schema (Optional)")
@@ -56,5 +77,9 @@ text_prompt = st.text_input("What DuckDB SQL query can I write for you?", value=
56
 
57
  if text_prompt:
58
  sql_query = generate_sql(text_prompt, schema)
59
- st.code(sql_query, language="sql")
 
 
 
 
60
 
 
1
  import streamlit as st
2
  import requests
3
+ import subprocess
4
+ import sys
5
 
6
  PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}{context}\n### Question:\n{question}\n\n### Response:\n"""
7
  INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
8
+ TMP_DIR = "tmp"
9
+ ERROR_MESSAGE = "Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not capable of crafting the desired SQL. \nSorry my duck friend."
10
 
11
  def generate_prompt(question, schema):
12
  input = ""
 
44
  with s.post(url, json=body, headers=headers) as resp:
45
  return resp.json()["choices"][0]["text"]
46
 
47
def validate_sql(query, schema):
    """Check whether `query` parses and binds against `schema`.

    Runs ./validate_sql.py in a child process; that script re-raises
    parser, syntax, binder and selected catalog errors, which surface
    here on the child's stderr.

    Returns True when no such error is reported within the timeout,
    False when the validator wrote anything to stderr.
    """
    # Popen outside the try: a spawn failure should propagate, and only
    # TimeoutExpired from communicate() is handled below.
    process = subprocess.Popen(
        [sys.executable, './validate_sql.py', query, schema],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    try:
        # Any parser/binder error message shows up on stderr.
        _, stderr = process.communicate(timeout=0.5)
        return not stderr
    except subprocess.TimeoutExpired:
        # Timeout means parsing and binding very likely succeeded and
        # the query moved on to execution, so treat it as valid.
        process.kill()
        process.communicate()  # reap the killed child to avoid a zombie
        return True
65
  st.title("DuckDB-NSQL-7B Demo")
66
 
67
  expander = st.expander("Customize Schema (Optional)")
 
77
 
78
  if text_prompt:
79
  sql_query = generate_sql(text_prompt, schema)
80
+ valid = validate_sql(sql_query, schema)
81
+ if not valid:
82
+ st.code(ERROR_MESSAGE, language="text")
83
+ else:
84
+ st.code(sql_query, language="sql")
85
 
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ duckdb==0.9.2
validate_sql.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import shutil
3
+ import os
4
+ import uuid
5
+ import duckdb
6
+ from duckdb import ParserException, SyntaxException, BinderException, CatalogException
7
+
8
+ TMP_DIR = "tmp"
9
class WithDuckDBConnectionInTmpDir(object):
    """Context manager yielding a DuckDB connection that operates inside
    a throwaway working directory.

    The connection has external access disabled so validated queries
    cannot touch files outside the sandbox.  On exit the connection is
    closed, the original working directory is restored and the temporary
    directory is removed — even if closing the connection fails.
    """

    def __init__(self):
        # Unique sibling directory such as "tmp<uuid>"; created eagerly
        # so a name collision fails fast with FileExistsError.
        self.tmp_dir = TMP_DIR + str(uuid.uuid1())
        os.makedirs(self.tmp_dir)
        self.original_wd = os.getcwd()

    def __enter__(self):
        os.chdir(self.tmp_dir)
        self.con = duckdb.connect()
        # Sandbox: forbid filesystem/network access from within SQL.
        self.con.execute("SET enable_external_access=False")
        return self.con

    def __exit__(self, *args):
        # try/finally so the cwd is restored and the sandbox directory
        # deleted even when close() raises (the original leaked both).
        try:
            self.con.close()
        finally:
            os.chdir(self.original_wd)
            shutil.rmtree(self.tmp_dir)
25
+
26
def validate_query(query, schemas):
    """Execute `query` against a sandboxed DuckDB seeded with `schemas`.

    `schemas` is a ";"-separated string of DDL statements.  Parser,
    syntax and binder errors are re-raised, as are catalog errors about
    missing tables, table functions and copy functions.  Any other
    failure (for instance an extension that cannot load because external
    access is disabled) is only printed, i.e. treated as tolerable.
    """
    try:
        with WithDuckDBConnectionInTmpDir() as con:
            # Create the tables the query refers to.
            for ddl in schemas.split(";"):
                con.execute(ddl)
            con.cursor().execute(query)
    except (ParserException, SyntaxException, BinderException):
        raise
    except Exception as exc:
        text = str(exc)
        if "but it exists" in text and "extension" in text:
            # Extension is known but unavailable inside the sandbox.
            print(text)
        elif text.startswith("Catalog Error: Table with name"):
            raise
        elif "Catalog Error: Table Function with name" in text:
            raise
        elif "Catalog Error: Copy Function" in text:
            raise
        else:
            print(text)
+
53
if __name__ == '__main__':
    # Expects two positional arguments: the SQL query and the
    # ";"-separated schema DDL.
    if len(sys.argv) > 2:
        validate_query(sys.argv[1], sys.argv[2])
    else:
        # The old message said "No query provided." even when a query
        # was given but the schema argument was missing.
        print("Usage: validate_sql.py <query> <schemas>")