DrishtiSharma committed on
Commit
c95d3e8
·
verified ·
1 Parent(s): 2b71376

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +34 -17
interim.py CHANGED
@@ -20,9 +20,10 @@ from langchain_community.utilities.sql_database import SQLDatabase
20
  from datasets import load_dataset
21
  import tempfile
22
 
 
23
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
24
 
25
- # LLM Logging
26
  class LLMCallbackHandler(BaseCallbackHandler):
27
  def __init__(self, log_path: Path):
28
  self.log_path = log_path
@@ -36,6 +37,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
36
  with self.log_path.open("a", encoding="utf-8") as file:
37
  file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
38
 
 
39
  llm = ChatGroq(
40
  temperature=0,
41
  model_name="mixtral-8x7b-32768",
@@ -45,7 +47,7 @@ llm = ChatGroq(
45
  st.title("SQL-RAG Using CrewAI 🚀")
46
  st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
47
 
48
- # Data Input Options
49
  input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
50
  df = None
51
 
@@ -67,7 +69,7 @@ else:
67
  st.success("File uploaded successfully!")
68
  st.dataframe(df.head())
69
 
70
- # SQL-RAG and Query Workflow
71
  if df is not None:
72
  temp_dir = tempfile.TemporaryDirectory()
73
  db_path = os.path.join(temp_dir.name, "data.db")
@@ -75,45 +77,60 @@ if df is not None:
75
  df.to_sql("salaries", connection, if_exists="replace", index=False)
76
  db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
77
 
 
78
  @tool("list_tables")
79
  def list_tables() -> str:
80
- """List all tables in the database."""
81
  return ListSQLDatabaseTool(db=db).invoke("")
82
 
83
  @tool("tables_schema")
84
  def tables_schema(tables: str) -> str:
85
- """Return schema and example rows for given tables."""
 
 
 
 
86
  return InfoSQLDatabaseTool(db=db).invoke(tables)
87
 
88
  @tool("execute_sql")
89
  def execute_sql(sql_query: str) -> str:
90
- """Execute a SQL query and return results."""
 
 
 
 
91
  return QuerySQLDataBaseTool(db=db).invoke(sql_query)
92
 
93
  @tool("check_sql")
94
  def check_sql(sql_query: str) -> str:
95
- """Check SQL query validity."""
 
 
 
 
96
  return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
97
 
 
98
  sql_dev = Agent(
99
- role="Senior Database Developer",
100
- goal="Construct and execute SQL queries.",
101
  llm=llm,
102
  tools=[list_tables, tables_schema, execute_sql, check_sql],
103
  )
104
 
105
  data_analyst = Agent(
106
- role="Senior Data Analyst",
107
- goal="Analyze the data returned from SQL queries.",
108
  llm=llm,
109
  )
110
 
111
  report_writer = Agent(
112
- role="Senior Report Editor",
113
- goal="Summarize the analysis into a short report.",
114
  llm=llm,
115
  )
116
 
 
117
  extract_data = Task(
118
  description="Extract data for the query: {query}.",
119
  expected_output="Database query results.",
@@ -122,14 +139,14 @@ if df is not None:
122
 
123
  analyze_data = Task(
124
  description="Analyze the query results for: {query}.",
125
- expected_output="Detailed analysis report.",
126
  agent=data_analyst,
127
  context=[extract_data],
128
  )
129
 
130
  write_report = Task(
131
- description="Summarize the analysis into a brief executive summary.",
132
- expected_output="Markdown report.",
133
  agent=report_writer,
134
  context=[analyze_data],
135
  )
@@ -138,7 +155,7 @@ if df is not None:
138
  agents=[sql_dev, data_analyst, report_writer],
139
  tasks=[extract_data, analyze_data, write_report],
140
  process=Process.sequential,
141
- verbose=2,
142
  )
143
 
144
  query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")
 
20
  from datasets import load_dataset
21
  import tempfile
22
 
23
+ # Environment setup
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
+ # LLM Callback Logger
27
  class LLMCallbackHandler(BaseCallbackHandler):
28
  def __init__(self, log_path: Path):
29
  self.log_path = log_path
 
37
  with self.log_path.open("a", encoding="utf-8") as file:
38
  file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
39
 
40
+ # Initialize the LLM
41
  llm = ChatGroq(
42
  temperature=0,
43
  model_name="mixtral-8x7b-32768",
 
47
  st.title("SQL-RAG Using CrewAI 🚀")
48
  st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
49
 
50
+ # Input Options
51
  input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
52
  df = None
53
 
 
69
  st.success("File uploaded successfully!")
70
  st.dataframe(df.head())
71
 
72
+ # SQL-RAG Analysis
73
  if df is not None:
74
  temp_dir = tempfile.TemporaryDirectory()
75
  db_path = os.path.join(temp_dir.name, "data.db")
 
77
  df.to_sql("salaries", connection, if_exists="replace", index=False)
78
  db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
79
 
80
+ # Tools with proper docstrings
81
  @tool("list_tables")
82
  def list_tables() -> str:
83
+ """List all tables in the SQLite database."""
84
  return ListSQLDatabaseTool(db=db).invoke("")
85
 
86
  @tool("tables_schema")
87
  def tables_schema(tables: str) -> str:
88
+ """
89
+ Get the schema and sample rows for specific tables in the database.
90
+ Input: Comma-separated table names.
91
+ Example: 'salaries'
92
+ """
93
  return InfoSQLDatabaseTool(db=db).invoke(tables)
94
 
95
  @tool("execute_sql")
96
  def execute_sql(sql_query: str) -> str:
97
+ """
98
+ Execute a valid SQL query on the database and return the results.
99
+ Input: A SQL query string.
100
+ Example: 'SELECT * FROM salaries LIMIT 5;'
101
+ """
102
  return QuerySQLDataBaseTool(db=db).invoke(sql_query)
103
 
104
  @tool("check_sql")
105
  def check_sql(sql_query: str) -> str:
106
+ """
107
+ Check the validity of a SQL query before execution.
108
+ Input: A SQL query string.
109
+ Example: 'SELECT salary FROM salaries WHERE salary > 10000;'
110
+ """
111
  return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
112
 
113
+ # Agents
114
  sql_dev = Agent(
115
+ role="Database Developer",
116
+ goal="Extract relevant data by executing SQL queries.",
117
  llm=llm,
118
  tools=[list_tables, tables_schema, execute_sql, check_sql],
119
  )
120
 
121
  data_analyst = Agent(
122
+ role="Data Analyst",
123
+ goal="Analyze the extracted data and generate detailed insights.",
124
  llm=llm,
125
  )
126
 
127
  report_writer = Agent(
128
+ role="Report Writer",
129
+ goal="Summarize the analysis into an executive report.",
130
  llm=llm,
131
  )
132
 
133
+ # Tasks
134
  extract_data = Task(
135
  description="Extract data for the query: {query}.",
136
  expected_output="Database query results.",
 
139
 
140
  analyze_data = Task(
141
  description="Analyze the query results for: {query}.",
142
+ expected_output="Analysis report.",
143
  agent=data_analyst,
144
  context=[extract_data],
145
  )
146
 
147
  write_report = Task(
148
+ description="Summarize the analysis into an executive summary.",
149
+ expected_output="Markdown-formatted report.",
150
  agent=report_writer,
151
  context=[analyze_data],
152
  )
 
155
  agents=[sql_dev, data_analyst, report_writer],
156
  tasks=[extract_data, analyze_data, write_report],
157
  process=Process.sequential,
158
+ verbose=True,
159
  )
160
 
161
  query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")