File size: 5,260 Bytes
333cd91
fc8771f
333cd91
fc8771f
6c8e814
 
0eb38a5
74f4d51
 
 
 
df63252
74f4d51
 
 
df63252
 
333cd91
74f4d51
 
df63252
 
74f4d51
 
 
df63252
 
333cd91
74f4d51
 
df63252
 
74f4d51
 
 
df63252
 
333cd91
74f4d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93d2618
fc8771f
74f4d51
df63252
 
fc8771f
 
74f4d51
 
 
 
 
 
 
 
 
1b938f5
df63252
fc8771f
 
df63252
 
74f4d51
 
 
 
 
 
 
77122ee
 
fc8771f
 
 
74f4d51
 
 
 
 
 
 
 
 
 
fc8771f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer
# NOTE(review): `from_pretrained` downloads the checkpoint from the Hugging
# Face Hub on first run, so this module requires network access at import
# time. Both objects are module-level singletons shared by every task
# handler below.
model = T5ForConditionalGeneration.from_pretrained("Addaci/mT5-small-experiment-13-checkpoint-2790")
tokenizer = T5Tokenizer.from_pretrained("Addaci/mT5-small-experiment-13-checkpoint-2790")

# Define task-specific prompts
def correct_htr_text(input_text, max_new_tokens, temperature):
    """Correct obvious errors in a raw HTR transcription via the mT5 model.

    Args:
        input_text: Raw handwritten-text-recognition output to correct.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The model's corrected transcription as a plain string.
    """
    prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # BUG FIX: without do_sample=True, transformers uses greedy decoding and
    # silently ignores `temperature`, so the UI slider had no effect.
    # Passing attention_mask avoids the generate() warning and makes padding
    # handling explicit.
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Summarize a legal text via the mT5 model.

    Args:
        input_text: The legal text to summarize.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The model's summary as a plain string.
    """
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # BUG FIX: without do_sample=True, transformers uses greedy decoding and
    # silently ignores `temperature`, so the UI slider had no effect.
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def answer_legal_question(input_text, question, max_new_tokens, temperature):
    """Answer a question about a legal text via the mT5 model.

    Args:
        input_text: The legal text providing context for the answer.
        question: The question to answer.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The model's answer as a plain string.
    """
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # BUG FIX: without do_sample=True, transformers uses greedy decoding and
    # silently ignores `temperature`, so the UI slider had no effect.
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Define Gradio interface functions
def correct_htr_interface(text, max_new_tokens, temperature):
    """Gradio callback for the HTR tab; delegates to correct_htr_text."""
    corrected = correct_htr_text(text, max_new_tokens, temperature)
    return corrected

def summarize_interface(text, max_new_tokens, temperature):
    """Gradio callback for the summary tab; delegates to summarize_legal_text."""
    summary = summarize_legal_text(text, max_new_tokens, temperature)
    return summary

def question_interface(text, question, max_new_tokens, temperature):
    """Gradio callback for the Q&A tab; delegates to answer_legal_question."""
    answer = answer_legal_question(text, question, max_new_tokens, temperature)
    return answer

def clear_all():
    """Clear callback: blank out a pair of textbox components."""
    blank = ""
    return blank, blank

# External clickable buttons
def clickable_buttons():
    """Return the HTML for the two external-reference link buttons.

    Rendered once at the top of the app: a link to the Admiralty court
    legal glossary and a link to the HCA 13/70 ground-truth transcript.
    """
    return """
    <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
        <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" 
        style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
        Admiralty Court Legal Glossary</a>
        <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" 
        style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
        HCA 13/70 Ground Truth</a>
    </div>
    """

# Interface layout
# Three-tab UI: each tab pairs a task button (wired to its handler) with a
# clear button that empties that tab's textboxes.
with gr.Blocks() as demo:
    gr.HTML("<h1>Marinelives mT5-small Legal Assistant</h1>")
    gr.HTML(clickable_buttons())  # external reference links

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Input Text")
        output_text = gr.Textbox(label="Corrected Text")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")

        correct_button.click(fn=correct_htr_interface,
                             inputs=[input_text, max_new_tokens, temperature],
                             outputs=output_text)
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Input Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")

        summarize_button.click(fn=summarize_interface,
                               inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
                               outputs=output_text_summarize)
        clear_button_summarize.click(fn=clear_all, outputs=[input_text_summarize, output_text_summarize])

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Input Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")

        question_button.click(fn=question_interface,
                              inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
                              outputs=output_text_question)
        # BUG FIX: this tab clears THREE components, but clear_all() returns
        # only two values, which makes Gradio raise when the button is
        # clicked. Supply a three-value clear instead.
        clear_button_question.click(fn=lambda: ("", "", ""),
                                    outputs=[input_text_question, question, output_text_question])

demo.launch()