Spaces:
Sleeping
Sleeping
File size: 5,121 Bytes
803aa7a 6258aee 803aa7a ac2f0d0 44018ec 6258aee 803aa7a 6258aee 44018ec 6258aee 44018ec 6258aee 44018ec 6258aee 2e9d187 6258aee 9ffc8e2 2e9d187 6258aee 44018ec 6258aee 6e3fb2b 6258aee 803aa7a 6258aee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from litellm import completion
import os
import torch
from sentence_transformers import CrossEncoder
os.environ['GROQ_API_KEY'] = "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"
PROMPT = """/
Ти - віртуальний представник ритейл-компанії та консультант для клієнтів.
Для генерації відповідей, використовуй тільки інформацію з контексту!
Не задавай додаткових запитань, а просто запропонуй наявний в контексті товар!
Твоя ціль відповідати на запитання клієнтів, таким чином допомогти їм.
Ти повинен проконсультувати клієнта у виборі товарів використовуючи контекст.
Відповідай тільки українською мовою, назви товарів та компанії залишай англійською.
У разі, якщо конкретну відповідь знайти не вдалося:
- Відповідай "Я не знаю. Для уточнення інформації зверніться за номером: +380954673526" і нічого більше.
Ти завжди дотримуєшся ввічливого, професійного тону. Формат відповіді має бути простим, зрозумілим і чітким. Уникай довгих пояснень, якщо вони не потрібні.
"""
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")
def load_documents(file_paths):
documents = []
for path in file_paths:
with open(path, 'r', encoding='utf-8') as file:
documents.append(file.read().strip())
return documents
class DocumentIndexer:
def __init__(self, documents, tokenizer, model):
self.documents = documents
self.bm25 = BM25Okapi([doc.split() for doc in documents])
self.tokenizer = tokenizer
self.model = model
self.index = self.create_faiss_index()
def create_faiss_index(self):
embeddings = self.embed_documents(self.documents)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
return index
def embed_documents(self, docs):
tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
embeddings = self.model(**tokens).last_hidden_state.mean(dim=1).numpy()
return embeddings
def search_bm25(self, query, top_k=5):
query_terms = query.split()
scores = self.bm25.get_scores(query_terms)
top_indices = np.argsort(scores)[::-1][:top_k]
return [self.documents[i] for i in top_indices]
def search_semantic(self, query, top_k=5):
query_embedding = self.embed_documents([query])
distances, indices = self.index.search(query_embedding, top_k)
return [self.documents[i] for i in indices[0]]
class Reranker:
def __init__(self, reranker):
self.model = reranker
def rank(self, query, documents):
pairs = [(query, doc) for doc in documents]
scores = self.model.predict(pairs)
ranked_docs = [documents[i] for i in np.argsort(scores)[::-1]]
return ranked_docs
class QAChatbot:
def __init__(self, indexer, reranker):
self.indexer = indexer
self.reranker = reranker
def generate_answer(self, query):
bm25_results = self.indexer.search_bm25(query)
semantic_results = self.indexer.search_semantic(query)
combined_results = list(set(bm25_results + semantic_results))
ranked_docs = self.reranker.rank(query, combined_results)
context = "\n".join(ranked_docs[:3])
response = completion(
model="groq/llama3-8b-8192",
messages=[
{
"role": "system",
"content": PROMPT
},
{
"role": "user",
"content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
}
],
)
return response
def chatbot_interface(query, history):
file_paths = ["company.txt", "Base.txt"]
documents = load_documents(file_paths)
indexer = DocumentIndexer(documents, tokenizer, model)
reranker = Reranker(reranker_model)
chatbot = QAChatbot(indexer, reranker)
answer = chatbot.generate_answer(query)
return answer["choices"][0]["message"]["content"]
iface = gr.ChatInterface(fn=chatbot_interface, type="messages")
if __name__ == "__main__":
iface.launch()
|