import os
from typing import List

import chainlit as cl
import fireworks.client
import openai
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.fireworks import ChatFireworks
from langchain.docstore.document import Document
from langchain.embeddings import CohereEmbeddings
from langchain.llms.fireworks import Fireworks
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma

from langsmith_config import setup_langsmith_config
# Project-local hook from ``langsmith_config`` (body not visible here);
# presumably configures LangSmith tracing. Runs once, at import time.
setup_langsmith_config()

# Splitter applied to every uploaded document: targets ~1000-character chunks
# with a 100-character overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=100, chunk_size=1000)

# System prompt, one tuple entry per line of the final template.
# ``{summaries}`` is the placeholder for the retrieved context.
_SYSTEM_TEMPLATE_LINES = (
    "Use the following pieces of context to answer the users question.",
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.",
    'ALWAYS return a "SOURCES" part in your answer.',
    'The "SOURCES" part should be a reference to the source of the document from which you got your answer.',
    "And if the user greets with greetings like Hi, hello, How are you, etc reply accordingly as well.",
    "Example of your response should be:",
    "The answer is foo",
    "SOURCES: xyz",
    "Begin!",
    "----------------",
    "{summaries}",
)
system_template = "\n".join(_SYSTEM_TEMPLATE_LINES)

# Chat prompt = system instructions followed by the user's question.
_system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
_human_message_prompt = HumanMessagePromptTemplate.from_template("{question}")
messages = [_system_message_prompt, _human_message_prompt]
prompt = ChatPromptTemplate.from_messages(messages)

# NOTE(review): nothing in the visible code hands this dict to a chain —
# confirm whether it is still needed.
chain_type_kwargs = dict(prompt=prompt)
@cl.on_chat_start
async def on_chat_start() -> None:
    """Ask for a text file, index it, and store a retrieval chain in the session.

    Steps:
      1. Prompt the user until a plain-text file has been uploaded.
      2. Split the file into chunks, embed them with Cohere and load them
         into a Chroma vector store.
      3. Build a ``ConversationalRetrievalChain`` with buffer memory and save
         it in the user session under the key ``"chain"``.

    Raises:
        KeyError: if the ``COHERE_API_KEY`` environment variable is not set.
    """
    files = None

    # Re-prompt for as long as no file has come back from the upload dialog.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a text file to begin!",
            accept=["text/plain"],
            max_size_mb=20,
            timeout=180,
        ).send()

    file = files[0]

    # Progress message; it is edited in place once indexing has finished.
    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    text = file.content.decode("utf-8")
    texts = text_splitter.split_text(text)

    # Tag every chunk with a synthetic source id: "0-pl", "1-pl", ...
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Take the API key from the environment rather than passing the literal
    # placeholder string "COHERE_API_KEY" as if it were the key itself.
    embeddings = CohereEmbeddings(cohere_api_key=os.environ["COHERE_API_KEY"])

    # Chroma.from_texts is synchronous; make_async wraps it so it can be awaited.
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    message_history = ChatMessageHistory()

    # The chain is asked for source documents as well as the answer;
    # output_key names the output that the memory should record.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    chain = ConversationalRetrievalChain.from_llm(
        ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    # Make the chain available to the on_message handler for this session.
    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message: cl.Message) -> None:
    """Answer a user message with the retrieval chain stored in the session.

    The chain's answer is sent back together with one ``cl.Text`` element per
    returned source document. The answer text is suffixed with the element
    names, or with "No sources found" when the chain returned no documents.

    Args:
        message: Incoming chat message; its ``content`` is used as the question.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # No chain is stored yet (file not uploaded or indexing not finished);
        # reply with a hint instead of failing on ``None.acall``.
        await cl.Message(
            content="Please upload a text file before asking questions."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler()
    res = await chain.acall(message.content, callbacks=[cb])

    answer = res["answer"]
    source_documents = res["source_documents"]

    # One text element per retrieved document, named source_0, source_1, ...
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents or [])
    ]

    if text_elements:
        source_names = ", ".join(text_el.name for text_el in text_elements)
        answer += f"\nSources: {source_names}"
    else:
        # Reached whenever the chain returned no source documents.
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()