Spaces:
Sleeping
Sleeping
import gradio as gr | |
from litellm import completion | |
import glob | |
import os | |
from retriever_reranker_final import Retriever | |
# System prompt placed before every question. Each rule is one line; the
# joined text ends with a trailing newline.
PROMPT = "\n".join([
    "You are a helpful assistant that can answer questions.",
    "Rules:",
    "- Provide clear and concise answers to all questions.",
    "- Always paraphrase the context when forming your response. Do not copy the text directly.",
    "- Use only the provided context to answer questions.",
    "- Only give the answer.",
    "- Avoid repeating long phrases from the context.",
    "- Structure your responses in a way that is easy to read and understand.",
    # Disabled rule, kept for reference:
    # "- Include only information relevant to the question and ignore unrelated details.",
]) + "\n"
class QuestionAnsweringBot:
    """Retrieval-augmented QA bot.

    Retrieves and reranks document chunks with ``Retriever``, then asks a
    Groq-hosted LLM (through litellm) to answer using only those chunks.
    """

    def __init__(self, docs) -> None:
        """Build the retriever over *docs*.

        Args:
            docs: Passed unchanged to ``Retriever``. In this file it is a dict
                mapping a document's file name to its full text.
        """
        self.retriever = Retriever(docs)
        # Upper bound on the number of chunks listed as citations.
        self.max_citations = 3

    def answer_question(self, question: str, api_key: str, methods: list[str]) -> list[str]:
        """Answer *question* from retrieved context and list the cited chunks.

        Args:
            question: The user's question.
            api_key: Groq API key, sent with this request only. When empty,
                litellm falls back to its own key lookup (e.g. the
                ``GROQ_API_KEY`` environment variable).
            methods: Search-method selection forwarded to
                ``Retriever.get_docs``. NOTE(review): the Gradio Radio in this
                file passes a single string (e.g. "BM25"), not a list;
                confirm ``get_docs`` expects that.

        Returns:
            A 3-element list: the answer followed by the ids of the cited
            chunks, the cited chunks' text, and ``str()`` of the full reranked
            context. When no method is selected, or when processing fails, a
            message followed by two "-" placeholders is returned instead.
        """
        if not methods:
            return ["No search method selected. Please select at least one search method.", "-", "-"]
        try:
            retr_context = self.retriever.get_docs(question, methods)
            # Used below as a mapping of chunk id -> chunk text.
            reranked_context = self.retriever.rerank(question, retr_context)
            context_with_all_chunks = "\n".join(
                f"[{chunk_id}] {chunk_text}" for chunk_id, chunk_text in reranked_context.items()
            )
            messages = [
                {"role": "system", "content": PROMPT},
                {"role": "user", "content": f"Context:\n{context_with_all_chunks}\nQuestion: {question}"},
            ]
            response = completion(
                model="groq/llama3-8b-8192",
                messages=messages,
                # Pass the key per request rather than writing os.environ: the
                # environment is process-global, so concurrent sessions would
                # overwrite (and use) each other's keys. `or None` lets an
                # empty textbox fall back to a server-configured key.
                api_key=api_key or None,
            )
            # Guard against a None content field so the string handling below
            # does not fail.
            response_text = response.choices[0].message.content or ""
            used_chunks = self._filter_used_chunks(response_text, reranked_context)
            top_chunks = self._get_top_chunks(used_chunks, self.max_citations)
            context_with_citations = "\n".join(
                f"[{chunk_id}] {chunk_text}" for chunk_id, chunk_text in top_chunks.items()
            )
            used_sources = " ".join(f"[{chunk_id}]" for chunk_id in top_chunks)
            final_response = f"{response_text} \nUsed chunks: {used_sources}"
            return [final_response, context_with_citations, f"{reranked_context}"]
        except Exception as e:
            # UI boundary: show a generic message and print the details.
            print(f"Error: {e}")
            return ["Error occurred during processing.", "-", "-"]

    @staticmethod
    def _tokenize(text):
        """Lowercase *text* and split it into alphanumeric word tokens."""
        cleaned = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return cleaned.split()

    def _filter_used_chunks(self, response_text, reranked_context, min_word_len=4):
        """Return the chunks of *reranked_context* that appear in *response_text*.

        A chunk counts as used when it shares at least one whole word of at
        least *min_word_len* characters with the response (case-insensitive).
        The earlier test checked whether any chunk word occurred as a
        substring of the response. Short words such as "a" or "is" matched
        almost any answer, so every chunk was reported as used.

        Args:
            response_text: The model's answer (None is treated as empty).
            reranked_context: Mapping of chunk id -> chunk text.
            min_word_len: Shortest word length that counts as evidence.

        Returns:
            A dict with the matching (chunk id, chunk text) pairs, in the
            original order.
        """
        response_words = set(self._tokenize(response_text or ""))
        return {
            chunk_id: chunk_text
            for chunk_id, chunk_text in reranked_context.items()
            if any(
                len(word) >= min_word_len and word in response_words
                for word in self._tokenize(chunk_text)
            )
        }

    def _get_top_chunks(self, used_chunks, max_citations):
        """Keep at most *max_citations* chunks, skipping near-duplicates.

        Chunks are taken in the order of *used_chunks*. A chunk is skipped
        when its first 50 characters occur inside the 50-character prefix of
        a chunk already kept.

        Returns:
            A dict of the kept (chunk id, chunk text) pairs. It is empty when
            *max_citations* <= 0.
        """
        unique_chunks = {}
        seen_prefixes = set()
        for chunk_id, chunk_text in used_chunks.items():
            # Check the limit before adding so that max_citations <= 0 yields
            # no citations; previously one chunk slipped through.
            if len(unique_chunks) >= max_citations:
                break
            prefix = chunk_text[:50]
            # Substring (not equality) test: a short chunk contained in an
            # already-kept prefix is also treated as a duplicate.
            if any(prefix in seen for seen in seen_prefixes):
                continue
            unique_chunks[chunk_id] = chunk_text
            seen_prefixes.add(prefix)
        return unique_chunks
