rag_app / app.py
sofiia19's picture
Update app.py
6b023f2 verified
raw
history blame
5.41 kB
import re
import subprocess
import sys
def install(package):
    """Install ``package`` with pip into the environment of the running interpreter.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.run(pip_command, check=True)
# Install required libraries that are missing from the environment.
# Keys are importable module names, values are the pip distributions that
# provide them. They differ for sentence-transformers (module
# `sentence_transformers`): importing the pip name can never succeed, which
# forced a reinstall on every start. `typing` is part of the standard library
# and needs no install.
REQUIRED_PACKAGES = {
    "litellm": "litellm",
    "gradio": "gradio",
    "datasets": "datasets",
    "rank_bm25": "rank_bm25",
    "sentence_transformers": "sentence-transformers",
}
for module_name, pip_name in REQUIRED_PACKAGES.items():
    try:
        __import__(module_name)
    except ImportError:
        install(pip_name)
from litellm import completion
import os
# The Groq key must come from the environment (e.g. a Space secret), never
# from source code. A key that was previously committed in this file should
# be treated as leaked and revoked.
# The former "hello from litellm" smoke-test request was removed: its result
# was unused and it cost one API call on every startup.
if not os.environ.get("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY is not set; provide it as an environment variable or secret."
    )
from datasets import load_dataset
# Load the conflict-articles dataset and keep the article text of the first
# 10 rows of its training split as the corpus to index.
dataset = load_dataset("hugginglearners/russia-ukraine-conflict-articles")
train_subset = dataset['train'].select(range(10))
docs = [row['articles'] for row in train_subset]
def _split_to_max(text: str, max_chunk_size: int) -> list[str]:
    """Cut ``text`` into consecutive slices of at most ``max_chunk_size`` characters."""
    return [text[start:start + max_chunk_size] for start in range(0, len(text), max_chunk_size)]


def chunk_document(doc: str, doc_id: int, desired_chunk_size: int = 100, max_chunk_size: int = 3000):
    """Split one document into line-aligned text chunks.

    Lines are accumulated (each re-terminated with a newline) until the
    buffer holds at least ``desired_chunk_size`` characters, at which point it
    is emitted. A buffer longer than ``max_chunk_size`` (e.g. one very long
    line) is emitted as several consecutive slices, each with its own chunk
    number, so no text is dropped.

    Args:
        doc: Document text.
        doc_id: Identifier copied into every yielded tuple.
        desired_chunk_size: Buffer length that triggers emitting a chunk.
        max_chunk_size: Upper bound on the length of any yielded chunk text.

    Yields:
        ``(doc_id, chunk_number, text)`` tuples; ``chunk_number`` counts up
        from 0 within this document.

    Raises:
        ValueError: When iterated, if ``max_chunk_size`` is less than 1.
    """
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be at least 1")
    chunk = ''
    chunk_number = 0
    for line in doc.splitlines():
        chunk += line + '\n'
        if len(chunk) >= desired_chunk_size:
            # Slice instead of truncating so text past max_chunk_size stays retrievable.
            for piece in _split_to_max(chunk, max_chunk_size):
                yield (doc_id, chunk_number, piece)
                chunk_number += 1
            chunk = ''
    # Flush the remaining partial chunk, also capped at max_chunk_size;
    # an empty buffer yields nothing.
    for piece in _split_to_max(chunk, max_chunk_size):
        yield (doc_id, chunk_number, piece)
        chunk_number += 1
def chunk_documents(docs: list[str], desired_chunk_size: int = 100, max_chunk_size: int = 3000):
    """Chunk every document in ``docs`` and return all chunks in a single list.

    A document's position in ``docs`` is used as its ``doc_id``. Chunks are
    ordered by document first, then by chunk number within each document.
    """
    return [
        piece
        for position, text in enumerate(docs)
        for piece in chunk_document(text, position, desired_chunk_size, max_chunk_size)
    ]
# Builtin generic annotations such as list[str] need Python 3.9+; no typing import is required.
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import torch
class Retriever:
    """Rank document chunks against a query with BM25, SBERT, or a hybrid of both.

    Documents are split with ``chunk_documents``. Every chunk is indexed
    lexically (BM25 over lower-cased word tokens) and semantically
    (SentenceTransformer embeddings).
    """

    # Runs of word characters: punctuation and line breaks never become part
    # of a token, so "War.\n" and "war" both tokenize to "war".
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(self, docs: list[str]):
        # Each chunk is a (doc_id, chunk_number, text) tuple.
        self.chunks = chunk_documents(docs)
        self.docs = [chunk[2] for chunk in self.chunks]
        self.bm25 = BM25Okapi([self._tokenize(doc) for doc in self.docs])
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        self.doc_embeddings = self.sbert.encode(self.docs)

    @classmethod
    def _tokenize(cls, text: str) -> list[str]:
        """Lower-case ``text`` and split it into word tokens for BM25."""
        return cls._TOKEN_RE.findall(text.lower())

    @staticmethod
    def _min_max(scores) -> np.ndarray:
        """Linearly rescale ``scores`` to [0, 1]; constant input maps to all zeros."""
        scores = np.asarray(scores, dtype=float)
        if scores.size == 0:
            return scores
        lo, hi = scores.min(), scores.max()
        if hi == lo:
            return np.zeros_like(scores)
        return (scores - lo) / (hi - lo)

    def get_docs(self, query: str, method: str = "bm25", n: int = 3) -> list[tuple[int, int, str]]:
        """Return the ``n`` best-scoring chunks for ``query``.

        Args:
            query: Free-text query.
            method: ``"bm25"`` (lexical), ``"sbert"`` (embedding cosine
                similarity) or ``"hybrid"`` (0.3 * BM25 + 0.7 * SBERT, each
                min-max scaled to [0, 1] first).
            n: Maximum number of chunks to return.

        Returns:
            ``(doc_id, chunk_number, text)`` tuples, best first.

        Raises:
            ValueError: If ``method`` is not one of the names above.
        """
        if method == "bm25":
            scores = self._get_bm25_scores(query)
        elif method == "sbert":
            scores = self._get_semantic_scores(query)
        elif method == "hybrid":
            # BM25 scores are on an open-ended scale while cosine similarity lies
            # in [-1, 1]; rescale both so the 0.3 / 0.7 weights are meaningful.
            bm25_scores = self._min_max(self._get_bm25_scores(query))
            semantic_scores = self._min_max(self._get_semantic_scores(query))
            scores = 0.3 * bm25_scores + 0.7 * semantic_scores
        else:
            raise ValueError("Invalid method. Choose 'bm25', 'sbert', or 'hybrid'.")
        # Indices ordered from highest to lowest score.
        sorted_indices = np.argsort(scores)[::-1]
        # Return the top n chunks together with their source (doc id, chunk number).
        return [(self.chunks[i][0], self.chunks[i][1], self.docs[i]) for i in sorted_indices[:n]]

    def _get_bm25_scores(self, query: str):
        """BM25 score of every chunk for ``query``, tokenized like the corpus."""
        return self.bm25.get_scores(self._tokenize(query))

    def _get_semantic_scores(self, query: str):
        """Cosine similarity between the query embedding and every chunk embedding."""
        query_embedding = self.sbert.encode(query)
        # as_tensor avoids copying the whole embedding matrix on every query.
        scores = torch.cosine_similarity(
            torch.as_tensor(query_embedding).unsqueeze(0),
            torch.as_tensor(self.doc_embeddings),
            dim=1,
        )
        return scores.numpy()
class QuestionAnsweringBot:
    """Answer questions with an LLM, using chunks from a ``Retriever`` as context."""

    PROMPT = '''\
You are a helpful assistant that can answer questions.
Rules:
-Reply with the answer only and nothing but the answer.
-Say 'I don't know(((' if you don't know the answer.
-Use the provided context.
'''

    # Default litellm model id.
    # NOTE(review): confirm Groq still serves this model id.
    MODEL = "groq/llama3-8b-8192"

    def __init__(self, docs: list[str], model: str = MODEL):
        """Index ``docs`` for retrieval; ``model`` is the litellm model id passed to ``completion``."""
        self.retriever = Retriever(docs)
        self.model = model

    @staticmethod
    def _format_context(context_with_indices) -> str:
        """Render retrieved chunks for the prompt, one labelled chunk per line."""
        return "\n".join(
            f"Doc {doc_id}, Chunk {chunk_id}: {text}"
            for doc_id, chunk_id, text in context_with_indices
        )

    @staticmethod
    def _format_sources(context_with_indices) -> str:
        """List the retrieved chunks as ``Doc d: Chunk c`` entries separated by ``"; "``."""
        return "; ".join(
            f"Doc {doc_id}: Chunk {chunk_id}" for doc_id, chunk_id, _ in context_with_indices
        )

    def answer_question(self, question: str, method: str = "bm25") -> str:
        """Answer ``question`` from chunks retrieved with ``method``.

        Returns:
            The model's answer followed by the bracketed list of source chunks.
            Returns "I don't know(((" if retrieval finds nothing, or
            "Error: ..." if the LLM call fails.
        """
        context_with_indices = self.retriever.get_docs(question, method=method)
        if not context_with_indices:
            return "I don't know((("
        # Context for the model.
        context = self._format_context(context_with_indices)
        messages = [
            {"role": "system", "content": self.PROMPT},
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
        ]
        try:
            response = completion(model=self.model, messages=messages)
            answer = response['choices'][0]['message']['content']
        except Exception as e:
            # Report LLM-call failures as the returned text instead of raising.
            return f"Error: {str(e)}"
        return f"{answer} [{self._format_sources(context_with_indices)}]"
# Example of querying the bot directly, without the Gradio UI:
#     bot = QuestionAnsweringBot(docs)
#     question = "Tell about war"
#     answer = bot.answer_question(question)
#     print(f'Q: {question}')
#     print(f'A: {answer}')
import gradio as gr
# Built lazily on the first request, then reused. Constructing the bot chunks
# the corpus, builds the BM25 index, loads the SentenceTransformer and embeds
# every chunk, which is far too expensive to repeat for each question.
_bot = None


def answer_question_with_method(query, method):
    """Gradio callback: answer ``query`` using the chosen retrieval ``method``.

    A blank query returns a short prompt without building the bot or calling
    the model.
    """
    global _bot
    if not query or not query.strip():
        return "Please enter a question."
    if _bot is None:
        _bot = QuestionAnsweringBot(docs)
    return _bot.answer_question(query, method=method)
# Create the interface: a free-text question plus a choice of retrieval method.
demo = gr.Interface(
    fn=answer_question_with_method,
    inputs=[
        gr.Textbox(label="Your Question"),
        gr.Dropdown(
            choices=["bm25", "sbert", "hybrid"],
            value="hybrid",
            label="Select Retrieval Method",
        ),
    ],
    outputs="text",
)

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
# NOTE(review): Gradio Spaces run `python app.py`, which satisfies this guard;
# confirm for other deployments.
if __name__ == "__main__":
    demo.launch()