"""Chat-Your-Data (Investor Education): a Gradio front end for a LangChain QA chain.

The user selects which site's embeddings to search (keys of ``site_options``),
clicks "Step 1" to build a ConversationalRetrievalChain backed by GPT-4 and a
local FAISS index, then asks questions in the chat box.
"""
import os
import pickle
from functools import lru_cache
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.llms import OpenAIChat
from langchain.memory import ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.vectorstores import FAISS

# System message in raw OpenAI chat format; not used by the chain built in get_chain().
prefix_messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that is very good at answering "
        "questions about investments using the information given.",
    }
]

# Radio label -> directory passed to FAISS.load_local for that site's index.
site_options = {'US': 'vanguard_embeddings_US', 'AUS': 'vanguard_embeddings'}
site_options_list = list(site_options.keys())


def _new_memory():
    """Return an empty chat buffer with the keys ConversationalRetrievalChain reads/writes."""
    return ConversationBufferMemory(
        memory_key='chat_history', return_messages=True, output_key='answer'
    )


# Kept for backward compatibility only. get_chain() builds a fresh buffer per
# chain so that separate browser sessions no longer share one chat history.
memory = _new_memory()


def load_prompt():
    """Build the chat prompt for the "stuff" QA chain.

    The system message carries the retrieved ``{context}`` and answering rules;
    the human message is the (condensed) ``{question}``.
    """
    # NOTE(review): the two "no answer" fallback phrases below differ
    # ("Investor Website" vs "website content"); kept verbatim.
    system_template = (
        "Use only the following pieces of context that has been scraped from a "
        "website to answer the users question accurately. Do not use any "
        "information not provided in the website context. If you don't know the "
        "answer, just say 'There is no relevant answer in the Investor Website', "
        "don't try to make up an answer. ALWAYS return a \"SOURCES\" part in your "
        "answer. The \"SOURCES\" part should be a reference to the source of the "
        "document from which you got your answer. Remember, do not reference any "
        "information not given in the context. \n"
        "If the answer is not available in the given context just say 'There is "
        "no relevant answer in the website content' Follow the below format when "
        "answering: Question: {question} SOURCES: [xyz] Begin! ---------------- "
        "{context}"
    )
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    return ChatPromptTemplate.from_messages(messages)


@lru_cache(maxsize=1)
def _embeddings():
    """Return the shared sentence-transformer embedding model, loading it only once."""
    return HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")


@lru_cache(maxsize=None)  # keyed by site name; unknown sites raise and are not cached
def load_vectorstore(site):
    '''Load embeddings and the FAISS vectorstore for *site* (a key of ``site_options``).

    Results are cached per site, so repeated calls (radio changes, chain reloads)
    reuse the same store instead of re-reading it from disk. Restart the app to
    pick up a rebuilt index.

    Raises:
        KeyError: if *site* is not a key of ``site_options``.
    '''
    # NOTE: FAISS.load_local deserializes a pickled docstore; only point it at
    # trusted directories.
    return FAISS.load_local(site_options[site], _embeddings())


# Default store: the first radio option, which is also the radio's initial value.
vectorstore = load_vectorstore(site_options_list[0])


def on_value_change(site):
    '''When the radio changes, switch the global reference data to *site*.

    Returns None. The UI wires this return value to the session's chain state,
    so a chain built against the previously selected site is discarded and the
    user reloads it with "Step 1".
    '''
    global vectorstore
    vectorstore = load_vectorstore(site)
    return None


_template = (
    "Given the following conversation and a follow up question, rephrase the "
    "follow up question to be a standalone question. You can assume the "
    "question about investing and the investment management industry. Chat "
    "History: {chat_history} Follow Up Input: {question} Standalone question:"
)
# Deliberately replaces LangChain's imported default CONDENSE_QUESTION_PROMPT.
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

template = (
    "You are an AI assistant for answering questions about investing and the "
    "investment management industry. You are given the following extracted "
    "parts of a long document and a question. Provide a conversational answer. "
    "If you don't know the answer, just say \"Hmm, I'm not sure.\" Don't try to "
    "make up an answer. If the question is not about investing, politely inform "
    "them that you are tuned to only answer questions about investing and the "
    "investment management industry. \n"
    "Question: {question} ========= {context} ========= Answer in Markdown:"
)
# Not used by get_chain(), whose document chain uses load_prompt() instead.
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])


