Files changed (1) hide show
  1. app.py +85 -51
app.py CHANGED
@@ -1,79 +1,113 @@
1
  import gradio as gr
2
- from transformers import T5ForConditionalGeneration, T5Tokenizer
3
 
4
- # Load model and tokenizer
5
- model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
6
- tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
7
 
8
- # Function for Correct Raw HTR
9
- def correct_htr(text, max_new_tokens, temperature):
10
- inputs = tokenizer(text, return_tensors="pt")
 
11
  outputs = model.generate(
12
- **inputs,
13
- max_new_tokens=max_new_tokens,
14
- temperature=temperature,
15
- do_sample=True
16
  )
17
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
18
 
19
- # Function for Summarize Legal Text
20
- def summarize_legal_text(text, max_new_tokens, temperature):
21
- prompt = "summarize: " + text
22
  inputs = tokenizer(prompt, return_tensors="pt")
23
  outputs = model.generate(
24
- **inputs,
25
- max_new_tokens=max_new_tokens,
26
- temperature=temperature,
27
- do_sample=True
28
  )
29
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
30
 
31
- # Function for Answer Legal Question
32
- def answer_legal_question(context, question, max_new_tokens, temperature):
33
- prompt = f"question: {question} context: {context}"
34
  inputs = tokenizer(prompt, return_tensors="pt")
35
  outputs = model.generate(
36
- **inputs,
37
- max_new_tokens=max_new_tokens,
38
- temperature=temperature,
39
- do_sample=True
40
  )
41
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
42
 
43
- # Gradio Interface Setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  with gr.Blocks() as demo:
45
- # Title and clickable buttons with URLs
46
- gr.Markdown("# Flan-T5 Legal Assistant")
47
 
48
- with gr.Row():
49
- gr.Markdown('[Admiralty Court Legal Glossary](http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary)')
50
- gr.Markdown('[HCA 13/70 Ground Truth](https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt)')
51
-
52
- # Tabs for different functionalities
53
  with gr.Tab("Correct Raw HTR"):
54
- text_input_htr = gr.Textbox(label="Textbox", placeholder="Enter text to correct")
55
- text_output_htr = gr.Textbox(label="Textbox", placeholder="Corrected text will appear here")
56
- max_new_tokens_htr = gr.Slider(10, 512, value=128, label="Max New Tokens")
57
- temperature_htr = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
58
- gr.Button("Correct HTR").click(correct_htr, inputs=[text_input_htr, max_new_tokens_htr, temperature_htr], outputs=text_output_htr)
59
- gr.Button("Clear").click(lambda: "", None, text_input_htr)
 
 
 
 
 
60
 
61
  with gr.Tab("Summarize Legal Text"):
62
- text_input_summarize = gr.Textbox(label="Textbox", placeholder="Enter legal text to summarize")
63
- text_output_summarize = gr.Textbox(label="Textbox", placeholder="Summary will appear here")
64
  max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
65
  temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
66
- gr.Button("Summarize Text").click(summarize_legal_text, inputs=[text_input_summarize, max_new_tokens_summarize, temperature_summarize], outputs=text_output_summarize)
67
- gr.Button("Clear").click(lambda: "", None, text_input_summarize)
 
 
 
 
 
68
 
69
  with gr.Tab("Answer Legal Question"):
70
- context_input = gr.Textbox(label="Textbox", placeholder="Enter legal text for context")
71
- question_input = gr.Textbox(label="Textbox", placeholder="Enter your question")
72
- answer_output = gr.Textbox(label="Textbox", placeholder="Answer will appear here")
73
- max_new_tokens_answer = gr.Slider(10, 512, value=128, label="Max New Tokens")
74
- temperature_answer = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
75
- gr.Button("Get Answer").click(answer_legal_question, inputs=[context_input, question_input, max_new_tokens_answer, temperature_answer], outputs=answer_output)
76
- gr.Button("Clear").click(lambda: "", None, [context_input, question_input])
 
 
 
 
 
 
 
77
 
78
- # Launch the demo
79
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
3
 
4
+ # Load model and tokenizer for mT5-small
5
+ model = T5ForConditionalGeneration.from_pretrained("google/mt5-small")
6
+ tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
7
 
