tdoehmen commited on
Commit
8a8e4f3
1 Parent(s): 1d0f3f8

udpated error message and validation

Browse files
Files changed (2) hide show
  1. app.py +19 -11
  2. validate_sql.py +7 -17
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import streamlit as st
2
  import requests
3
- import sqlglot
4
- import tempfile
5
  import re
6
- from validate_sql import validate_query
7
 
8
  PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
9
  INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
10
- ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not capable of crafting the desired SQL. \nSorry my duck friend. ]\n\n:red[ If the question is about your own database, make sure to set the correct schema. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
11
  STOP_TOKENS = ["###", ";", "--", "```"]
12
 
13
 
@@ -54,17 +53,26 @@ def generate_sql(question, schema):
54
  with s.post(url, json=body, headers=headers) as resp:
55
  sql_query = resp.json()["choices"][0]["text"]
56
 
57
- #for token in STOP_TOKENS:
58
- # sql_query = sql_query.split(token)[0]
59
- #sql_query = sqlglot.parse_one(sql_query, read="duckdb").sql(
60
- # dialect="duckdb", pretty=True
61
- #)
62
  return sql_query
63
 
64
 
65
  def validate_sql(query, schema):
66
- valid, msg = validate_query(query, schema)
67
- return valid, msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  st.title("DuckDB-NSQL-7B Demo")
 
1
  import streamlit as st
2
  import requests
3
+ import subprocess
 
4
  import re
5
+ import sys
6
 
7
  PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
8
  INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
9
+ ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not able to craft a correct SQL query for this request. \nSorry my duck friend. ]\n\n:red[ Try rephrasing the question/instruction. And if the question is about your own database, make sure to set the correct schema. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
10
  STOP_TOKENS = ["###", ";", "--", "```"]
11
 
12
 
 
53
  with s.post(url, json=body, headers=headers) as resp:
54
  sql_query = resp.json()["choices"][0]["text"]
55
 
 
 
 
 
 
56
  return sql_query
57
 
58
 
59
  def validate_sql(query, schema):
60
+ try:
61
+ # Define subprocess
62
+ process = subprocess.Popen(
63
+ [sys.executable, './validate_sql.py', query, schema],
64
+ stdout=subprocess.PIPE,
65
+ stderr=subprocess.PIPE
66
+ )
67
+ # Get output and potential parser, and binder error message
68
+ stdout, stderr = process.communicate(timeout=0.5)
69
+ if stderr:
70
+ return False, stderr.decode('utf8')
71
+ return True, ""
72
+ except subprocess.TimeoutExpired:
73
+ process.kill()
74
+ # timeout reached, so parsing and binding was very likely successful
75
+ return True, ""
76
 
77
 
78
  st.title("DuckDB-NSQL-7B Demo")
validate_sql.py CHANGED
@@ -4,8 +4,6 @@ from duckdb import ParserException, SyntaxException, BinderException, CatalogExc
4
 
5
 
6
  def validate_query(query, schemas):
7
- valid = True
8
- msg = ""
9
  try:
10
  print("Running query: ", query)
11
  with duckdb.connect(
@@ -17,25 +15,17 @@ def validate_query(query, schemas):
17
  cursor = duckdb_conn.cursor()
18
  cursor.execute(query)
19
  except ParserException as e:
20
- msg = str(e)
21
- valid = False
22
  except SyntaxException as e:
23
- msg = str(e)
24
- valid = False
25
  except BinderException as e:
26
- msg = str(e)
27
- valid = False
28
  except CatalogException as e:
29
- msg = str(e)
30
- if "but it exists" in msg and "extension" in msg:
31
- valid = True
32
- msg = ""
33
- else:
34
- valid = False
35
  except Exception as e:
36
- msg = str(e)
37
- valid = True
38
- return valid, msg
39
 
40
 
41
  if __name__ == "__main__":
 
4
 
5
 
6
  def validate_query(query, schemas):
 
 
7
  try:
8
  print("Running query: ", query)
9
  with duckdb.connect(
 
15
  cursor = duckdb_conn.cursor()
16
  cursor.execute(query)
17
  except ParserException as e:
18
+ raise e
 
19
  except SyntaxException as e:
20
+ raise e
 
21
  except BinderException as e:
22
+ raise e
 
23
  except CatalogException as e:
24
+ if not ("but it exists" in str(e) and "extension" in str(e)):
25
+ raise e
 
 
 
 
26
  except Exception as e:
27
+ return True
28
+ return True
 
29
 
30
 
31
  if __name__ == "__main__":