Ahmadkhan12 committed
Commit 2b6a7bc · verified · 1 Parent(s): 7c2f7da

Create app.py

Files changed (1): app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
+ import faiss
+ import numpy as np
+ import streamlit as st
+ import torch
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+ # Hugging Face dataset names, keyed by the labels shown in the UI.
+ DATASET_NAMES = {"BillSum": "billsum", "EurLex": "eurlex"}
+
+ # Load the T5 model and tokenizer once. st.cache_resource keeps them alive
+ # across Streamlit reruns instead of reloading them on every interaction.
+ @st.cache_resource
+ def load_model():
+     tokenizer = AutoTokenizer.from_pretrained("t5-base")
+     model = T5ForConditionalGeneration.from_pretrained("t5-base")
+     model.eval()
+     return tokenizer, model
+
+ t5_tokenizer, t5_model = load_model()
+
+ # Embed text for retrieval: mean-pool the T5 encoder's last hidden state.
+ def embed_text(text):
+     input_ids = t5_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
+     with torch.no_grad():
+         outputs = t5_model.encoder(input_ids)
+     return outputs.last_hidden_state.mean(dim=1).numpy()
+
+ # Load the selected dataset, embed a 100-document subset (kept small for demo
+ # purposes), and build a FAISS index over the embeddings. Caching per dataset
+ # name avoids recomputing on every rerun and guarantees the index always
+ # matches the documents it was built from; building the index once at module
+ # level would leave it stale after the user switches datasets.
+ @st.cache_resource
+ def prepare_dataset(dataset_name):
+     dataset = load_dataset(DATASET_NAMES[dataset_name])
+     documents = dataset["train"]["text"][:100]
+     titles = dataset["train"]["title"][:100]
+     doc_embeddings = np.vstack([embed_text(doc) for doc in documents]).astype(np.float32)
+     index = faiss.IndexFlatL2(doc_embeddings.shape[1])
+     index.add(doc_embeddings)
+     return documents, titles, index
+
+ # Retrieve the top_k documents nearest to the query in embedding space.
+ def retrieve_cases(query, documents, titles, index, top_k=3):
+     query_embedding = embed_text(query).astype(np.float32)
+     distances, indices = index.search(query_embedding, top_k)
+     return [(documents[i], titles[i]) for i in indices[0]]
+
+ # Summarize each retrieved case with T5. The "summarize: " task prefix
+ # matters: T5 was trained with task prefixes, and omitting it degrades output.
+ def summarize_cases(cases):
+     summaries = []
+     for case, _ in cases:
+         input_ids = t5_tokenizer.encode("summarize: " + case, return_tensors="pt", max_length=512, truncation=True)
+         outputs = t5_model.generate(input_ids, max_length=60, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
+         summaries.append(t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
+     return summaries
+
+ # Streamlit app
+ st.title("Legal Case Summarizer")
+ st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")
+
+ # Dropdown for selecting the dataset; the index is (re)built to match it.
+ selected_dataset = st.selectbox("Choose a dataset:", list(DATASET_NAMES.keys()))
+ documents, titles, index = prepare_dataset(selected_dataset)
+
+ query = st.text_input("Enter search keywords:", "healthcare")
+
+ if st.button("Retrieve and Summarize Cases"):
+     with st.spinner("Retrieving and summarizing cases..."):
+         cases = retrieve_cases(query, documents, titles, index)
+         if cases:
+             summaries = summarize_cases(cases)
+             for i, ((case, title), summary) in enumerate(zip(cases, summaries)):
+                 st.write(f"### Case {i + 1}")
+                 st.write(f"**Title:** {title}")
+                 st.write(f"**Case Text:** {case}")
+                 st.write(f"**Summary:** {summary}")
+         else:
+             st.write("No cases found for the given query.")
+
+ st.write("Using T5 for summarization and retrieval.")