DrishtiSharma committed on
Commit
c9e66b7
·
verified ·
1 Parent(s): 9abae49

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +43 -38
interim.py CHANGED
@@ -7,7 +7,6 @@ from pathlib import Path
7
  from datetime import datetime, timezone
8
  from crewai import Agent, Crew, Process, Task
9
  from crewai_tools import tool
10
- from langchain_core.prompts import ChatPromptTemplate
11
  from langchain_groq import ChatGroq
12
  from langchain.schema.output import LLMResult
13
  from langchain_core.callbacks.base import BaseCallbackHandler
@@ -18,18 +17,13 @@ from langchain_community.tools.sql_database.tool import (
18
  QuerySQLDataBaseTool,
19
  )
20
  from langchain_community.utilities.sql_database import SQLDatabase
 
21
  import tempfile
22
 
23
- # Setup GROQ API Key
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
- # Callback handler for logging LLM responses
27
- class Event:
28
- def __init__(self, event, text):
29
- self.event = event
30
- self.timestamp = datetime.now(timezone.utc).isoformat()
31
- self.text = text
32
-
33
  class LLMCallbackHandler(BaseCallbackHandler):
34
  def __init__(self, log_path: Path):
35
  self.log_path = log_path
@@ -50,32 +44,44 @@ llm = ChatGroq(
50
  callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
51
  )
52
 
53
- # App Header
54
- st.title("SQL-RAG with CrewAI 🚀")
55
- st.write("Provide your query, and the app will extract, analyze, and summarize the data dynamically.")
56
-
57
- # File Upload for Dataset
58
- uploaded_file = st.file_uploader("Upload your dataset (CSV file)", type=["csv"])
59
-
60
- if uploaded_file:
61
- st.success("File uploaded successfully!")
62
-
63
- # Temporary directory for SQLite DB
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  temp_dir = tempfile.TemporaryDirectory()
65
  db_path = os.path.join(temp_dir.name, "data.db")
66
-
67
- # Create SQLite database
68
- df = pd.read_csv(uploaded_file)
69
  connection = sqlite3.connect(db_path)
70
  df.to_sql("data_table", connection, if_exists="replace", index=False)
71
-
72
  db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
73
 
74
  # Tools
75
  @tool("list_tables")
76
  def list_tables() -> str:
77
  return ListSQLDatabaseTool(db=db).invoke("")
78
-
79
  @tool("tables_schema")
80
  def tables_schema(tables: str) -> str:
81
  return InfoSQLDatabaseTool(db=db).invoke(tables)
@@ -90,23 +96,23 @@ if uploaded_file:
90
 
91
  # Agents
92
  sql_dev = Agent(
93
- role="Senior Database Developer",
94
- goal="Extract data from the database based on user query",
95
  llm=llm,
96
  tools=[list_tables, tables_schema, execute_sql, check_sql],
97
  allow_delegation=False,
98
  )
99
 
100
  data_analyst = Agent(
101
- role="Senior Data Analyst",
102
- goal="Analyze the database response and provide insights",
103
  llm=llm,
104
  allow_delegation=False,
105
  )
106
 
107
  report_writer = Agent(
108
- role="Senior Report Editor",
109
- goal="Summarize the analysis into a short report",
110
  llm=llm,
111
  allow_delegation=False,
112
  )
@@ -119,20 +125,19 @@ if uploaded_file:
119
  )
120
 
121
  analyze_data = Task(
122
- description="Analyze the data and generate insights for: {query}.",
123
  expected_output="Detailed analysis text",
124
  agent=data_analyst,
125
  context=[extract_data],
126
  )
127
 
128
  write_report = Task(
129
- description="Summarize the analysis into a concise executive report.",
130
  expected_output="Markdown report",
131
  agent=report_writer,
132
  context=[analyze_data],
133
  )
134
 
135
- # Crew
136
  crew = Crew(
137
  agents=[sql_dev, data_analyst, report_writer],
138
  tasks=[extract_data, analyze_data, write_report],
@@ -141,8 +146,7 @@ if uploaded_file:
141
  memory=False,
142
  )
143
 
144
- # User Input Query
145
- query = st.text_input("Enter your query:")
146
  if query:
147
  with st.spinner("Processing your query..."):
148
  inputs = {"query": query}
@@ -150,5 +154,6 @@ if uploaded_file:
150
  st.markdown("### Analysis Report:")
151
  st.markdown(result)
152
 
153
- # Clean up
154
- temp_dir.cleanup()
 
 
7
  from datetime import datetime, timezone
