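"""Chainlit app: retrieval-augmented QA over an uploaded PDF.

The user uploads a PDF, its text is chunked and embedded into a Chroma
vector store, and each question is answered from the most similar chunks
by an OpenAI chat model, streamed back token by token.
"""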
import io
import os
from typing import List

import chainlit as cl
from chainlit.types import AskFileResponse
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from PyPDF2 import PdfReader

# Check if the API key is set
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY environment variable is not set")

# Set up prompts
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)

human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

class RetrievalAugmentedQAPipeline:
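    """Answers questions by retrieving relevant chunks and streaming an LLM response."""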
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        # Retrieve the chunks most similar to the query; the async variant
        # avoids blocking the event loop during the lookup.
        context_docs = await self.vector_db.asimilarity_search(user_query, k=2)
        context_prompt = "\n".join(doc.page_content for doc in context_docs)

        # Keep the context within a rough character budget so the prompt
        # stays inside the model's context window.
        max_context_length = 12000
        context_prompt = context_prompt[:max_context_length]

        messages = chat_prompt.format_prompt(
            context=context_prompt, question=user_query
        ).to_messages()

        # Stream the model's answer token by token.
        async for chunk in self.llm.astream(messages):
            yield chunk.content

def process_pdf(file: AskFileResponse) -> List[str]:
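    """Extract the PDF's text and split it into overlapping chunks."""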
    # PdfReader expects a file-like object or a path, not raw bytes.
    pdf_reader = PdfReader(io.BytesIO(file.content))
    # extract_text() can return None for pages with no extractable text.
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)

@cl.on_chat_start
async def on_chat_start():
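    """Ask for a PDF, index it, and store the QA pipeline in the user session."""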
    files = await cl.AskFileMessage(
        content="Please upload a PDF file to begin!",
        accept=["application/pdf"],
        max_size_mb=20,
    ).send()

    if not files:
        await cl.Message(content="No file was uploaded. Please try again.").send()
        return

    file = files[0]
    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    texts = process_pdf(file)

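    # Embed the chunks and index them in an in-memory Chroma vector store.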
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(texts, embeddings)
    
    chat_openai = ChatOpenAI()
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
    
    cl.user_session.set("pipeline", retrieval_augmented_qa_pipeline)

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

@cl.on_message
async def main(message: cl.Message):
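    """Run the user's question through the pipeline and stream the answer."""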
    pipeline = cl.user_session.get("pipeline")
    if not pipeline:
        await cl.Message(content="Please upload a PDF file first.").send()
        return

    msg = cl.Message(content="")
    try:
        async for chunk in pipeline.arun_pipeline(message.content):
            await msg.stream_token(chunk)
    except Exception as e:
        await cl.Message(content=f"An error occurred: {str(e)}").send()
        return

    await msg.send()