from typing import List, Optional, Tuple
from threading import Lock

import gradio as gr

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.vectorstores import FAISS

# Embedding model display name -> Hugging Face model id
model_options = {'all-mpnet-base-v2': "sentence-transformers/all-mpnet-base-v2",
                 'instructor-base': "hkunlp/instructor-base"}

model_options_list = list(model_options.keys())

# Conversation memory; output_key='answer' is required because the chain
# returns source documents alongside the answer.
memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')
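# The buffer memory supplies {chat_history} to the condense-question prompt
# defined further down, so follow-up questions can be rewritten as
# standalone questions before retrieval.
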
def load_prompt():
    '''Build the chat prompt used by the question-answering chain.'''

    system_template = """Use only the following pieces of context, which have been scraped from a website, to answer the user's question accurately.
Do not use any information that is not provided in the website context.
If you don't know the answer, just say 'There is no relevant answer in the Investor Website';
don't try to make up an answer.

ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.

Remember, do not reference any information not given in the context.
If the answer is not available in the given context, just say 'There is no relevant answer in the Investor Website'.

Follow the format below when answering:

Question: {question}
SOURCES: [xyz]

Begin!
----------------
{context}"""

    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    return ChatPromptTemplate.from_messages(messages)

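# Illustration only: formatting the prompt fills {context} and {question}
# (the values below are hypothetical).
#
#   prompt = load_prompt()
#   msgs = prompt.format_messages(context="ETFs trade on an exchange...",
#                                 question="What is an ETF?")
#   # msgs[0] is the system message carrying the context; msgs[1] is the
#   # human message carrying the question.
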
def load_vectorstore(model):
    '''Load the embedding model and its matching pre-built FAISS vectorstore.'''

    if 'mpnet' in model:
        emb = HuggingFaceEmbeddings(model_name=model)
        return FAISS.load_local('vanguard-embeddings', emb)

    elif 'instructor' in model:
        emb = HuggingFaceInstructEmbeddings(model_name=model,
                                            query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
                                            embed_instruction='Represent the Financial paragraph for retrieval: ')
        return FAISS.load_local('vanguard_embeddings_inst', emb)

    raise ValueError(f'Unknown embedding model: {model}')

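# Quick sanity check (illustrative; assumes the FAISS indexes exist on disk):
#
#   vs = load_vectorstore(model_options['all-mpnet-base-v2'])
#   docs = vs.similarity_search('What are the benefits of ETFs?', k=3)
#   print(docs[0].page_content)
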
# Default vectorstore, built with the first embedding model
vectorstore = load_vectorstore(model_options['all-mpnet-base-v2'])

def on_value_change(change):
    '''When the radio selection changes, load the matching vectorstore.'''
    global vectorstore
    vectorstore = load_vectorstore(model_options[change])

_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
You can assume the question is about investing and the investment management industry.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

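# What the condense step does, on a hypothetical exchange:
#
#   Chat History:        Human: What is an ETF? / AI: An ETF is ...
#   Follow Up Input:     How are they taxed?
#   Standalone question: How are ETFs taxed?
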
# Alternative plain-text QA prompt. Note: the chain below uses load_prompt()
# instead, so this prompt is kept only for reference.
template = """You are an AI assistant for answering questions about investing and the investment management industry.
You are given the following extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
If the question is not about investing, politely inform them that you are tuned to only answer questions about investing and the investment management industry.
Question: {question}
=========
{context}
=========
Answer in Markdown:"""
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])

def get_chain(vectorstore):
    # ChatOpenAI supersedes the deprecated OpenAIChat wrapper for chat models
    llm = ChatOpenAI(streaming=True,
                     callbacks=[StdOutCallbackHandler()],
                     verbose=True,
                     temperature=0,
                     model_name='gpt-4-0613')

    # Rewrites follow-up questions into standalone questions
    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
    # Stuffs the retrieved documents into the prompt from load_prompt()
    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=load_prompt())

    chain = ConversationalRetrievalChain(retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
                                         question_generator=question_generator,
                                         combine_docs_chain=doc_chain,
                                         memory=memory,
                                         return_source_documents=True,
                                         get_chat_history=lambda h: h)
    return chain

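# A call to the chain returns both the answer and the retrieved sources; the
# keys follow from return_source_documents=True and the memory's output_key
# (illustrative):
#
#   result = chain({"question": "What is an ETF?"})
#   result["answer"]            # the model's reply
#   result["source_documents"]  # the k=3 retrieved Documents
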
def load_chain():
    chain = get_chain(vectorstore)
    return chain

class ChatWrapper:
    '''Serializes chain calls so concurrent Gradio requests don't interleave.'''

    def __init__(self):
        self.lock = Lock()

    def __call__(self, inp: str, history: Optional[List[Tuple[str, str]]], chain):
        """Execute the chat functionality."""
        with self.lock:
            history = history or []
            # The chain only exists after the 'Load QA Chain' button is clicked
            if chain is None:
                history.append((inp, "Please click 'Load QA Chain' first."))
                return history, history
            output = chain({"question": inp})["answer"]
            history.append((inp, output))
        return history, history

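# Gradio calls the wrapper with (message, state, agent_state) and expects the
# updated (chatbot, state) pair back, e.g. (hypothetical values):
#
#   chat = ChatWrapper()
#   history, _ = chat("What is an ETF?", [], load_chain())
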
block = gr.Blocks(css=".gradio-container {background-color: lightgray}")

with block:
    with gr.Row():
        gr.Markdown("<h3><center>Chat-Your-Data (Investor Education)</center></h3>")
        embed_but = gr.Button(value='Load QA Chain')

    with gr.Row():
        embeddings = gr.Radio(choices=model_options_list, value=model_options_list[0],
                              label='Choose your Embedding Model', interactive=True)
        embeddings.change(on_value_change, embeddings)

    # Load the vectorstore for the initially selected model
    vectorstore = load_vectorstore(model_options[embeddings.value])

    chatbot = gr.Chatbot()

    chat = ChatWrapper()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about Investing",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What are the benefits of investing in ETFs?",
            "What is the average cost of investing in a managed fund?",
            "At what age can I start investing?",
            "Do you offer investment accounts for kids?",
        ],
        inputs=message,
    )

    gr.HTML("Demo application of a LangChain chain.")

    gr.HTML(
        "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
    )

    # state holds the chat history; agent_state holds the loaded chain
    state = gr.State()
    agent_state = gr.State()

    submit.click(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
    message.submit(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])

    # Build the chain (with the current vectorstore) when the button is clicked
    embed_but.click(
        load_chain,
        outputs=[agent_state],
    )

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-investor-chatchain)")

block.launch(debug=True)
|