hlydecker's picture
Update streamlit_langchain_chat/streamlit_app.py
c356dff
raw
history blame contribute delete
No virus
20.7 kB
# -*- coding: utf-8 -*-
"""
To run:
- activate the virtual environment
- streamlit run path/to/streamlit_app.py
"""
import logging
import os
import re
import sys
import time
import warnings
import shutil
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
import openai
import pandas as pd
import streamlit as st
from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode
from streamlit_chat import message
from streamlit_langchain_chat.constants import *
from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat
from streamlit_langchain_chat.dataset import Dataset
# Configure logging: INFO level, timestamp and message on separate lines.
# force=True drops any handlers already attached to the root logger; an extra
# handler is then added so records are also written to stdout.
logging.basicConfig(
    format="\n%(asctime)s\n%(message)s",
    level=logging.INFO,
    force=True,
)
logging.root.addHandler(logging.StreamHandler(stream=sys.stdout))
# Silence every Python warning for the whole process.
warnings.filterwarnings('ignore')
# Seed per-session state on the first script run; later reruns keep whatever
# values the session already holds.
for _state_key, _state_default in (
    ('generated', []),
    ('past', []),
    ('costs', []),
    ('contexts', []),
    ('chunks', []),
    ('user_input', ""),
    ('dataset', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
def check_api_keys() -> bool:
    """Return True when the API keys needed to build an index are available.

    OPENAI_API_KEY must be non-empty. PINECONE_API_KEY must also be non-empty,
    but only when the Pinecone index store (``app.params['index_id'] == 2``) is
    selected. Both are read from the environment, where the sidebar stores them.
    """
    # os.getenv with a '' default always returns a str, so truthiness is enough.
    openai_key_ready = bool(os.getenv('OPENAI_API_KEY', ''))
    uses_pinecone = app.params['index_id'] == 2
    pinecone_key_ready = bool(os.getenv('PINECONE_API_KEY', '')) if uses_pinecone else True
    return openai_key_ready and pinecone_key_ready
def check_combination_point() -> bool:
    """Return True when sidebar step 2 (model type and credentials) is complete.

    Azure OpenAI (``type_id == 1``) needs a non-empty OPENAI_API_KEY, an API
    base and a deployment id. OpenAI (``type_id == 2``) needs the key and an
    API base. Any other type is never ready.
    """
    model_type = app.params['type_id']
    api_key = os.getenv('OPENAI_API_KEY', '')
    has_api_key = isinstance(api_key, str) and len(api_key) > 0
    endpoint = app.params['api_base']
    if model_type == 1:
        required = (has_api_key, endpoint, app.params['deployment_id'])
    elif model_type == 2:
        required = (has_api_key, endpoint)
    else:
        return False
    return all(required)
def check_index() -> bool:
    """Return True when the chat can be used.

    It can be used when the stored dataset has a truthy ``index_docstore`` or
    when source 4 (no source) is selected.
    """
    dataset = st.session_state['dataset']
    has_index = getattr(dataset, "index_docstore", False)
    no_source_needed = app.params['source_id'] == 4
    return bool(has_index or no_source_needed)
def check_index_point() -> bool:
    """Return True when sidebar step 3 (index store) is complete.

    An index store must be selected. Pinecone (``index_id == 2``) additionally
    needs non-empty PINECONE_API_KEY and PINECONE_ENVIRONMENT variables.
    """
    store_id = app.params['index_id']
    if store_id == 2:
        key = os.getenv('PINECONE_API_KEY', '')
        key_ok = isinstance(key, str) and len(key) > 0
        environment_ok = os.getenv('PINECONE_ENVIRONMENT', False)
    else:
        key_ok = environment_ok = True
    return bool(store_id and key_ok and environment_ok)
def check_params_point() -> bool:
    """Return True when sidebar step 4 is complete.

    Step 4 is complete when top-k is truthy (non-zero) and the temperature is
    a float.
    """
    top_k = app.params['max_sources']
    temp = app.params['temperature']
    if not top_k:
        return False
    return isinstance(temp, float)
def check_source_point() -> bool:
    """Return whether sidebar step 1 (choosing a source) is complete.

    The source is chosen from a selectbox, so this step is treated as always
    satisfied.
    """
    always_ready = True
    return always_ready
def clear_chat_history():
    """Empty the stored conversation: questions, answers, contexts, chunks and costs.

    Nothing is reassigned when every history list is already empty.
    """
    history_keys = ('past', 'generated', 'contexts', 'chunks', 'costs')
    if not any(st.session_state[name] for name in history_keys):
        return
    for name in history_keys:
        st.session_state[name] = []
def clear_index():
    """Delete the on-disk index and forget the current dataset.

    With a dataset in the session, its ``index_path`` directory is removed (if
    it exists) and the session's dataset is reset to None. Without a dataset,
    ``TEMP_DIR / "default"`` is removed if present.
    """
    current = st.session_state['dataset']
    if not current:
        fallback_dir = TEMP_DIR / "default"
        if fallback_dir.exists():
            shutil.rmtree(str(fallback_dir))
        return
    stored_at = current.index_path
    if stored_at.exists():
        shutil.rmtree(str(stored_at))
    st.session_state['dataset'] = None
def check_sources() -> bool:
    """Return True when sidebar step 5 (sources) is complete.

    It is complete when exactly one kind of source is present, or when source
    4 (no source) is selected. The two kinds are:
    - uploaded files, where the last selected row has a non-empty ``filepath``;
    - URLs, where at least one URL is non-empty.
    """
    file_rows = app.params['uploaded_files_rows']
    url_table = app.params['urls_df']
    selected_source = app.params['source_id']
    has_files = bool(file_rows) and file_rows[-1].get('filepath') != ""
    has_urls = any(url for url, _citation in url_table.to_numpy())
    # XOR: having both files and URLs at the same time is not accepted.
    return has_files != has_urls or selected_source == 4
def collect_dataset_and_built_index():
    """Build a Dataset from the selected sources, index it and keep it in the session.

    Reads the sidebar choices from ``app.params`` and configures the ``openai``
    module for Azure OpenAI (``type_id == 1``) or OpenAI. It then adds every
    non-empty URL and every selected uploaded file to a new ``Dataset``. Next it
    builds a FAISS index (``index_id == 1``) or a Pinecone index (any other id).
    Finally it stores the dataset in ``st.session_state['dataset']``.

    A URL that fails to load is logged and skipped, so the rest still get indexed.
    """
    start = time.time()
    uploaded_files_rows = app.params['uploaded_files_rows']
    urls_df = app.params['urls_df']
    type_id = app.params['type_id']
    temperature = app.params['temperature']
    index_id = app.params['index_id']
    api_base = app.params['api_base']
    deployment_id = app.params['deployment_id']
    some_files = bool(uploaded_files_rows) and uploaded_files_rows[-1].get('filepath') != ""
    some_urls = any(url for url, _citation in urls_df.to_numpy())

    openai.api_type = "azure" if type_id == 1 else "open_ai"
    openai.api_base = api_base
    openai.api_version = "2023-03-15-preview" if type_id == 1 else None

    if deployment_id != "text-davinci-003":
        llm = ChatOpenAI(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    else:
        # text-davinci-003 is a completion (non-chat) model, so it goes through
        # the project's OpenAI wrapper. The deployment to call is the id the
        # user typed in the sidebar.
        llm = OpenAI(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    dataset = Dataset(llm=llm)

    # get url documents
    if some_urls:
        for _, url_row in urls_df.reset_index().iterrows():
            url = url_row.get('urls', '')
            citation = url_row.get('citation string', '')
            if not url:
                continue
            try:
                # The citation string is passed twice (the second copy presumably
                # as the key); disable_check=True to accept Japanese letters.
                dataset.add(url, citation, citation, disable_check=True)
            except Exception:
                # Best effort: skip a URL that cannot be loaded, but record why.
                logging.exception("Could not add url %s to the dataset", url)

    # uploaded local files
    if some_files:
        for uploaded_files_row in uploaded_files_rows:
            citation = uploaded_files_row.get('citation string')
            # Citations containing a comma are not used as key (the reason is not
            # visible here). A missing citation also gets no key instead of
            # raising TypeError on the membership test.
            key = citation if isinstance(citation, str) and ',' not in citation else None
            dataset.add(
                uploaded_files_row.get('filepath'),
                citation,
                key=key,
                disable_check=True,  # True to accept Japanese letters
            )

    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    if index_id == 1:
        dataset._build_faiss_index(openai_embeddings)
    else:
        dataset._build_pinecone_index(openai_embeddings)
    st.session_state['dataset'] = dataset
    if OPERATING_MODE == "debug":
        print(f"time to collect dataset: {time.time() - start:.2f} [s]")
def configure_streamlit_and_page():
    """Apply ``ST_CONFIG`` to the page and inject CSS for the column layout.

    The CSS forces a responsive layout for columns, also on mobile.
    """
    st.set_page_config(**ST_CONFIG)
    responsive_columns_css = """<style>
[data-testid="column"] {
width: calc(50% - 1rem);
flex: 1 1 calc(50% - 1rem);
min-width: calc(50% - 1rem);
}
</style>"""
    st.write(responsive_columns_css, unsafe_allow_html=True)
def get_answer():
    """Query the indexed dataset with the current chat input.

    Returns the result of ``dataset.query``. ``on_enter`` reads its
    ``answer``, ``context``, ``chunks`` and ``cost_str`` attributes.

    Returns None when any of these is missing: the query text, the dataset,
    the model type or the index store.
    """
    query = st.session_state['user_input']
    dataset = st.session_state['dataset']
    type_id = app.params['type_id']
    index_id = app.params['index_id']
    # NOTE(review): the sidebar "Top-k" value (app.params['max_sources']) is not
    # passed to dataset.query, so it has no effect here. Confirm whether
    # Dataset.query should receive it.
    # NOTE(review): for source 4 ("without source") the sidebar offers no
    # "Build index" button. Unless a dataset was built earlier, this therefore
    # returns None even though the chat box is enabled.
    if not (query and dataset and type_id and index_id):
        return None

    # Earlier (question, answer) pairs give the model the conversation so far.
    chat_history = list(zip(st.session_state['past'], st.session_state['generated']))
    # Only enabled for FAISS (index 1); if pinecone is used it must be False.
    marginal_relevance = index_id == 1
    start = time.time()
    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    result = dataset.query(
        query,
        openai_embeddings,
        chat_history,
        marginal_relevance=marginal_relevance,
    )
    if OPERATING_MODE == "debug":
        print(f"time to get answer: {time.time() - start:.2f} [s]")
        print("-" * 10)
    return result
def load_main_page():
    """Render the main page.

    The page shows, in order:
    - the title and the readiness status;
    - the clear buttons and the jump link;
    - the chat history;
    - the chat input box.
    """
    # Streamlit HTML Markdown
    # st.title <h1> #
    # st.header <h2> ##
    # st.subheader <h3> ###
    st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
    validate_status()
    st.markdown(f"#### **Status**: {app.params['status']}")
    # hidden div with anchor, target of the "Link to top" link below
    st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)
    col1, col2, col3 = st.columns(3)
    col1.button(label="clear index", type="primary", on_click=clear_index)
    col2.button(label="clear conversation", type="primary", on_click=clear_chat_history)
    col3.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)

    # The five history lists are appended together by on_enter, so they stay
    # aligned and can be walked in lockstep.
    exchanges = zip(
        st.session_state['past'],
        st.session_state['generated'],
        st.session_state['contexts'],
        st.session_state['chunks'],
        st.session_state['costs'],
    )
    for i, (question, answer, context, chunks, costs) in enumerate(exchanges):
        message(question, is_user=True, key=f"{i}_user")
        message(answer, key=str(i))
        with st.expander("See context"):
            st.write(context)
        with st.expander("See chunks"):
            st.write(chunks)
        with st.expander("See costs"):
            st.write(costs)

    # The chat box is usable once an index exists or no source is required.
    st.text_input("You:",
                  key='user_input',
                  on_change=on_enter,
                  disabled=not check_index()
                  )
    st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
    # hidden div with anchor, target of the "Link to bottom" link above
    st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)
def load_sidebar_page():
    """Render sidebar steps 1-6 and record the user's choices.

    Choices are stored in ``app.params`` under these keys: ``source_id``,
    ``type_id``, ``api_base``, ``deployment_id``, ``index_id``,
    ``max_sources``, ``temperature``, ``uploaded_files_rows`` and ``urls_df``.

    API keys and the Pinecone environment are written to ``os.environ``.

    For sources 1-3, step 6 shows a "Build index" button. The button calls
    ``collect_dataset_and_built_index`` and is enabled only when the API keys
    and the sources are ready.
    """
    st.sidebar.markdown("## Instructions")

    # ############ #
    # SOURCES TYPE #
    # ############ #
    st.sidebar.markdown("1. Select a source:")
    source_selected = st.sidebar.selectbox(
        "Choose the location of your info to give context to chatgpt",
        list(SOURCES_IDS),
    )
    app.params['source_id'] = SOURCES_IDS.get(source_selected, None)

    # ##### #
    # MODEL #
    # ##### #
    st.sidebar.markdown("2. Select a model (LLM):")
    combination_selected = st.sidebar.selectbox(
        "Choose type: MSF Azure OpenAI and model / OpenAI",
        list(TYPE_IDS),
    )
    app.params['type_id'] = TYPE_IDS.get(combination_selected, None)
    # NOTE(review): os.environ is process-wide, and Streamlit serves every
    # browser session from the same server process. A key typed here is
    # therefore visible to other concurrent sessions. Consider keeping
    # credentials in st.session_state instead.
    if app.params['type_id'] == 1:  # with AzureOpenAI endpoint
        # https://docs.streamlit.io/library/api-reference/widgets/st.text_input
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter Azure OpenAI API Key",
            type="password",
        ).strip()
        app.params['api_base'] = st.sidebar.text_input(
            label="Enter Azure API base",
            placeholder="https://<api_base_endpoint>.openai.azure.com/",
        ).strip()
        app.params['deployment_id'] = st.sidebar.text_input(
            label="Enter Azure deployment_id",
        ).strip()
    elif app.params['type_id'] == 2:  # with OpenAI endpoint
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter OpenAI API Key",
            placeholder="sk-...",
            type="password",
        ).strip()
        app.params['api_base'] = "https://api.openai.com/v1"
        app.params['deployment_id'] = None

    # ####### #
    # INDEXES #
    # ####### #
    st.sidebar.markdown("3. Select a index store:")
    index_selected = st.sidebar.selectbox(
        "Type of Index",
        list(INDEX_IDS),
    )
    app.params['index_id'] = INDEX_IDS.get(index_selected, None)
    if app.params['index_id'] == 2:  # with pinecone
        os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
            label="Enter pinecone API Key",
            type="password",
        ).strip()
        os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
            label="Enter pinecone environment",
            placeholder="eu-west1-gcp",
        ).strip()

    # ############## #
    # CONFIGURATIONS #
    # ############## #
    st.sidebar.markdown("4. Choose configuration:")
    # https://docs.streamlit.io/library/api-reference/widgets/st.number_input
    # min_value/max_value enforce the 1-5 range promised by the label. Without
    # them the widget accepted 0 and negative values.
    app.params['max_sources'] = st.sidebar.number_input(
        label="Top-k: Number of chunks/sections (1-5)",
        min_value=1,
        max_value=5,
        step=1,
        format="%d",
        value=5,
    )
    temperature = st.sidebar.number_input(
        label="Temperature (0.0 – 1.0)",
        step=0.1,
        format="%f",
        value=0.0,
        min_value=0.0,
        max_value=1.0,
    )
    # Rounded to one decimal, presumably to drop float noise from the 0.1 steps.
    app.params['temperature'] = round(temperature, 1)

    # ############## #
    # UPLOAD SOURCES #
    # ############## #
    app.params['uploaded_files_rows'] = []
    if app.params['source_id'] == 1:
        # https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
        # https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
        st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
        uploaded_files = st.sidebar.file_uploader(
            "Choose files",
            accept_multiple_files=True,
            type=['pdf', 'PDF',
                  'txt', 'TXT',
                  'html',
                  'docx', 'DOCX',
                  'pptx', 'PPTX',
                  ],
        )
        uploaded_files_dataset = request_pathname(uploaded_files)
        uploaded_files_df = pd.DataFrame(
            uploaded_files_dataset,
            columns=['filepath', 'citation string'])
        uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
        # Pre-select every row unless the table only holds the empty
        # placeholder row returned when nothing is uploaded.
        uploaded_files_grid_options_builder.configure_selection(
            selection_mode='multiple',
            pre_selected_rows=list(range(uploaded_files_df.shape[0])) if uploaded_files_df.iloc[-1, 0] != "" else [],
            use_checkbox=True,
        )
        uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
        uploaded_files_grid_options_builder.configure_auto_height()
        uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
        with st.sidebar:
            uploaded_files_ag_grid = AgGrid(
                uploaded_files_df,
                gridOptions=uploaded_files_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]

    app.params['urls_df'] = pd.DataFrame()
    if app.params['source_id'] == 3:
        st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
        # Editable grid with five empty rows. With streamlit 1.20.0+,
        # st.sidebar.experimental_data_editor(num_rows="dynamic") could replace it.
        urls_df = pd.DataFrame(
            [["", ""] for _ in range(5)],
            columns=['urls', 'citation string'])
        urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
        urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
        urls_grid_options_builder.configure_auto_height()
        urls_grid_options = urls_grid_options_builder.build()
        with st.sidebar:
            urls_ag_grid = AgGrid(
                urls_df,
                gridOptions=urls_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        df = urls_ag_grid.data
        # Keep only the rows where a URL was typed.
        app.params['urls_df'] = df[df.urls != ""]

    if app.params['source_id'] in (1, 2, 3):
        st.sidebar.markdown("""6. Build an index where you can ask""")
        api_keys_ready = check_api_keys()
        source_ready = check_sources()
        enable_index_button = api_keys_ready and source_ready
        if st.sidebar.button("Build index", disabled=not enable_index_button):
            collect_dataset_and_built_index()
def main():
    """Render one Streamlit run.

    The steps run in this order: page setup, sidebar, main page body.
    """
    # Order matters: the sidebar fills app.params, which the main page then reads.
    for render_step in (configure_streamlit_and_page, load_sidebar_page, load_main_page):
        render_step()
def on_enter():
    """Handle the chat input's on_change callback: answer the question and record it.

    When ``get_answer`` returns a result, it appends five items to their
    session lists: the question, answer, context, chunks and cost string.
    It then clears the input box. When there is no result, nothing changes.
    """
    result = get_answer()
    if not result:
        return
    session = st.session_state
    session.past.append(session['user_input'])
    session.generated.append(result.answer)
    session.contexts.append(result.context)
    session.chunks.append(result.chunks)
    session.costs.append(result.cost_str)
    session['user_input'] = ""
def request_pathname(files):
    """Save uploaded files under TEMP_DIR and describe them for the sources grid.

    Parameters
    ----------
    files : sequence of uploaded file objects (from ``st.file_uploader``), or
        an empty/None value when nothing was uploaded.

    Returns
    -------
    list of ``[filepath, citation]`` pairs. ``filepath`` is the saved file's
    path relative to ROOT_DIR. ``citation`` defaults to the uploaded file's
    name. When ``files`` is empty, a single placeholder row ``["", ""]`` is
    returned so the grid still renders.
    """
    if not files:
        return [["", ""]]
    # exist_ok makes a separate "does it exist?" check unnecessary.
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    rows = []
    for file in files:
        # The upload name comes from the client. Keep only its final component
        # so a name such as "../../x" cannot write outside TEMP_DIR.
        safe_name = os.path.basename(file.name)
        if safe_name in ("", ".", ".."):
            logging.warning("Skipping upload with unusable file name %r", file.name)
            continue
        # NOTE(review): this relative path is opened against the current working
        # directory, so it only lands in TEMP_DIR when the app runs from ROOT_DIR.
        file_path = str((TEMP_DIR / safe_name).relative_to(ROOT_DIR))
        with open(file_path, "wb") as f:
            f.write(file.getbuffer())
        rows.append([file_path, file.name])
    return rows
def validate_status():
    """Set ``app.params['status']`` to a message about the first unmet sidebar step.

    All six checks are evaluated up front, in sidebar order. If every check
    passes, the status is "✨Ready✨". Otherwise it points at the first step
    that failed.
    """
    checks = (
        (check_source_point(), "⚠️Review step 1 on the sidebar."),
        (check_combination_point(), "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."),
        (check_index_point(), "⚠️Review step 3 on the sidebar. Index API Key or environment."),
        (check_params_point(), "⚠️Review step 4 on the sidebar"),
        (check_sources(), "⚠️Review step 5 on the sidebar. Waiting for some source..."),
        (check_index(), "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."),
    )
    # Either some check failed (its message is used) or all passed (Ready).
    # No third outcome exists, so no catch-all message is needed.
    app.params['status'] = next((msg for ready, msg in checks if not ready), "✨Ready✨")
class StreamlitLangchainChatApp:
    """Application object holding the parameters of the current run.

    ``params`` is filled in by the sidebar widgets. The check/page helper
    functions read it through the module-level ``app`` instance.
    """

    def __init__(self) -> None:
        """Create an empty parameter store; the constructor takes no arguments."""
        self.params = {}

    def run(self, **state) -> None:
        """Render the application; keyword arguments are accepted but ignored."""
        main()
if __name__ == "__main__":
    # `app` is bound at module level on purpose: the check_* and load_* helpers
    # read `app.params` as a global, so it must exist before run() is called.
    app = StreamlitLangchainChatApp()
    app.run()