# import os
# os.system("bash setup.sh")  # optional one-time environment setup (left disabled)

import re

import streamlit as st
import torch
import transformers
from transformers import AutoTokenizer

# The community package hosts these integrations in current LangChain releases;
# the old langchain.embeddings / langchain.vectorstores import paths are deprecated.
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import Chroma
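
# Assumed dependencies (nothing is pinned in this file): streamlit, torch,
# transformers, accelerate, langchain, langchain-community, chromadb,
# sentence-transformers, and unstructured[pdf] for DirectoryLoader's PDF parsing.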

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize embeddings and ChromaDB. st.cache_resource keeps the result alive
# across Streamlit reruns, so the PDFs are loaded and embedded once per process
# instead of on every user interaction.
@st.cache_resource
def build_vector_store():
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": device},
    )
    loader = DirectoryLoader("./pdf", glob="**/*.pdf", recursive=True, use_multithreading=True)
    docs = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(docs)
    # Build and persist the store in one step; reusing the returned handle makes
    # a second Chroma(persist_directory=...) round-trip unnecessary.
    return Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="./pdf_db")

books_db = build_vector_store()
books_db_client = books_db.as_retriever()
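
# as_retriever() defaults to similarity search returning k=4 chunks. A sketch of
# tuning recall (the right k is corpus-dependent, an assumption rather than
# anything measured here):
# books_db_client = books_db.as_retriever(search_kwargs={"k": 6})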

# Initialize the model and tokenizer
model_name = "unsloth/Llama-3.2-3B-Instruct"

# Optional 4-bit quantization (requires bitsandbytes): uncomment this block and
# pass quantization_config=bnb_config to from_pretrained below to enable it.
# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Cache the LLM as well: reloading a 3B-parameter model on every Streamlit rerun
# would otherwise dominate each response.
@st.cache_resource
def load_llm():
    # max_new_tokens is a generation setting, so it is passed to the pipeline
    # below instead of AutoConfig (the original config-level value had no effect,
    # since the pipeline's max_new_tokens overrides it).
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    query_pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Echo the prompt back so the "Helpful Answer:" marker can be located
        # in the chain's output further down.
        return_full_text=True,
        do_sample=True,  # required for temperature/top_p/top_k to take effect
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        max_new_tokens=128,
    )
    return HuggingFacePipeline(pipeline=query_pipeline)

llm = load_llm()
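
# Optional smoke test (hypothetical prompt, not part of the app flow): LangChain
# LLMs are Runnables, so the wrapped pipeline can be invoked on a plain string.
# print(llm.invoke("What is retrieval-augmented generation?"))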

books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True
)
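
# "stuff" concatenates every retrieved chunk into one prompt. If the combined
# context overflows the model's window, "map_reduce" is the usual alternative.
# A sketch, assuming the default prompts suit your documents:
# books_db_client_retriever = RetrievalQA.from_chain_type(
#     llm=llm, chain_type="map_reduce", retriever=books_db_client)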

st.title("RAG System with ChromaDB")

if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', "content": 'Hello! Upload PDF files and ask me anything about their content.'}]
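
# st.session_state persists across reruns within a browser session, which is
# what keeps the chat history alive between interactions.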

# Function to retrieve answer using the RAG system
def test_rag(qa, query):
    return qa.run(query)
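
# Example usage (hypothetical query):
# answer = test_rag(books_db_client_retriever, "Summarize chapter 1 of the PDF.")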

# st.chat_input yields the text once per submission, unlike st.text_input, which
# would re-submit the lingering value on every rerun of the script.
user_prompt = st.chat_input("Ask me anything about the content of the PDF(s):")
if user_prompt:
    st.session_state.messages.append({'role': 'user', "content": user_prompt})
    books_retriever = test_rag(books_db_client_retriever, user_prompt)
    # Extract the relevant answer using regex: the default "stuff" QA prompt ends
    # with "Helpful Answer:", and the pipeline returns the full text, so everything
    # after that marker is the model's reply.
    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
    if corrected_text_match:
        corrected_text_books = corrected_text_match.group(1).strip()
    else:
        corrected_text_books = "No helpful answer found."
    st.session_state.messages.append({'role': 'assistant', "content": corrected_text_books})

for message in st.session_state.messages:
    with st.chat_message(message['role']):
        st.write(message['content'])
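
# To launch the app (the filename app.py is an assumption, not given in the source):
#   streamlit run app.py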