"""Streamlit demo for DuckDB-NSQL-7B: turn a natural-language question into DuckDB SQL.

The app builds an instruction prompt (optionally including a user-supplied
schema), sends it to a hosted completion endpoint, checks the returned SQL with
``validate_query`` against the schema, and renders either the SQL or an error.
"""
import re
import tempfile  # noqa: F401  (currently unused)

import requests
import sqlglot  # noqa: F401  (only needed if sqlglot post-processing is re-enabled)
import streamlit as st

from validate_sql import validate_query

PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""  # noqa: E501
INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}"""  # noqa: E501
ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not capable of crafting the desired SQL. \nSorry my duck friend. ]\n\n:red[ If the question is about your own database, make sure to set the correct schema. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"  # noqa: E501
STOP_TOKENS = ["###", ";", "--", "```"]

# Hosted completion endpoint for the model.
API_BASE_URL = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
# Upper bound on how long one completion request may block the script.
REQUEST_TIMEOUT_SECONDS = 120

# Body of each ``CREATE TABLE name( ... );`` statement, captured with its
# surrounding text so it can be rebuilt after rewriting the body.
_CREATE_TABLE_RE = re.compile(r"(CREATE TABLE [^(]+\()(.*?)(\);)", flags=re.DOTALL)
# A ``<word> <word>`` pair, e.g. ``trip_miles DOUBLE``.
_COLUMN_DEF_RE = re.compile(r"(\w+) (\w+)")


def _lowercase_column_types(schema):
    """Lowercase the second word of each ``name TYPE`` pair inside CREATE TABLE bodies."""

    def lower_body(table_match):
        body = _COLUMN_DEF_RE.sub(
            lambda col: f"{col[1]} {col[2].lower()}", table_match[2]
        )
        return f"{table_match[1]}{body}{table_match[3]}"

    # Only the matched statement bodies are rewritten; text outside them
    # (other statements, comments) is left untouched.
    return _CREATE_TABLE_RE.sub(lower_body, schema)


def generate_prompt(question, schema):
    """Build the instruction prompt sent to the model.

    Args:
        question: Natural-language question to translate into SQL.
        schema: DDL text (``CREATE TABLE ...;`` statements). Empty,
            whitespace-only or ``None`` means "no schema".

    Returns:
        The prompt string rendered from ``PROMPT_TEMPLATE``.
    """
    has_schema = bool(schema and schema.strip())
    schema_section = ""
    if has_schema:
        # Column types are lowercased before prompting — presumably to match
        # the model's training format. NOTE(review): confirm.
        schema_section = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format(  # noqa: E501
            schema=_lowercase_column_types(schema)
        )
    return PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(
            has_schema=", given a duckdb database schema." if has_schema else "."
        ),
        input=schema_section,
        question=question,
    )


def generate_sql(question, schema, timeout=REQUEST_TIMEOUT_SECONDS):
    """Ask the hosted model for SQL answering *question*.

    Args:
        question: Natural-language question.
        schema: Optional DDL passed through to ``generate_prompt``.
        timeout: Seconds to wait for the endpoint before giving up.

    Returns:
        The raw completion text returned by the model.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
        KeyError, IndexError, TypeError, ValueError: if the response body is
            not JSON of the expected ``{"choices": [{"text": ...}]}`` shape.
    """
    prompt = generate_prompt(question, schema)
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        "stop": "",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # A one-off request (closed by the ``with`` block) instead of an unclosed
    # Session; the timeout keeps a stalled endpoint from hanging the app.
    with requests.post(
        f"{API_BASE_URL}/v1/completions", json=body, headers=headers, timeout=timeout
    ) as resp:
        resp.raise_for_status()
        sql_query = resp.json()["choices"][0]["text"]
    # NOTE: post-processing (truncating at STOP_TOKENS and pretty-printing with
    # sqlglot) is currently disabled; the raw completion is returned as-is.
    return sql_query


def validate_sql(query, schema):
    """Check *query* against *schema* with ``validate_query``.

    Returns:
        A ``(valid, msg)`` pair as produced by ``validate_query``.
    """
    valid, msg = validate_query(query, schema)
    return valid, msg


st.title("DuckDB-NSQL-7B Demo")

expander = st.expander("Customize Schema (Optional)")
expander.markdown(
    "If your DuckDB database is `database.duckdb`, execute this query in your terminal to get your current schema:"  # noqa: E501
)
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",  # noqa: E501
)

# Default schema shown in (and editable from) the "Current schema" text area.
default_schema = """CREATE TABLE rideshare(
    hvfhs_license_num VARCHAR,
    dispatching_base_num VARCHAR,
    originating_base_num VARCHAR,
    request_datetime TIMESTAMP,
    on_scene_datetime TIMESTAMP,
    pickup_datetime TIMESTAMP,
    dropoff_datetime TIMESTAMP,
    PULocationID BIGINT,
    DOLocationID BIGINT,
    trip_miles DOUBLE,
    trip_time BIGINT,
    base_passenger_fare DOUBLE,
    tolls DOUBLE,
    bcf DOUBLE,
    sales_tax DOUBLE,
    congestion_surcharge DOUBLE,
    airport_fee DOUBLE,
    tips DOUBLE,
    driver_pay DOUBLE,
    shared_request_flag VARCHAR,
    shared_match_flag VARCHAR,
    access_a_ride_flag VARCHAR,
    wav_request_flag VARCHAR,
    wav_match_flag VARCHAR
);

