# DuckDB-NSQL-7B / app.py
# tdoehmen's picture
# updated validation
# dcec0ff
import streamlit as st
import requests
import subprocess
import re
import sys
# Overall prompt layout: instruction, optional schema section, question, then
# the response header the model is expected to continue from.
PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
# Instruction sentence; ``has_schema`` is the sentence ending picked by
# generate_prompt depending on whether a schema was supplied.
INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}"""  # noqa: E501
# Streamlit markdown rendered when validation of the generated SQL fails; shows
# both the generated query and the validator's error output.
ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not able to craft a correct SQL query for this request. \nSorry my duck friend. ]\n\n:red[If the question is about your own database, make sure to set the correct schema. Otherwise, try to rephrase your request. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
# NOTE(review): not referenced by the code visible in this file (the request
# body in generate_sql sends its own "stop" value) -- confirm whether it is
# still needed.
STOP_TOKENS = ["###", ";", "--", "```"]

# Splits a CREATE TABLE statement into (header up to "(", body, closing ");").
_CREATE_TABLE_RE = re.compile(r"(CREATE TABLE [^(]+\()(.*?)(\);)", flags=re.DOTALL)
# Two words separated by a single space, e.g. a column name and its type.
_WORD_PAIR_RE = re.compile(r"(\w+) (\w+)")


def _lowercase_column_types(schema):
    """Lowercase the second word of every ``word word`` pair inside CREATE TABLE bodies."""

    def _lower_pair(pair):
        return f"{pair.group(1)} {pair.group(2).lower()}"

    def _lower_body(statement):
        header, body, closing = statement.groups()
        return f"{header}{_WORD_PAIR_RE.sub(_lower_pair, body)}{closing}"

    # Each match is rewritten in place. A global ``str.replace`` on the pair
    # text would also hit substrings: lowercasing "id INT" used to turn
    # "uid INTEGER" into "uid intEGER".
    return _CREATE_TABLE_RE.sub(_lower_body, schema)


def generate_prompt(question, schema):
    """Build the model prompt for ``question``, optionally including ``schema``.

    Args:
        question: Natural-language request to be answered with DuckDB SQL.
        schema: Text containing ``CREATE TABLE ...(...);`` statements. An empty
            string or ``None`` produces a schema-less prompt.

    Returns:
        The formatted prompt string based on ``PROMPT_TEMPLATE``.
    """
    schema_section = ""
    if schema:
        # Lowercase types inside each CREATE TABLE (...) statement
        schema = _lowercase_column_types(schema)
        schema_section = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format(  # noqa: E501
            schema=schema
        )

    # Use the same truthiness test as above so a ``None`` schema does not
    # announce a schema that is absent from the prompt.
    has_schema = ", given a duckdb database schema." if schema else "."
    return PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(has_schema=has_schema),
        input=schema_section,
        question=question,
    )
def generate_sql(question, schema):
    """Request a SQL completion for ``question`` from the hosted model endpoint.

    The prompt is built with ``generate_prompt`` and posted to the
    ``/v1/completions`` route of the endpoint below, authenticated with the
    ``octoml_token`` entry of the Streamlit secrets.

    Args:
        question: Natural-language request.
        schema: Schema text forwarded to ``generate_prompt``.

    Returns:
        The ``text`` field of the first entry in the response's ``choices``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If the endpoint does not respond in time.
    """
    prompt = generate_prompt(question, schema)
    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        "stop": "<s>",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # Use the session as a context manager so its pooled connections are
    # released; previously the session was created and never closed.
    with requests.Session() as session:
        # Without a timeout a stalled endpoint would hang the app indefinitely.
        resp = session.post(url, json=body, headers=headers, timeout=60)
        # Surface HTTP failures as such instead of a KeyError on "choices".
        resp.raise_for_status()
        return resp.json()["choices"][0]["text"]
def validate_sql(query, schema):
    """Check ``query`` against ``schema`` by running ``./validate_sql.py`` in a subprocess.

    The helper script is started with the current interpreter and receives the
    query and the schema as command-line arguments. Anything it writes to
    stderr is treated as a validation failure.

    Args:
        query: SQL text to validate.
        schema: Schema text passed through to the helper script.

    Returns:
        A ``(valid, message)`` tuple. ``message`` is always a string: empty on
        success, otherwise the helper's stderr output with the first three
        lines removed when there are more than three lines.
    """
    process = subprocess.Popen(
        [sys.executable, "./validate_sql.py", query, schema],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        # Get output and potential parser, and binder error message
        _, stderr = process.communicate(timeout=0.5)
    except subprocess.TimeoutExpired:
        process.kill()
        # Finish communication after the kill so the child is reaped and its
        # pipes are closed instead of lingering as a zombie process.
        process.communicate()
        # timeout reached, so parsing and binding was very likely successful
        return True, ""

    if not stderr:
        return True, ""

    # errors="replace" keeps undecodable bytes from masking the real error.
    lines = stderr.decode("utf8", errors="replace").split("\n")
    # skip traceback
    if len(lines) > 3:
        lines = lines[3:]
    # Always join: short outputs used to be returned as a list, which was then
    # rendered as a list repr inside the error message.
    return False, "\n".join(lines)
# Page title and a collapsible section where the user can supply their own schema.
st.title("DuckDB-NSQL-7B Demo")
expander = st.expander("Customize Schema (Optional)")
# Instructions for dumping the schema of an existing database file.
expander.markdown(
    "If your DuckDB database is `database.duckdb`, execute this query in your terminal to get your current schema:"
)
# Shell one-liner: pipes `.schema` into the duckdb CLI and uses sed to insert
# line breaks after "(", after ", " and around ");".
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",
)
# Default schema pre-filled in the schema editor: three example tables
# (rideshare, service_requests, taxi).
default_schema = """CREATE TABLE rideshare(
hvfhs_license_num VARCHAR,
dispatching_base_num VARCHAR,
originating_base_num VARCHAR,
request_datetime TIMESTAMP,
on_scene_datetime TIMESTAMP,
pickup_datetime TIMESTAMP,
dropoff_datetime TIMESTAMP,
PULocationID BIGINT,
DOLocationID BIGINT,
trip_miles DOUBLE,
trip_time BIGINT,
base_passenger_fare DOUBLE,
tolls DOUBLE,
bcf DOUBLE,
sales_tax DOUBLE,
congestion_surcharge DOUBLE,
airport_fee DOUBLE,
tips DOUBLE,
driver_pay DOUBLE,
shared_request_flag VARCHAR,
shared_match_flag VARCHAR,
access_a_ride_flag VARCHAR,
wav_request_flag VARCHAR,
wav_match_flag VARCHAR
);
CREATE TABLE service_requests(
unique_key BIGINT,
created_date TIMESTAMP,
closed_date TIMESTAMP,
agency VARCHAR,
agency_name VARCHAR,
complaint_type VARCHAR,
descriptor VARCHAR,
location_type VARCHAR,
incident_zip VARCHAR,
incident_address VARCHAR,
street_name VARCHAR,
cross_street_1 VARCHAR,
cross_street_2 VARCHAR,
intersection_street_1 VARCHAR,
intersection_street_2 VARCHAR,
address_type VARCHAR,
city VARCHAR,
landmark VARCHAR,
facility_type VARCHAR,
status VARCHAR,
due_date TIMESTAMP,
resolution_description VARCHAR,
resolution_action_updated_date TIMESTAMP,
community_board VARCHAR,
bbl VARCHAR,
borough VARCHAR,
x_coordinate_state_plane VARCHAR,
y_coordinate_state_plane VARCHAR,
open_data_channel_type VARCHAR,
park_facility_name VARCHAR,
park_borough VARCHAR,
vehicle_type VARCHAR,
taxi_company_borough VARCHAR,
taxi_pick_up_location VARCHAR,
bridge_highway_name VARCHAR,
bridge_highway_direction VARCHAR,
road_ramp VARCHAR,
bridge_highway_segment VARCHAR,
latitude DOUBLE,
longitude DOUBLE
);
CREATE TABLE taxi(
VendorID BIGINT,
tpep_pickup_datetime TIMESTAMP,
tpep_dropoff_datetime TIMESTAMP,
passenger_count DOUBLE,
trip_distance DOUBLE,
RatecodeID DOUBLE,
store_and_fwd_flag VARCHAR,
PULocationID BIGINT,
DOLocationID BIGINT,
payment_type BIGINT,
fare_amount DOUBLE,
extra DOUBLE,
mta_tax DOUBLE,
tip_amount DOUBLE,
tolls_amount DOUBLE,
improvement_surcharge DOUBLE,
total_amount DOUBLE,
congestion_surcharge DOUBLE,
airport_fee DOUBLE,
drivers VARCHAR[],
speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
other_violations JSON
);"""
# Editable schema text area inside the expander; its current value is what gets
# passed to SQL generation and validation below.
schema = expander.text_area("Current schema:", value=default_schema, height=500)
# Free-text request from the user; a non-empty value triggers generation.
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?",
    value="Read a CSV file from test.csv",
)

if text_prompt:
    # Generate the query, validate it against the current schema, then render
    # either the query itself or the error message with the validator output.
    sql_query = generate_sql(text_prompt, schema)
    is_valid, error_msg = validate_sql(sql_query, schema)
    if is_valid:
        rendered = f"```sql\n{sql_query}\n```"
    else:
        rendered = ERROR_MESSAGE.format(sql_query=sql_query, error_msg=error_msg)
    st.markdown(rendered)