Quazim0t0 committed on
Commit
237bccb
·
verified ·
1 Parent(s): 79d5248

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -151
app.py CHANGED
@@ -1,49 +1,20 @@
1
  import os
2
  import gradio as gr
 
 
 
3
  import pandas as pd
4
- from sqlalchemy import create_engine, text
5
- from langchain.tools import tool
6
- from code_agent import CodeAgent
7
- from hf_api_model import HfApiModel
8
-
9
- # Initialize SQLite database engine
10
- engine = create_engine('sqlite:///data.db')
11
-
12
- def clear_database():
13
- """
14
- Clear all tables from the database.
15
- """
16
- with engine.connect() as con:
17
- # Get all table names
18
- tables = con.execute(text(
19
- "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
20
- )).fetchall()
21
-
22
- # Drop each table
23
- for table in tables:
24
- con.execute(text(f"DROP TABLE IF EXISTS {table[0]}"))
25
-
26
- def create_dynamic_table(df):
27
- """
28
- Create a table dynamically based on DataFrame structure.
29
- """
30
- df.to_sql('data_table', engine, index=False, if_exists='replace')
31
- return 'data_table'
32
-
33
- def insert_rows_into_table(records, table_name):
34
- """
35
- Insert records into the specified table.
36
- """
37
- with engine.begin() as conn:
38
- for record in records:
39
- conn.execute(
40
- text(f"INSERT INTO {table_name} ({', '.join(record.keys())}) VALUES ({', '.join(['?' for _ in record])})")
41
- .bindparams(*record.values())
42
- )
43
 
44
  def get_data_table():
45
  """
46
- Get the current data table as a DataFrame.
47
  """
48
  try:
49
  # Get list of tables
@@ -57,10 +28,18 @@ def get_data_table():
57
 
58
  # Use the first table found
59
  table_name = tables[0][0]
60
-
61
- # Read the table into a DataFrame
62
- return pd.read_sql_table(table_name, engine)
63
-
 
 
 
 
 
 
 
 
64
  except Exception as e:
65
  return pd.DataFrame({"Error": [str(e)]})
66
 
@@ -181,7 +160,7 @@ def process_uploaded_file(file):
181
  def sql_engine(query: str) -> str:
182
  """
183
  Executes an SQL query and returns formatted results.
184
-
185
  Args:
186
  query: The SQL query string to execute on the database. Must be a valid SELECT query.
187
 
@@ -203,10 +182,48 @@ def sql_engine(query: str) -> str:
203
  except Exception as e:
204
  return f"Error: {str(e)}"
205
 
206
- def process_sql_result(generated_sql, table_name, column_names):
 
 
 
 
 
207
  """
208
- Process and execute the generated SQL query.
209
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  # Remove any trailing semicolons
211
  generated_sql = generated_sql.strip().rstrip(';')
212
 
@@ -238,103 +255,12 @@ def process_sql_result(generated_sql, table_name, column_names):
238
  return generated_sql
239
  return f"Error executing query: {str(e)}"
240
 
241
- def query_sql(user_query: str, show_full: bool) -> tuple:
242
- """
243
- Converts natural language input to an SQL query using CodeAgent.
244
- Returns both short and full responses based on switch state.
245
- """
246
- table_name, column_names, column_info = get_table_info()
247
-
248
- if not table_name:
249
- return "Error: No data table exists. Please upload a file first.", ""
250
-
251
- schema_info = (
252
- f"The database has a table named '{table_name}' with the following columns:\n"
253
- + "\n".join([
254
- f"- {col} ({info['type']}{' primary key' if info['is_primary'] else ''})"
255
- for col, info in column_info.items()
256
- ])
257
- + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
258
- "The table name is '" + table_name + "'.\n"
259
- "If column names contain spaces, they must be quoted.\n"
260
- "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
261
- "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
262
- )
263
-
264
- # Get full response from the agent
265
- full_response = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
266
-
267
- # Process the short response as before
268
- if not isinstance(full_response, str):
269
- return "Error: Invalid query generated", ""
270
-
271
- # Extract and process SQL for short response
272
- generated_sql = full_response
273
- if generated_sql.isnumeric():
274
- short_response = generated_sql
275
- else:
276
- sql_lines = [line for line in generated_sql.split('\n') if 'select' in line.lower()]
277
- if sql_lines:
278
- generated_sql = sql_lines[0]
279
-
280
- # Process the SQL query and get the short result
281
- short_response = process_sql_result(generated_sql, table_name, column_names)
282
-
283
- return short_response, full_response
284
-
285
- def handle_upload(file_obj):
286
- if file_obj is None:
287
- return (
288
- "Please upload a file.",
289
- None,
290
- "No schema available",
291
- gr.update(visible=True),
292
- gr.update(visible=False)
293
- )
294
-
295
- success, message = process_uploaded_file(file_obj)
296
- if success:
297
- df = get_data_table()
298
- _, _, column_info = get_table_info()
299
- schema = "\n".join([
300
- f"- {col} ({info['type']}){'primary key' if info['is_primary'] else ''}"
301
- for col, info in column_info.items()
302
- ])
303
- return (
304
- message,
305
- df,
306
- f"### Current Schema:\n```\n{schema}\n```",
307
- gr.update(visible=False),
308
- gr.update(visible=True)
309
- )
310
- return (
311
- message,
312
- None,
313
- "No schema available",
314
- gr.update(visible=True),
315
- gr.update(visible=False)
316
- )
317
-
318
- def refresh_data():
319
- df = get_data_table()
320
- _, _, column_info = get_table_info()
321
- schema = "\n".join([
322
- f"- {col} ({info['type']}){'primary key' if info['is_primary'] else ''}"
323
- for col, info in column_info.items()
324
- ])
325
- return df, f"### Current Schema:\n```\n{schema}\n```"
326
-
327
- # Initialize the CodeAgent
328
- agent = CodeAgent(
329
- tools=[sql_engine],
330
- model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
331
- )
332
-
333
  # Create the Gradio interface
