Spaces:
Sleeping
Sleeping
import gradio as gr | |
import faiss | |
import numpy as np | |
from rank_bm25 import BM25Okapi | |
from transformers import AutoTokenizer, AutoModel | |
from litellm import completion | |
import os | |
import torch | |
from sentence_transformers import CrossEncoder | |
os.environ['GROQ_API_KEY'] = "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" | |
PROMPT = """/ | |
Ти - віртуальний представник ритейл-компанії та консультант для клієнтів. | |
Для генерації відповідей, використовуй тільки інформацію з контексту! | |
Не задавай додаткових запитань, а просто запропонуй наявний в контексті товар! | |
Твоя ціль відповідати на запитання клієнтів, таким чином допомогти їм. | |
Ти повинен проконсультувати клієнта у виборі товарів використовуючи контекст. | |
Відповідай тільки українською мовою, назви товарів та компанії залишай англійською. | |
У разі, якщо конкретну відповідь знайти не вдалося: | |
- Відповідай "Я не знаю. Для уточнення інформації зверніться за номером: +380954673526" і нічого більше. | |
Ти завжди дотримуєшся ввічливого, професійного тону. Формат відповіді має бути простим, зрозумілим і чітким. Уникай довгих пояснень, якщо вони не потрібні. | |
""" | |
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2") | |
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2") | |
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6") | |
def load_documents(file_paths): | |
documents = [] | |
for path in file_paths: | |
with open(path, 'r', encoding='utf-8') as file: | |
documents.append(file.read().strip()) | |
return documents | |
class DocumentIndexer: | |
def __init__(self, documents, tokenizer, model): | |
self.documents = documents | |
self.bm25 = BM25Okapi([doc.split() for doc in documents]) | |
self.tokenizer = tokenizer | |
self.model = model | |
self.index = self.create_faiss_index() | |
def create_faiss_index(self): | |
embeddings = self.embed_documents(self.documents) | |
dimension = embeddings.shape[1] | |
index = faiss.IndexFlatL2(dimension) | |
index.add(embeddings) | |
return index | |
def embed_documents(self, docs): | |
tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt") | |
with torch.no_grad(): | |
embeddings = self.model(**tokens).last_hidden_state.mean(dim=1).numpy() | |
return embeddings | |
def search_bm25(self, query, top_k=5): | |
query_terms = query.split() | |
scores = self.bm25.get_scores(query_terms) | |
top_indices = np.argsort(scores)[::-1][:top_k] | |
return [self.documents[i] for i in top_indices] | |
def search_semantic(self, query, top_k=5): | |
query_embedding = self.embed_documents([query]) | |
distances, indices = self.index.search(query_embedding, top_k) | |
return [self.documents[i] for i in indices[0]] | |
class Reranker: | |
def __init__(self, reranker): | |
self.model = reranker | |
def rank(self, query, documents): | |
pairs = [(query, doc) for doc in documents] | |
scores = self.model.predict(pairs) | |
ranked_docs = [documents[i] for i in np.argsort(scores)[::-1]] | |
return ranked_docs | |
class QAChatbot: | |
def __init__(self, indexer, reranker): | |
self.indexer = indexer | |
self.reranker = reranker | |
def generate_answer(self, query): | |
bm25_results = self.indexer.search_bm25(query) | |
semantic_results = self.indexer.search_semantic(query) | |
combined_results = list(set(bm25_results + semantic_results)) | |
ranked_docs = self.reranker.rank(query, combined_results) | |
context = "\n".join(ranked_docs[:3]) | |
response = completion( | |
model="groq/llama3-8b-8192", | |
messages=[ | |
{ | |
"role": "system", | |
"content": PROMPT | |
}, | |
{ | |
"role": "user", | |
"content": f"Context: {context}\n\nQuestion: {query}\nAnswer:", | |
} | |
], | |
) | |
return response | |
def chatbot_interface(query, history): | |
file_paths = ["company.txt", "Base.txt"] | |
documents = load_documents(file_paths) | |
indexer = DocumentIndexer(documents, tokenizer, model) | |
reranker = Reranker(reranker_model) | |
chatbot = QAChatbot(indexer, reranker) | |
answer = chatbot.generate_answer(query) | |
return answer["choices"][0]["message"]["content"] | |
iface = gr.ChatInterface(fn=chatbot_interface, type="messages") | |
if __name__ == "__main__": | |
iface.launch() | |