Rag_proj / app.py
sgt444pepper's picture
Update app.py
44018ec verified
raw
history blame
5.12 kB
import gradio as gr
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from litellm import completion
import os
import torch
from sentence_transformers import CrossEncoder
os.environ['GROQ_API_KEY'] = "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"
PROMPT = """/
Ти - віртуальний представник ритейл-компанії та консультант для клієнтів.
Для генерації відповідей, використовуй тільки інформацію з контексту!
Не задавай додаткових запитань, а просто запропонуй наявний в контексті товар!
Твоя ціль відповідати на запитання клієнтів, таким чином допомогти їм.
Ти повинен проконсультувати клієнта у виборі товарів використовуючи контекст.
Відповідай тільки українською мовою, назви товарів та компанії залишай англійською.
У разі, якщо конкретну відповідь знайти не вдалося:
- Відповідай "Я не знаю. Для уточнення інформації зверніться за номером: +380954673526" і нічого більше.
Ти завжди дотримуєшся ввічливого, професійного тону. Формат відповіді має бути простим, зрозумілим і чітким. Уникай довгих пояснень, якщо вони не потрібні.
"""
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")
def load_documents(file_paths):
documents = []
for path in file_paths:
with open(path, 'r', encoding='utf-8') as file:
documents.append(file.read().strip())
return documents
class DocumentIndexer:
def __init__(self, documents, tokenizer, model):
self.documents = documents
self.bm25 = BM25Okapi([doc.split() for doc in documents])
self.tokenizer = tokenizer
self.model = model
self.index = self.create_faiss_index()
def create_faiss_index(self):
embeddings = self.embed_documents(self.documents)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
return index
def embed_documents(self, docs):
tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
embeddings = self.model(**tokens).last_hidden_state.mean(dim=1).numpy()
return embeddings
def search_bm25(self, query, top_k=5):
query_terms = query.split()
scores = self.bm25.get_scores(query_terms)
top_indices = np.argsort(scores)[::-1][:top_k]
return [self.documents[i] for i in top_indices]
def search_semantic(self, query, top_k=5):
query_embedding = self.embed_documents([query])
distances, indices = self.index.search(query_embedding, top_k)
return [self.documents[i] for i in indices[0]]
class Reranker:
def __init__(self, reranker):
self.model = reranker
def rank(self, query, documents):
pairs = [(query, doc) for doc in documents]
scores = self.model.predict(pairs)
ranked_docs = [documents[i] for i in np.argsort(scores)[::-1]]
return ranked_docs
class QAChatbot:
def __init__(self, indexer, reranker):
self.indexer = indexer
self.reranker = reranker
def generate_answer(self, query):
bm25_results = self.indexer.search_bm25(query)
semantic_results = self.indexer.search_semantic(query)
combined_results = list(set(bm25_results + semantic_results))
ranked_docs = self.reranker.rank(query, combined_results)
context = "\n".join(ranked_docs[:3])
response = completion(
model="groq/llama3-8b-8192",
messages=[
{
"role": "system",
"content": PROMPT
},
{
"role": "user",
"content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
}
],
)
return response
def chatbot_interface(query, history):
file_paths = ["company.txt", "Base.txt"]
documents = load_documents(file_paths)
indexer = DocumentIndexer(documents, tokenizer, model)
reranker = Reranker(reranker_model)
chatbot = QAChatbot(indexer, reranker)
answer = chatbot.generate_answer(query)
return answer["choices"][0]["message"]["content"]
iface = gr.ChatInterface(fn=chatbot_interface, type="messages")
if __name__ == "__main__":
iface.launch()