Ahmadkhan12 committed on
Commit
87ab71f
1 Parent(s): d5384c5

Update app.py

Files changed (1)
  1. app.py +2 -91
app.py CHANGED
@@ -1,92 +1,3 @@
-
-from transformers import AutoTokenizer, T5ForConditionalGeneration
-from datasets import load_dataset
-import faiss
-import numpy as np
-import streamlit as st
-
-# Load the datasets from Hugging Face
-datasets_dict = {
-    "BillSum": load_dataset("billsum"),
-    "EurLex": load_dataset("eurlex")
-}
-
-# Load the T5 model and tokenizer for summarization
-t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
-t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
-
-# Initialize variables for the selected dataset
-selected_dataset = "BillSum"
-documents = []
-titles = []
-
-# Prepare the dataset for retrieval based on user selection
-def prepare_dataset(dataset_name):
-    global documents, titles
-    dataset = datasets_dict[dataset_name]
-    documents = dataset['train']['text'][:100]  # Use a subset for demo purposes
-    titles = dataset['train']['title'][:100]  # Get corresponding titles
-
-prepare_dataset(selected_dataset)
-
-# Function to embed text for retrieval
-def embed_text(text):
-    input_ids = t5_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
-    with torch.no_grad():
-        outputs = t5_model.encoder(input_ids)
-    return outputs.last_hidden_state.mean(dim=1).numpy()
-
-# Create embeddings for the documents
-doc_embeddings = np.vstack([embed_text(doc) for doc in documents]).astype(np.float32)
-
-# Initialize FAISS index
-index = faiss.IndexFlatL2(doc_embeddings.shape[1])
-index.add(doc_embeddings)
-
-# Define functions for retrieving and summarizing cases
-def retrieve_cases(query, top_k=3):
-    query_embedding = embed_text(query)
-    distances, indices = index.search(query_embedding, top_k)
-    return [(documents[i], titles[i]) for i in indices[0]]  # Return documents and their titles
-
-def summarize_cases(cases):
-    summaries = []
-    for case, _ in cases:
-        input_ids = t5_tokenizer.encode(case, return_tensors="pt", max_length=512, truncation=True)
-        outputs = t5_model.generate(input_ids, max_length=60, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
-        summaries.append(summary)
-    return summaries
-
-# Step 3: Streamlit App Code
-st.title("Legal Case Summarizer")
-st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")
-
-# Dropdown for selecting dataset
-dataset_options = list(datasets_dict.keys())
-selected_dataset = st.selectbox("Choose a dataset:", dataset_options)
-
-# Prepare the selected dataset
-prepare_dataset(selected_dataset)
-
-query = st.text_input("Enter search keywords:", "healthcare")
-
-if st.button("Retrieve and Summarize Cases"):
-    with st.spinner("Retrieving and summarizing cases..."):
-        cases = retrieve_cases(query)
-        if cases:
-            summaries = summarize_cases(cases)
-            for i, (case, title) in enumerate(cases):
-                summary = summaries[i]
-                st.write(f"### Case {i + 1}")
-                st.write(f"**Title:** {title}")
-                st.write(f"**Case Text:** {case}")
-                st.write(f"**Summary:** {summary}")
-        else:
-            st.write("No cases found for the given query.")
-
-st.write("Using T5 for summarization and retrieval.")
-import torch
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 from datasets import load_dataset
 import faiss
@@ -96,7 +7,7 @@ import streamlit as st
 # Load the datasets from Hugging Face
 datasets_dict = {
     "BillSum": load_dataset("billsum"),
-    "EurLex": load_dataset("eurlex")
+    "EurLex": load_dataset("eurlex", trust_remote_code=True)  # Set trust_remote_code=True
 }
 
 # Load the T5 model and tokenizer for summarization
@@ -146,7 +57,7 @@ def summarize_cases(cases):
         summaries.append(summary)
     return summaries
 
-# Step 3: Streamlit App Code
+# Streamlit App Code
 st.title("Legal Case Summarizer")
 st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")
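Note on the one functional change in this diff: newer releases of the Hugging Face `datasets` library execute a dataset's Hub-hosted loading script only when the caller explicitly opts in, so script-based datasets such as `eurlex` raise an error when loaded without `trust_remote_code=True`, while datasets published as prepared data files (like `billsum`) load without it. A minimal sketch of the distinction, assuming a `datasets` version that still supports script-based loading (later major releases removed loading-script support entirely, at which point this pattern no longer works):

```python
from datasets import load_dataset

# "billsum" is published as prepared data files, so no opt-in is needed.
billsum = load_dataset("billsum")

# "eurlex" (assumed here to be script-based) relies on a loading script hosted
# in its Hub repository; newer `datasets` releases refuse to run such scripts
# unless the caller explicitly opts in with trust_remote_code=True.
eurlex = load_dataset("eurlex", trust_remote_code=True)

print(billsum["train"].column_names)  # e.g. ['text', 'summary', 'title']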