# legal-ease / app.py
# shivi's picture
# Update app.py
# 07fceec
from __future__ import annotations
from typing import Iterable
import gradio as gr
from base.legal_document_utils import (
summarize,
question_answer,
load_gpl_license,
load_pokemon_license,
)
from base.document_search import cross_lingual_document_search, translate_search_result
from gradio.themes.base import Base
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
class CustomTheme(Base):
    """Gradio theme for the app, built on ``gradio.themes.base.Base``.

    The constructor forwards the hue and size choices to ``Base.__init__``,
    names the theme ``"custom_theme"`` and then overrides a fixed set of
    theme variables (colors, shadows, block labels, inputs, buttons,
    checkboxes and borders) through ``Base.set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.blue,
        secondary_hue: colors.Color | str = colors.cyan,
        neutral_hue: colors.Color | str = colors.zinc,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
    ):
        """Initialise the base theme and apply the custom variable overrides.

        All parameters are keyword-only and are passed through unchanged to
        ``Base.__init__``.

        Args:
            primary_hue: Primary color palette (defaults to ``colors.blue``).
            secondary_hue: Secondary color palette (defaults to ``colors.cyan``).
            neutral_hue: Neutral color palette (defaults to ``colors.zinc``).
            spacing_size: Spacing scale (defaults to ``sizes.spacing_md``).
            radius_size: Corner-radius scale (defaults to ``sizes.radius_md``).
            text_size: Text-size scale (defaults to ``sizes.text_md``).
        """
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
        )
        self.name = "custom_theme"

        # Each dict below holds one thematic group of overrides. Values that
        # start with "*" refer to other theme variables rather than literals.
        color_overrides = {
            "background_fill_primary": "*neutral_50",
            "slider_color": "*primary_500",
            "slider_color_dark": "*primary_600",
        }
        shadow_overrides = {
            "shadow_drop": "0 1px 4px 0 rgb(0 0 0 / 0.1)",
            "shadow_drop_lg": "0 2px 5px 0 rgb(0 0 0 / 0.1)",
        }
        block_label_overrides = {
            "block_background_fill": "white",
            "block_label_padding": "*spacing_sm *spacing_md",
            "block_label_background_fill": "*primary_100",
            "block_label_background_fill_dark": "*primary_600",
            "block_label_radius": "*radius_md",
            "block_label_text_size": "*text_md",
            "block_label_text_weight": "600",
            "block_label_text_color": "*primary_500",
            "block_label_text_color_dark": "*white",
            "block_title_radius": "*block_label_radius",
            "block_title_padding": "*block_label_padding",
            "block_title_background_fill": "*block_label_background_fill",
            "block_title_text_weight": "600",
            "block_title_text_color": "*primary_500",
            "block_title_text_color_dark": "*white",
            "block_label_margin": "*spacing_md",
            "block_shadow": "*shadow_drop_lg",
        }
        input_overrides = {
            "input_border_color": "*neutral_50",
            "input_shadow": "*shadow_drop",
            "input_shadow_focus": "*shadow_drop_lg",
            "checkbox_shadow": "none",
        }
        button_overrides = {
            "shadow_spread": "6px",
            "button_shadow": "*shadow_drop_lg",
            "button_shadow_hover": "*shadow_drop_lg",
            "button_shadow_active": "*shadow_inset",
            "button_primary_background_fill": "linear-gradient(90deg, *primary_300, *secondary_400)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *primary_200, *secondary_300)",
            "button_primary_text_color": "white",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *primary_600, *secondary_800)",
            "button_primary_background_fill_hover_dark": "*primary_500",
            "button_secondary_background_fill": "white",
            "button_secondary_background_fill_hover": "*neutral_100",
            "button_secondary_background_fill_hover_dark": "*primary_500",
            "button_secondary_text_color": "*neutral_800",
            "button_cancel_background_fill": "*button_secondary_background_fill",
            "button_cancel_background_fill_hover": "*button_secondary_background_fill_hover",
            "button_cancel_background_fill_hover_dark": "*button_secondary_background_fill_hover",
            "button_cancel_text_color": "*button_secondary_text_color",
        }
        checkbox_overrides = {
            "checkbox_label_shadow": "*shadow_drop_lg",
            "checkbox_label_background_fill_selected": "*primary_500",
            "checkbox_label_background_fill_selected_dark": "*primary_600",
            "checkbox_border_width": "1px",
            "checkbox_border_color": "*neutral_100",
            "checkbox_border_color_dark": "*neutral_600",
            "checkbox_background_color_selected": "*primary_600",
            "checkbox_background_color_selected_dark": "*primary_700",
            "checkbox_border_color_focus": "*primary_500",
            "checkbox_border_color_focus_dark": "*primary_600",
            "checkbox_border_color_selected": "*primary_600",
            "checkbox_border_color_selected_dark": "*primary_700",
            "checkbox_label_text_color_selected": "white",
        }
        border_overrides = {
            "block_border_width": "0px",
            "panel_border_width": "1px",
        }

        # The groups share no keys, so unpacking them together is equivalent
        # to one flat keyword-argument list.
        super().set(
            **color_overrides,
            **shadow_overrides,
            **block_label_overrides,
            **input_overrides,
            **button_overrides,
            **checkbox_overrides,
            **border_overrides,
        )
# Number of search-result slots in the Document Search tab; used below as both
# the maximum and the initial value of the (hidden) result-count slider.
max_search_results: int = 3

# Single theme instance handed to ``gr.Blocks(theme=...)`` when the UI is built.
custom_theme = CustomTheme()
def reset_chatbot():
    """Return a Gradio update that sets the target component's value to ``""``.

    Wired to the document textbox's ``change`` event with the chatbot as its
    output, so editing the document discards the previous conversation.
    """
    blank_value_update = gr.update(value="")
    return blank_value_update
def get_user_input(input_question, history):
    """Append the submitted question to the chat history as an unanswered turn.

    Args:
        input_question: The text the user submitted.
        history: Existing chat turns, each a ``[question, answer]`` pair.
            ``None`` or an empty sequence is treated as "no history yet",
            and any sequence type (list or tuple) is accepted.

    Returns:
        A 2-tuple ``("", new_history)``. The empty string is meant for the
        question textbox (clearing it); ``new_history`` is a new list made of
        the earlier turns plus ``[input_question, None]``, where ``None``
        marks the answer as still pending. The ``history`` argument itself
        is never mutated.
    """
    # Normalise first: concatenating onto None or a tuple would raise TypeError.
    earlier_turns = list(history) if history else []
    return "", earlier_turns + [[input_question, None]]
def legal_doc_qa_bot(input_document, history):
    """Fill in the answer slot of the most recent chat turn.

    ``question_answer`` is called with the document and the full history
    (including the pending turn); its result is stored as the second element
    of the last turn.

    Args:
        input_document: The document text the question refers to.
        history: Chat turns as mutable ``[question, answer]`` pairs; the last
            one is expected to be the turn awaiting an answer.

    Returns:
        The same ``history`` object, modified in place.
    """
    # Compute the answer before touching the history, so the helper sees the
    # pending turn exactly as it was queued.
    answer = question_answer(input_document, history)
    pending_turn = history[-1]
    pending_turn[1] = answer
    return history
def _build_search_result_row(position):
    """Create one "search result + translation" row and return its components.

    Intended to be called inside the ``gr.Blocks`` layout below so the
    components are created as part of that layout.

    Args:
        position: 1-based number shown in the two textbox labels.

    Returns:
        ``(result_box, translate_button, translation_box)``.
    """
    with gr.Row():
        with gr.Column():
            result_box = gr.Textbox(label=f"Search Result {position}")
        with gr.Column():
            with gr.Accordion("Translate Search Result", open=False):
                # A button's caption is given through ``value``; Button has no
                # documented ``label`` argument, so none is passed.
                translate_button = gr.Button(value="Translate", variant="primary")
                translation_box = gr.Textbox(label=f"Translation Result {position}")
    return result_box, translate_button, translation_box


# ---------------------------------------------------------------------------
# UI layout: a logo header followed by three tabs (Q&A, Summarize, Document
# Search). Event handlers are attached after all components exist.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=custom_theme) as demo:
    gr.HTML(
        """<html><center><img src='file/logo/flc_design4.png', alt='Legal-ease logo', width=250, height=250 /></center><br></html>"""
    )

    # NOTE(review): this state is not referenced by any handler in this file.
    qa_bot_state = gr.State(value=[])

    with gr.Tabs():
        # ------------------------------ Q&A tab ------------------------------
        with gr.TabItem("Q&A"):
            gr.HTML(
                """<p style="text-align:center;"><b>Legal documents can be difficult to comprehend and understand. Add a legal document below and ask any questions related to it.</p>"""
            )
            with gr.Row():
                with gr.Column():
                    input_document = gr.Text(label="Copy your document here", lines=10)
                with gr.Column():
                    chatbot = gr.Chatbot(label="Chat History")
                    input_question = gr.Text(
                        label="Ask a question",
                        placeholder="Type a question here and hit enter.",
                    )
                    clear = gr.Button("Clear", variant="primary")
            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    example_1 = gr.Button(
                        "Load GPL License Document", variant="primary"
                    )
                    example_2 = gr.Button(
                        "Load Pokemon Go Terms of Service", variant="primary"
                    )

        # --------------------------- Summarize tab ---------------------------
        with gr.TabItem("Summarize"):
            gr.HTML(
                """<p style="text-align:center;"><b>Legal documents can be very lengthy. Add a legal document below and generate a quick summary for it.</p>"""
            )
            with gr.Row():
                with gr.Column():
                    summary_input = gr.Text(label="Document", lines=10)
                    generate_summary = gr.Button("Generate Summary", variant="primary")
                with gr.Column():
                    summary_output = gr.Text(label="Summary", lines=10)
                    # Hidden sink for the second output of the example loaders.
                    invisible_comp = gr.Text(label="Dummy Component", visible=False)
            with gr.Row():
                with gr.Accordion("Advanced Settings:", open=False):
                    summary_length = gr.Radio(
                        ["short", "medium", "long"],
                        label="Summary Length",
                        value="long",
                    )
                    summary_format = gr.Radio(
                        ["paragraph", "bullets"],
                        label="Summary Format",
                        value="bullets",
                    )
                    # Hidden controls: their fixed values are still passed to
                    # ``summarize``.
                    extractiveness = gr.Radio(
                        ["low", "medium", "high"],
                        label="Extractiveness",
                        info="Controls how close to the original text the summary is.",
                        visible=False,
                        value="high",
                    )
                    temperature = gr.Slider(
                        minimum=0,
                        maximum=5.0,
                        value=0.64,
                        step=0.1,
                        interactive=True,
                        visible=False,
                        label="Temperature",
                        info="Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output.",
                    )
            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    example_3 = gr.Button(
                        "Load GPL License Document", variant="primary"
                    )
                    example_4 = gr.Button(
                        "Load Pokemon Go Terms of Service", variant="primary"
                    )

        # ------------------------ Document Search tab ------------------------
        with gr.TabItem("Document Search"):
            gr.HTML(
                """<p style="text-align:center;"><b>Search across a set of legal documents in any language or even a mix of languages. Query them using any one of over 100 supported languages.</p>"""
            )
            gr.HTML(
                """<p style="text-align:center; font-style:italic;">Get started with a pre-indexed set of documents from eight European countries (Belgium, France, Hungary, Italy, Netherlands, Norway, Poland, UK) in seven languages, outlining legislation passed during the COVID-19 pandemic.</p>"""
            )
            with gr.Row():
                text_match = gr.CheckboxGroup(
                    ["Full Text Search"], label="find exact text in documents"
                )
            with gr.Row():
                lang_choices = gr.CheckboxGroup(
                    [
                        "English",
                        "French",
                        "Italian",
                        "Dutch",
                        "Polish",
                        "Hungarian",
                        "Norwegian",
                    ],
                    label="Filter results based on language",
                )
            with gr.Row():
                with gr.Column():
                    user_query = gr.Text(
                        label="Enter query here",
                        placeholder="Search through all your documents",
                    )
                    # Hidden slider, so the search always receives
                    # ``max_search_results`` as the requested count.
                    num_search_results = gr.Slider(
                        1,
                        max_search_results,
                        visible=False,
                        value=max_search_results,
                        step=1,
                        interactive=True,
                        label="How many search results to show:",
                    )

            # Three identical result rows; names are kept individually because
            # the event wiring below refers to them.
            query_match_out_1, translate_1, translate_res_1 = _build_search_result_row(1)
            query_match_out_2, translate_2, translate_res_2 = _build_search_result_row(2)
            query_match_out_3, translate_3, translate_res_3 = _build_search_result_row(3)

    # ----------------------------- Event wiring -----------------------------
    # On submit: first move the question into the chat history and clear the
    # textbox (unqueued), then fetch the answer for the input document.
    input_question.submit(
        get_user_input,
        [input_question, chatbot],
        [input_question, chatbot],
        queue=False,
    ).then(legal_doc_qa_bot, [input_document, chatbot], chatbot)

    # Reset the chatbot Q&A history when the input document changes.
    input_document.change(fn=reset_chatbot, inputs=[], outputs=chatbot)

    # Load example documents on click. Each loader is wired to two outputs:
    # on the Q&A tab they are the document and question boxes, on the
    # Summarize tab the document box and the hidden dummy component.
    example_wiring = (
        (example_1, load_gpl_license, [input_document, input_question]),
        (example_2, load_pokemon_license, [input_document, input_question]),
        (example_3, load_gpl_license, [summary_input, invisible_comp]),
        (example_4, load_pokemon_license, [summary_input, invisible_comp]),
    )
    for example_button, load_example, example_outputs in example_wiring:
        example_button.click(load_example, [], example_outputs, queue=False)

    # Generate a summary of the document submitted by the user.
    generate_summary.click(
        summarize,
        [summary_input, summary_length, summary_format, extractiveness, temperature],
        [summary_output],
        queue=False,
    )

    # Clear the chatbot Q&A history when the user clicks "Clear".
    clear.click(lambda: None, None, chatbot, queue=False)

    # Run the search both while the user types (change) and on submit.
    search_inputs = [user_query, num_search_results, lang_choices, text_match]
    search_outputs = [query_match_out_1, query_match_out_2, query_match_out_3]
    for register_search in (user_query.change, user_query.submit):
        register_search(
            cross_lingual_document_search,
            search_inputs,
            search_outputs,
            queue=False,
        )

    # Translate the matching search result when its "Translate" button is
    # clicked; the query is passed along with the result text.
    translation_wiring = (
        (translate_1, query_match_out_1, translate_res_1),
        (translate_2, query_match_out_2, translate_res_2),
        (translate_3, query_match_out_3, translate_res_3),
    )
    for translate_button, result_box, translation_box in translation_wiring:
        translate_button.click(
            translate_search_result,
            [result_box, user_query],
            [translation_box],
            queue=False,
        )

if __name__ == "__main__":
    demo.launch(debug=True)