File size: 3,566 Bytes
6a0ec6a
 
6c37e10
6a0ec6a
4f04c00
6a0ec6a
79f396e
6c37e10
 
042246b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a0ec6a
7306c07
 
 
 
edb7e14
7306c07
 
edb7e14
7306c07
edb7e14
7306c07
61d9b40
35cddc5
 
 
 
 
 
61d9b40
 
35cddc5
 
61d9b40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35cddc5
 
7306c07
edb7e14
 
 
35cddc5
 
1f7ee11
 
7306c07
edb7e14
 
35cddc5
edb7e14
 
35cddc5
1f7ee11
 
 
 
 
 
 
6a0ec6a
0380e03
6a0ec6a
7306c07
6a0ec6a
7306c07
 
c6d6658
6a0ec6a
 
 
0380e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import gradio as gr
from sqlalchemy import text
from smolagents import tool, CodeAgent, HfApiModel
import spaces

# Import the persistent database
from database import engine, receipts

@tool
def sql_engine(query: str) -> str:
    """
    Executes a read-only SQL query on the 'receipts' table and returns formatted results.

    Only statements starting with SELECT, SHOW or PRAGMA (case-insensitive,
    after leading whitespace) are executed; anything else is rejected with an
    error string. The check lives here as well as in query_sql because the
    agent can call this tool directly with SQL it generated itself.

    Args:
        query: The SQL query to execute.

    Returns:
        Query result as a formatted string: one line per row with column
        values joined by ", ", "No results found." when no rows come back, or
        a string starting with "Error:" when the query is rejected or fails.
    """
    # Normalise whatever the agent passed into clean SQL text.
    statement = str(query).strip()

    # The agent is driven by free-form user input, so enforce the read-only
    # policy at the point of execution rather than only in the UI entry point.
    if not statement.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."

    try:
        with engine.connect() as con:
            rows = con.execute(text(statement)).fetchall()

        if not rows:
            return "No results found."

        # Convert results into a readable string format
        return "\n".join(", ".join(map(str, row)) for row in rows)

    except Exception as e:
        # Report failures as text so the agent/UI can show them instead of crashing
        return f"Error: {e}"

@tool
def query_sql(user_query: str) -> str:
    """
    Converts natural language input to an SQL query using CodeAgent
    and returns the execution results.

    The agent's answer is converted to text and stripped of surrounding
    whitespace and of a Markdown code fence (a reply wrapped in ```sql ... ```)
    before the read-only check, so a fenced but otherwise valid SELECT is
    still executed.

    Args:
        user_query: The user's request in natural language.

    Returns:
        The query result from the database as a formatted string, or an error
        string if the generated statement is not a read-only query.
    """
    # Provide the AI with the correct schema and strict instructions
    schema_info = (
        "The database has a table named 'receipts' with the following schema:\n"
        "- receipt_id (INTEGER, primary key)\n"
        "- customer_name (VARCHAR(16))\n"
        "- price (FLOAT)\n"
        "- tip (FLOAT)\n"
        "Generate a valid SQL SELECT query using ONLY these column names.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )

    # Generate SQL query using the provided schema
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")

    # Log the raw agent output for debugging
    print(f"Generated SQL: {generated_sql}")

    # The agent's final answer is not necessarily a str; normalise it to text.
    sql = str(generated_sql).strip()

    # LLMs frequently wrap code in Markdown fences; unwrap so the prefix check
    # sees the statement itself rather than the opening ```.
    if sql.startswith("```"):
        _, _, sql = sql.partition("\n")  # drop the opening fence line (e.g. ```sql)
        sql = sql.rsplit("```", 1)[0].strip()  # drop the closing fence, if any

    # Ensure we only execute valid SELECT queries
    if not sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."

    # Execute the cleaned SQL query and return the result
    result = sql_engine(sql)

    # Log the SQL query result
    print(f"SQL Query Result: {result}")

    return result  # Return the final result, NOT the generated SQL

# The SQL-writing agent used by query_sql. Its only tool is sql_engine, and it
# is backed by HfApiModel with the Qwen/Qwen2.5-Coder-32B-Instruct model id.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)

# Gradio front end: one text box for the plain-English request, one text box
# for the formatted result returned by query_sql. Flagging is disabled.
demo = gr.Interface(
    title="Natural Language to SQL Executor",
    description="Enter a plain English request, and the AI will generate an SQL query and return the results.",
    fn=query_sql,
    inputs=gr.Textbox(label="Enter your query in plain English"),
    outputs=gr.Textbox(label="Query Result"),
    flagging_mode="never",
)

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and request a Gradio share link.
    # NOTE(review): share=True exposes a public URL that can drive the SQL
    # agent; confirm that is intended outside a hosted Space.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )