import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Load model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("Addaci/mT5-small-experiment-13-checkpoint-2790")
tokenizer = T5Tokenizer.from_pretrained("Addaci/mT5-small-experiment-13-checkpoint-2790")
# Define task-specific prompts
def correct_htr_text(input_text, max_new_tokens, temperature):
    prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sampling must be enabled for the temperature setting to take effect
        temperature=temperature
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def summarize_legal_text(input_text, max_new_tokens, temperature):
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sampling must be enabled for the temperature setting to take effect
        temperature=temperature
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def answer_legal_question(input_text, question, max_new_tokens, temperature):
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sampling must be enabled for the temperature setting to take effect
        temperature=temperature
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Define Gradio interface functions
def correct_htr_interface(text, max_new_tokens, temperature):
    return correct_htr_text(text, max_new_tokens, temperature)

def summarize_interface(text, max_new_tokens, temperature):
    return summarize_legal_text(text, max_new_tokens, temperature)

def question_interface(text, question, max_new_tokens, temperature):
    return answer_legal_question(text, question, max_new_tokens, temperature)

def clear_all():
    return "", ""

def clear_all_question():
    # The question tab clears three components: input text, question, and answer
    return "", "", ""

# External clickable buttons
def clickable_buttons():
    button_html = """
    <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
        <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           Admiralty Court Legal Glossary</a>
        <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt"
           style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
           HCA 13/70 Ground Truth</a>
    </div>
    """
    return button_html

# Interface layout
with gr.Blocks() as demo:
    gr.HTML("<h1>Marinelives mT5-small Legal Assistant</h1>")
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Input Text")
        output_text = gr.Textbox(label="Corrected Text")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")
        correct_button.click(fn=correct_htr_interface,
                             inputs=[input_text, max_new_tokens, temperature],
                             outputs=output_text)
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Input Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")
        summarize_button.click(fn=summarize_interface,
                               inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
                               outputs=output_text_summarize)
        clear_button_summarize.click(fn=clear_all, outputs=[input_text_summarize, output_text_summarize])

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Input Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")
        question_button.click(fn=question_interface,
                              inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
                              outputs=output_text_question)
        clear_button_question.click(fn=clear_all_question, outputs=[input_text_question, question, output_text_question])

demo.launch()