# app.py — Chainlit RAG application: upload a text/PDF file, index it in
# Qdrant, then answer questions over the indexed content.
import os
import chainlit as cl
from dotenv import load_dotenv
from operator import itemgetter
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain_openai.chat_models import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_openai import OpenAIEmbeddings
from helpers import process_file, add_to_qdrant
# Load environment variables (OPENAI_API_KEY, QDRANT_ENDPOINT, QDRANT_API_KEY)
# from a local .env file into os.environ.
load_dotenv()

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
llm = ChatOpenAI(model="gpt-4")

# Read Qdrant connection settings from the environment. The original code
# referenced `constants.QDRANT_ENDPOINT` / `constants.QDRANT_API_KEY`, but no
# `constants` module is imported anywhere, so the file raised NameError at
# import time; the .env file loaded above is the natural source instead.
qdrant_client = QdrantClient(
    url=os.environ.get("QDRANT_ENDPOINT"),
    api_key=os.environ.get("QDRANT_API_KEY"),
)
collection_name = "marketing_data"

# Chat prompt for the RAG chain. The variable is named {question} (not {query})
# because the chain built in on_chat_start feeds the prompt a dict with keys
# "context" and "question" — the original {query} placeholder was never
# supplied and raised a KeyError on every message.
# NOTE(review): the <|start_header_id|>/<|eot_id|> markers are Llama-3 chat
# tokens; gpt-4 treats them as literal text. Confirm the intended model.
RAG_PROMPT_TEMPLATE = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
User Query:
{question}
Context:
{context}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
### On Chat Start (Session Start) ###
@cl.on_chat_start
async def on_chat_start():
    """Start a session: ask for a file, index it, and build the RAG chain.

    Blocks until the user uploads a text or PDF file (3-minute timeout per
    prompt), indexes the parsed documents, then stores an LCEL RAG chain in
    the user session under the key "chain" for the on_message handler.
    """
    # Wait for the user to upload a file; AskFileMessage returns None on
    # timeout, so re-prompt until a file actually arrives.
    files = None
    while files is None:  # `is None`, not `== None` (PEP 8 identity check)
        files = await cl.AskFileMessage(
            content="Please ask a question or upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=2,
            timeout=180,
        ).send()

    file = files[0]
    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Parse the upload into documents and push each one to the shared Qdrant
    # collection, tagging it so answers can be traced back to their chunk.
    docs = process_file(file)
    print(f"Processing {len(docs)} text chunks")
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"  # TODO: richer metadata
        add_to_qdrant(doc, embeddings, qdrant_client, collection_name)

    # Build the in-memory vectorstore the retriever actually queries.
    # NOTE(review): documents are indexed twice — once into the external
    # collection above and again here in memory; confirm both are needed.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    vectorstore = Qdrant.from_documents(
        documents=splits,
        embedding=embeddings,
        location=":memory:",  # TODO: point at the Qdrant server URL
    )
    retriever = vectorstore.as_retriever()

    # LCEL chain: retrieve context for the question, fill the prompt,
    # call the LLM, and parse the reply to a plain string.
    rag_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # Let the user know that the system is ready.
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", rag_chain)
@cl.on_message
async def main(message: cl.Message):
    """Answer one user message with the session's RAG chain.

    Looks up the chain stored by on_chat_start and streams the answer back
    as a single Chainlit message.
    """
    chain = cl.user_session.get("chain")
    # Use the async Runnable API: the original called the synchronous
    # `invoke` from inside an async handler, blocking the event loop for
    # the full duration of the retrieval + LLM call.
    result = await chain.ainvoke({"question": message.content})
    msg = cl.Message(content=result)
    await msg.send()