334
  with gr.Blocks() as demo:
335
  with gr.Group() as upload_group:
336
  gr.Markdown("""
337
  # CSVAgent
 
338
  Upload your data file to begin.
339
 
340
  ### Supported File Types:
@@ -351,7 +277,10 @@ with gr.Blocks() as demo:
351
  https://tableconvert.com/sql-to-csv
352
  - Will work on the handling of SQL files soon.
353
 
 
354
  ### Try it out! Upload a CSV file and then ask a question about the data!
 
 
355
  """)
356
 
357
  file_input = gr.File(
@@ -366,9 +295,6 @@ with gr.Blocks() as demo:
366
  with gr.Column(scale=1):
367
  user_input = gr.Textbox(label="Ask a question about the data")
368
  query_output = gr.Textbox(label="Result")
369
- # Add the switch and secondary result box
370
- full_response_switch = gr.Switch(label="Show Full Response", value=False)
371
- full_response_output = gr.Textbox(label="Full Response", visible=False)
372
 
373
  with gr.Column(scale=2):
374
  gr.Markdown("### Current Data")
@@ -381,6 +307,48 @@ with gr.Blocks() as demo:
381
  schema_display = gr.Markdown(value="Loading schema...")
382
  refresh_btn = gr.Button("Refresh Data")
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  # Event handlers
385
  file_input.upload(
386
  fn=handle_upload,
@@ -396,15 +364,8 @@ with gr.Blocks() as demo:
396
 
397
  user_input.change(
398
  fn=query_sql,
399
- inputs=[user_input, full_response_switch],
400
- outputs=[query_output, full_response_output]
401
- )
402
-
403
- # Add switch change event to control visibility of full response
404
- full_response_switch.change(
405
- fn=lambda x: gr.update(visible=x),
406
- inputs=full_response_switch,
407
- outputs=full_response_output
408
  )
409
 
410
  refresh_btn.click(
 
1
  import os
2
  import gradio as gr
3
+ from sqlalchemy import text
4
+ from smolagents import tool, CodeAgent, HfApiModel
5
+ import spaces
6
  import pandas as pd
7
+ from database import (
8
+ engine,
9
+ create_dynamic_table,
10
+ clear_database,
11
+ insert_rows_into_table,
12
+ get_table_schema
13
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_data_table():
16
  """
17
+ Fetches all data from the current table and returns it as a Pandas DataFrame.
18
  """
19
  try:
20
  # Get list of tables
 
28
 
29
  # Use the first table found
30
  table_name = tables[0][0]
31
+
32
+ with engine.connect() as con:
33
+ result = con.execute(text(f"SELECT * FROM {table_name}"))
34
+ rows = result.fetchall()
35
+
36
+ if not rows:
37
+ return pd.DataFrame()
38
+
39
+ columns = result.keys()
40
+ df = pd.DataFrame(rows, columns=columns)
41
+ return df
42
+
43
  except Exception as e:
44
  return pd.DataFrame({"Error": [str(e)]})
45
 
 
160
  def sql_engine(query: str) -> str:
161
  """
162
  Executes an SQL query and returns formatted results.
163
+
164
  Args:
165
  query: The SQL query string to execute on the database. Must be a valid SELECT query.
166
 
 
182
  except Exception as e:
183
  return f"Error: {str(e)}"
184
 
185
+ agent = CodeAgent(
186
+ tools=[sql_engine],
187
+ model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
188
+ )
189
+
190
+ def query_sql(user_query: str) -> str:
191
  """
192
+ Converts natural language input to an SQL query using CodeAgent.
193
  """
194
+ table_name, column_names, column_info = get_table_info()
195
+
196
+ if not table_name:
197
+ return "Error: No data table exists. Please upload a file first."
198
+
199
+ schema_info = (
200
+ f"The database has a table named '{table_name}' with the following columns:\n"
201
+ + "\n".join([
202
+ f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
203
+ for col, info in column_info.items()
204
+ ])
205
+ + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
206
+ "The table name is '" + table_name + "'.\n"
207
+ "If column names contain spaces, they must be quoted.\n"
208
+ "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
209
+ "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
210
+ )
211
+
212
+ # Get SQL from the agent
213
+ generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
214
+
215
+ if not isinstance(generated_sql, str):
216
+ return "Error: Invalid query generated"
217
+
218
+ # Clean up the SQL
219
+ if generated_sql.isnumeric(): # If the agent returned just a number
220
+ return generated_sql
221
+
222
+ # Extract just the SQL query if there's additional text
223
+ sql_lines = [line for line in generated_sql.split('\n') if 'select' in line.lower()]
224
+ if sql_lines:
225
+ generated_sql = sql_lines[0]
226
+
227
  # Remove any trailing semicolons
228
  generated_sql = generated_sql.strip().rstrip(';')
229
 
 
255
  return generated_sql
256
  return f"Error executing query: {str(e)}"
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # Create the Gradio interface
259
  with gr.Blocks() as demo:
260
  with gr.Group() as upload_group:
261
  gr.Markdown("""
262
  # CSVAgent
263
+
264
  Upload your data file to begin.
265
 
266
  ### Supported File Types:
 
277
  https://tableconvert.com/sql-to-csv
278
  - Will work on the handling of SQL files soon.
279
 
280
+
281
  ### Try it out! Upload a CSV file and then ask a question about the data!
282
+ - There is issues with the UI displaying the answer correctly, some questions such as "How many Customers are located in Korea?"
283
+ The right answer will appear in the logs, but throws an error on the "Results" section.
284
  """)
285
 
286
  file_input = gr.File(
 
295
  with gr.Column(scale=1):
296
  user_input = gr.Textbox(label="Ask a question about the data")
297
  query_output = gr.Textbox(label="Result")
 
 
 
298
 
299
  with gr.Column(scale=2):
300
  gr.Markdown("### Current Data")
 
307
  schema_display = gr.Markdown(value="Loading schema...")
308
  refresh_btn = gr.Button("Refresh Data")
309
 
310
+ def handle_upload(file_obj):
311
+ if file_obj is None:
312
+ return (
313
+ "Please upload a file.",
314
+ None,
315
+ "No schema available",
316
+ gr.update(visible=True),
317
+ gr.update(visible=False)
318
+ )
319
+
320
+ success, message = process_uploaded_file(file_obj)
321
+ if success:
322
+ df = get_data_table()
323
+ _, _, column_info = get_table_info()
324
+ schema = "\n".join([
325
+ f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
326
+ for col, info in column_info.items()
327
+ ])
328
+ return (
329
+ message,
330
+ df,
331
+ f"### Current Schema:\n```\n{schema}\n```",
332
+ gr.update(visible=False),
333
+ gr.update(visible=True)
334
+ )
335
+ return (
336
+ message,
337
+ None,
338
+ "No schema available",
339
+ gr.update(visible=True),
340
+ gr.update(visible=False)
341
+ )
342
+
343
+ def refresh_data():
344
+ df = get_data_table()
345
+ _, _, column_info = get_table_info()
346
+ schema = "\n".join([
347
+ f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
348
+ for col, info in column_info.items()
349
+ ])
350
+ return df, f"### Current Schema:\n```\n{schema}\n```"
351
+
352
  # Event handlers
353
  file_input.upload(
354
  fn=handle_upload,
 
364
 
365
  user_input.change(
366
  fn=query_sql,
367
+ inputs=user_input,
368
+ outputs=query_output
 
 
 
 
 
 
 
369
  )
370
 
371
  refresh_btn.click(