audit_assistant / app.py
ppsingh's picture
Update app.py
723ac7e verified
raw
history blame
No virus
11.6 kB
import gradio as gr
import pandas as pd
import numpy as np
import os
import time
import re
import json
from auditqa.sample_questions import QUESTIONS
from auditqa.reports import POSSIBLE_REPORTS
from auditqa.engine.prompts import audience_prompts, answer_prompt_template
from auditqa.doc_process import process_pdf
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.llms import HuggingFaceEndpoint
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into os.environ before reading them.
load_dotenv()

# Token for the Hugging Face inference endpoint used by `chat`. Fail fast with an
# explicit message instead of a bare KeyError when it is missing or empty.
HF_token = os.environ.get("HF_TOKEN")
if not HF_token:
    raise RuntimeError(
        "HF_TOKEN is not set: define it in the environment or in a .env file."
    )

# Mapping of source name -> vector store; `chat` looks up the "ABC" and "XYZ" entries.
# Presumably built from the audit-report PDFs by process_pdf() — confirm in auditqa.doc_process.
vectorstores = process_pdf()
# Matches inline citations such as "[Doc 1]", "[Doc2]" or "[Doc 1, Doc 3]" in an LLM answer.
_DOC_REF_PATTERN = re.compile(r"\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]")

# UI audience label -> key into `audience_prompts`; unknown labels fall back to "experts".
_AUDIENCE_PROMPT_KEYS = {
    "Children": "children",
    "General public": "general",
    "Experts": "experts",
}


def parse_output_llm_with_sources(output):
    """Turn "[Doc N]" citations in an LLM answer into links to the source cards.

    Each cited number N becomes a superscript anchor pointing at ``#docN``,
    which is the element id ``make_html_source`` gives the N-th source card.
    Text without citations is returned unchanged.

    Args:
        output: The answer text produced by the LLM.

    Returns:
        The answer with every citation group replaced by HTML anchors.
    """
    def _to_links(match):
        numbers = re.findall(r"\d+", match.group(1))
        return "".join(
            f'<a href="#doc{n}" class="a-doc-ref" target="_self">'
            f"<span class='doc-ref'><sup>{n}</sup></span></a>"
            for n in numbers
        )

    # re.sub only touches the matched citation groups, so ordinary text that
    # happens to start with "Doc" (e.g. "Document") is left intact.
    return _DOC_REF_PATTERN.sub(_to_links, output)


async def chat(query,history,audience,sources,reports):
    """Answer `query` with retrieval-augmented generation and update the chat.

    Retrieves the single most similar chunk from the vector store selected by
    `sources`, asks the Hugging Face endpoint to answer using that context and
    the prompt for `audience`, and yields exactly once.

    Args:
        query: The user's question.
        history: Chatbot history as (user, bot) pairs. The last pair is
            expected to be the pending turn added by the UI, with a bot side
            of None or a partial answer string.
        audience: "Children", "General public" or "Experts"; any other value
            uses the experts prompt.
        sources: Selected source names. An empty selection means ["ABC"].
            Exactly ["ABC"] searches the "ABC" store; any other selection
            searches "XYZ".
        reports: Selected reports. Currently only logged, not used for filtering.

    Yields:
        (history, docs_html, output_query, output_language): the updated
        history with the answer filled in, HTML cards for the retrieved
        documents, the retrieval query shown in the UI (currently always
        empty), and the answer language.
    """
    print(f">> NEW QUESTION : {query}")
    print(f"history:{history}")
    print(f"audience:{audience}")
    print(f"sources:{sources}")
    print(f"reports:{reports}")

    output_query = ""
    output_language = "english"

    audience_prompt = audience_prompts[_AUDIENCE_PROMPT_KEYS.get(audience, "experts")]

    # Default to the "ABC" store when nothing is ticked.
    # NOTE(review): any selection other than exactly ["ABC"] (including both
    # sources) searches only "XYZ" — confirm this is intended.
    if not sources:
        sources = ["ABC"]
    vectorstore = vectorstores["ABC"] if sources == ["ABC"] else vectorstores["XYZ"]

    # Retrieve the single most similar chunk (k=1) to use as answer context.
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 1},
    )
    context_retrieved = retriever.get_relevant_documents(query)
    context = "\n\n".join(doc.page_content for doc in context_retrieved)

    # RAG chain: prompt template -> Hugging Face inference endpoint -> plain string.
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    llm_qa = HuggingFaceEndpoint(
        endpoint_url= "https://fesg9gjsfde5yfr4.us-east-1.aws.endpoints.huggingface.cloud",
        task="text-generation",
        huggingfacehub_api_token=HF_token,
        model_kwargs={},
    )
    chain = prompt | llm_qa | StrOutputParser()

    # Await the async variant. A synchronous invoke() inside this coroutine would
    # block the event loop shared by concurrent requests for the whole remote call.
    answer = await chain.ainvoke({
        "context": context,
        "question": query,
        "audience": audience_prompt,
        "language": output_language,
    })

    # One HTML card per retrieved document, numbered from 1.
    docs_html = "".join(make_html_source(d, i) for i, d in enumerate(context_retrieved, 1))

    # Append to whatever the pending bot turn already holds (None for a fresh turn).
    previous_answer = history[-1][1]
    if previous_answer is None:
        previous_answer = ""
    answer_yet = parse_output_llm_with_sources(previous_answer + answer)
    history[-1] = (query, answer_yet)
    history = [tuple(x) for x in history]
    yield history, docs_html, output_query, output_language
