Ari commited on
Commit
1c7e913
·
verified ·
1 Parent(s): a6107d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -34,12 +34,21 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
 
 
 
37
  # Step 3: Set up the SQL Database for LangChain
38
  db = SQLDatabase.from_uri('sqlite:///:memory:')
39
  db.raw_connection = conn # Use the in-memory connection for LangChain
40
 
41
  # Step 4: Create the SQL agent with the correct parameter name
42
- sql_agent = create_sql_agent(OpenAI(temperature=0), db=db, verbose=True)
 
 
 
 
 
 
43
 
44
  # Step 5: Use FAISS with RAG for context retrieval
45
  embeddings = OpenAIEmbeddings()
@@ -53,9 +62,12 @@ rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=ret
53
  # Step 6: Define SQL validation helpers
54
  def validate_sql(query, valid_columns):
55
  """Validates the SQL query by ensuring it references only valid columns."""
56
- for column in valid_columns:
57
- if column not in query:
58
- return False
 
 
 
59
  return True
60
 
61
  def validate_sql_with_sqlparse(query):
@@ -73,6 +85,8 @@ if user_prompt:
73
 
74
  # Step 9: Generate SQL query using SQL agent
75
  generated_sql = sql_agent.run(f"{user_prompt} {context}")
 
 
76
  st.write(f"Generated SQL Query: {generated_sql}")
77
 
78
  # Step 10: Validate SQL query
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
37
+ # Debug: Display valid columns for user to verify
38
+ st.write(f"Valid columns: {valid_columns}")
39
+
40
  # Step 3: Set up the SQL Database for LangChain
41
  db = SQLDatabase.from_uri('sqlite:///:memory:')
42
  db.raw_connection = conn # Use the in-memory connection for LangChain
43
 
44
  # Step 4: Create the SQL agent with the correct parameter name
45
+ sql_agent = create_sql_agent(
46
+ OpenAI(temperature=0),
47
+ db=db,
48
+ verbose=True,
49
+ max_iterations=15, # Increased iteration limit
50
+ max_execution_time=60 # Set timeout limit to 60 seconds
51
+ )
52
 
53
  # Step 5: Use FAISS with RAG for context retrieval
54
  embeddings = OpenAIEmbeddings()
 
62
  # Step 6: Define SQL validation helpers
63
  def validate_sql(query, valid_columns):
64
  """Validates the SQL query by ensuring it references only valid columns."""
65
+ parsed = sqlparse.parse(query)
66
+ for token in parsed[0].tokens:
67
+ if token.ttype is None: # If it's a column name
68
+ column_name = str(token).strip()
69
+ if column_name not in valid_columns:
70
+ return False
71
  return True
72
 
73
  def validate_sql_with_sqlparse(query):
 
85
 
86
  # Step 9: Generate SQL query using SQL agent
87
  generated_sql = sql_agent.run(f"{user_prompt} {context}")
88
+
89
+ # Debug: Display generated SQL query for inspection
90
  st.write(f"Generated SQL Query: {generated_sql}")
91
 
92
  # Step 10: Validate SQL query