evanperez committed on
Commit
6035089
1 Parent(s): aafb505

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -0
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PyPDF2 import PdfReader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
5
+ import google.generativeai as genai
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain.chains.question_answering import load_qa_chain
9
+ from langchain.prompts import PromptTemplate
10
+ import os
11
+ import json
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer, ConversationalPipeline
13
+
14
####CREDIT#####
# Credit to author (Sri Laxmi) of original code reference: SriLaxmi1993
# Sri Laxmi GitHub Link: https://github.com/SriLaxmi1993/Document-Genie-using-RAG-Framwork
# Sri Laxmi Youtube: https://www.youtube.com/watch?v=SkY2u4UUr6M&t=112s
###############
# NOTE(review): the original called os.system("pip install -r requirements.txt")
# here, at import time. Installing dependencies at runtime is slow, insecure
# (arbitrary shell command on every start) and redundant — the hosting platform
# installs requirements.txt before the app launches. Removed.

# Optional local model, currently disabled; kept for reference.
#tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
#model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
25
+
26
+
27
st.set_page_config(page_title="Gemini RAG", layout="wide")

# SECURITY: the original hard-coded a Google API key in this file. A key
# committed to a public repository must be treated as leaked — revoke it and
# supply a fresh one via the environment (or st.secrets) instead.
api_key = os.environ.get("GOOGLE_API_KEY", "")

#os.mkdir('faiss_index')
36
def delete_files_in_folder(folder_path):
    """Delete the saved chat history and every regular file inside *folder_path*.

    Failures are reported to stdout rather than raised.
    """
    history_path = "chat_history.json"
    try:
        # Drop the persisted chat history first, if present.
        if os.path.exists(history_path):
            os.remove(history_path)
        for entry in os.listdir(folder_path):
            candidate = os.path.join(folder_path, entry)
            if not os.path.isfile(candidate):
                continue  # leave subdirectories untouched
            os.remove(candidate)
            print(f"Deleted file: {candidate}")
        print("All files within the folder have been deleted successfully!")
    except Exception as e:
        print(f"An error occurred: {e}")
50
+
51
+
52
# Manual reset: wipe the FAISS index files plus the saved chat history.
if st.button("Reset Files", key="reset_button"):
    delete_files_in_folder('faiss_index')
55
+
56
# Text-splitter tuning: chunk size and overlap between chunks, in characters.
CH_size = 450
CH_overlap = 50
58
+
59
+
60
def get_pdf_text(pdf_docs):
    """Concatenate the extracted text of every page of every uploaded PDF.

    Args:
        pdf_docs: iterable of file-like objects readable by PyPDF2's PdfReader.

    Returns:
        A single string containing all extracted page text.
    """
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() returns None for image-only pages; the original
            # then crashed with TypeError on `str += None`.
            text += page.extract_text() or ""
    return text
