File size: 5,081 Bytes
57816fc
 
b3eb06a
 
57816fc
a148b10
 
b3eb06a
 
a148b10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57816fc
 
 
 
a148b10
 
57816fc
 
 
 
 
a148b10
 
57816fc
a148b10
57816fc
a148b10
57816fc
 
b3eb06a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57816fc
 
 
 
 
 
 
cef725d
57816fc
 
 
b91e941
57816fc
 
 
b3eb06a
 
 
 
 
57816fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
import requests
import subprocess
import sys

# Skeleton of the completion prompt. The placeholders ``instruction``,
# ``input``, ``context`` and ``question`` are filled in with ``str.format``.
PROMPT_TEMPLATE = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}{context}\n"
    "### Question:\n{question}\n\n"
    "### Response:\n"
)

# Instruction sentence; ``has_schema`` supplies the sentence ending.
INSTRUCTION_TEMPLATE = (
    "Your task is to generate valid duckdb SQL "
    "to answer the following question{has_schema}"
)

# Directory name constant (not referenced in the visible part of this file).
TMP_DIR = "tmp"

# Text shown to the user when the generated SQL does not pass validation.
ERROR_MESSAGE = (
    "Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\n"
    "The model is currently not capable of crafting the desired SQL. \n"
    "Sorry my duck friend."
)

def generate_prompt(question, schema):
    """Build the text prompt sent to the SQL-generation model.

    Args:
        question: Natural-language request to be answered with SQL.
        schema: Schema text (e.g. ``CREATE TABLE`` statements) to include in
            the prompt. Any falsy value (``""`` or ``None``) means "no schema".

    Returns:
        ``PROMPT_TEMPLATE`` filled in with the instruction, the optional
        schema section and the question.
    """
    # Decide both schema-dependent pieces from a single truthiness check so
    # the instruction never claims a schema that is not actually included.
    if schema:
        schema_section = (
            "Here is the database schema that the SQL query will run on:\n"
            f"{schema}\n"
        )
        schema_hint = ", given a duckdb database schema."
    else:
        schema_section = ""
        schema_hint = "."

    return PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(has_schema=schema_hint),
        context="",
        input=schema_section,
        question=question + ". Use DuckDB shorthand if possible.",
    )

def generate_sql(question, schema, timeout=60):
    """Ask the hosted completion endpoint to generate SQL for a question.

    Args:
        question: Natural-language request to be answered with SQL.
        schema: Schema text passed through to ``generate_prompt``.
        timeout: Seconds to wait for the HTTP request before giving up
            (forwarded to ``requests``); ``None`` waits indefinitely.

    Returns:
        The ``text`` field of the first entry in the response's ``choices``.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status returned by the endpoint.
    """
    prompt = generate_prompt(question, schema)
    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        "stop": '<s>',
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}

    # Context managers close both the response and the session's pooled
    # connections, even when an exception is raised.
    with requests.Session() as session, session.post(
        url, json=body, headers=headers, timeout=timeout
    ) as resp:
        # Fail with a clear HTTP error instead of an opaque KeyError when the
        # service answers with a 4xx/5xx payload that has no "choices".
        resp.raise_for_status()
        return resp.json()["choices"][0]["text"]

def validate_sql(query, schema):
    """Check a query by running ``./validate_sql.py`` in a child interpreter.

    The helper script is started with ``query`` and ``schema`` as its two
    command-line arguments and given at most 0.5 seconds to finish.

    Args:
        query: SQL text to validate.
        schema: Schema text passed to the helper script.

    Returns:
        ``False`` if the helper wrote anything to stderr within the time
        limit; ``True`` if it wrote nothing to stderr or did not finish
        in time.
    """
    try:
        # subprocess.run kills the child and waits for it when the timeout
        # expires, so no zombie process or open pipe is left behind.
        completed = subprocess.run(
            [sys.executable, './validate_sql.py', query, schema],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=0.5,
        )
    except subprocess.TimeoutExpired:
        # timeout reached, so parsing and binding was very likely successful
        return True
    # Any stderr output is treated as a parser/binder error message.
    return not completed.stderr

# --- Streamlit page: runs top to bottom on every interaction ---
st.title("DuckDB-NSQL-7B Demo")

# Collapsible section where the user can paste their own schema.
expander = st.expander("Customize Schema (Optional)")
expander.text("Execute this query in your DuckDB database to get your current schema:")
expander.code("SELECT array_to_string(list(sql), '\\n') from duckdb_tables()", language="sql")

# Input field for the database schema, pre-filled with three example tables.
# The literal is built from adjacent string pieces (one per CREATE TABLE
# statement); a single-quoted string cannot span physical lines.
default_schema = (
    'CREATE TABLE rideshare(hvfhs_license_num VARCHAR, dispatching_base_num VARCHAR, originating_base_num VARCHAR, request_datetime TIMESTAMP, on_scene_datetime TIMESTAMP, pickup_datetime TIMESTAMP, dropoff_datetime TIMESTAMP, PULocationID BIGINT, DOLocationID BIGINT, trip_miles DOUBLE, trip_time BIGINT, base_passenger_fare DOUBLE, tolls DOUBLE, bcf DOUBLE, sales_tax DOUBLE, congestion_surcharge DOUBLE, airport_fee DOUBLE, tips DOUBLE, driver_pay DOUBLE, shared_request_flag VARCHAR, shared_match_flag VARCHAR, access_a_ride_flag VARCHAR, wav_request_flag VARCHAR, wav_match_flag VARCHAR);\n'
    'CREATE TABLE taxi(VendorID BIGINT, tpep_pickup_datetime TIMESTAMP, tpep_dropoff_datetime TIMESTAMP, passenger_count DOUBLE, trip_distance DOUBLE, RatecodeID DOUBLE, store_and_fwd_flag VARCHAR, PULocationID BIGINT, DOLocationID BIGINT, payment_type BIGINT, fare_amount DOUBLE, extra DOUBLE, mta_tax DOUBLE, tip_amount DOUBLE, tolls_amount DOUBLE, improvement_surcharge DOUBLE, total_amount DOUBLE, congestion_surcharge DOUBLE, airport_fee DOUBLE);\n'
    'CREATE TABLE service_requests(unique_key BIGINT, created_date TIMESTAMP, closed_date TIMESTAMP, agency VARCHAR, agency_name VARCHAR, complaint_type VARCHAR, descriptor VARCHAR, location_type VARCHAR, incident_zip VARCHAR, incident_address VARCHAR, street_name VARCHAR, cross_street_1 VARCHAR, cross_street_2 VARCHAR, intersection_street_1 VARCHAR, intersection_street_2 VARCHAR, address_type VARCHAR, city VARCHAR, landmark VARCHAR, facility_type VARCHAR, status VARCHAR, due_date TIMESTAMP, resolution_description VARCHAR, resolution_action_updated_date TIMESTAMP, community_board VARCHAR, bbl VARCHAR, borough VARCHAR, x_coordinate_state_plane VARCHAR, y_coordinate_state_plane VARCHAR, open_data_channel_type VARCHAR, park_facility_name VARCHAR, park_borough VARCHAR, vehicle_type VARCHAR, taxi_company_borough VARCHAR, taxi_pick_up_location VARCHAR, bridge_highway_name VARCHAR, bridge_highway_direction VARCHAR, road_ramp VARCHAR, bridge_highway_segment VARCHAR, latitude DOUBLE, longitude DOUBLE);'
)
schema = expander.text_input("Current schema:", value=default_schema)

# Input field for text prompt
text_prompt = st.text_input("What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv")

if text_prompt:
    sql_query = generate_sql(text_prompt, schema)
    # Only show SQL that passed validation; otherwise show the apology text.
    if validate_sql(sql_query, schema):
        st.code(sql_query, language="sql")
    else:
        st.code(ERROR_MESSAGE, language="text")