tdoehmen commited on
Commit
1d0f3f8
1 Parent(s): cca4edc

updated validation

Browse files
Files changed (2) hide show
  1. app.py +143 -40
  2. validate_sql.py +27 -16
app.py CHANGED
@@ -1,33 +1,44 @@
1
  import streamlit as st
2
  import requests
3
- import subprocess
4
- import sys
 
 
5
 
6
- PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}{context}\n### Question:\n{question}\n\n### Response:\n"""
7
  INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
8
- TMP_DIR = "tmp"
9
- ERROR_MESSAGE = "Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not capable of crafting the desired SQL. \nSorry my duck friend."
 
10
 
11
  def generate_prompt(question, schema):
12
  input = ""
13
  if schema:
 
 
 
 
 
 
 
 
 
14
  input = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format( # noqa: E501
15
  schema=schema
16
  )
17
  prompt = PROMPT_TEMPLATE.format(
18
- instruction = INSTRUCTION_TEMPLATE.format(
19
- has_schema="."
20
- if schema == ""
21
- else ", given a duckdb database schema."
22
  ),
23
- context="",
24
  input=input,
25
- question=question + ". Use DuckDB shorthand if possible.",
26
  )
27
  return prompt
28
 
 
29
  def generate_sql(question, schema):
30
  prompt = generate_prompt(question, schema)
 
31
  s = requests.Session()
32
  api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
33
  url = f"{api_base}/v1/completions"
@@ -36,50 +47,142 @@ def generate_sql(question, schema):
36
  "prompt": prompt,
37
  "temperature": 0.1,
38
  "max_tokens": 200,
39
- "stop":'<s>',
40
- "n": 1
41
  }
42
  headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
43
-
44
  with s.post(url, json=body, headers=headers) as resp:
45
- return resp.json()["choices"][0]["text"]
 
 
 
 
 
 
 
 
46
 
47
  def validate_sql(query, schema):
48
- try:
49
- # Define subprocess
50
- process = subprocess.Popen(
51
- [sys.executable, './validate_sql.py', query, schema],
52
- stdout=subprocess.PIPE,
53
- stderr=subprocess.PIPE
54
- )
55
- # Get output and potential parser, and binder error message
56
- stdout, stderr = process.communicate(timeout=0.5)
57
- if stderr:
58
- return False
59
- return True
60
- except subprocess.TimeoutExpired:
61
- process.kill()
62
- # timeout reached, so parsing and binding was very likely successful
63
- return True
64
 
65
  st.title("DuckDB-NSQL-7B Demo")
66
 
67
  expander = st.expander("Customize Schema (Optional)")
68
- expander.text("Execute this query in your DuckDB database to get your current schema:")
69
- expander.code("SELECT array_to_string(list(sql), '\\n') from duckdb_tables()", language="sql")
 
 
 
 
70
 
71
  # Input field for text prompt