def get_chain(vectorstore, memory=None):
    """Build a conversational retrieval chain over *vectorstore*.

    Args:
        vectorstore: a LangChain vectorstore; its retriever returns the top 4 documents.
        memory: chat memory for this chain. When None (the default), a fresh,
            empty buffer is created so each chain has its own conversation.

    Returns:
        A ConversationalRetrievalChain whose output includes ``answer`` and
        the source documents.
    """
    if memory is None:
        memory = _new_memory()
    llm = OpenAIChat(
        streaming=True,
        callbacks=[StdOutCallbackHandler()],
        verbose=True,
        temperature=0,
        model_name='gpt-4',
    )
    # Rewrites a follow-up question into a standalone one using the chat history.
    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=load_prompt())
    return ConversationalRetrievalChain(
        # Was misspelled "search_kwags", so k=4 never reached the retriever.
        retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        memory=memory,
        return_source_documents=True,
        # History is passed through unchanged (the memory uses return_messages=True).
        get_chat_history=lambda h: h,
    )


def load_chain(site=None):
    """Build a new chain for *site*, or for the global ``vectorstore`` when *site* is None.

    The UI passes the currently selected radio value, so the chain always matches
    the site the user picked rather than whatever another session last selected.
    """
    store = vectorstore if site is None else load_vectorstore(site)
    return get_chain(store)


_NOT_LOADED_MSG = (
    "The QA system is not loaded yet. Please click "
    "'Step 1: Click Me to Load the QA System' and ask again."
)


class ChatWrapper:
    """Callable Gradio handler that runs one chat turn through a chain.

    A single lock serializes chain calls across all sessions.
    """

    def __init__(self):
        self.lock = Lock()

    def __call__(
        self, inp: str, history: Optional[List[Tuple[str, str]]], chain
    ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        """Execute the chat functionality.

        Args:
            inp: the user's question.
            history: (question, answer) pairs shown so far, or None on the first turn.
            chain: this session's chain from ``load_chain``, or None if not loaded.

        Returns:
            The updated history twice: once for the Chatbot, once for the State.
        """
        history = history or []
        if not inp or not inp.strip():
            # Nothing to ask; skip the LLM call.
            return history, history
        if chain is None:
            # Previously calling None raised TypeError when the user chatted
            # before clicking "Step 1" (or after switching sites).
            history.append((inp, _NOT_LOADED_MSG))
            return history, history
        with self.lock:
            output = chain({"question": inp})["answer"]
        history.append((inp, output))
        return history, history


block = gr.Blocks(css=".gradio-container {background-color: lightgray}")

with block:
    with gr.Row():
        gr.Markdown("\n\nChat-Your-Data (Investor Education)\n\n")
        embed_but = gr.Button(value='Step 1: Click Me to Load the QA System')
    with gr.Row():
        websites = gr.Radio(
            choices=site_options_list,
            value=site_options_list[0],
            label='Select US or AUS website data',
            interactive=True,
        )

    chatbot = gr.Chatbot()
    chat = ChatWrapper()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about Investing",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What are the benefits of investing in ETFs?",
            "What is the average cost of investing in a managed fund?",
            "At what age can I start investing?",
            "Do you offer investment accounts for kids?",
        ],
        inputs=message,
    )

    gr.HTML("Demo application of a LangChain chain.")
    gr.HTML("\nPowered by LangChain 🦜️🔗\n")

    # Per-session state: displayed chat history and this session's chain.
    state = gr.State()
    agent_state = gr.State()

    # Switching sites drops the session's chain so it is rebuilt for the new site.
    websites.change(on_value_change, inputs=websites, outputs=[agent_state])
    submit.click(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
    message.submit(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
    embed_but.click(load_chain, inputs=[websites], outputs=[agent_state])

    gr.Markdown("![](https://komarev.com/ghpvc/?username=nickmuchi87&style=flat-square)")

block.launch(debug=True)