CREATE TABLE service_requests(
    unique_key BIGINT,
    created_date TIMESTAMP,
    closed_date TIMESTAMP,
    agency VARCHAR,
    agency_name VARCHAR,
    complaint_type VARCHAR,
    descriptor VARCHAR,
    location_type VARCHAR,
    incident_zip VARCHAR,
    incident_address VARCHAR,
    street_name VARCHAR,
    cross_street_1 VARCHAR,
    cross_street_2 VARCHAR,
    intersection_street_1 VARCHAR,
    intersection_street_2 VARCHAR,
    address_type VARCHAR,
    city VARCHAR,
    landmark VARCHAR,
    facility_type VARCHAR,
    status VARCHAR,
    due_date TIMESTAMP,
    resolution_description VARCHAR,
    resolution_action_updated_date TIMESTAMP,
    community_board VARCHAR,
    bbl VARCHAR,
    borough VARCHAR,
    x_coordinate_state_plane VARCHAR,
    y_coordinate_state_plane VARCHAR,
    open_data_channel_type VARCHAR,
    park_facility_name VARCHAR,
    park_borough VARCHAR,
    vehicle_type VARCHAR,
    taxi_company_borough VARCHAR,
    taxi_pick_up_location VARCHAR,
    bridge_highway_name VARCHAR,
    bridge_highway_direction VARCHAR,
    road_ramp VARCHAR,
    bridge_highway_segment VARCHAR,
    latitude DOUBLE,
    longitude DOUBLE
);

CREATE TABLE taxi(
    VendorID BIGINT,
    tpep_pickup_datetime TIMESTAMP,
    tpep_dropoff_datetime TIMESTAMP,
    passenger_count DOUBLE,
    trip_distance DOUBLE,
    RatecodeID DOUBLE,
    store_and_fwd_flag VARCHAR,
    PULocationID BIGINT,
    DOLocationID BIGINT,
    payment_type BIGINT,
    fare_amount DOUBLE,
    extra DOUBLE,
    mta_tax DOUBLE,
    tip_amount DOUBLE,
    tolls_amount DOUBLE,
    improvement_surcharge DOUBLE,
    total_amount DOUBLE,
    congestion_surcharge DOUBLE,
    airport_fee DOUBLE,
    drivers VARCHAR[],
    speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
    other_violations JSON
);"""
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Input field for the natural-language question.
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
)

if text_prompt:
    try:
        sql_query = generate_sql(text_prompt, schema)
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as err:
        # Endpoint unreachable / timed out / non-2xx / unexpected response body.
        st.error(f"Could not get SQL from the model endpoint: {err!r}")
        st.stop()
    valid, msg = validate_sql(sql_query, schema)
    if not valid:
        st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))
    else:
        st.markdown(f"""```sql\n{sql_query}\n```""")