|
import subprocess
import sys


def install(package):
    """Install a pip package into the current interpreter's environment."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# Bootstrap dependencies at runtime (notebook-style). Import names can differ
# from pip package names, so map them explicitly; "typing" is dropped from the
# original list because it is part of the standard library.
for module_name, pip_name in [
    ("litellm", "litellm"),
    ("gradio", "gradio"),
    ("datasets", "datasets"),
    ("rank_bm25", "rank_bm25"),
    ("sentence_transformers", "sentence-transformers"),
]:
    try:
        __import__(module_name)
    except ImportError:
        install(pip_name)
|
import os

from litellm import completion

# Expect the Groq API key in the environment; never hard-code a real key.
os.environ.setdefault("GROQ_API_KEY", "YOUR_GROQ_API_KEY")

# Smoke test: confirm the LLM backend is reachable before building the app.
response = completion(
    model="groq/llama3-8b-8192",
    messages=[
        {"role": "user", "content": "hello from litellm"},
    ],
)
print(response["choices"][0]["message"]["content"])
|
from datasets import load_dataset

dataset = load_dataset("hugginglearners/russia-ukraine-conflict-articles")

# Use the first 10 articles as the document corpus for retrieval.
docs = [item['articles'] for item in dataset['train'].select(range(10))]
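
# Quick look at the corpus we just built (the 'articles' field is a string,
# per the list comprehension above).
print(f"Loaded {len(docs)} articles; first article starts: {docs[0][:100]!r}")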
|
def chunk_document(doc: str, doc_id: int, desired_chunk_size: int = 100, max_chunk_size: int = 3000):
    """Split one document into line-aligned chunks of at least
    desired_chunk_size characters, truncated at max_chunk_size."""
    chunk = ''
    chunk_number = 0
    for line in doc.splitlines():
        chunk += line + '\n'
        if len(chunk) >= desired_chunk_size:
            yield (doc_id, chunk_number, chunk[:max_chunk_size])
            chunk = ''
            chunk_number += 1
    if chunk:
        # Flush whatever is left as a final, possibly short, chunk.
        yield (doc_id, chunk_number, chunk)


def chunk_documents(docs: list[str], desired_chunk_size: int = 100, max_chunk_size: int = 3000):
    """Chunk every document, returning (doc_id, chunk_number, text) tuples."""
    chunks = []
    for doc_id, doc in enumerate(docs):
        chunks.extend(chunk_document(doc, doc_id, desired_chunk_size, max_chunk_size))
    return chunks
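
# Tiny sanity check of the chunker on made-up text (not from the dataset):
# each yielded tuple is (doc_id, chunk_number, chunk_text).
for _doc_id, _chunk_id, _text in chunk_document("line one\nline two\nline three\n", doc_id=0, desired_chunk_size=10):
    print(_doc_id, _chunk_id, repr(_text))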
|
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer


class Retriever:
    def __init__(self, docs: list[str]):
        self.chunks = chunk_documents(docs)
        self.chunk_texts = [chunk[2] for chunk in self.chunks]
        # Lexical index: BM25 over lowercased, whitespace-tokenized chunks.
        tokenized_chunks = [text.lower().split(" ") for text in self.chunk_texts]
        self.bm25 = BM25Okapi(tokenized_chunks)
        # Semantic index: one embedding per chunk, computed once up front.
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        self.chunk_embeddings = self.sbert.encode(self.chunk_texts)

    def get_docs(self, query, method="bm25", n=3):
        if method == "bm25":
            scores = self._get_bm25_scores(query)
        elif method == "sbert":
            scores = self._get_semantic_scores(query)
        elif method == "hybrid":
            # BM25 scores are unbounded while cosine similarity lies in
            # [-1, 1], so min-max normalize BM25 before blending; otherwise
            # the 0.3/0.7 weights would be meaningless.
            bm25_scores = self._normalize(self._get_bm25_scores(query))
            semantic_scores = self._get_semantic_scores(query)
            scores = 0.3 * bm25_scores + 0.7 * semantic_scores
        else:
            raise ValueError("Invalid method. Choose 'bm25', 'sbert', or 'hybrid'.")

        # Highest-scoring chunks first.
        sorted_indices = np.argsort(scores)[::-1]
        return [(self.chunks[i][0], self.chunks[i][1], self.chunk_texts[i]) for i in sorted_indices[:n]]

    @staticmethod
    def _normalize(scores):
        span = scores.max() - scores.min()
        return (scores - scores.min()) / span if span > 0 else np.zeros_like(scores)

    def _get_bm25_scores(self, query):
        tokenized_query = query.lower().split(" ")
        return self.bm25.get_scores(tokenized_query)

    def _get_semantic_scores(self, query):
        query_embedding = self.sbert.encode(query)
        scores = torch.cosine_similarity(
            torch.tensor(query_embedding).unsqueeze(0),
            torch.tensor(self.chunk_embeddings),
            dim=1,
        )
        return scores.numpy()
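
# Usage sketch for the retriever on its own (illustrative query; left
# commented out so the embedding model is not loaded twice, since the
# QuestionAnsweringBot below builds its own Retriever):
#
#   retriever = Retriever(docs)
#   for doc_id, chunk_id, text in retriever.get_docs("ceasefire talks", method="hybrid"):
#       print(f"Doc {doc_id}, Chunk {chunk_id}: {text[:80]}")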
|
class QuestionAnsweringBot:
    PROMPT = '''\
You are a helpful assistant that can answer questions.

Rules:
- Reply with the answer only and nothing but the answer.
- Say "I don't know(((" if you don't know the answer.
- Use the provided context.
'''

    def __init__(self, docs):
        self.retriever = Retriever(docs)

    def answer_question(self, question: str, method: str = "bm25") -> str:
        context_with_indices = self.retriever.get_docs(question, method=method)
        if not context_with_indices:
            return "I don't know((("

        # Label each retrieved chunk so both the model and the citation
        # string below can refer back to its source.
        context = "\n".join(
            f"Doc {doc_id}, Chunk {chunk_id}: {text}"
            for doc_id, chunk_id, text in context_with_indices
        )

        messages = [
            {"role": "system", "content": self.PROMPT},
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
        ]

        try:
            response = completion(
                model="groq/llama3-8b-8192",
                messages=messages,
            )
            answer = response['choices'][0]['message']['content']
            # Append the retrieved chunks as lightweight source citations.
            sources = [f"Doc {doc_id}, Chunk {chunk_id}" for doc_id, chunk_id, _ in context_with_indices]
            return f"{answer} [{'; '.join(sources)}]"
        except Exception as e:
            return f"Error: {str(e)}"
|
import gradio as gr

# Build the bot once at startup so the BM25 index and chunk embeddings are
# reused across queries instead of being rebuilt on every request.
bot = QuestionAnsweringBot(docs)


def answer_question_with_method(query, method):
    return bot.answer_question(query, method=method)
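
# One-off sanity check of the full pipeline before launching the UI (the
# question is illustrative, not drawn from the dataset; any API failure is
# caught inside answer_question and returned as an "Error: ..." string).
print(answer_question_with_method("What topics do these articles cover?", "hybrid"))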
|
demo = gr.Interface(
    fn=answer_question_with_method,
    inputs=[
        gr.Textbox(label="Your Question"),
        gr.Dropdown(
            choices=["bm25", "sbert", "hybrid"],
            value="hybrid",
            label="Select Retrieval Method",
        ),
    ],
    outputs="text",
)

demo.launch()
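
# Note: launch() accepts share=True to expose a temporary public URL, which
# is handy when running this script in a hosted notebook.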
|
|
|
|