8
+ # Define task-specific prompts
9
+ def correct_htr_text(input_text, max_new_tokens, temperature):
10
+ prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
11
+ inputs = tokenizer(prompt, return_tensors="pt")
12
  outputs = model.generate(
13
+ inputs.input_ids,
14
+ max_new_tokens=max_new_tokens,
15
+ temperature=temperature
 
16
  )
17
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
18
 
19
+ def summarize_legal_text(input_text, max_new_tokens, temperature):
20
+ prompt = f"Summarize this legal text: {input_text}"
 
21
  inputs = tokenizer(prompt, return_tensors="pt")
22
  outputs = model.generate(
23
+ inputs.input_ids,
24
+ max_new_tokens=max_new_tokens,
25
+ temperature=temperature
 
26
  )
27
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
28
 
29
+ def answer_legal_question(input_text, question, max_new_tokens, temperature):
30
+ prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
 
31
  inputs = tokenizer(prompt, return_tensors="pt")
32
  outputs = model.generate(
33
+ inputs.input_ids,
34
+ max_new_tokens=max_new_tokens,
35
+ temperature=temperature
 
36
  )
37
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
38
 
39
+ # Define Gradio interface functions
40
+ def correct_htr_interface(text, max_new_tokens, temperature):
41
+ return correct_htr_text(text, max_new_tokens, temperature)
42
+
43
+ def summarize_interface(text, max_new_tokens, temperature):
44
+ return summarize_legal_text(text, max_new_tokens, temperature)
45
+
46
+ def question_interface(text, question, max_new_tokens, temperature):
47
+ return answer_legal_question(text, question, max_new_tokens, temperature)
48
+
49
+ def clear_all():
50
+ return "", ""
51
+
52
+ # External clickable buttons
53
+ def clickable_buttons():
54
+ button_html = """
55
+ <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
56
+ <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary"
57
+ style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
58
+ Admiralty Court Legal Glossary</a>
59
+ <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt"
60
+ style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
61
+ HCA 13/70 Ground Truth</a>
62
+ </div>
63
+ """
64
+ return button_html
65
+
66
+ # Interface layout
67
  with gr.Blocks() as demo:
68
+ gr.HTML("<h1>Flan-T5 Legal Assistant</h1>")
69
+ gr.HTML(clickable_buttons())
70
 
 
 
 
 
 
71
  with gr.Tab("Correct Raw HTR"):
72
+ input_text = gr.Textbox(lines=10, label="Textbox")
73
+ output_text = gr.Textbox(label="Textbox")
74
+ max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
75
+ temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
76
+ correct_button = gr.Button("Correct HTR")
77
+ clear_button = gr.Button("Clear")
78
+
79
+ correct_button.click(fn=correct_htr_interface,
80
+ inputs=[input_text, max_new_tokens, temperature],
81
+ outputs=output_text)
82
+ clear_button.click(fn=clear_all, outputs=[input_text, output_text])
83
 
84
  with gr.Tab("Summarize Legal Text"):
85
+ input_text_summarize = gr.Textbox(lines=10, label="Textbox")
86
+ output_text_summarize = gr.Textbox(label="Textbox")
87
  max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
88
  temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
89
+ summarize_button = gr.Button("Summarize Text")
90
+ clear_button_summarize = gr.Button("Clear")
91
+
92
+ summarize_button.click(fn=summarize_interface,
93
+ inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
94
+ outputs=output_text_summarize)
95
+ clear_button_summarize.click(fn=clear_all, outputs=[input_text_summarize, output_text_summarize])
96
 
97
  with gr.Tab("Answer Legal Question"):
98
+ input_text_question = gr.Textbox(lines=10, label="Textbox")
99
+ question = gr.Textbox(label="Textbox")
100
+ output_text_question = gr.Textbox(label="Textbox")
101
+ max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
102
+ temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
103
+ question_button = gr.Button("Get Answer")
104
+ clear_button_question = gr.Button("Clear")
105
+
106
+ question_button.click(fn=question_interface,
107
+ inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
108
+ outputs=output_text_question)
109
+ clear_button_question.click(fn=clear_all, outputs=[input_text_question, question, output_text_question])
110
+
111
+ gr.Button("Clear", elem_id="clear_button").click(clear_all)
112
 
 
113
  demo.launch()