|
import html
import time

import streamlit as st

from Retrieval import bsic_chain, multiQuery_chain
from htmlTemplates import css, bot_template, user_template, source_template
|
|
|
|
# Page configuration is kept as the first Streamlit call, ahead of any output.
st.set_page_config(
    page_title="Chat with StockGPT",
    page_icon=":currency_exchange:",
)

# Emit the shared `css` markup from htmlTemplates as raw HTML so the message
# templates rendered later can use it (presumably a <style> block — confirm).
st.write(
    css,
    unsafe_allow_html=True,
)
|
|
|
|
# Markdown for the sidebar "Guideline" section.
_GUIDELINE_MARKDOWN = (
    "1. Type your message in the chat box on the right.\n\n"
    "2. Hit Enter or click the send button to send your message.\n\n"
    "3. Chat bot responses will appear below.\n\n"
    "4. Source documents will be displayed in the sidebar.\n"
)

# Styled "Sources" link to a Google Drive folder, opened in a new tab.
_SOURCES_LINK_HTML = (
    '<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
    'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
    'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>'
)


def _render_sidebar():
    """Draw the sidebar guideline, model picker and Sources link; return the chosen model."""
    st.sidebar.title("Guideline")
    st.sidebar.markdown(_GUIDELINE_MARKDOWN)
    model_selection = st.sidebar.selectbox("Select Model", ["Basic", "MultiQuery"])
    st.sidebar.markdown(_SOURCES_LINK_HTML, unsafe_allow_html=True)
    return model_selection


def _source_title(meta, index):
    """Return a short expander label for one source document.

    Uses the last path component of ``meta["source"]``, splitting on both
    backslash and forward slash so Windows and POSIX paths are shortened.
    Falls back to ``"Source <index>"`` when no usable source is present.
    """
    source = str((meta or {}).get("source") or "")
    name = source.replace("\\", "/").rstrip("/").rsplit("/", 1)[-1]
    return name or f"Source {index}"


def _render_sources(source_documents, metadata):
    """List each retrieved document in a sidebar expander titled by its file name."""
    st.sidebar.title("Source Documents")
    # zip keeps documents and their metadata paired; both lists are built from
    # the same source_documents sequence by generate_bot_response.
    for index, (doc, meta) in enumerate(zip(source_documents, metadata), start=1):
        with st.sidebar.expander(_source_title(meta, index)):
            st.write(doc)


def main():
    """Render the StockGPT chat page and answer a submitted question.

    Draws the sidebar and a question form. Once a non-empty question is
    submitted, it shows the user and bot messages and lists the retrieved
    source documents in the sidebar.
    """
    model_selection = _render_sidebar()

    st.title("StockGPT Chat App")

    # A form submits both on the Send button and on Enter inside its text
    # input, which is what guideline step 2 promises to the user.
    with st.form("chat_form"):
        user_input = st.text_input(
            "Question",
            key="user_input",
            placeholder="Type your question here...",
            label_visibility="collapsed",
        )
        submitted = st.form_submit_button("Send")

    if not (submitted and user_input):
        return

    with st.spinner("Waiting for response..."):
        response, metadata, source_documents = generate_bot_response(
            user_input, model_selection
        )

    # The templates are rendered as raw HTML, so escape both messages to keep
    # characters such as "<" or "&" from being interpreted as markup.
    st.write(
        user_template.replace("{{MSG}}", html.escape(user_input)),
        unsafe_allow_html=True,
    )
    st.write(
        bot_template.replace("{{MSG}}", html.escape(str(response))),
        unsafe_allow_html=True,
    )

    _render_sources(source_documents, metadata)
|
|
|
|
def generate_bot_response(user_input, model):
    """Run ``user_input`` through the retrieval chain selected by ``model``.

    Args:
        user_input: The question typed by the user.
        model: Chain name, either ``"Basic"`` (``bsic_chain``) or
            ``"MultiQuery"`` (``multiQuery_chain``).

    Returns:
        A ``(response, metadata, source_documents)`` tuple: the chain's
        ``"result"`` value, the ``.metadata`` of every returned source
        document, and the source documents themselves (an empty list when
        the chain returns none).

    Raises:
        ValueError: If ``model`` is not a recognised chain name.
    """
    start_time = time.perf_counter()
    print(f"User Input: {user_input}")

    if model == "Basic":
        res = bsic_chain(user_input)
    elif model == "MultiQuery":
        res = multiQuery_chain(user_input)
    else:
        # Without this branch `res` would be unbound and the caller would get
        # an opaque UnboundLocalError instead of a clear message.
        raise ValueError(
            f"Unknown model {model!r}; expected 'Basic' or 'MultiQuery'"
        )

    response = res["result"]
    # `or []` also covers a chain that returns the key with a None value.
    source_documents = res.get("source_documents") or []
    metadata = [doc.metadata for doc in source_documents]

    response_time = time.perf_counter() - start_time
    print(f"Response Time: {response_time} seconds")

    return response, metadata, source_documents
|
|
|
|
# Script entry point: build and render the page via main().
if __name__ == "__main__":
    main()
|
|
|