"""Retrieval-augmented customer-support chatbot served through Gradio.

Pipeline: hybrid retrieval (BM25 lexical search + FAISS dense search over
all-mpnet-base-v2 embeddings) -> cross-encoder re-ranking -> the top passages
are sent as context to a Groq-hosted Llama 3 model through litellm.

The Groq credential is read from the GROQ_API_KEY environment variable; it must
never be written into this file.
"""

import os

import gradio as gr
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from litellm import completion
import torch
from sentence_transformers import CrossEncoder

# NOTE(review): an API key used to be hard-coded here. Anything that was ever
# committed must be treated as leaked -- revoke it and issue a new one, then
# provide it via the GROQ_API_KEY environment variable (or a secrets store).

# LLM used to phrase the final answer.
# NOTE(review): confirm this model id is still served by Groq.
LLM_MODEL = "groq/llama3-8b-8192"

# Knowledge-base files loaded when the chatbot is built without explicit paths.
DEFAULT_FILE_PATHS = ("Company_eng.txt", "base_eng.txt")

PROMPT = """\
You are a virtual representative of a retail company and a consultant for customers.
To generate answers, use only information from the context!
Do not ask additional questions, but simply offer the product available in the context!
Your goal is to answer customers' questions, thus helping them.
You should advise the customer in choosing products using the context.
If you could not find a specific answer:
- Answer "I do not know. For more information, please contact: +380954673526" and nothing more.
You always maintain a polite, professional tone.
The format of the answer should be simple, understandable and clear.
Avoid long explanations if they are not necessary.
"""

# Heavy models are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")


def load_documents(file_paths):
    """Read each UTF-8 file in *file_paths* as one whitespace-stripped document.

    Files that are empty after stripping are skipped so they cannot end up as
    zero-length entries in the BM25 corpus.
    """
    documents = []
    for path in file_paths:
        with open(path, "r", encoding="utf-8") as file:
            text = file.read().strip()
        if text:
            documents.append(text)
    return documents


