import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging

# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Load the Flan-T5 Small model and tokenizer
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
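# Note: from_pretrained loads the model on CPU by default; call model.to("cuda") to use a GPU if one is available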

def correct_htr(raw_htr_text, max_new_tokens, temperature):
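    """Correct raw HTR (handwritten text recognition) output using the Flan-T5 model."""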
    try:
        logging.info("Processing HTR correction...")
        prompt = f"Correct this text: {raw_htr_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Generate at most max_new_tokens new tokens; do_sample=True is required for temperature to take effect
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature)
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected_text
    except Exception as e:
        logging.error(f"Error in HTR correction: {e}", exc_info=True)
        return str(e)

def summarize_text(legal_text, max_new_tokens, temperature):
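    """Summarize a passage of legal text using the Flan-T5 model."""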
    try:
        logging.info("Processing summarization...")
        prompt = f"Summarize the following legal text: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Generate at most max_new_tokens new tokens; do_sample=True is required for temperature to take effect
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature)
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary
    except Exception as e:
        logging.error(f"Error in summarization: {e}", exc_info=True)
        return str(e)

def answer_question(legal_text, question, max_new_tokens, temperature):
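    """Answer a question about the supplied legal text using the Flan-T5 model."""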
    try:
        logging.info("Processing question-answering...")
        prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Generate at most max_new_tokens new tokens; do_sample=True is required for temperature to take effect
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        logging.error(f"Error in question-answering: {e}", exc_info=True)
        return str(e)

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Flan-T5 Small Legal Assistant")
    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")
    
    with gr.Row():
        gr.HTML('''
            <div style="display: flex; gap: 10px;">
                <div style="border: 2px solid black; padding: 10px;">
                    <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" target="_blank">
                        <button style="font-weight:bold;">Admiralty Court Legal Glossary</button>
                    </a>
                </div>
                <div style="border: 2px solid black; padding: 10px;">
                    <a href="https://raw.githubusercontent.com/Addaci/HCA/refs/heads/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" target="_blank">
                        <button style="font-weight:bold;">HCA 13/70 Ground Truth (1654-55)</button>
                    </a>
                </div>
            </div>
        ''')

    # Tab 1: Correct HTR
    with gr.Tab("Correct HTR"):
        gr.Markdown("### Correct Raw HTR Text")
        raw_htr_input = gr.Textbox(lines=5, placeholder="Enter raw HTR text here...")
        corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")
        max_new_tokens_slider = gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens")
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        correct_button.click(correct_htr, inputs=[raw_htr_input, max_new_tokens_slider, temperature_slider], outputs=corrected_output)
        clear_button.click(lambda: ("", ""), outputs=[raw_htr_input, corrected_output])

    # Tab 2: Summarize Legal Text
    with gr.Tab("Summarize Legal Text"):
        gr.Markdown("### Summarize Legal Text")
        legal_text_input = gr.Textbox(lines=10, placeholder="Enter legal text to summarize...")
        summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
        summarize_button = gr.Button("Summarize Text")
        clear_button = gr.Button("Clear")
        max_new_tokens_slider = gr.Slider(minimum=10, maximum=512, value=256, step=1, label="Max New Tokens")
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Temperature")
        summarize_button.click(summarize_text, inputs=[legal_text_input, max_new_tokens_slider, temperature_slider], outputs=summary_output)
        clear_button.click(lambda: ("", ""), outputs=[legal_text_input, summary_output])

    # Tab 3: Answer Legal Question
    with gr.Tab("Answer Legal Question"):
        gr.Markdown("### Answer a Question Based on Legal Text")
        legal_text_input_q = gr.Textbox(lines=10, placeholder="Enter legal text...")
        question_input = gr.Textbox(lines=2, placeholder="Enter your question...")
        answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
        answer_button = gr.Button("Get Answer")
        clear_button = gr.Button("Clear")
        max_new_tokens_slider = gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max New Tokens")
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature")
        answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, max_new_tokens_slider, temperature_slider], outputs=answer_output)
        clear_button.click(lambda: ("", "", ""), outputs=[legal_text_input_q, question_input, answer_output])

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch()
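    # When running locally, demo.launch(share=True) would additionally create a temporary public link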