DrishtiSharma committed on
Commit
b40c374
·
verified ·
1 Parent(s): dfda1b6

Delete mylab/app.py

Browse files
Files changed (1) hide show
  1. mylab/app.py +0 -162
mylab/app.py DELETED
@@ -1,162 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import sqlite3
4
- import os
5
- import json
6
- from pathlib import Path
7
- from datetime import datetime, timezone
8
- from crewai import Agent, Crew, Process, Task
9
- from crewai.tools import tool
10
- from langchain_groq import ChatGroq
11
- from langchain_openai import ChatOpenAI
12
- from langchain.schema.output import LLMResult
13
- from langchain_core.callbacks.base import BaseCallbackHandler
14
- from langchain_community.tools.sql_database.tool import (
15
- InfoSQLDatabaseTool,
16
- ListSQLDatabaseTool,
17
- QuerySQLCheckerTool,
18
- QuerySQLDataBaseTool,
19
- )
20
- from langchain_community.utilities.sql_database import SQLDatabase
21
- from datasets import load_dataset
22
- import tempfile
23
-
24
- st.title("SQL-RAG Using CrewAI 🚀")
25
- st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
26
-
27
- # Initialize LLM
28
- llm = None
29
-
30
- # Model Selection
31
- model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
32
-
33
-
34
- # API Key Validation and LLM Initialization
35
- groq_api_key = os.getenv("GROQ_API_KEY")
36
- openai_api_key = os.getenv("OPENAI_API_KEY")
37
-
38
- if model_choice == "llama-3.3-70b":
39
- if not groq_api_key:
40
- st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
41
- llm = None
42
- else:
43
- llm = ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
44
- elif model_choice == "GPT-4o":
45
- if not openai_api_key:
46
- st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
47
- llm = None
48
- else:
49
- llm = ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
50
-
51
- # Initialize session state for data persistence
52
- if "df" not in st.session_state:
53
- st.session_state.df = None
54
-
55
- # Dataset Input
56
- input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
57
- if input_option == "Use Hugging Face Dataset":
58
- dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
59
- if st.button("Load Dataset"):
60
- try:
61
- with st.spinner("Loading dataset..."):
62
- dataset = load_dataset(dataset_name, split="train")
63
- st.session_state.df = pd.DataFrame(dataset)
64
- st.success(f"Dataset '{dataset_name}' loaded successfully!")
65
- st.dataframe(st.session_state.df.head())
66
- except Exception as e:
67
- st.error(f"Error: {e}")
68
- elif input_option == "Upload CSV File":
69
- uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
70
- if uploaded_file:
71
- st.session_state.df = pd.read_csv(uploaded_file)
72
- st.success("File uploaded successfully!")
73
- st.dataframe(st.session_state.df.head())
74
-
75
- # SQL-RAG Analysis
76
- if st.session_state.df is not None:
77
- temp_dir = tempfile.TemporaryDirectory()
78
- db_path = os.path.join(temp_dir.name, "data.db")
79
- connection = sqlite3.connect(db_path)
80
- st.session_state.df.to_sql("salaries", connection, if_exists="replace", index=False)
81
- db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
82
-
83
- @tool("list_tables")
84
- def list_tables() -> str:
85
- """List all tables in the database."""
86
- return ListSQLDatabaseTool(db=db).invoke("")
87
-
88
- @tool("tables_schema")
89
- def tables_schema(tables: str) -> str:
90
- """Get schema and sample rows for given tables."""
91
- return InfoSQLDatabaseTool(db=db).invoke(tables)
92
-
93
- @tool("execute_sql")
94
- def execute_sql(sql_query: str) -> str:
95
- """Execute a SQL query against the database."""
96
- return QuerySQLDataBaseTool(db=db).invoke(sql_query)
97
-
98
- @tool("check_sql")
99
- def check_sql(sql_query: str) -> str:
100
- """Check the validity of a SQL query."""
101
- return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
102
-
103
- sql_dev = Agent(
104
- role="Senior Database Developer",
105
- goal="Extract data using optimized SQL queries.",
106
- backstory="An expert in writing optimized SQL queries for complex databases.",
107
- llm=llm,
108
- tools=[list_tables, tables_schema, execute_sql, check_sql],
109
- )
110
-
111
- data_analyst = Agent(
112
- role="Senior Data Analyst",
113
- goal="Analyze the data and produce insights.",
114
- backstory="A seasoned analyst who identifies trends and patterns in datasets.",
115
- llm=llm,
116
- )
117
-
118
- report_writer = Agent(
119
- role="Technical Report Writer",
120
- goal="Summarize the insights into a clear report.",
121
- backstory="An expert in summarizing data insights into readable reports.",
122
- llm=llm,
123
- )
124
-
125
- extract_data = Task(
126
- description="Extract data based on the query: {query}.",
127
- expected_output="Database results matching the query.",
128
- agent=sql_dev,
129
- )
130
-
131
- analyze_data = Task(
132
- description="Analyze the extracted data for query: {query}.",
133
- expected_output="Analysis text summarizing findings.",
134
- agent=data_analyst,
135
- context=[extract_data],
136
- )
137
-
138
- write_report = Task(
139
- description="Summarize the analysis into an executive report.",
140
- expected_output="Markdown report of insights.",
141
- agent=report_writer,
142
- context=[analyze_data],
143
- )
144
-
145
- crew = Crew(
146
- agents=[sql_dev, data_analyst, report_writer],
147
- tasks=[extract_data, analyze_data, write_report],
148
- process=Process.sequential,
149
- verbose=True,
150
- )
151
-
152
- query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary for senior employees?'")
153
- if st.button("Submit Query"):
154
- with st.spinner("Processing query..."):
155
- inputs = {"query": query}
156
- result = crew.kickoff(inputs=inputs)
157
- st.markdown("### Analysis Report:")
158
- st.markdown(result)
159
-
160
- temp_dir.cleanup()
161
- else:
162
- st.info("Please load a dataset to proceed.")