Govind commited on
Commit
aa774c1
·
1 Parent(s): aa2e117

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import os
2
+ #os.system("bash setup.sh")
3
+
4
+ import streamlit as st
5
+ # import fitz # PyMuPDF for extracting text from PDFs
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.docstore.document import Document
10
+ from langchain.llms import HuggingFacePipeline
11
+ from langchain.chains import RetrievalQA
12
+ from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
13
+ import torch
14
+ import re
15
+ import transformers
16
+ from torch import bfloat16
17
+ from langchain_community.document_loaders import DirectoryLoader
18
+
19
+ # Initialize embeddings and ChromaDB
20
+ model_name = "sentence-transformers/all-mpnet-base-v2"
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ model_kwargs = {"device": device}
23
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
24
+
25
+ # loader = DirectoryLoader('./pdf', glob="**/*.pdf", use_multithreading=True)
26
+ loader = DirectoryLoader('./pdf', glob="**/*.pdf", recursive=True, use_multithreading=True)
27
+ docs = loader.load()
28
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
29
+ all_splits = text_splitter.split_documents(docs)
30
+ vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="pdf_db")
31
+ books_db = Chroma(persist_directory="./pdf_db", embedding_function=embeddings)
32
+
33
+ books_db_client = books_db.as_retriever()
34
+
35
+ # Initialize the model and tokenizer
36
+ model_name = "stabilityai/stablelm-zephyr-3b"
37
+
38
+ bnb_config = transformers.BitsAndBytesConfig(
39
+ load_in_4bit=True,
40
+ bnb_4bit_quant_type='nf4',
41
+ bnb_4bit_use_double_quant=True,
42
+ bnb_4bit_compute_dtype=torch.bfloat16
43
+ )
44
+
45
+ model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
46
+ model = transformers.AutoModelForCausalLM.from_pretrained(
47
+ model_name,
48
+ trust_remote_code=True,
49
+ config=model_config,
50
+ quantization_config=bnb_config,
51
+ device_map=device,
52
+ )
53
+
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+
56
+ query_pipeline = transformers.pipeline(
57
+ "text-generation",
58
+ model=model,
59
+ tokenizer=tokenizer,
60
+ return_full_text=True,
61
+ torch_dtype=torch.float16,
62
+ device_map=device,
63
+ temperature=0.7,
64
+ top_p=0.9,
65
+ top_k=50,
66
+ max_new_tokens=256
67
+ )
68
+
69
+ llm = HuggingFacePipeline(pipeline=query_pipeline)
70
+
71
+ books_db_client_retriever = RetrievalQA.from_chain_type(
72
+ llm=llm,
73
+ chain_type="stuff",
74
+ retriever=books_db_client,
75
+ verbose=True
76
+ )
77
+
78
+ st.title("RAG System with ChromaDB")
79
+
80
+ if 'messages' not in st.session_state:
81
+ st.session_state.messages = [{'role': 'assistant', "content": 'Hello! Upload PDF files and ask me anything about their content.'}]
82
+
83
+ # Function to retrieve answer using the RAG system
84
+ def test_rag(qa, query):
85
+ return qa.run(query)
86
+
87
+ user_prompt = st.chat_input("Ask me anything about the content of the PDF(s):")
88
+ if user_prompt:
89
+ st.session_state.messages.append({'role': 'user', "content": user_prompt})
90
+ books_retriever = test_rag(books_db_client_retriever, user_prompt)
91
+ # Extracting the relevant answer using regex
92
+ corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
93
+
94
+ if corrected_text_match:
95
+ corrected_text_books = corrected_text_match.group(1).strip()
96
+ else:
97
+ corrected_text_books = "No helpful answer found."
98
+ st.session_state.messages.append({'role': 'assistant', "content": corrected_text_books})
99
+
100
+ for message in st.session_state.messages:
101
+ with st.chat_message(message['role']):
102
+ st.write(message['content'])