# Load documents: every data/*.txt file becomes one entry of
# {file name: full file text}, which is handed to the bot's retriever.
all_docs = {}
# Sorted so the corpus order is the same on every run; raw glob order
# depends on the filesystem.
for path in sorted(glob.glob("data/*.txt")):
    # Explicit UTF-8: open()'s default encoding is platform-dependent
    # (e.g. cp1252 on Windows). Assumes the data files are UTF-8.
    with open(path, encoding="utf-8") as f:
        all_docs[os.path.basename(path)] = f.read()
if not all_docs:
    print("Warning: no documents found matching data/*.txt; the bot has nothing to retrieve from.")
bot = QuestionAnsweringBot(all_docs)
# Gradio UI: header/instructions, the three inputs, and two output boxes
# wired to bot.answer_question.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="text-align: center; font-size: 40px; font-weight: bold; margin-bottom: 10px;">
Harry Potter Encyclopedia
</div>
<div style="text-align: center; font-size: 25px; margin-bottom: 10px; color: #ffdac7">
Performed by Subtelna Sofiia (CS-414) and Iryna Iskovych (CS-415)
</div>
<div style="font-size: 20px;">
This bot answers questions about the world of Harry Potter and the plot of the first book. The system uses 3 files: general information about the world, the main characters, and the text of the first book. However, some information about the universe or later years may be missing.
<div style="font-weight: bold; margin-top: 25px;">
Instructions
</div>
<ol>
<li>Enter your Groq API Key in the textbox below.</li>
<li>The API key can be generated using this <a href="https://console.groq.com/keys" target="_blank" rel="noopener noreferrer">link</a></li>
<li>Input your query</li>
<li>Select the scoring method from the proposed ones:
<ul>
<li>BM25</li>
<li>Semantic search</li>
<li>Combined search (combination of BM25 and semantic search)</li>
</ul>
</li>
<li>Click "Submit" button</li>
</ol>
</div>
""")
    question_input = gr.Textbox(
        label="Question",
        placeholder="Ask your question here",
    )
    # type="password" masks the key so the secret is not displayed in clear text.
    api_key_input = gr.Textbox(
        label="API key",
        placeholder="Provide API key here",
        info="Input key from Groq",
        type="password",
    )
    # A Radio yields a single selected string, which is passed to
    # answer_question as its `methods` argument.
    search_method = gr.Radio(
        info="Choose the search method from the proposed ones",
        choices=["BM25", "semantic", "combined search"],
        label="Search Method",
        value="BM25",
    )
    submit_button = gr.Button("Submit")
    output_bot = gr.Textbox(label="Answer")
    output_citations = gr.Textbox(label="Citations")
    # answer_question returns a third value (the full reranked context) that
    # has no output component here. Gradio is expected to drop the extra
    # return value with a warning; confirm on the Gradio version in use.
    submit_button.click(
        bot.answer_question,
        inputs=[question_input, api_key_input, search_method],
        outputs=[output_bot, output_citations],
    )

demo.launch()