|
import gradio as gr |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
|
|
# Single checkpoint identifier shared by the model weights and the tokenizer,
# so the two can never drift apart.
# NOTE(review): the checkpoint name mentions mT5 while the T5 classes are used
# to load it — confirm the classes match the checkpoint's architecture.
CHECKPOINT_ID = "Addaci/mT5-small-experiment-13-checkpoint-2790"

model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT_ID)
tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT_ID)
|
|
|
|
|
def correct_htr_text(input_text, max_new_tokens, temperature):
    """Ask the model to correct a raw HTR transcription.

    Args:
        input_text: Transcribed text to be corrected.
        max_new_tokens: Upper bound on the number of generated tokens. Coerced
            to ``int`` because UI sliders may deliver a float.
        temperature: Sampling temperature. A positive value enables sampling;
            ``None`` or a non-positive value falls back to greedy decoding.

    Returns:
        The decoded model output with special tokens removed.
    """
    prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")

    generation_kwargs = {"max_new_tokens": int(max_new_tokens)}
    # ``temperature`` only influences generation when sampling is switched on;
    # without ``do_sample=True`` it is ignored and decoding stays greedy.
    if temperature and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)

    # Unpack the whole encoding so the attention mask travels with the ids.
    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Ask the model to summarize a legal text.

    Args:
        input_text: Legal text to summarize.
        max_new_tokens: Upper bound on the number of generated tokens. Coerced
            to ``int`` because UI sliders may deliver a float.
        temperature: Sampling temperature. A positive value enables sampling;
            ``None`` or a non-positive value falls back to greedy decoding.

    Returns:
        The decoded model output with special tokens removed.
    """
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")

    generation_kwargs = {"max_new_tokens": int(max_new_tokens)}
    # ``temperature`` only influences generation when sampling is switched on;
    # without ``do_sample=True`` it is ignored and decoding stays greedy.
    if temperature and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)

    # Unpack the whole encoding so the attention mask travels with the ids.
    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def answer_legal_question(input_text, question, max_new_tokens, temperature):
    """Ask the model to answer a question about a legal text.

    Args:
        input_text: Legal text the answer should be based on.
        question: Question to answer.
        max_new_tokens: Upper bound on the number of generated tokens. Coerced
            to ``int`` because UI sliders may deliver a float.
        temperature: Sampling temperature. A positive value enables sampling;
            ``None`` or a non-positive value falls back to greedy decoding.

    Returns:
        The decoded model output with special tokens removed.
    """
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")

    generation_kwargs = {"max_new_tokens": int(max_new_tokens)}
    # ``temperature`` only influences generation when sampling is switched on;
    # without ``do_sample=True`` it is ignored and decoding stays greedy.
    if temperature and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)

    # Unpack the whole encoding so the attention mask travels with the ids.
    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def correct_htr_interface(text, max_new_tokens, temperature):
    """UI callback: forward the textbox and slider values to ``correct_htr_text``."""
    corrected = correct_htr_text(
        input_text=text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return corrected
|
|
|
def summarize_interface(text, max_new_tokens, temperature):
    """UI callback: forward the textbox and slider values to ``summarize_legal_text``."""
    summary = summarize_legal_text(
        input_text=text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return summary
|
|
|
def question_interface(text, question, max_new_tokens, temperature):
    """UI callback: forward the text, question and slider values to ``answer_legal_question``."""
    answer = answer_legal_question(
        input_text=text,
        question=question,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return answer
|
|
|
def clear_all():
    """Return a pair of empty strings, one per textbox to be blanked.

    Intended as a click handler bound to exactly two output components.
    """
    blank = ""
    return blank, blank
|
|
|
|
|
def clickable_buttons():
    """Return the HTML snippet for the two button-styled reference links.

    The snippet is a flex row holding links to the Admiralty court legal
    glossary and to the HCA 13/70 ground-truth text.
    """
    return """
    <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
        <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           Admiralty Court Legal Glossary</a>
        <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           HCA 13/70 Ground Truth</a>
    </div>
    """
|
|
|
|
|
# Build the three-tab UI: each tab has its own input/output textboxes,
# generation sliders, an action button and a clear button.
with gr.Blocks() as demo:
    gr.HTML("<h1>Marinelives mT5-small Legal Assistant</h1>")
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Input Text")
        output_text = gr.Textbox(label="Corrected Text")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")

        correct_button.click(
            fn=correct_htr_interface,
            inputs=[input_text, max_new_tokens, temperature],
            outputs=output_text,
        )
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Input Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")

        summarize_button.click(
            fn=summarize_interface,
            inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
            outputs=output_text_summarize,
        )
        clear_button_summarize.click(
            fn=clear_all,
            outputs=[input_text_summarize, output_text_summarize],
        )

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Input Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")

        question_button.click(
            fn=question_interface,
            inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
            outputs=output_text_question,
        )
        # This tab resets three textboxes, but clear_all returns only two
        # values; a handler must return one value per output component, so
        # supply three empty strings here.
        clear_button_question.click(
            fn=lambda: ("", "", ""),
            outputs=[input_text_question, question, output_text_question],
        )

demo.launch()