72
- default_schema = 'CREATE TABLE rideshare(hvfhs_license_num VARCHAR, dispatching_base_num VARCHAR, originating_base_num VARCHAR, request_datetime TIMESTAMP, on_scene_datetime TIMESTAMP, pickup_datetime TIMESTAMP, dropoff_datetime TIMESTAMP, PULocationID BIGINT, DOLocationID BIGINT, trip_miles DOUBLE, trip_time BIGINT, base_passenger_fare DOUBLE, tolls DOUBLE, bcf DOUBLE, sales_tax DOUBLE, congestion_surcharge DOUBLE, airport_fee DOUBLE, tips DOUBLE, driver_pay DOUBLE, shared_request_flag VARCHAR, shared_match_flag VARCHAR, access_a_ride_flag VARCHAR, wav_request_flag VARCHAR, wav_match_flag VARCHAR);\nCREATE TABLE taxi(VendorID BIGINT, tpep_pickup_datetime TIMESTAMP, tpep_dropoff_datetime TIMESTAMP, passenger_count DOUBLE, trip_distance DOUBLE, RatecodeID DOUBLE, store_and_fwd_flag VARCHAR, PULocationID BIGINT, DOLocationID BIGINT, payment_type BIGINT, fare_amount DOUBLE, extra DOUBLE, mta_tax DOUBLE, tip_amount DOUBLE, tolls_amount DOUBLE, improvement_surcharge DOUBLE, total_amount DOUBLE, congestion_surcharge DOUBLE, airport_fee DOUBLE);\nCREATE TABLE service_requests(unique_key BIGINT, created_date TIMESTAMP, closed_date TIMESTAMP, agency VARCHAR, agency_name VARCHAR, complaint_type VARCHAR, descriptor VARCHAR, location_type VARCHAR, incident_zip VARCHAR, incident_address VARCHAR, street_name VARCHAR, cross_street_1 VARCHAR, cross_street_2 VARCHAR, intersection_street_1 VARCHAR, intersection_street_2 VARCHAR, address_type VARCHAR, city VARCHAR, landmark VARCHAR, facility_type VARCHAR, status VARCHAR, due_date TIMESTAMP, resolution_description VARCHAR, resolution_action_updated_date TIMESTAMP, community_board VARCHAR, bbl VARCHAR, borough VARCHAR, x_coordinate_state_plane VARCHAR, y_coordinate_state_plane VARCHAR, open_data_channel_type VARCHAR, park_facility_name VARCHAR, park_borough VARCHAR, vehicle_type VARCHAR, taxi_company_borough VARCHAR, taxi_pick_up_location VARCHAR, bridge_highway_name VARCHAR, bridge_highway_direction VARCHAR, road_ramp VARCHAR, bridge_highway_segment VARCHAR, latitude DOUBLE, longitude DOUBLE);'
73
- schema = expander.text_input("Current schema:", value=default_schema)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # Input field for text prompt
76
- text_prompt = st.text_input("What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv")
 
 
77
 
78
  if text_prompt:
79
  sql_query = generate_sql(text_prompt, schema)
80
- valid = validate_sql(sql_query, schema)
81
  if not valid:
82
- st.code(ERROR_MESSAGE, language="text")
83
  else:
84
- st.code(sql_query, language="sql")
85
-
 
1
  import streamlit as st
2
  import requests
3
+ import sqlglot
4
+ import tempfile
5
+ import re
6
+ from validate_sql import validate_query
7
 
8
+ PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
9
  INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
10
+ ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not capable of crafting the desired SQL. \nSorry my duck friend. ]\n\n:red[ If the question is about your own database, make sure to set the correct schema. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
11
+ STOP_TOKENS = ["###", ";", "--", "```"]
12
+
13
 
14
  def generate_prompt(question, schema):
15
  input = ""
16
  if schema:
17
+ # Lowercase types inside each CREATE TABLE (...) statement
18
+ for create_table in re.findall(
19
+ r"CREATE TABLE [^(]+\((.*?)\);", schema, flags=re.DOTALL | re.MULTILINE
20
+ ):
21
+ for create_col in re.findall(r"(\w+) (\w+)", create_table):
22
+ schema = schema.replace(
23
+ f"{create_col[0]} {create_col[1]}",
24
+ f"{create_col[0]} {create_col[1].lower()}",
25
+ )
26
  input = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format( # noqa: E501
27
  schema=schema
28
  )
29
  prompt = PROMPT_TEMPLATE.format(
30
+ instruction=INSTRUCTION_TEMPLATE.format(
31
+ has_schema="." if schema == "" else ", given a duckdb database schema."
 
 
32
  ),
 
33
  input=input,
34
+ question=question,
35
  )
36
  return prompt
37
 
38
+
39
  def generate_sql(question, schema):
40
  prompt = generate_prompt(question, schema)
41
+
42
  s = requests.Session()
43
  api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
44
  url = f"{api_base}/v1/completions"
 
47
  "prompt": prompt,
48
  "temperature": 0.1,
49
  "max_tokens": 200,
50
+ "stop": "<s>",
51
+ "n": 1,
52
  }
53
  headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
 
54
  with s.post(url, json=body, headers=headers) as resp:
