"""Gradio app for asking questions about uploaded PDF / text documents.

Documents are split into fixed-size character chunks, embedded with a
SentenceTransformer model, and the chunks most similar to a query are
prepended to the prompt given to a causal language model.
"""

import logging

import gradio as gr
import numpy as np
import torch
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

logger = logging.getLogger(__name__)


class RAGChatbot:
    """Retrieval-augmented chatbot over an in-memory set of text chunks.

    Attributes:
        tokenizer: Tokenizer loaded for ``model_name``.
        model: Causal language model used to generate answers.
        embedding_model: SentenceTransformer used for chunks and queries.
        documents: List of text chunks currently loaded.
        embeddings: Embeddings of ``documents`` (empty list when none loaded).
    """

    # Number of characters per document chunk.
    CHUNK_SIZE = 500
    # Maximum number of whitespace-separated context words put in the prompt.
    MAX_CONTEXT_WORDS = 100
    # Upper bound on the number of tokens generated per answer.
    MAX_NEW_TOKENS = 128

    def __init__(self, model_name="facebook/opt-350m", embedding_model="all-MiniLM-L6-v2"):
        """Load the tokenizer, language model and embedding model.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the causal language model.
            embedding_model: Identifier passed to ``SentenceTransformer``.
        """
        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        # Disabled 8-bit quantization config, kept for reference:
        # self.bnb_config = BitsAndBytesConfig(
        #     load_in_8bit=True,  # Enable 8-bit loading
        #     llm_int8_threshold=6.0,  # Threshold for mixed-precision computation
        # )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Initialize embedding model
        self.embedding_model = SentenceTransformer(embedding_model)

        # Initialize document storage
        self.documents = []
        self.embeddings = []

    def extract_text_from_pdf(self, pdf_path):
        """Return the text of every page of a PDF, one newline after each page.

        Args:
            pdf_path: Path of the PDF file to read.

        Returns:
            The concatenated page texts.
        """
        reader = PdfReader(pdf_path)
        # `or ""` guards against a page whose extraction yields None, which
        # would otherwise raise TypeError on the string concatenation.
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    def load_documents(self, file_paths):
        """Replace the loaded documents with chunks read from ``file_paths``.

        Files whose name ends in ``.pdf`` (case-insensitive) are read with
        ``extract_text_from_pdf``; everything else is read as UTF-8 text.

        Args:
            file_paths: A list of file paths, a single path string, or
                ``None`` / empty when nothing was uploaded.

        Returns:
            A status message with the number of chunks and files loaded.
        """
        if not file_paths:
            # The UI passes None when the button is pressed with no upload.
            file_paths = []
        elif isinstance(file_paths, str):
            # A bare string would otherwise be iterated character by character.
            file_paths = [file_paths]

        chunks = []
        for file_path in file_paths:
            if file_path.lower().endswith(".pdf"):
                text = self.extract_text_from_pdf(file_path)
            else:
                with open(file_path, "r", encoding="utf-8") as f:
                    text = f.read()

            # Split text into fixed-size character chunks
            chunks.extend(
                text[start:start + self.CHUNK_SIZE]
                for start in range(0, len(text), self.CHUNK_SIZE)
            )

        # Generate embeddings; skip the encoder entirely when nothing was read
        embeddings = self.embedding_model.encode(chunks) if chunks else []

        # Assign both attributes only after everything succeeded so that a
        # failure part-way through never leaves documents without embeddings.
        self.documents = chunks
        self.embeddings = embeddings

        return f"Loaded {len(self.documents)} text chunks from {len(file_paths)} files"

    def retrieve_relevant_context(self, query, top_k=3):
        """Return the ``top_k`` chunks most similar to ``query``, space-joined.

        Similarity is the cosine similarity between the query embedding and
        each stored chunk embedding.

        Args:
            query: The user question.
            top_k: Maximum number of chunks to return.

        Returns:
            The selected chunks joined by single spaces, most similar first;
            ``"No documents loaded"`` when there are no documents; an empty
            string when ``top_k`` is not positive.
        """
        if not self.documents:
            return "No documents loaded"
        if top_k <= 0:
            # A slice of [-0:] would select every chunk instead of none.
            return ""

        # Generate query embedding
        query_embedding = self.embedding_model.encode([query])[0]
        doc_embeddings = np.asarray(self.embeddings)

        # Calculate cosine similarities; the floor on the denominator avoids
        # a division by zero for an all-zero embedding.
        norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        similarities = np.dot(doc_embeddings, query_embedding) / np.maximum(norms, 1e-12)

        # Get top k most similar documents, best match first
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return " ".join(self.documents[i] for i in top_indices)

    def generate_response(self, query, context):
        """Generate an answer to ``query`` given the retrieved ``context``.

        Args:
            query: The user question.
            context: Retrieved text; only its first ``MAX_CONTEXT_WORDS``
                whitespace-separated words are placed in the prompt.

        Returns:
            The generated continuation of the prompt, stripped of
            surrounding whitespace.
        """
        # Construct prompt with truncated context
        truncated_context = " ".join(context.split()[:self.MAX_CONTEXT_WORDS])
        full_prompt = f"Context: {truncated_context}\n\nQuestion: {query}\n\nAnswer:"

        # Tokenize once and move the whole batch to the model's device
        tokens = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.model.device)

        outputs = self.model.generate(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
            max_new_tokens=self.MAX_NEW_TOKENS,
            pad_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.0,
        )

        # The output sequence starts with the prompt tokens. Decoding only
        # what follows them is robust to "Answer:" appearing inside the
        # generated text, unlike splitting the decoded string on the marker.
        prompt_length = tokens.input_ids.shape[-1]
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    @staticmethod
    def _append_exchange(history, query, reply):
        """Return ``history`` plus one user/assistant exchange (messages format)."""
        return history + [
            {"role": "user", "content": query},
            {"role": "assistant", "content": reply},
        ]

    def chat(self, query, history):
        """Answer ``query`` and append the exchange to the chat history.

        Args:
            query: Text entered by the user.
            history: Existing list of ``{"role", "content"}`` messages, or
                ``None`` when the chat is empty.

        Returns:
            A ``(history, "")`` tuple: the updated history and an empty string
            that clears the input textbox. A blank query leaves the history
            unchanged.
        """
        # The history may arrive as None (e.g. after the chat was cleared).
        history = history or []

        if not query or not query.strip():
            return history, ""

        try:
            # Retrieve relevant context
            context = self.retrieve_relevant_context(query)

            # Generate response
            response = self.generate_response(query, context)
        except Exception as e:
            # UI boundary: surface the error in the chat instead of crashing
            # the handler, but keep the traceback in the log.
            logger.exception("Failed to answer query")
            response = f"An error occurred: {str(e)}"

        return self._append_exchange(history, query, response), ""


# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI wired to a fresh ``RAGChatbot``.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    rag_chatbot = RAGChatbot()

    with gr.Blocks() as demo:
        gr.Markdown("# Ask your PDf!")

        with gr.Row():
            file_input = gr.File(label="Upload Documents", file_count="multiple", type="filepath")
            load_btn = gr.Button("Load Documents")

        status_output = gr.Textbox(label="Load Status")

        chatbot = gr.Chatbot(type="messages")  # Specify message type
        msg = gr.Textbox(label="Enter your query")
        submit_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")

        # Event handlers
        load_btn.click(
            rag_chatbot.load_documents,
            inputs=[file_input],
            outputs=[status_output],
        )

        # The Send button and pressing Enter in the textbox do the same thing.
        submit_btn.click(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg],
        )

        msg.submit(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg],
        )

        # Reset to an empty message list (not None) so the next chat call
        # always receives a list it can extend.
        clear_btn.click(lambda: ([], ""), None, [chatbot, msg])

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()