Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from streamlit_chat import message as st_message | |
from sqlalchemy import create_engine | |
from langchain.agents import Tool, initialize_agent | |
from langchain.chains.conversation.memory import ConversationBufferMemory | |
from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext | |
from llama_index import SQLDatabase as llama_SQLDatabase | |
from llama_index.indices.struct_store import SQLContextContainerBuilder | |
from constants import ( | |
DEFAULT_SQL_PATH, | |
DEFAULT_BUSINESS_TABLE_DESCRP, | |
DEFAULT_VIOLATIONS_TABLE_DESCRP, | |
DEFAULT_INSPECTIONS_TABLE_DESCRP, | |
DEFAULT_LC_TOOL_DESCRP, | |
) | |
from utils import get_sql_index_tool, get_llm | |
def initialize_index( | |
llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH | |
): | |
"""Create the GPTSQLStructStoreIndex object.""" | |
llm = get_llm(llm_name, model_temperature, api_key) | |
engine = create_engine(sql_path) | |
sql_database = llama_SQLDatabase(engine) | |
context_container = None | |
if table_context_dict is not None: | |
context_builder = SQLContextContainerBuilder( | |
sql_database, context_dict=table_context_dict | |
) | |
context_container = context_builder.build_context_container() | |
service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm)) | |
index = GPTSQLStructStoreIndex( | |
[], | |
sql_database=sql_database, | |
sql_context_container=context_container, | |
service_context=service_context, | |
) | |
return index | |
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index): | |
"""Create a (rather hacky) custom agent and sql_index tool.""" | |
sql_tool = Tool( | |
name="SQL Index", | |
func=get_sql_index_tool( | |
_sql_index, _sql_index.sql_context_container.context_dict | |
), | |
description=lc_descrp, | |
) | |
llm = get_llm(llm_name, model_temperature, api_key=api_key) | |
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) | |
agent_chain = initialize_agent( | |
[sql_tool], | |
llm, | |
agent="chat-conversational-react-description", | |
verbose=True, | |
memory=memory, | |
) | |
return agent_chain | |
st.title("π¦ Llama Index SQL Sandbox π¦") | |
st.markdown( | |
( | |
"This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n" | |
"The database contains information on health violations and inspections at restaurants in San Francisco." | |
"This data is spread across three tables - businesses, inspections, and violations.\n\n" | |
"Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain." | |
"The other tabs will perform chatbot and text2sql operations.\n\n" | |
"Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)" | |
) | |
) | |
setup_tab, llama_tab, lc_tab = st.tabs( | |
["Setup", "Llama Index", "Langchain+Llama Index"] | |
) | |
with setup_tab: | |
st.subheader("LLM Setup") | |
api_key = st.text_input("Enter your OpenAI API key here", type="password") | |
llm_name = st.selectbox( | |
"Which LLM?", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"] | |
) | |
model_temperature = st.slider( | |
"LLM Temperature", min_value=0.0, max_value=1.0, step=0.1 | |
) | |
st.subheader("Table Setup") | |
business_table_descrp = st.text_area( | |
"Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP | |
) | |
violations_table_descrp = st.text_area( | |
"Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP | |
) | |
inspections_table_descrp = st.text_area( | |
"Business table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP | |
) | |
table_context_dict = { | |
"businesses": business_table_descrp, | |
"inspections": inspections_table_descrp, | |
"violations": violations_table_descrp, | |
} | |
use_table_descrp = st.checkbox("Use table descriptions?", value=True) | |
lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP) | |
with llama_tab: | |
st.subheader("Text2SQL with Llama Index") | |
if st.button("Initialize Index", key="init_index_1"): | |
st.session_state["llama_index"] = initialize_index( | |
llm_name, | |
model_temperature, | |
table_context_dict if use_table_descrp else None, | |
api_key, | |
) | |
if "llama_index" in st.session_state: | |
query_text = st.text_input( | |
"Query:", value="Which restaurant has the most violations?" | |
) | |
use_nl = st.checkbox("Return natural language response?") | |
if st.button("Run Query") and query_text: | |
with st.spinner("Getting response..."): | |
try: | |
response = st.session_state["llama_index"].as_query_engine(synthesize_response=use_nl).query(query_text) | |
response_text = str(response) | |
response_sql = response.extra_info["sql_query"] | |
except Exception as e: | |
response_text = "Error running SQL Query." | |
response_sql = str(e) | |
col1, col2 = st.columns(2) | |
with col1: | |
st.text("SQL Result:") | |
st.markdown(response_text) | |
with col2: | |
st.text("SQL Query:") | |
st.markdown(response_sql) | |
with lc_tab: | |
st.subheader("Langchain + Llama Index SQL Demo") | |
if st.button("Initialize Agent"): | |
st.session_state["llama_index"] = initialize_index( | |
llm_name, | |
model_temperature, | |
table_context_dict if use_table_descrp else None, | |
api_key, | |
) | |
st.session_state["lc_agent"] = initialize_chain( | |
llm_name, | |
model_temperature, | |
lc_descrp, | |
api_key, | |
st.session_state["llama_index"], | |
) | |
st.session_state["chat_history"] = [] | |
model_input = st.text_input( | |
"Message:", value="Which restaurant has the most violations?" | |
) | |
if "lc_agent" in st.session_state and st.button("Send"): | |
model_input = "User: " + model_input | |
st.session_state["chat_history"].append(model_input) | |
with st.spinner("Getting response..."): | |
response = st.session_state["lc_agent"].run(input=model_input) | |
st.session_state["chat_history"].append(response) | |
if "chat_history" in st.session_state: | |
for msg in st.session_state["chat_history"]: | |
st_message(msg.split("User: ")[-1], is_user="User: " in msg) | |