Wintersmith committed on
Commit
257dc3a
·
verified ·
1 Parent(s): eb707f3

Upload main_class.py

Browse files
Files changed (1) hide show
  1. main_class.py +92 -0
main_class.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import PyPDFLoader
2
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
3
+ from langchain_openai import OpenAIEmbeddings
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_openai import ChatOpenAI
6
+ from langchain.retrievers import ContextualCompressionRetriever
7
+ from langchain.retrievers.document_compressors import LLMChainExtractor
8
+ from langchain.tools.retriever import create_retriever_tool
9
+ from langchain import hub
10
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
11
+ import os
12
+ import gradio as gr
13
+
14
+ # The Agent retriever is based on: https://python.langchain.com/docs/use_cases/question_answering/conversational_retrieval_agents?ref=blog.langchain.dev
15
+ # The chat history is based on: https://python.langchain.com/docs/use_cases/question_answering/chat_history
16
+ # Inspired by https://github.com/Niez-Gharbi/PDF-RAG-with-Llama2-and-Gradio/tree/master
17
+ # Inspired by https://github.com/mirabdullahyaser/Retrieval-Augmented-Generation-Engine-with-LangChain-and-Streamlit/tree/master
18
+
19
class PDFChatBot:
    """RAG chatbot over a single PDF, driven by an OpenAI tools agent.

    Pipeline: load the PDF -> split it into overlapping chunks -> embed
    the chunks into a FAISS vectorstore -> expose the retriever as a
    tool -> run an OpenAI tools agent that may call the tool while
    answering. Gradio-specific plumbing (`add_text`, the streaming loop
    in `generate_response`) is noted where it appears.
    """

    # Initialize the class with the api_key and the model_name
    def __init__(self, api_key):
        """Store the API key and build the chat model.

        Args:
            api_key: OpenAI API key, used for both the chat model and
                the embeddings.
        """
        self.processed = False   # becomes True once a PDF has been indexed
        self.final_agent = None  # AgentExecutor, built lazily in create_agent()
        self.chat_history = []   # HumanMessage/AIMessage turns fed to the prompt
        self.api_key = api_key
        # temperature=0 keeps answers deterministic and grounded in the document
        self.llm = ChatOpenAI(openai_api_key=self.api_key,
                              temperature=0,
                              model_name="gpt-3.5-turbo-0125")

    # add text to Gradio text block (not needed without Gradio)
    def add_text(self, history, text):
        """Append the user's message to the Gradio chat history.

        Appends a mutable ``[user, bot]`` pair because
        ``generate_response`` mutates ``history[-1][1]`` in place while
        streaming; the original immutable tuple ``(text, '')`` raised
        ``TypeError`` on that path.

        Raises:
            gr.Error: if ``text`` is empty.
        """
        if not text:
            raise gr.Error("Please enter text.")
        history.append([text, ''])
        return history

    # Load a pdf document with langchain textloader
    def load_document(self, file_name):
        """Load the PDF at ``file_name`` into a list of page Documents."""
        loader = PyPDFLoader(file_name)
        raw_document = loader.load()
        return raw_document

    # Split the document
    def split_documents(self, raw_document):
        """Split page documents into ~1000-char chunks with 100-char overlap."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,  # overlap preserves context across chunk edges
            length_function=len,
            is_separator_regex=False,
            separators=["\n\n", "\n", " ", ""])
        chunks = text_splitter.split_documents(raw_document)
        return chunks

    # Embed the document with OpenAI Embeddings & store it to vectorstore
    def create_retriever(self, chunks):
        """Embed ``chunks`` into a FAISS vectorstore and return a retriever."""
        embedding_func = OpenAIEmbeddings(openai_api_key=self.api_key)
        # Create a new vectorstore from the chunks
        vectorstore = FAISS.from_documents(chunks, embedding_func)

        # Plain similarity retriever over the vectorstore.
        basic_retriever = vectorstore.as_retriever()
        # Optional variant: compress each retrieved passage down to the
        # relevant sentences with an extra LLM pass — more precise, but
        # one extra LLM call per query, so the plain retriever is used.
        compressor = LLMChainExtractor.from_llm(self.llm)
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=basic_retriever)
        return basic_retriever  # or compression_retriever

    # Create an agent
    def create_agent(self, retriever):
        """Wrap ``retriever`` in a tool and build the agent executor."""
        tool = create_retriever_tool(
            retriever,
            "search_document",
            "Searches and returns excerpts from the provided document.")
        tools = [tool]
        # Stock OpenAI-tools agent prompt; provides the chat_history and
        # agent_scratchpad placeholders the agent expects.
        prompt = hub.pull("hwchase17/openai-tools-agent")
        agent = create_openai_tools_agent(self.llm, tools, prompt)
        self.final_agent = AgentExecutor(agent=agent, tools=tools)

    # Process files
    def process_file(self, file_name):
        """Run the full indexing pipeline (load, split, embed, agent) for one PDF."""
        documents = self.load_document(file_name)
        texts = self.split_documents(documents)
        retriever = self.create_retriever(texts)
        self.create_agent(retriever)
        print("Files successfully processed")

    # Generate a response and write to memory
    def generate_response(self, history, query, path):
        """Answer ``query`` about the PDF at ``path``, streaming into ``history``.

        Returns:
            ``(history, " ")`` — updated Gradio history plus a blank
            string that clears the input textbox.
        """
        if not self.processed:
            # Index the document only once per session.
            self.process_file(path)
            self.processed = True
        result = self.final_agent.invoke(
            {'input': query, 'chat_history': self.chat_history})['output']
        # Record the turn as typed messages. The original
        # ``extend((query, result))`` appended two bare strings, which the
        # prompt's MessagesPlaceholder coerces into two *human* messages,
        # making the model read its own answers as user input.
        from langchain.schema import AIMessage, HumanMessage
        self.chat_history.extend(
            [HumanMessage(content=query), AIMessage(content=result)])
        for char in result:  # history argument and the subsequent code is only for the purpose of Gradio
            history[-1][1] += char
        return history, " "