def load_documents_with_chunking(file_paths, chunk_size=500, overlap=0):
    """Read UTF-8 files and split each into fixed-size character chunks.

    Args:
        file_paths: iterable of paths to read.
        chunk_size: number of characters per chunk (must be > 0).
        overlap: characters shared between consecutive chunks; must satisfy
            0 <= overlap < chunk_size. The default 0 gives non-overlapping chunks.

    Returns:
        A flat list of non-empty chunks, in file order.

    Raises:
        ValueError: if chunk_size or overlap is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    documents = []
    for path in file_paths:
        with open(path, "r", encoding="utf-8") as file:
            text = file.read().strip()
        for start in range(0, len(text), step):
            chunk = text[start:start + chunk_size]
            if chunk:
                documents.append(chunk)
    return documents


class Retriver:
    """Hybrid retriever combining BM25 and a FAISS dense-vector index.

    Dense vectors are attention-mask-weighted mean-pooled hidden states of
    *model*, L2-normalised, so the L2 index ranks by cosine similarity.
    """

    def __init__(self, documents, tokenizer, model):
        """Index *documents* (a non-empty sequence of strings).

        Raises:
            ValueError: if *documents* is empty (BM25 cannot be built on an
                empty corpus).
        """
        if not documents:
            raise ValueError("Retriver needs at least one non-empty document")
        self.documents = list(documents)
        self.bm25 = BM25Okapi([self._tokenize(doc) for doc in self.documents])
        self.tokenizer = tokenizer
        self.model = model
        self.index = self.create_faiss_index()

    @staticmethod
    def _tokenize(text):
        """Split *text* into lower-cased whitespace tokens for BM25.

        Documents and queries must share this tokenizer, otherwise their terms
        would not match.
        """
        return text.lower().split()

    def create_faiss_index(self):
        """Embed all documents and load them into an exact (flat) L2 index."""
        embeddings = self.embed_documents(self.documents)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index

    def embed_documents(self, docs, batch_size=32):
        """Return a (len(docs), hidden) float32 array of normalised embeddings.

        Padding tokens are excluded from the mean via the attention mask.
        Without this, a short text in a padded batch would be averaged
        together with pad-token states. Inputs longer than the tokenizer's
        maximum length are truncated.

        Raises:
            ValueError: if *docs* is empty.
        """
        if not docs:
            raise ValueError("embed_documents needs at least one text")
        parts = []
        for start in range(0, len(docs), batch_size):
            tokens = self.tokenizer(
                docs[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                hidden = self.model(**tokens).last_hidden_state
            mask = tokens["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            parts.append(pooled.cpu().numpy())
        # FAISS requires C-contiguous float32 input.
        return np.ascontiguousarray(np.vstack(parts), dtype=np.float32)

    def search_bm25(self, query, top_k=5):
        """Return up to *top_k* documents with the highest BM25 score for *query*."""
        scores = self.bm25.get_scores(self._tokenize(query))
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def search_semantic(self, query, top_k=5):
        """Return up to *top_k* documents nearest to *query* in embedding space."""
        # FAISS pads results with index -1 when k exceeds the index size. Clamp
        # k and drop any -1, so that documents[-1] is never returned by mistake.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        query_embedding = self.embed_documents([query])
        _, indices = self.index.search(query_embedding, k)
        return [self.documents[i] for i in indices[0] if i >= 0]


# Correctly spelled alias; the original name is kept for existing callers.
Retriever = Retriver


class Reranker:
    """Re-order candidate documents with a cross-encoder relevance model."""

    def __init__(self, reranker):
        self.model = reranker

    def rank(self, query, documents):
        """Return *documents* sorted by descending cross-encoder score for *query*.

        An empty candidate list is returned as-is rather than passed to the model.
        """
        if not documents:
            return []
        pairs = [(query, doc) for doc in documents]
        scores = self.model.predict(pairs)
        return [documents[i] for i in np.argsort(scores)[::-1]]


class QAChatbot:
    """Answer questions with retrieved context and an LLM."""

    def __init__(self, indexer, reranker, context_size=3):
        """
        Args:
            indexer: object exposing search_bm25(query) and search_semantic(query).
            reranker: object exposing rank(query, documents).
            context_size: how many top re-ranked documents go into the prompt.
        """
        self.indexer = indexer
        self.reranker = reranker
        self.context_size = context_size

    def generate_answer(self, query):
        """Retrieve, re-rank, and ask the LLM; returns the raw litellm response."""
        bm25_results = self.indexer.search_bm25(query)
        semantic_results = self.indexer.search_semantic(query)
        # dict.fromkeys de-duplicates while keeping a deterministic order. A
        # set's order depends on per-process string hashing.
        combined_results = list(dict.fromkeys(bm25_results + semantic_results))
        ranked_docs = self.reranker.rank(query, combined_results)
        context = "\n".join(ranked_docs[:self.context_size])
        return completion(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": PROMPT},
                {
                    "role": "user",
                    "content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
                },
            ],
        )


def build_chatbot(file_paths=DEFAULT_FILE_PATHS):
    """Load the knowledge base from *file_paths* and assemble a QAChatbot.

    NOTE(review): each file is indexed as one whole document. The embedding
    and cross-encoder steps truncate long inputs, so only the start of a large
    file influences retrieval. Consider load_documents_with_chunking if the
    files grow.
    """
    documents = load_documents(file_paths)
    indexer = Retriver(documents, tokenizer, model)
    reranker = Reranker(reranker_model)
    return QAChatbot(indexer, reranker)


# Built eagerly under __main__. When the module is imported instead (e.g. by
# another launcher), it is built lazily on the first request.
chatbot = None


def chatbot_interface(query, history):
    """Gradio ChatInterface callback; *history* is currently not used."""
    global chatbot
    if chatbot is None:
        chatbot = build_chatbot()
    answer = chatbot.generate_answer(query)
    return answer["choices"][0]["message"]["content"]


iface = gr.ChatInterface(fn=chatbot_interface, type="messages")

if __name__ == "__main__":
    if not os.environ.get("GROQ_API_KEY"):
        raise SystemExit(
            "Set the GROQ_API_KEY environment variable before starting the chatbot."
        )
    chatbot = build_chatbot()
    iface.launch()