# Source: Hugging Face Spaces file view of app.py (author: hatim00101)
# Commit f3ec75b ("Update app.py", verified); malware scan: no virus; size: 3.71 kB
import pickle

import faiss
import gradio as gr
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import pipeline

# Models are loaded once at import time and shared by every request.
# Causal LM used for both generation passes in two_stage_rag_search.
generator = pipeline("text-generation", model="tiiuae/falcon-7B")
# Bi-encoder that embeds queries for the FAISS search.
# NOTE(review): presumably the same model that built my_faiss_index.faiss — confirm.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# Cross-encoder that scores (query, chunk text) pairs for reranking.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Precomputed retrieval artifacts shipped next to the app.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# these pickles must come from a trusted source.
with open("Assemesment5_day4.my_faiss_embeddings.pkl", "rb") as embeddings_file:
    embeddings = pickle.load(embeddings_file)
faiss_index = faiss.read_index("my_faiss_index.faiss")

# Text chunks. Assumes each item is a dict with a 'text' key (and 'metadata',
# shown in the UI) — TODO confirm against the code that produced chunks.pkl.
with open('chunks.pkl', 'rb') as chunks_file:
    chunks = pickle.load(chunks_file)

# Lexical index over whitespace-split chunk text.
bm25 = BM25Okapi([document['text'].split() for document in chunks])
def hybrid_search(query, top_k=5):
    """Retrieve chunks for *query* with BM25 + FAISS, then rerank with a cross-encoder.

    Each retriever contributes up to ``top_k`` candidate chunk ids. The union
    of those candidates is scored by ``cross_encoder``, and the ``top_k``
    highest-scoring chunks are returned, best first.

    Args:
        query: Free-text search query.
        top_k: Number of candidates taken from each retriever, and the
            maximum number of chunks returned.

    Returns:
        A list of at most ``top_k`` items of the module-level ``chunks`` list,
        ordered by descending cross-encoder score. Empty when ``top_k <= 0``
        or nothing was retrieved.
    """
    if top_k <= 0 or not chunks:
        return []

    # Lexical retrieval: BM25 over whitespace tokens, matching how the corpus was indexed.
    bm25_scores = bm25.get_scores(query.split())
    top_bm25_indices = np.argsort(bm25_scores)[::-1][:top_k]

    # Dense retrieval: faiss expects a contiguous float32 (n_queries, dim) matrix.
    query_embedding = np.ascontiguousarray(embedder.encode([query]), dtype="float32")
    _, faiss_indices = faiss_index.search(query_embedding, top_k)

    # Union of both candidate lists in first-seen order (BM25 hits first).
    # Faiss pads its result with -1 when it has fewer than top_k hits. A
    # negative id would silently select chunks from the end of the list,
    # so keep only valid ids.
    candidate_ids = list(dict.fromkeys(
        int(i)
        for i in np.concatenate((top_bm25_indices, faiss_indices[0]))
        if 0 <= i < len(chunks)
    ))
    if not candidate_ids:
        return []
    candidates = [chunks[i] for i in candidate_ids]

    # Rerank *all* candidates and truncate only afterwards. Truncating first
    # (e.g. np.unique(...)[:top_k]) keeps the numerically smallest chunk ids
    # rather than the most relevant chunks.
    scores = cross_encoder.predict([(query, chunk['text']) for chunk in candidates])
    # Sort on the score alone: comparing (score, dict) tuples raises TypeError on ties.
    ranking = sorted(range(len(candidates)), key=lambda j: scores[j], reverse=True)
    return [candidates[j] for j in ranking[:top_k]]
def two_stage_rag_search(query, top_k=5, *, extraction_max_new_tokens=256,
                         answer_max_new_tokens=512):
    """Answer *query* with two generation passes over hybrid-search results.

    Pass 1 asks ``generator`` to extract the passage most relevant to the
    question from the ``top_k`` chunks returned by ``hybrid_search``.
    Pass 2 asks it for a detailed answer grounded in that passage.

    Args:
        query: The user's question.
        top_k: Number of chunks retrieved as context.
        extraction_max_new_tokens: Generation budget (new tokens) for pass 1.
        answer_max_new_tokens: Generation budget (new tokens) for pass 2.

    Returns:
        The stripped text generated in pass 2, without the prompt.
    """
    results = hybrid_search(query, top_k)
    context = "\n\n".join(chunk['text'] for chunk in results)

    extraction_prompt = (
        "Given the following context, extract the most relevant passage that answers the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Relevant Passage:\n"
    )
    # return_full_text=False: by default the pipeline echoes the prompt, which
    # would paste the whole extraction prompt (context included) into pass 2.
    # max_new_tokens caps only the generated part. max_length also counts the
    # prompt, so a long retrieved context could leave no room to generate.
    extraction_response = generator(
        extraction_prompt,
        max_new_tokens=extraction_max_new_tokens,
        num_return_sequences=1,
        return_full_text=False,
    )
    # Fall back to the raw context if the model produced nothing usable.
    relevant_passage = extraction_response[0]['generated_text'].strip() or context

    # "Answer:" comes last so that the model's continuation is the answer
    # itself, not a continuation of the passage.
    answer_prompt = (
        "Based on the passage below, generate a detailed and thoughtful answer to the question.\n\n"
        f"Relevant Passage: {relevant_passage}\n\n"
        f"Question: {query}\n\n"
        "Format your response as follows:\n"
        "Metadata:\n"
        "Author: 'author'\n"
        "Title: 'title'\n"
        "Date: 'date'\n"
        "Description: 'description'\n\n"
        "Content or text:\n"
        f"{relevant_passage}\n\n"
        "Answer:\n"
    )
    answer_response = generator(
        answer_prompt,
        max_new_tokens=answer_max_new_tokens,
        num_return_sequences=1,
        return_full_text=False,
    )
    return answer_response[0]['generated_text'].strip()
def gradio_interface(query, feedback=None):
    """Gradio callback returning retrieved chunks and a generated answer.

    Args:
        query: Text entered in the query textbox.
        feedback: Currently unused. Defaults to ``None`` because Gradio passes
            one positional argument per input component, and the interface
            may supply only the query.

    Returns:
        ``(result_texts, detailed_answer)`` for the two output textboxes.
        The first is the retrieved chunks with their metadata; the second is
        the generated answer. Both are empty strings for a blank query.
    """
    # Skip retrieval and the expensive LLM calls when there is nothing to ask.
    if not query or not query.strip():
        return "", ""

    results = hybrid_search(query, top_k=5)
    # One "Text/Metadata" entry per retrieved chunk, separated by blank lines.
    result_texts = "\n\n".join(
        f"Text: {chunk['text']}\nMetadata: {chunk['metadata']}" for chunk in results
    )
    detailed_answer = two_stage_rag_search(query)
    return result_texts, detailed_answer
# Gradio UI: one query textbox in; two textboxes out (retrieved chunks, generated answer).
# A feedback dropdown ("positive"/"negative") was once wired here and is currently disabled.
query_input = gr.Textbox(lines=2, placeholder="Enter your query here...")
search_results_output = gr.Textbox(lines=20, placeholder="The search results will be displayed here...")
answer_output = gr.Textbox(lines=20, placeholder="The detailed answer will be displayed here...")

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[query_input],
    outputs=[search_results_output, answer_output],
    title="News share engine_zz",
    description=".",
)

# share=True asks Gradio for a public share link in addition to the local server.
interface.launch(share=True)