55
+ sql_query = resp.json()["choices"][0]["text"]
56
+
57
+ #for token in STOP_TOKENS:
58
+ # sql_query = sql_query.split(token)[0]
59
+ #sql_query = sqlglot.parse_one(sql_query, read="duckdb").sql(
60
+ # dialect="duckdb", pretty=True
61
+ #)
62
+ return sql_query
63
+
64
 
65
  def validate_sql(query, schema):
66
+ valid, msg = validate_query(query, schema)
67
+ return valid, msg
68
+
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  st.title("DuckDB-NSQL-7B Demo")
71
 
72
  expander = st.expander("Customize Schema (Optional)")
73
+ expander.markdown(
74
+ "If you DuckDB database is `database.duckdb`, execute this query in your terminal to get your current schema:"
75
+ )
76
+ expander.markdown(
77
+ """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",
78
+ )
79
 
80
  # Input field for text prompt
81
+ default_schema = """CREATE TABLE rideshare(
82
+ hvfhs_license_num VARCHAR,
83
+ dispatching_base_num VARCHAR,
84
+ originating_base_num VARCHAR,
85
+ request_datetime TIMESTAMP,
86
+ on_scene_datetime TIMESTAMP,
87
+ pickup_datetime TIMESTAMP,
88
+ dropoff_datetime TIMESTAMP,
89
+ PULocationID BIGINT,
90
+ DOLocationID BIGINT,
91
+ trip_miles DOUBLE,
92
+ trip_time BIGINT,
93
+ base_passenger_fare DOUBLE,
94
+ tolls DOUBLE,
95
+ bcf DOUBLE,
96
+ sales_tax DOUBLE,
97
+ congestion_surcharge DOUBLE,
98
+ airport_fee DOUBLE,
99
+ tips DOUBLE,
100
+ driver_pay DOUBLE,
101
+ shared_request_flag VARCHAR,
102
+ shared_match_flag VARCHAR,
103
+ access_a_ride_flag VARCHAR,
104
+ wav_request_flag VARCHAR,
105
+ wav_match_flag VARCHAR
106
+ );
107
+
108
+ CREATE TABLE service_requests(
109
+ unique_key BIGINT,
110
+ created_date TIMESTAMP,
111
+ closed_date TIMESTAMP,
112
+ agency VARCHAR,
113
+ agency_name VARCHAR,
114
+ complaint_type VARCHAR,
115
+ descriptor VARCHAR,
116
+ location_type VARCHAR,
117
+ incident_zip VARCHAR,
118
+ incident_address VARCHAR,
119
+ street_name VARCHAR,
120
+ cross_street_1 VARCHAR,
121
+ cross_street_2 VARCHAR,
122
+ intersection_street_1 VARCHAR,
123
+ intersection_street_2 VARCHAR,
124
+ address_type VARCHAR,
125
+ city VARCHAR,
126
+ landmark VARCHAR,
127
+ facility_type VARCHAR,
128
+ status VARCHAR,
129
+ due_date TIMESTAMP,
130
+ resolution_description VARCHAR,
131
+ resolution_action_updated_date TIMESTAMP,
132
+ community_board VARCHAR,
133
+ bbl VARCHAR,
134
+ borough VARCHAR,
135
+ x_coordinate_state_plane VARCHAR,
136
+ y_coordinate_state_plane VARCHAR,
137
+ open_data_channel_type VARCHAR,
138
+ park_facility_name VARCHAR,
139
+ park_borough VARCHAR,
140
+ vehicle_type VARCHAR,
141
+ taxi_company_borough VARCHAR,
142
+ taxi_pick_up_location VARCHAR,
143
+ bridge_highway_name VARCHAR,
144
+ bridge_highway_direction VARCHAR,
145
+ road_ramp VARCHAR,
146
+ bridge_highway_segment VARCHAR,
147
+ latitude DOUBLE,
148
+ longitude DOUBLE
149
+ );
150
+
151
+ CREATE TABLE taxi(
152
+ VendorID BIGINT,
153
+ tpep_pickup_datetime TIMESTAMP,
154
+ tpep_dropoff_datetime TIMESTAMP,
155
+ passenger_count DOUBLE,
156
+ trip_distance DOUBLE,
157
+ RatecodeID DOUBLE,
158
+ store_and_fwd_flag VARCHAR,
159
+ PULocationID BIGINT,
160
+ DOLocationID BIGINT,
161
+ payment_type BIGINT,
162
+ fare_amount DOUBLE,
163
+ extra DOUBLE,
164
+ mta_tax DOUBLE,
165
+ tip_amount DOUBLE,
166
+ tolls_amount DOUBLE,
167
+ improvement_surcharge DOUBLE,
168
+ total_amount DOUBLE,
169
+ congestion_surcharge DOUBLE,
170
+ airport_fee DOUBLE,
171
+ drivers VARCHAR[],
172
+ speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
173
+ other_violations JSON
174
+ );"""
175
+ schema = expander.text_area("Current schema:", value=default_schema, height=500)
176
 
177
  # Input field for text prompt
178
+ text_prompt = st.text_input(
179
+ "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
180
+ )
181
 
182
  if text_prompt:
183
  sql_query = generate_sql(text_prompt, schema)
184
+ valid, msg = validate_sql(sql_query, schema)
185
  if not valid:
186
+ st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))
187
  else:
188
+ st.markdown(f"""```sql\n{sql_query}\n```""")
 
validate_sql.py CHANGED
@@ -1,34 +1,45 @@
1
  import sys
2
  import duckdb
3
- from duckdb import ParserException, SyntaxException, BinderException
 
4
 
5
  def validate_query(query, schemas):
 
 
6
  try:
7
- with duckdb.connect(":memory:", config={"enable_external_access": False}) as duckdb_conn:
 
 
 
8
  # register schemas
9
  for schema in schemas.split(";"):
10
  duckdb_conn.execute(schema)
11
  cursor = duckdb_conn.cursor()
12
  cursor.execute(query)
13
  except ParserException as e:
14
- raise e
 
15
  except SyntaxException as e:
16
- raise e
 
17
  except BinderException as e:
18
- raise e
 
 
 
 
 
 
 
 
19
  except Exception as e:
20
- message = str(e)
21
- if "but it exists" in message and "extension" in message:
22
- print(message)
23
- elif message.startswith("Catalog Error: Table with name"):
24
- raise e
25
- elif "Catalog Error: Table Function with name" in message:
26
- raise e
27
- elif "Catalog Error: Copy Function" in message:
28
- raise e
29
 
30
- if __name__ == '__main__':
31
  if len(sys.argv) > 2:
32
  validate_query(sys.argv[1], sys.argv[2])
33
  else:
34
- print("No query provided.")
 
1
  import sys
2
  import duckdb
3
+ from duckdb import ParserException, SyntaxException, BinderException, CatalogException
4
+
5
 
6
  def validate_query(query, schemas):
7
+ valid = True
8
+ msg = ""
9
  try:
10
+ print("Running query: ", query)
11
+ with duckdb.connect(
12
+ ":memory:", config={"enable_external_access": False}
13
+ ) as duckdb_conn:
14
  # register schemas
15
  for schema in schemas.split(";"):
16
  duckdb_conn.execute(schema)
17
  cursor = duckdb_conn.cursor()
18
  cursor.execute(query)
19
  except ParserException as e:
20
+ msg = str(e)
21
+ valid = False
22
  except SyntaxException as e:
23
+ msg = str(e)
24
+ valid = False
25
  except BinderException as e:
26
+ msg = str(e)
27
+ valid = False
28
+ except CatalogException as e:
29
+ msg = str(e)
30
+ if "but it exists" in msg and "extension" in msg:
31
+ valid = True
32
+ msg = ""
33
+ else:
34
+ valid = False
35
  except Exception as e:
36
+ msg = str(e)
37
+ valid = True
38
+ return valid, msg
39
+
 
 
 
 
 
40
 
41
+ if __name__ == "__main__":
42
  if len(sys.argv) > 2:
43
  validate_query(sys.argv[1], sys.argv[2])
44
  else:
45
+ print("No query provided.")