cheesyFishes
committed on
Commit
β’
3eeb9d5
1
Parent(s):
59f5daa
update to llamaindex v0.6.13
Browse files- app.py +100 -48
- constants.py +1 -1
- requirements.txt +3 -2
- utils.py +6 -5
app.py
CHANGED
@@ -15,13 +15,15 @@ from constants import (
|
|
15 |
DEFAULT_BUSINESS_TABLE_DESCRP,
|
16 |
DEFAULT_VIOLATIONS_TABLE_DESCRP,
|
17 |
DEFAULT_INSPECTIONS_TABLE_DESCRP,
|
18 |
-
DEFAULT_LC_TOOL_DESCRP
|
19 |
)
|
20 |
from utils import get_sql_index_tool, get_llm
|
21 |
|
22 |
|
23 |
@st.cache_resource
|
24 |
-
def initialize_index(
|
|
|
|
|
25 |
"""Create the GPTSQLStructStoreIndex object."""
|
26 |
llm = get_llm(llm_name, model_temperature, api_key)
|
27 |
|
@@ -30,14 +32,18 @@ def initialize_index(llm_name, model_temperature, table_context_dict, api_key, s
|
|
30 |
|
31 |
context_container = None
|
32 |
if table_context_dict is not None:
|
33 |
-
context_builder = SQLContextContainerBuilder(
|
|
|
|
|
34 |
context_container = context_builder.build_context_container()
|
35 |
-
|
36 |
service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
|
37 |
-
index = GPTSQLStructStoreIndex(
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
|
42 |
return index
|
43 |
|
@@ -45,63 +51,97 @@ def initialize_index(llm_name, model_temperature, table_context_dict, api_key, s
|
|
45 |
@st.cache_resource
|
46 |
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
|
47 |
"""Create a (rather hacky) custom agent and sql_index tool."""
|
48 |
-
sql_tool = Tool(
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
51 |
|
52 |
llm = get_llm(llm_name, model_temperature, api_key=api_key)
|
53 |
|
54 |
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
55 |
|
56 |
-
agent_chain = initialize_agent(
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
return agent_chain
|
59 |
|
60 |
|
61 |
st.title("π¦ Llama Index SQL Sandbox π¦")
|
62 |
-
st.markdown(
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
)
|
|
|
|
|
70 |
|
71 |
|
72 |
-
setup_tab, llama_tab, lc_tab = st.tabs(
|
|
|
|
|
73 |
|
74 |
with setup_tab:
|
75 |
st.subheader("LLM Setup")
|
76 |
api_key = st.text_input("Enter your OpenAI API key here", type="password")
|
77 |
-
llm_name = st.selectbox(
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
|
80 |
st.subheader("Table Setup")
|
81 |
-
business_table_descrp = st.text_area(
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
use_table_descrp = st.checkbox("Use table descriptions?", value=True)
|
90 |
lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
|
91 |
|
92 |
with llama_tab:
|
93 |
st.subheader("Text2SQL with Llama Index")
|
94 |
if st.button("Initialize Index", key="init_index_1"):
|
95 |
-
st.session_state[
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
97 |
if "llama_index" in st.session_state:
|
98 |
-
query_text = st.text_input(
|
|
|
|
|
|
|
99 |
if st.button("Run Query") and query_text:
|
100 |
with st.spinner("Getting response..."):
|
101 |
try:
|
102 |
-
response =
|
103 |
response_text = str(response)
|
104 |
-
response_sql = response.extra_info[
|
105 |
except Exception as e:
|
106 |
response_text = "Error running SQL Query."
|
107 |
response_sql = str(e)
|
@@ -119,19 +159,31 @@ with lc_tab:
|
|
119 |
st.subheader("Langchain + Llama Index SQL Demo")
|
120 |
|
121 |
if st.button("Initialize Agent"):
|
122 |
-
st.session_state[
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
model_input = "User: " + model_input
|
129 |
-
st.session_state[
|
130 |
with st.spinner("Getting response..."):
|
131 |
-
response =
|
132 |
-
st.session_state[
|
133 |
|
134 |
-
if
|
135 |
-
for msg in st.session_state[
|
136 |
st_message(msg.split("User: ")[-1], is_user="User: " in msg)
|
137 |
-
|
|
|
15 |
DEFAULT_BUSINESS_TABLE_DESCRP,
|
16 |
DEFAULT_VIOLATIONS_TABLE_DESCRP,
|
17 |
DEFAULT_INSPECTIONS_TABLE_DESCRP,
|
18 |
+
DEFAULT_LC_TOOL_DESCRP,
|
19 |
)
|
20 |
from utils import get_sql_index_tool, get_llm
|
21 |
|
22 |
|
23 |
@st.cache_resource
|
24 |
+
def initialize_index(
|
25 |
+
llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH
|
26 |
+
):
|
27 |
"""Create the GPTSQLStructStoreIndex object."""
|
28 |
llm = get_llm(llm_name, model_temperature, api_key)
|
29 |
|
|
|
32 |
|
33 |
context_container = None
|
34 |
if table_context_dict is not None:
|
35 |
+
context_builder = SQLContextContainerBuilder(
|
36 |
+
sql_database, context_dict=table_context_dict
|
37 |
+
)
|
38 |
context_container = context_builder.build_context_container()
|
39 |
+
|
40 |
service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
|
41 |
+
index = GPTSQLStructStoreIndex(
|
42 |
+
[],
|
43 |
+
sql_database=sql_database,
|
44 |
+
sql_context_container=context_container,
|
45 |
+
service_context=service_context,
|
46 |
+
)
|
47 |
|
48 |
return index
|
49 |
|
|
|
51 |
@st.cache_resource
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
    """Build a conversational LangChain agent wired to a Llama Index SQL tool.

    The leading underscore on ``_sql_index`` tells Streamlit's resource cache
    not to hash that argument when deciding whether to reuse the agent.
    """
    # Expose the SQL index to LangChain as a single Tool the agent may call.
    index_tool = Tool(
        name="SQL Index",
        description=lc_descrp,
        func=get_sql_index_tool(
            _sql_index, _sql_index.sql_context_container.context_dict
        ),
    )

    chat_llm = get_llm(llm_name, model_temperature, api_key=api_key)
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )

    # Conversational ReAct-style agent; memory carries history between turns.
    return initialize_agent(
        [index_tool],
        chat_llm,
        agent="chat-conversational-react-description",
        verbose=True,
        memory=chat_memory,
    )
|
75 |
|
76 |
|
77 |
# --- Page header -----------------------------------------------------------
st.title("🦙 Llama Index SQL Sandbox 🦙")
st.markdown(
    (
        "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n"
        "The database contains information on health violations and inspections at restaurants in San Francisco."
        "This data is spread across three tables - businesses, inspections, and violations.\n\n"
        "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
        "The other tabs will perform chatbot and text2sql operations.\n\n"
        "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
    )
)


setup_tab, llama_tab, lc_tab = st.tabs(
    ["Setup", "Llama Index", "Langchain+Llama Index"]
)

with setup_tab:
    st.subheader("LLM Setup")
    api_key = st.text_input("Enter your OpenAI API key here", type="password")
    llm_name = st.selectbox(
        "Which LLM?", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
    )
    model_temperature = st.slider(
        "LLM Temperature", min_value=0.0, max_value=1.0, step=0.1
    )

    st.subheader("Table Setup")
    # BUG FIX: the violations and inspections text areas previously reused the
    # "Business table description" label (copy-paste error); each table now
    # carries its own label.
    business_table_descrp = st.text_area(
        "Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP
    )
    violations_table_descrp = st.text_area(
        "Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP
    )
    inspections_table_descrp = st.text_area(
        "Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP
    )

    # Per-table context passed to the SQL index when enabled below.
    table_context_dict = {
        "businesses": business_table_descrp,
        "inspections": inspections_table_descrp,
        "violations": violations_table_descrp,
    }

    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)

with llama_tab:
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
        # initialize_index is cached, so repeat clicks with identical settings
        # reuse the same index object.
        st.session_state["llama_index"] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )

    if "llama_index" in st.session_state:
        query_text = st.text_input(
            "Query:", value="Which restaurant has the most violations?"
        )
        use_nl = st.checkbox("Return natural language response?")
        if st.button("Run Query") and query_text:
            with st.spinner("Getting response..."):
                try:
                    response = (
                        st.session_state["llama_index"]
                        .as_query_engine(synthesize_response=use_nl)
                        .query(query_text)
                    )
                    response_text = str(response)
                    response_sql = response.extra_info["sql_query"]
                except Exception as e:
                    # Surface failures in the UI instead of crashing the app.
                    response_text = "Error running SQL Query."
                    response_sql = str(e)
|
|
|
159 |
    st.subheader("Langchain + Llama Index SQL Demo")

    if st.button("Initialize Agent"):
        # Build (or fetch from cache) the SQL index first, then wrap it in a
        # LangChain conversational agent; start with an empty chat transcript.
        st.session_state["llama_index"] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )
        st.session_state["lc_agent"] = initialize_chain(
            llm_name,
            model_temperature,
            lc_descrp,
            api_key,
            st.session_state["llama_index"],
        )
        st.session_state["chat_history"] = []

    model_input = st.text_input(
        "Message:", value="Which restaurant has the most violations?"
    )
    if "lc_agent" in st.session_state and st.button("Send"):
        # User turns are tagged with a "User: " prefix so the transcript
        # renderer below can distinguish user messages from agent replies.
        model_input = "User: " + model_input
        st.session_state["chat_history"].append(model_input)
        with st.spinner("Getting response..."):
            response = st.session_state["lc_agent"].run(input=model_input)
            st.session_state["chat_history"].append(response)

    if "chat_history" in st.session_state:
        for msg in st.session_state["chat_history"]:
            # Strip the "User: " tag for display; its presence marks user turns.
            st_message(msg.split("User: ")[-1], is_user="User: " in msg)
|
|
constants.py
CHANGED
@@ -21,4 +21,4 @@ DEFAULT_INGEST_DOCUMENT = (
|
|
21 |
"The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31. "
|
22 |
"The business achieved a score of 50. We two violations, a high risk "
|
23 |
"vermin infestation as well as a high risk food holding temperatures."
|
24 |
-
)
|
|
|
21 |
"The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31. "
|
22 |
"The business achieved a score of 50. We two violations, a high risk "
|
23 |
"vermin infestation as well as a high risk food holding temperatures."
|
24 |
+
)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
|
2 |
-
|
|
|
3 |
streamlit==1.19.0
|
4 |
streamlit-chat==0.0.2.2
|
|
|
1 |
+
altair==4.2.2
|
2 |
+
langchain==0.0.154
|
3 |
+
llama-index==0.6.13
|
4 |
streamlit==1.19.0
|
5 |
streamlit-chat==0.0.2.2
|
utils.py
CHANGED
@@ -5,21 +5,22 @@ from langchain.chat_models import ChatOpenAI
|
|
5 |
|
6 |
def get_sql_index_tool(sql_index, table_context_dict):
|
7 |
table_context_str = "\n".join(table_context_dict.values())
|
|
|
8 |
def run_sql_index_query(query_text):
|
9 |
try:
|
10 |
-
response = sql_index.query(query_text)
|
11 |
except Exception as e:
|
12 |
return f"Error running SQL {e}.\nNot able to retrieve answer."
|
13 |
text = str(response)
|
14 |
-
sql = response.extra_info[
|
15 |
return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"
|
16 |
-
#return f"SQL Query Used: {sql}\nSQL Result: {text}\n"
|
17 |
-
return run_sql_index_query
|
18 |
|
|
|
19 |
|
20 |
|
21 |
def get_llm(llm_name, model_temperature, api_key):
|
22 |
-
os.environ[
|
23 |
if llm_name == "text-davinci-003":
|
24 |
return OpenAI(temperature=model_temperature, model_name=llm_name)
|
25 |
else:
|
|
|
5 |
|
6 |
def get_sql_index_tool(sql_index, table_context_dict):
    """Return a callable that answers natural-language questions via *sql_index*.

    The closure runs the question through the index's query engine (raw SQL
    mode, no NL synthesis) and formats the table context, the SQL that was
    executed, and the result into a single string.
    """
    # Join all table descriptions once; the closure reuses this on every call.
    table_context_str = "\n".join(table_context_dict.values())

    def run_sql_index_query(query_text):
        try:
            engine = sql_index.as_query_engine(synthesize_response=False)
            response = engine.query(query_text)
        except Exception as e:
            # Report the failure as text rather than raising to the caller.
            return f"Error running SQL {e}.\nNot able to retrieve answer."
        text = str(response)
        sql = response.extra_info["sql_query"]
        return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"

    return run_sql_index_query
|
20 |
|
21 |
|
22 |
def get_llm(llm_name, model_temperature, api_key):
|
23 |
+
os.environ["OPENAI_API_KEY"] = api_key
|
24 |
if llm_name == "text-davinci-003":
|
25 |
return OpenAI(temperature=model_temperature, model_name=llm_name)
|
26 |
else:
|