def make_html_source(source,i):
    """Render one retrieved document as an HTML card for the "Sources" tab.

    Args:
        source: Retrieved document with a ``page_content`` string and a
            ``metadata`` dict providing ``name``, ``short_name``,
            ``page_number`` and ``url``. ``toc_level0`` / ``toc_level1`` are
            optional; a value of ``"N/A"`` marks an absent level.
        i: 1-based position of the document. It is shown in the card title
            and used for the card's ``id="doc{i}"`` anchor.

    Returns:
        The card as an HTML string.
    """
    meta = source.metadata
    content = source.page_content.strip()

    # Breadcrumb of up to two table-of-contents levels. Stop at the first level
    # that is "N/A" or missing from the metadata.
    toc_levels = []
    for j in range(2):
        level = meta.get(f"toc_level{j}", "N/A")
        if level == "N/A":
            break
        toc_levels.append(level)
    toc_path = " > ".join(toc_levels)
    if toc_path:
        name = f"<b>{toc_path}</b><br/>{meta['name']}"
    else:
        name = meta['name']

    page = int(meta['page_number'])
    # Rendered for every chunk_type. The original only assigned `card` for
    # "text" chunks and raised UnboundLocalError for anything else.
    return f"""
    <div class="card" id="doc{i}">
        <div class="card-content">
            <h2>Doc {i} - {meta['short_name']} - Page {page}</h2>
            <p>{content}</p>
        </div>
        <div class="card-footer">
            <span>{name}</span>
            <a href="{meta['url']}#page={page}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------
# Set up Gradio Theme
# Shared Gradio theme: blue primary / red secondary accents, with the Poppins
# web font and system sans-serif fallbacks.
_ui_font_stack = [gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"]
theme = gr.themes.Base(
    font=_ui_font_stack,
    secondary_hue="red",
    primary_hue="blue",
)
# Markdown welcome message used as the chatbot's first (bot-side) turn.
# Blank lines separate the blocks so the headings and paragraphs render apart.
init_prompt = """
Hello, I am Audit Q&A, a conversational assistant designed to help you understand audit reports. I will answer your questions by **crawling through the audit reports published by the Office of the Auditor General**.

❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the Annual or District or Department focused reports, or all.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn?
"""
# Setting Tabs
with gr.Blocks(title="Audit Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
    # Page layout: an "AuditQ&A" tab (chat + side panel) and an "About" tab,
    # followed by the event wiring that connects the widgets to `chat`.
    # Component creation order defines the layout, so statement order matters here.
    # user_id_state = gr.State([user_id])
    with gr.Tab("AuditQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            # Left column (scale=2): conversation history and the question box.
            with gr.Column(scale=2):
                # state = gr.State([system_template])
                # Seeded with the welcome message as a bot-only turn (user side is None).
                chatbot = gr.Chatbot(
                    value=[(None,init_prompt)],
                    show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
                )#,avatar_images = ("assets/logo4.png",None))
                # bot.like(vote,None,None)
                with gr.Row(elem_id = "input-message"):
                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
                    # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
            # Right column (scale=1): example questions, retrieved sources and configuration.
            with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
                # `tabs` is an output of start_chat, which selects id=1 ("Sources") when a chat starts.
                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
                        # Hidden textbox filled by clicking an example; its .change event starts a chat below.
                        examples_hidden = gr.Textbox(visible = False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")
                        # One row of examples per QUESTIONS category; only the first is visible initially.
                        # change_sample_questions toggles visibility by index into `samples`.
                        samples = []
                        for i,key in enumerate(QUESTIONS.keys()):
                            examples_visible = True if i == 0 else False
                            with gr.Row(visible = examples_visible) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                    # label = "Click on the example question or enter your own",
                                    # cache_examples=True,
                                )
                            samples.append(group_examples)
                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
                        # Receives the HTML source cards yielded by `chat`.
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")
                    # with Modal(visible = False) as config_modal:
                    with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
                        # NOTE(review): chat() always passes language 'english' to the prompt;
                        # confirm the multi-lingual claim below still holds.
                        gr.Markdown("Reminder: You can talk in any language, Audit Q&A is multi-lingual!")
                        # Source / report / audience selections are passed to `chat` as inputs (wiring below).
                        dropdown_sources = gr.CheckboxGroup(
                            ["ABC", "XYZ"],
                            label="Select source",
                            value=["ABC"],
                            interactive=True,
                        )
                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )
                        dropdown_audience = gr.Dropdown(
                            ["Children","General public","Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )
                        # Read-only diagnostics written by `chat`.
                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
    with gr.Tab("About",elem_classes = "max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://www.oag.go.ug/](https://www.oag.go.ug/welcome)")

    def start_chat(query,history):
        """Append the user's turn, lock the input box and switch to the "Sources" tab.

        Returns updates for the outputs (textbox, tabs, chatbot).
        """
        history = history + [(query,None)]
        history = [tuple(x) for x in history]
        return (gr.update(interactive = False),gr.update(selected=1),history)

    def finish_chat():
        """Re-enable and clear the input textbox after `chat` has finished."""
        return (gr.update(interactive = True,value = ""))

    # Typed question: start_chat (queue=False) -> chat (up to 8 concurrent) -> finish_chat.
    (textbox
        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
    )
    # Clicked example: same pipeline, triggered when the hidden textbox value changes.
    # NOTE(review): .change presumably does not fire when the same example is clicked
    # twice in a row (the value does not change) — confirm.
    (examples_hidden
        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language],concurrency_limit = 8,api_name = "chat_examples")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
    )

    def change_sample_questions(key):
        """Show only the example row for the selected category; hide all others."""
        index = list(QUESTIONS.keys()).index(key)
        visible_bools = [False] * len(samples)
        visible_bools[index] = True
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
# Enable Gradio's request queue, then start the web server.
# Blocks.queue() returns the Blocks instance itself, so the calls chain.
demo.queue().launch()