barathm111 commited on
Commit
283a0f0
·
verified ·
1 Parent(s): 3de53f6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -1,7 +1,8 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
- import mysql.connector
 
5
  import os
6
  from dotenv import load_dotenv
7
 
@@ -69,20 +70,20 @@ def home():
69
  @app.post("/query")
70
  def handle_query(request: QueryRequest):
71
  try:
 
72
  text = request.query
73
 
74
  # Fetch the database schema
75
  schema = get_database_schema()
76
- schema_str = "\n".join([f"{table}: {', '.join(columns)}" for table, columns in schema.items()])
 
77
 
78
- # Construct the system message
79
  system_message = f"""
80
- You are a helpful, cheerful database assistant.
81
- Use the following dynamically retrieved database schema when creating your answers:
82
 
83
- {schema_str}
84
-
85
- When creating your answers, consider the following:
86
 
87
  1. If a query involves a column or value that is not present in the provided database schema, correct it and mention the correction in the summary. If a column or value is missing, provide an explanation of the issue and adjust the query accordingly.
88
  2. If there is a spelling mistake in the column name or value, attempt to correct it by matching the closest possible column or value from the schema. Mention this correction in the summary to clarify any changes made.
@@ -97,19 +98,18 @@ def handle_query(request: QueryRequest):
97
  In the preceding JSON response, substitute "your-query" with a MariaDB query to retrieve the requested data.
98
  In the preceding JSON response, substitute "your-summary" with a summary of the query and any corrections or clarifications made.
99
  Always include all columns in the table.
100
- """
101
-
102
  prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"
103
  output = pipe(prompt, max_new_tokens=100)
104
-
 
105
  generated_text = output[0]['generated_text']
106
  sql_query = generated_text.split("SQL query:")[-1].strip()
107
 
108
- # Basic validation
109
  if not sql_query.lower().startswith(('select', 'show', 'describe')):
110
  raise ValueError("Generated text is not a valid SQL query")
111
-
112
- # Example: execute the generated SQL query and return the results
113
  conn = get_db_connection()
114
  cursor = conn.cursor()
115
  cursor.execute(sql_query)
@@ -120,8 +120,10 @@ def handle_query(request: QueryRequest):
120
 
121
  return {"sql": sql_query, "results": results}
122
  except Exception as e:
 
123
  raise HTTPException(status_code=500, detail=str(e))
124
 
 
125
  if __name__ == "__main__":
126
  import uvicorn
127
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
+ import mysql.connector
5
+ import json
6
  import os
7
  from dotenv import load_dotenv
8
 
 
70
  @app.post("/query")
71
  def handle_query(request: QueryRequest):
72
  try:
73
+ print("Received query:", request.query) # Debugging: Print the received query
74
  text = request.query
75
 
76
  # Fetch the database schema
77
  schema = get_database_schema()
78
+ schema_str = json.dumps(schema, indent=4)
79
+ print("Fetched schema:", schema) # Debugging: Print the fetched schema
80
 
 
81
  system_message = f"""
82
+ You are a helpful, cheerful database assistant.
83
+ Use the following dynamically retrieved database schema when creating your answers:
84
 
85
+ {schema_str}
86
+ When creating your answers, consider the following:
 
87
 
88
  1. If a query involves a column or value that is not present in the provided database schema, correct it and mention the correction in the summary. If a column or value is missing, provide an explanation of the issue and adjust the query accordingly.
89
  2. If there is a spelling mistake in the column name or value, attempt to correct it by matching the closest possible column or value from the schema. Mention this correction in the summary to clarify any changes made.
 
98
  In the preceding JSON response, substitute "your-query" with a MariaDB query to retrieve the requested data.
99
  In the preceding JSON response, substitute "your-summary" with a summary of the query and any corrections or clarifications made.
100
  Always include all columns in the table.
101
+ """
102
+
103
  prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"
104
  output = pipe(prompt, max_new_tokens=100)
105
+ print("Generated output:", output) # Debugging: Print the generated output
106
+
107
  generated_text = output[0]['generated_text']
108
  sql_query = generated_text.split("SQL query:")[-1].strip()
109
 
 
110
  if not sql_query.lower().startswith(('select', 'show', 'describe')):
111
  raise ValueError("Generated text is not a valid SQL query")
112
+
 
113
  conn = get_db_connection()
114
  cursor = conn.cursor()
115
  cursor.execute(sql_query)
 
120
 
121
  return {"sql": sql_query, "results": results}
122
  except Exception as e:
123
+ print("Error occurred:", str(e)) # Debugging: Print the error
124
  raise HTTPException(status_code=500, detail=str(e))
125
 
126
+
127
  if __name__ == "__main__":
128
  import uvicorn
129
  uvicorn.run(app, host="0.0.0.0", port=7860)