File size: 4,046 Bytes
6a0ec6a
 
6c37e10
6a0ec6a
4f04c00
20e319d
6a0ec6a
6c37e10
 
1767e22
 
20e319d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
042246b
 
 
 
215368b
042246b
 
 
 
 
 
 
 
 
 
 
 
 
10e2935
 
 
042246b
 
 
 
 
7306c07
 
 
 
edb7e14
7306c07
 
edb7e14
7306c07
edb7e14
7306c07
1df3c5d
35cddc5
 
 
 
 
 
61d9b40
 
35cddc5
 
61d9b40
 
2443195
8cc89f4
2443195
8cc89f4
61d9b40
f8c651a
 
61d9b40
2e81bab
61d9b40
54c2240
2443195
10e2935
 
2443195
10e2935
2443195
1f7ee11
215368b
 
 
edb7e14
215368b
 
edb7e14
215368b
 
 
1df3c5d
1f7ee11
 
1df3c5d
1f7ee11
 
 
1767e22
926164d
1767e22
 
1df3c5d
b9e14b1
 
1767e22
1df3c5d
b9e14b1
1767e22
 
 
 
8be3748
6a0ec6a
 
0380e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import gradio as gr
from sqlalchemy import text
from smolagents import tool, CodeAgent, HfApiModel
import spaces
import pandas as pd

from database import engine, receipts

import pandas as pd

def get_receipts_table():
    """
    Load every row of the 'receipts' table into a Pandas DataFrame.

    Returns:
        A DataFrame with the columns receipt_id, customer_name, price and tip
        (empty when the table has no rows). If querying or building the frame
        fails, a one-row DataFrame with a single "Error" column holding the
        exception message is returned instead.
    """
    column_names = ["receipt_id", "customer_name", "price", "tip"]

    try:
        with engine.connect() as connection:
            records = connection.execute(text("SELECT * FROM receipts")).fetchall()

        if records:
            return pd.DataFrame(records, columns=column_names)
        # No rows: still return the expected column headers.
        return pd.DataFrame(columns=column_names)
    except Exception as exc:
        # Report the failure as data rather than raising.
        return pd.DataFrame({"Error": [str(exc)]})

@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL query on the 'receipts' table and returns formatted results.

    Args:
        query: The SQL query to execute.

    Returns:
        Query result as a formatted string.
    """
    # NOTE: the docstring above is left verbatim on purpose; the @tool decorator
    # is understood to derive the tool description from it (confirm in smolagents docs).
    try:
        with engine.connect() as connection:
            records = connection.execute(text(query)).fetchall()

        if not records:
            return "No results found."

        # A lone value (one row, one column) is returned bare, without separators.
        first = records[0]
        is_scalar = len(records) == 1 and len(first) == 1
        if is_scalar:
            return str(first[0])

        # Otherwise: one line per row, cells separated by ", ".
        lines = []
        for record in records:
            lines.append(", ".join(str(value) for value in record))
        return "\n".join(lines)

    except Exception as exc:
        # Any failure (bad SQL, connection problem, ...) is reported as text.
        return "Error: " + str(exc)

def query_sql(user_query: str) -> str:
    """
    Converts natural language input to an SQL query using CodeAgent
    and returns the execution results.

    Args:
        user_query: The user's request in natural language.

    Returns:
        The query result from the database as a formatted string. Special cases:
        an empty string for blank input (the agent is not called); the agent's
        answer converted to text when it is not a string; an "Error: ..."
        message when the generated statement does not start with SELECT, SHOW
        or PRAGMA; and a value formatted with two decimals when the result
        parses as a number.
    """
    # Blank input carries no question, so skip the agent call entirely.
    if not user_query or not user_query.strip():
        return ""

    schema_info = (
        "The database has a table named 'receipts' with the following schema:\n"
        "- receipt_id (INTEGER, primary key)\n"
        "- customer_name (VARCHAR(16))\n"
        "- price (FLOAT)\n"
        "- tip (FLOAT)\n"
        "Generate a valid SQL SELECT query using ONLY these column names.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )

    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")

    # The agent may hand back a non-string (e.g. a number it computed itself);
    # show that directly instead of treating it as SQL.
    if not isinstance(generated_sql, str):
        return str(generated_sql)

    print(f"{generated_sql}")

    # Only statements starting with one of these keywords are executed.
    allowed_prefixes = ("select", "show", "pragma")
    if not generated_sql.strip().lower().startswith(allowed_prefixes):
        # Message lists every accepted keyword so it matches the check above.
        return "Error: Only SELECT, SHOW and PRAGMA queries are allowed."

    result = sql_engine(generated_sql)

    print(f"{result}")

    # Numeric results are displayed with two decimals; anything else is returned unchanged.
    try:
        return f"{float(result):.2f}"
    except ValueError:
        return result

def handle_query(user_input: str) -> str:
    """
    Answer a question typed into the UI by delegating to query_sql.

    Args:
        user_input: The natural language question entered by the user.

    Returns:
        Whatever query_sql returns for that question, unchanged, as the
        plain string to display in the UI.
    """
    answer = query_sql(user_input)
    return answer

# Module-level agent used by query_sql; sql_engine is the only tool it is given.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)

# Build the UI: a question/result column (scale 1) next to a receipts table column (scale 2).
with gr.Blocks() as demo:
    gr.Markdown("## Plain Text Query Interface")

    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="Ask a question about the data")
            query_output = gr.Textbox(label="Result")
        
        with gr.Column(scale=2):
            gr.Markdown("### Receipts Table")
            # get_receipts_table() is called here, while the layout is being built,
            # to supply the table's initial value.
            receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")

    # NOTE(review): `change` presumably fires on every edit of the textbox, so each
    # keystroke would start a new agent run via handle_query — consider `submit`
    # instead; confirm against the Gradio event-listener docs.
    user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)

    # Fetch the table again through the load event (presumably on each page load — TODO confirm).
    demo.load(fn=get_receipts_table, outputs=receipts_table)

if __name__ == "__main__":
    # Start the app only when run as a script: listen on 0.0.0.0:7860
    # and ask Gradio for a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)