# AIE4MidtermV2 / app.py
import io
import os
import asyncio
from typing import List

from chainlit.types import AskFileResponse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings  # newer releases: langchain_openai.OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI  # newer releases: langchain_openai.ChatOpenAI
from PyPDF2 import PdfReader
import chainlit as cl

# Check if the API key is set
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY environment variable is not set")
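# The key can be supplied via the environment before launch, e.g.:
#   export OPENAI_API_KEY=<your key>
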
# Set up prompts
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
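# For illustration, a formatted prompt sent to the model looks like
# (placeholder values, not from a real run):
#   System: Use the following context to answer a user's question. If you
#           cannot find the answer in the context, say you don't know the answer.
#   Human:  Context:
#           <retrieved chunk 1>
#           <retrieved chunk 2>
#
#           Question:
#           <user question>
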
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        # Run the blocking similarity search in a worker thread so the event
        # loop stays responsive while retrieving the top-2 chunks.
        context_docs = await asyncio.to_thread(self.vector_db.similarity_search, user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Truncate over-long context to keep the request within the model's window.
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()

        # Stream tokens back to the caller as they arrive.
        async for chunk in self.llm.astream(messages):
            yield chunk.content

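# Minimal usage sketch (assumes an already-populated Chroma store named `vector_db`):
#   pipeline = RetrievalAugmentedQAPipeline(llm=ChatOpenAI(), vector_db=vector_db)
#   async for token in pipeline.arun_pipeline("What does the document say about X?"):
#       print(token, end="")
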
def process_pdf(file: AskFileResponse) -> List[str]:
    # PdfReader expects a path or file-like object, so wrap the raw bytes.
    # (Newer Chainlit versions expose the upload via `file.path` instead.)
    pdf_reader = PdfReader(io.BytesIO(file.content))
    # extract_text() can return None for image-only pages; substitute "".
    text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)

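# Example: with chunk_size=500 and chunk_overlap=40, a ~1,200-character page
# splits into roughly three chunks, each sharing ~40 characters with its
# neighbor so sentences spanning a boundary stay retrievable.
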
@cl.on_chat_start
async def on_chat_start():
    files = await cl.AskFileMessage(
        content="Please upload a PDF file to begin!",
        accept=["application/pdf"],
        max_size_mb=20,
    ).send()
    if not files:
        await cl.Message(content="No file was uploaded. Please try again.").send()
        return

    file = files[0]
    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    texts = process_pdf(file)

    # Build an in-memory Chroma index over the extracted chunks.
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(texts, embeddings)

    chat_openai = ChatOpenAI()
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
    cl.user_session.set("pipeline", retrieval_augmented_qa_pipeline)

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

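# The pipeline is stored in cl.user_session, so each connected user gets an
# independent vector store and conversation state.
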
@cl.on_message
async def main(message: cl.Message):
    pipeline = cl.user_session.get("pipeline")
    if not pipeline:
        await cl.Message(content="Please upload a PDF file first.").send()
        return

    msg = cl.Message(content="")
    try:
        # Stream the answer token by token into the same message.
        async for chunk in pipeline.arun_pipeline(message.content):
            await msg.stream_token(chunk)
    except Exception as e:
        await cl.Message(content=f"An error occurred: {str(e)}").send()
        return
    await msg.send()
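
# Launch locally with the Chainlit CLI:
#   chainlit run app.py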