Spaces:
Sleeping
Sleeping
cheesyFishes
commited on
Commit
•
3c8ea82
1
Parent(s):
4599370
Add files
Browse files- .gitattributes +1 -0
- README.md +4 -4
- app.py +137 -0
- constants.py +24 -0
- requirements.txt +4 -0
- sfscores.sqlite +3 -0
- sql_index.json +1 -0
- utils.py +26 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*sqlite filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
title: Llama Index Sql Sandbox
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
11 |
---
|
12 |
|
|
|
1 |
---
|
2 |
title: Llama Index Sql Sandbox
|
3 |
+
emoji: 🦙
|
4 |
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.19.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: mit
|
11 |
---
|
12 |
|
app.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
from streamlit_chat import message as st_message
|
4 |
+
from sqlalchemy import create_engine
|
5 |
+
|
6 |
+
from langchain.agents import Tool, initialize_agent
|
7 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
8 |
+
|
9 |
+
from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext
|
10 |
+
from llama_index import SQLDatabase as llama_SQLDatabase
|
11 |
+
from llama_index.indices.struct_store import SQLContextContainerBuilder
|
12 |
+
|
13 |
+
from constants import (
|
14 |
+
DEFAULT_SQL_PATH,
|
15 |
+
DEFAULT_BUSINESS_TABLE_DESCRP,
|
16 |
+
DEFAULT_VIOLATIONS_TABLE_DESCRP,
|
17 |
+
DEFAULT_INSPECTIONS_TABLE_DESCRP,
|
18 |
+
DEFAULT_LC_TOOL_DESCRP
|
19 |
+
)
|
20 |
+
from utils import get_sql_index_tool, get_llm
|
21 |
+
|
22 |
+
|
23 |
+
@st.cache_resource
|
24 |
+
def initialize_index(llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH):
|
25 |
+
"""Create the GPTSQLStructStoreIndex object."""
|
26 |
+
llm = get_llm(llm_name, model_temperature, api_key)
|
27 |
+
|
28 |
+
engine = create_engine(sql_path)
|
29 |
+
sql_database = llama_SQLDatabase(engine)
|
30 |
+
|
31 |
+
context_container = None
|
32 |
+
if table_context_dict is not None:
|
33 |
+
context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict)
|
34 |
+
context_container = context_builder.build_context_container()
|
35 |
+
|
36 |
+
service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
|
37 |
+
index = GPTSQLStructStoreIndex([],
|
38 |
+
sql_database=sql_database,
|
39 |
+
sql_context_container=context_container,
|
40 |
+
service_context=service_context)
|
41 |
+
|
42 |
+
return index
|
43 |
+
|
44 |
+
|
45 |
+
@st.cache_resource
|
46 |
+
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
|
47 |
+
"""Create a (rather hacky) custom agent and sql_index tool."""
|
48 |
+
sql_tool = Tool(name="SQL Index",
|
49 |
+
func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict),
|
50 |
+
description=lc_descrp)
|
51 |
+
|
52 |
+
llm = get_llm(llm_name, model_temperature, api_key=api_key)
|
53 |
+
|
54 |
+
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
55 |
+
|
56 |
+
agent_chain = initialize_agent([sql_tool], llm, agent="chat-conversational-react-description", verbose=True, memory=memory)
|
57 |
+
|
58 |
+
return agent_chain
|
59 |
+
|
60 |
+
|
61 |
+
st.title("🦙 Llama Index SQL Sandbox 🦙")
|
62 |
+
st.markdown((
|
63 |
+
"This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n"
|
64 |
+
"The database contains information on health violations and inspections at restaurants in San Francisco."
|
65 |
+
"This data is spread across three tables - businesses, inspections, and violations.\n\n"
|
66 |
+
"Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
|
67 |
+
"The other tabs will perform chatbot and text2sql operations.\n\n"
|
68 |
+
"Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
|
69 |
+
))
|
70 |
+
|
71 |
+
|
72 |
+
setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])
|
73 |
+
|
74 |
+
with setup_tab:
|
75 |
+
st.subheader("LLM Setup")
|
76 |
+
api_key = st.text_input("Enter your OpenAI API key here", type="password")
|
77 |
+
llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])
|
78 |
+
model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)
|
79 |
+
|
80 |
+
st.subheader("Table Setup")
|
81 |
+
business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
|
82 |
+
violations_table_descrp = st.text_area("Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
|
83 |
+
inspections_table_descrp = st.text_area("Business table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)
|
84 |
+
|
85 |
+
table_context_dict = {"businesses": business_table_descrp,
|
86 |
+
"inspections": inspections_table_descrp,
|
87 |
+
"violations": violations_table_descrp}
|
88 |
+
|
89 |
+
use_table_descrp = st.checkbox("Use table descriptions?", value=True)
|
90 |
+
lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
|
91 |
+
|
92 |
+
with llama_tab:
|
93 |
+
st.subheader("Text2SQL with Llama Index")
|
94 |
+
if st.button("Initialize Index", key="init_index_1"):
|
95 |
+
st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
|
96 |
+
|
97 |
+
if "llama_index" in st.session_state:
|
98 |
+
query_text = st.text_input("Query:", value="Which restaurant has the most violations?")
|
99 |
+
if st.button("Run Query") and query_text:
|
100 |
+
with st.spinner("Getting response..."):
|
101 |
+
try:
|
102 |
+
response = st.session_state['llama_index'].query(query_text)
|
103 |
+
response_text = str(response)
|
104 |
+
response_sql = response.extra_info['sql_query']
|
105 |
+
except Exception as e:
|
106 |
+
response_text = "Error running SQL Query."
|
107 |
+
response_sql = str(e)
|
108 |
+
|
109 |
+
col1, col2 = st.columns(2)
|
110 |
+
with col1:
|
111 |
+
st.text("SQL Result:")
|
112 |
+
st.markdown(response_text)
|
113 |
+
|
114 |
+
with col2:
|
115 |
+
st.text("SQL Query:")
|
116 |
+
st.markdown(response_sql)
|
117 |
+
|
118 |
+
with lc_tab:
|
119 |
+
st.subheader("Langchain + Llama Index SQL Demo")
|
120 |
+
|
121 |
+
if st.button("Initialize Agent"):
|
122 |
+
st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
|
123 |
+
st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, api_key, st.session_state['llama_index'])
|
124 |
+
st.session_state['chat_history'] = []
|
125 |
+
|
126 |
+
model_input = st.text_input("Message:", value="Which restaurant has the most violations?")
|
127 |
+
if 'lc_agent' in st.session_state and st.button("Send"):
|
128 |
+
model_input = "User: " + model_input
|
129 |
+
st.session_state['chat_history'].append(model_input)
|
130 |
+
with st.spinner("Getting response..."):
|
131 |
+
response = st.session_state['lc_agent'].run(input=model_input)
|
132 |
+
st.session_state['chat_history'].append(response)
|
133 |
+
|
134 |
+
if 'chat_history' in st.session_state:
|
135 |
+
for msg in st.session_state['chat_history']:
|
136 |
+
st_message(msg.split("User: ")[-1], is_user="User: " in msg)
|
137 |
+
|
constants.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DEFAULT_SQL_PATH = "sqlite:///sfscores.sqlite"
|
2 |
+
DEFAULT_BUSINESS_TABLE_DESCRP = (
|
3 |
+
"This table gives information on the IDs, addresses, and other location "
|
4 |
+
"information for several restaurants in San Francisco. This table will "
|
5 |
+
"need to be referenced when users ask about specific businesses."
|
6 |
+
)
|
7 |
+
DEFAULT_VIOLATIONS_TABLE_DESCRP = (
|
8 |
+
"This table gives information on which business IDs have recorded health violations, "
|
9 |
+
"including the date, risk, and description of each violation. The user may query "
|
10 |
+
"about specific businesses, whose names can be found by mapping the business_id "
|
11 |
+
"to the 'businesses' table."
|
12 |
+
)
|
13 |
+
DEFAULT_INSPECTIONS_TABLE_DESCRP = (
|
14 |
+
"This table gives information on when each business ID was inspected, including "
|
15 |
+
"the score, date, and type of inspection. The user may query about specific "
|
16 |
+
"businesses, whose names can be found by mapping the business_id to the 'businesses' table."
|
17 |
+
)
|
18 |
+
DEFAULT_LC_TOOL_DESCRP = "Useful for when you want to answer queries about violations and inspections of businesses."
|
19 |
+
|
20 |
+
DEFAULT_INGEST_DOCUMENT = (
|
21 |
+
"The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31. "
|
22 |
+
"The business achieved a score of 50. We two violations, a high risk "
|
23 |
+
"vermin infestation as well as a high risk food holding temperatures."
|
24 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain==0.0.123
|
2 |
+
llama-index==0.5.1
|
3 |
+
streamlit==1.19.0
|
4 |
+
streamlit-chat==0.0.2.2
|
sfscores.sqlite
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:240deebf58f54606266cdd4a5dca14c48d58f8b530941c0249a9b23f00589afa
|
3 |
+
size 9639936
|
sql_index.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"index_struct_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "docstore": {"docs": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"text": null, "doc_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "embedding": null, "doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4", "extra_info": null, "context_dict": {}, "__type__": "sql"}}, "ref_doc_info": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4"}}}, "sql_context_container": {"context_dict": {"violations": "Schema of table violations:\nTable 'violations' has columns: business_id (TEXT), date (TEXT), ViolationTypeID (TEXT), risk_category (TEXT), description (TEXT) and foreign keys: .\nContext of table violations:\nThis table gives information on which business IDs have recorded health violations, including the date, risk, and description of each violation. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table.", "businesses": "Schema of table businesses:\nTable 'businesses' has columns: business_id (INTEGER), name (VARCHAR(64)), address (VARCHAR(50)), city (VARCHAR(23)), postal_code (VARCHAR(9)), latitude (FLOAT), longitude (FLOAT), phone_number (BIGINT), TaxCode (VARCHAR(4)), business_certificate (INTEGER), application_date (DATE), owner_name (VARCHAR(99)), owner_address (VARCHAR(74)), owner_city (VARCHAR(22)), owner_state (VARCHAR(14)), owner_zip (VARCHAR(15)) and foreign keys: .\nContext of table businesses:\nThis table gives information on the IDs, addresses, and other location information for several restaruants in San Fransisco. This table will need to be referenced when users ask about specific bussinesses.", "inspections": "Schema of table inspections:\nTable 'inspections' has columns: business_id (TEXT), Score (INTEGER), date (TEXT), type (VARCHAR(33)) and foreign keys: .\nContext of table inspections:\nThis table gives information on when each bussiness ID was inspected, including the score, date, and type of inspection. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table."}, "context_str": null}}
|
utils.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from langchain import OpenAI
|
3 |
+
from langchain.chat_models import ChatOpenAI
|
4 |
+
|
5 |
+
|
6 |
+
def get_sql_index_tool(sql_index, table_context_dict):
|
7 |
+
table_context_str = "\n".join(table_context_dict.values())
|
8 |
+
def run_sql_index_query(query_text):
|
9 |
+
try:
|
10 |
+
response = sql_index.query(query_text)
|
11 |
+
except Exception as e:
|
12 |
+
return f"Error running SQL {e}.\nNot able to retrieve answer."
|
13 |
+
text = str(response)
|
14 |
+
sql = response.extra_info['sql_query']
|
15 |
+
return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"
|
16 |
+
#return f"SQL Query Used: {sql}\nSQL Result: {text}\n"
|
17 |
+
return run_sql_index_query
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
def get_llm(llm_name, model_temperature, api_key):
|
22 |
+
os.environ['OPENAI_API_KEY'] = api_key
|
23 |
+
if llm_name == "text-davinci-003":
|
24 |
+
return OpenAI(temperature=model_temperature, model_name=llm_name)
|
25 |
+
else:
|
26 |
+
return ChatOpenAI(temperature=model_temperature, model_name=llm_name)
|