import os
from operator import itemgetter

import chainlit as cl
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient

from helpers import process_file, add_to_qdrant
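
# `process_file` and `add_to_qdrant` come from the local helpers module (not
# shown here). Assumed signatures, inferred from how they are used below:
#   process_file(file) -> list[Document]  # parse the uploaded file into documents
#   add_to_qdrant(doc, embeddings, qdrant_client, collection_name) -> None  # embed + upsert one doc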

load_dotenv()

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
llm = ChatOpenAI(model="gpt-4")

# Qdrant connection settings come from the .env file loaded above
qdrant_client = QdrantClient(
    url=os.environ["QDRANT_ENDPOINT"],
    api_key=os.environ["QDRANT_API_KEY"],
)
collection_name = "marketing_data"
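
# Ensure the target collection exists before upserting into it. A minimal
# sketch: the vector size (1536) matches text-embedding-3-small, and cosine
# distance is an assumption rather than a setting taken from the original app.
from qdrant_client.models import Distance, VectorParams

if collection_name not in {c.name for c in qdrant_client.get_collections().collections}:
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )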

# Note: the <|start_header_id|> / <|eot_id|> markers are Llama-3 chat-format
# tokens; ChatOpenAI sends them to gpt-4 as literal text.
RAG_PROMPT_TEMPLATE = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
User Query:
{question}

Context:
{context}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)


### On Chat Start (Session Start) ###
@cl.on_chat_start
async def on_chat_start():
    files = None

    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please ask a question or upload a text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=2,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Load the file and index each chunk into the remote Qdrant collection
    docs = process_file(file)
    print(f"Processing {len(docs)} text chunks")

    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"  # TODO: add richer metadata
        add_to_qdrant(doc, embeddings, qdrant_client, collection_name)

    # Create the vectorstore
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    vectorstore = Qdrant.from_documents(
        documents=splits,
        embedding=embeddings,
        location=":memory:",  # TODO: add Qdrant server URL
    )
    retriever = vectorstore.as_retriever()
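
    # Optional tweak (not in the original setup): as_retriever() returns the
    # top 4 chunks by default; pass search_kwargs={"k": n} to change that, e.g.
    # retriever = vectorstore.as_retriever(search_kwargs={"k": 6})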

    # Create a chain
    rag_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", rag_chain)


@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")

    result = await chain.ainvoke({"question": message.content})

    msg = cl.Message(content=result)
    await msg.send()
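
# A streaming variant of `main` (a sketch, not part of the original app):
# stream tokens to the UI as they arrive instead of waiting for the full
# answer. The chain ends in StrOutputParser, so astream() yields string chunks.
#
# @cl.on_message
# async def main(message: cl.Message):
#     chain = cl.user_session.get("chain")
#     msg = cl.Message(content="")
#     async for chunk in chain.astream({"question": message.content}):
#         await msg.stream_token(chunk)
#     await msg.send()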