Update app.py
#6
by
Addaci
- opened
app.py
CHANGED
@@ -1,79 +1,113 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import
|
3 |
|
4 |
-
# Load model and tokenizer
|
5 |
-
model = T5ForConditionalGeneration.from_pretrained("google/
|
6 |
-
tokenizer = T5Tokenizer.from_pretrained("google/
|
7 |
|
8 |
-
#
|
9 |
-
def
|
10 |
-
|
|
|
11 |
outputs = model.generate(
|
12 |
-
|
13 |
-
max_new_tokens=max_new_tokens,
|
14 |
-
temperature=temperature
|
15 |
-
do_sample=True
|
16 |
)
|
17 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
prompt = "summarize: " + text
|
22 |
inputs = tokenizer(prompt, return_tensors="pt")
|
23 |
outputs = model.generate(
|
24 |
-
|
25 |
-
max_new_tokens=max_new_tokens,
|
26 |
-
temperature=temperature
|
27 |
-
do_sample=True
|
28 |
)
|
29 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
prompt = f"question: {question} context: {context}"
|
34 |
inputs = tokenizer(prompt, return_tensors="pt")
|
35 |
outputs = model.generate(
|
36 |
-
|
37 |
-
max_new_tokens=max_new_tokens,
|
38 |
-
temperature=temperature
|
39 |
-
do_sample=True
|
40 |
)
|
41 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
42 |
|
43 |
-
# Gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
with gr.Blocks() as demo:
|
45 |
-
|
46 |
-
gr.
|
47 |
|
48 |
-
with gr.Row():
|
49 |
-
gr.Markdown('[Admiralty Court Legal Glossary](http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary)')
|
50 |
-
gr.Markdown('[HCA 13/70 Ground Truth](https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt)')
|
51 |
-
|
52 |
-
# Tabs for different functionalities
|
53 |
with gr.Tab("Correct Raw HTR"):
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
gr.Button("Correct HTR")
|
59 |
-
gr.Button("Clear")
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
with gr.Tab("Summarize Legal Text"):
|
62 |
-
|
63 |
-
|
64 |
max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
|
65 |
temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
|
66 |
-
gr.Button("Summarize Text")
|
67 |
-
gr.Button("Clear")
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
with gr.Tab("Answer Legal Question"):
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
gr.Button("Get Answer")
|
76 |
-
gr.Button("Clear")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
# Launch the demo
|
79 |
demo.launch()
|
|
|
1 |
import gradio as gr
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer for mT5-small
# NOTE(review): `pipeline` is imported but never used in this file — confirm
# it is not needed elsewhere before removing. Loading happens at import time,
# so starting the app downloads/loads the model before the UI is built.
model = T5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
7 |
|
8 |
+
# Define task-specific prompts
def correct_htr_text(input_text, max_new_tokens, temperature):
    """Correct obvious HTR transcription errors while preserving C17th spelling.

    Args:
        input_text: Raw handwritten-text-recognition output to correct.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The model's corrected transcription as a plain string.
    """
    prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # BUG FIX: without do_sample=True, generate() decodes greedily and
        # silently ignores `temperature`, so the UI slider had no effect.
        # (The pre-change version of this file set do_sample=True.)
        do_sample=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
18 |
|
19 |
+
def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Summarize a passage of legal text with the mT5 model.

    Args:
        input_text: The legal text to summarize.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The generated summary as a plain string.
    """
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # BUG FIX: temperature is ignored under greedy decoding; enable
        # sampling so the temperature slider actually influences output.
        do_sample=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
28 |
|
29 |
+
def answer_legal_question(input_text, question, max_new_tokens, temperature):
    """Answer a question grounded in the supplied legal text.

    Args:
        input_text: The legal text providing context for the answer.
        question: The question to answer.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The generated answer as a plain string.
    """
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # BUG FIX: temperature is ignored under greedy decoding; enable
        # sampling so the temperature slider actually influences output.
        do_sample=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
38 |
|
39 |
+
# Define Gradio interface functions
|
40 |
+
def correct_htr_interface(text, max_new_tokens, temperature):
    """Gradio callback: delegate HTR correction to correct_htr_text()."""
    corrected = correct_htr_text(text, max_new_tokens, temperature)
    return corrected
|
42 |
+
|
43 |
+
def summarize_interface(text, max_new_tokens, temperature):
    """Gradio callback: delegate summarization to summarize_legal_text()."""
    summary = summarize_legal_text(text, max_new_tokens, temperature)
    return summary
|
45 |
+
|
46 |
+
def question_interface(text, question, max_new_tokens, temperature):
    """Gradio callback: delegate question answering to answer_legal_question()."""
    answer = answer_legal_question(text, question, max_new_tokens, temperature)
    return answer
|
48 |
+
|
49 |
+
def clear_all():
    """Return a pair of empty strings to reset an input/output textbox pair."""
    empty = ""
    return empty, empty
|
51 |
+
|
52 |
+
# External clickable buttons
|
53 |
+
def clickable_buttons():
    """Build an HTML snippet with two external reference links styled as buttons.

    Returns:
        str: HTML for the Admiralty Court glossary and HCA 13/70 ground-truth
        links, intended to be rendered with gr.HTML() at the top of the app.
    """
    # Inline-styled so no external CSS is needed inside the Gradio page.
    button_html = """
    <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
        <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           Admiralty Court Legal Glossary</a>
        <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           HCA 13/70 Ground Truth</a>
    </div>
    """
    return button_html
|
65 |
+
|
66 |
+
# Interface layout
def _clear_question_tab():
    """Reset all three components of the question tab (text, question, answer)."""
    return "", "", ""

with gr.Blocks() as demo:
    # NOTE(review): the title says "Flan-T5" but the model loaded above is
    # google/mt5-small — confirm which model is intended and align the title.
    gr.HTML("<h1>Flan-T5 Legal Assistant</h1>")
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Textbox")
        output_text = gr.Textbox(label="Textbox")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")

        correct_button.click(fn=correct_htr_interface,
                             inputs=[input_text, max_new_tokens, temperature],
                             outputs=output_text)
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Textbox")
        output_text_summarize = gr.Textbox(label="Textbox")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")

        summarize_button.click(fn=summarize_interface,
                               inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
                               outputs=output_text_summarize)
        clear_button_summarize.click(fn=clear_all, outputs=[input_text_summarize, output_text_summarize])

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Textbox")
        question = gr.Textbox(label="Textbox")
        output_text_question = gr.Textbox(label="Textbox")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")

        question_button.click(fn=question_interface,
                              inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
                              outputs=output_text_question)
        # BUG FIX: this tab clears THREE components, but clear_all() returns
        # only two values — Gradio raises a return-count mismatch. Use the
        # dedicated three-value helper instead.
        clear_button_question.click(fn=_clear_question_tab,
                                    outputs=[input_text_question, question, output_text_question])

    # BUG FIX: removed a stray `gr.Button("Clear", elem_id="clear_button")
    # .click(clear_all)` that declared no outputs — clear_all's return values
    # had nowhere to go, so the button did nothing useful and produced a
    # value/output mismatch at click time.

demo.launch()
|