67
+
68
+
69
def get_text_chunks(text):
    """Split *text* into overlapping chunks sized by the module-level constants."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CH_size,
        chunk_overlap=CH_overlap,
    )
    return splitter.split_text(text)
73
+
74
+
75
def get_vector_store(text_chunks, api_key):
    """Embed *text_chunks* with Google embeddings and persist a FAISS index to disk."""
    embedder = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001", google_api_key=api_key
    )
    index = FAISS.from_texts(text_chunks, embedding=embedder)
    index.save_local("faiss_index")
79
+
80
+
81
def get_conversational_chain():
    """Build a 'stuff' question-answering chain around Gemini Pro.

    The prompt instructs the model to answer strictly from the retrieved
    context and to refuse when the answer is not present.
    """
    prompt_template = """
    Answer the question as detailed as possible from the provided context, make sure to provide all the details, if the answer is not in
    provided context just say, "answer is not available in the context", don't provide the wrong answer. When giving an answer, try to include all mentionings of the subject being asked and include this within your response\n\n
    Context:\n {context}?\n
    Question: \n{question}\n

    Answer:
    """
    # Low temperature keeps answers close to the retrieved context.
    llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.2, google_api_key=api_key)
    qa_prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    return load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
94
+
95
+
96
+ # chat history functionality
97
def update_chat_history(question, reply):
    """Append a question/reply pair to chat_history.json and render the history.

    Args:
        question: the user's question.
        reply: the model's answer text.
    """
    chat_history_file = "chat_history.json"
    chat_history = {"conversations": []}
    if os.path.exists(chat_history_file):
        try:
            with open(chat_history_file, "r", encoding="utf-8") as file:
                chat_history = json.load(file)
        except (json.JSONDecodeError, OSError):
            # A truncated or corrupt history file previously crashed the app
            # with JSONDecodeError; start a fresh history instead.
            chat_history = {"conversations": []}

    # Record the current exchange.
    chat_history["conversations"].append({"question": question, "reply": reply})

    # Persist the updated history (explicit UTF-8 for portability).
    with open(chat_history_file, "w", encoding="utf-8") as file:
        json.dump(chat_history, file, indent=4)

    # Render the accumulated conversation in the Streamlit page.
    st.subheader("Chat History")
    for conversation in chat_history["conversations"]:
        st.write(f"**Question:** {conversation['question']}")
        st.write(f"**Reply:** {conversation['reply']}")
        st.write("---")
120
+
121
+
122
+
123
def user_input(user_question, api_key):
    """Answer *user_question* from the persisted FAISS index and display the reply.

    Args:
        user_question: the question typed by the user.
        api_key: Google Generative AI API key used for embeddings and the LLM.
    """
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    # allow_dangerous_deserialization is acceptable here only because the index
    # is produced locally by this app, never loaded from untrusted sources.
    new_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    docs = new_db.similarity_search(user_question)
    chain = get_conversational_chain()
    response = chain({"input_documents": docs, "question": user_question}, return_only_outputs=True)
    st.write("Reply: ", response["output_text"])

    # Persist and display the chat history.
    update_chat_history(user_question, response["output_text"])

# (A stray no-op string literal `''''''''''''''''''` that followed this
# function in the original file has been removed.)
135
+
136
def clear_faiss_index(folder_path):
    """Delete any saved FAISS index files under *folder_path*, reporting via Streamlit."""
    try:
        if not os.path.exists(folder_path):
            st.write("No existing FAISS index files found.")
            return
        for name in os.listdir(folder_path):
            full_path = os.path.join(folder_path, name)
            if os.path.isfile(full_path):
                os.remove(full_path)
        st.write("Existing FAISS index files cleared successfully!")
    except Exception as e:
        st.error(f"An error occurred while clearing FAISS index files: {e}")
149
+
150
+
151
def recreate_faiss_index(pdf_docs, chunk_size, chunk_overlap, api_key):
    """Rebuild the on-disk FAISS index from the uploaded PDFs.

    Clears any previous index, extracts text from every page of every PDF,
    splits it into overlapping chunks, embeds the chunks with Google's
    embedding model, and saves the resulting index to ./faiss_index.

    Args:
        pdf_docs: uploaded PDF file-like objects.
        chunk_size: splitter chunk size, in characters.
        chunk_overlap: overlap between consecutive chunks, in characters.
        api_key: Google Generative AI API key.
    """
    try:
        # Start from a clean slate so stale documents don't leak into answers.
        clear_faiss_index("faiss_index")

        # Extract raw text from every page.
        text = ""
        for pdf in pdf_docs:
            pdf_reader = PdfReader(pdf)
            for page in pdf_reader.pages:
                # extract_text() may return None for image-only pages; the
                # original crashed with TypeError on `str += None`.
                text += page.extract_text() or ""

        # Split into overlapping chunks for retrieval.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        chunks = text_splitter.split_text(text)

        # Embed the chunks and persist the index.
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
        vector_store = FAISS.from_texts(chunks, embedding=embeddings)
        vector_store.save_local("faiss_index")

        st.success("FAISS index recreated successfully!")
    except Exception as e:
        st.error(f"An error occurred while recreating FAISS index: {e}")
177
+
178
+
179
def main():
    """Streamlit entry point: a question box plus a sidebar for PDF ingestion."""
    st.header("RAG based LLM Application")

    user_question = st.text_input("Ask a Question from the PDF Files", key="user_question")

    if user_question and api_key:
        user_input(user_question, api_key)

    with st.sidebar:
        st.title("Menu:")

        pdf_docs = st.file_uploader("Upload your PDF Files and Click on the Submit & Process Button",
                                    accept_multiple_files=True, key="pdf_uploader")
        if st.button("Submit & Process", key="process_button") and api_key:
            if not pdf_docs:
                # Nothing uploaded: avoid building an index from empty text.
                st.warning("Please upload at least one PDF first.")
            else:
                with st.spinner("Processing..."):
                    # recreate_faiss_index already extracts, chunks, embeds and
                    # saves the index. The original additionally re-ran
                    # get_pdf_text/get_text_chunks/get_vector_store afterwards,
                    # embedding everything a second time (double API cost) for
                    # an identical saved index; that duplication is removed.
                    recreate_faiss_index(pdf_docs, CH_size, CH_overlap, api_key)
                st.success("Done")


if __name__ == "__main__":
    main()
204
+