8
  from crewai import Agent, Crew, Process, Task
9
  from crewai_tools import tool
 
10
  from langchain_groq import ChatGroq
11
  from langchain.schema.output import LLMResult
12
  from langchain_core.callbacks.base import BaseCallbackHandler
 
17
  QuerySQLDataBaseTool,
18
  )
19
  from langchain_community.utilities.sql_database import SQLDatabase
20
+ from datasets import load_dataset
21
  import tempfile
22
 
23
+ # Setup API key
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
+ # Callback handler for logging
 
 
 
 
 
 
27
  class LLMCallbackHandler(BaseCallbackHandler):
28
  def __init__(self, log_path: Path):
29
  self.log_path = log_path
 
44
  callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
45
  )
46
 
47
+ st.title("SQL-RAG using CrewAI 🚀")
48
+ st.write("Analyze and summarize data using natural language queries with SQL-based retrieval.")
49
+
50
+ # File upload or Hugging Face dataset input
51
+ option = st.radio("Choose your input method:", ["Upload a CSV file", "Enter Hugging Face dataset name"])
52
+
53
+ if option == "Upload a CSV file":
54
+ uploaded_file = st.file_uploader("Upload your dataset (CSV format)", type=["csv"])
55
+ if uploaded_file:
56
+ df = pd.read_csv(uploaded_file)
57
+ st.success("File uploaded successfully!")
58
+ else:
59
+ dataset_name = st.text_input("Enter Hugging Face dataset name:", placeholder="e.g., imdb, ag_news")
60
+ if dataset_name:
61
+ try:
62
+ dataset = load_dataset(dataset_name, split="train")
63
+ df = pd.DataFrame(dataset)
64
+ st.success(f"Dataset '{dataset_name}' loaded successfully!")
65
+ except Exception as e:
66
+ st.error(f"Error loading Hugging Face dataset: {e}")
67
+ df = None
68
+
69
+ if 'df' in locals() and not df.empty:
70
+ st.write("### Dataset Preview:")
71
+ st.dataframe(df.head())
72
+
73
+ # Create a temporary SQLite database
74
  temp_dir = tempfile.TemporaryDirectory()
75
  db_path = os.path.join(temp_dir.name, "data.db")
 
 
 
76
  connection = sqlite3.connect(db_path)
77
  df.to_sql("data_table", connection, if_exists="replace", index=False)
 
78
  db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
79
 
80
  # Tools
81
  @tool("list_tables")
82
  def list_tables() -> str:
83
  return ListSQLDatabaseTool(db=db).invoke("")
84
+
85
  @tool("tables_schema")
86
  def tables_schema(tables: str) -> str:
87
  return InfoSQLDatabaseTool(db=db).invoke(tables)
 
96
 
97
  # Agents
98
  sql_dev = Agent(
99
+ role="Database Developer",
100
+ goal="Extract data from the database.",
101
  llm=llm,
102
  tools=[list_tables, tables_schema, execute_sql, check_sql],
103
  allow_delegation=False,
104
  )
105
 
106
  data_analyst = Agent(
107
+ role="Data Analyst",
108
+ goal="Analyze and provide insights.",
109
  llm=llm,
110
  allow_delegation=False,
111
  )
112
 
113
  report_writer = Agent(
114
+ role="Report Editor",
115
+ goal="Summarize the analysis.",
116
  llm=llm,
117
  allow_delegation=False,
118
  )
 
125
  )
126
 
127
  analyze_data = Task(
128
+ description="Analyze the data for: {query}.",
129
  expected_output="Detailed analysis text",
130
  agent=data_analyst,
131
  context=[extract_data],
132
  )
133
 
134
  write_report = Task(
135
+ description="Summarize the analysis into a short report.",
136
  expected_output="Markdown report",
137
  agent=report_writer,
138
  context=[analyze_data],
139
  )
140
 
 
141
  crew = Crew(
142
  agents=[sql_dev, data_analyst, report_writer],
143
  tasks=[extract_data, analyze_data, write_report],
 
146
  memory=False,
147
  )
148
 
149
+ query = st.text_input("Enter your query:", placeholder="e.g., 'What are the top 5 highest salaries?'")
 
150
  if query:
151
  with st.spinner("Processing your query..."):
152
  inputs = {"query": query}
 
154
  st.markdown("### Analysis Report:")
155
  st.markdown(result)
156
 
157
+ temp_dir.cleanup()
158
+ else:
159
+ st.warning("Please upload a valid file or provide a correct Hugging Face dataset name.")