# main.py
import streamlit as st
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger

from stats import add_usage, get_usage

# Load credentials from Streamlit secrets
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# Embeddings used to match queries against the document store
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5",
)

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

# Supabase-backed vector store and conversation memory for the retrieval chain
vector_store = SupabaseVectorStore(
    supabase,
    embeddings,
    query_name='match_documents',
    table_name="documents",
)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key='question',
    output_key='answer',
    return_messages=True,
)

# model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
model = "meta-llama/Meta-Llama-3-70B-Instruct"
temperature = 0.1
max_tokens = 500

# Total number of queries answered so far, shown in the UI
stats = str(get_usage(supabase))


def response_generator(query):
    """Answer a query with the retrieval chain, apologizing when no sources match."""
    add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
    logger.info('Using HF model %s', model)
    endpoint_url = "https://api-inference.huggingface.co/models/" + model
    model_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        # "repetition_penalty": 1.1,
        "return_full_text": False,
    }
    hf = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        model_kwargs=model_kwargs,
    )
    # Retrieve only this user's documents that score above the similarity threshold
    qa = ConversationalRetrievalChain.from_llm(
        hf,
        retriever=vector_store.as_retriever(
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )
    # Generate the model's response
    model_response = qa({"question": query})
    logger.info('Result: %s', model_response["answer"])
    sources = model_response["source_documents"]
    logger.info('Sources: %s', sources)
    # Only answer when retrieval found supporting documents
    if len(sources) > 0:
        response = model_response["answer"]
    else:
        response = "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email copilot@securade.ai."
    return response


# Configure the Streamlit page
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai",
    },
)

st.title("👷‍♂️ Safety Copilot 🦺")
st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
# st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
st.markdown("_" + stats + " queries answered!_")

# Display chat messages from history on app rerun
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Ask a question"):
    # Add the user message to the chat history and display it
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner('Safety briefing in progress...'):
        response = response_generator(prompt)
    # Display the assistant response and add it to the chat history
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.chat_history.append({"role": "assistant", "content": response})

# Alternative form-based UI (disabled):
# query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
# columns = st.columns(2)
# with columns[0]:
#     button = st.button("Ask")
# with columns[1]:
#     clear_history = st.button("Clear History", type='secondary')
# st.markdown("---\n\n")
# if clear_history:
#     # Clear memory in Langchain
#     memory.clear()
#     st.session_state['chat_history'] = []
#     st.experimental_rerun()