# legalspace / app.py
# Source: Hugging Face Space by Ahmadkhan12 (commit 2b6a7bc, "Create app.py", 6.49 kB)
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
import faiss
import numpy as np
import streamlit as st
# Load the candidate corpora from the Hugging Face Hub.
# NOTE(review): both downloads happen at import time on every app start;
# assumes each dataset's train split exposes 'text' and 'title' columns —
# verified below for BillSum usage, TODO confirm the same schema for EurLex.
datasets_dict = {
    "BillSum": load_dataset("billsum"),
    "EurLex": load_dataset("eurlex")
}
# Load one shared T5 model/tokenizer pair: the encoder alone is reused for
# retrieval embeddings (see embed_text) and the full seq2seq for summarization.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
# Module-level state for the currently selected corpus.
selected_dataset = "BillSum"  # default before the Streamlit selectbox renders
documents = []  # raw case texts used for retrieval
titles = []     # titles aligned index-for-index with `documents`
# Populate the retrieval corpus for whichever dataset the user picked.
def prepare_dataset(dataset_name):
    """Fill the module-level `documents` and `titles` from the named corpus.

    Only the first 100 training examples are kept so the demo stays fast.
    """
    global documents, titles
    train_split = datasets_dict[dataset_name]["train"]
    documents = train_split["text"][:100]
    titles = train_split["title"][:100]


prepare_dataset(selected_dataset)
# Turn a piece of text into a single dense vector for FAISS retrieval.
def embed_text(text):
    """Return a (1, hidden_dim) numpy vector: mean-pooled T5 encoder states."""
    token_ids = t5_tokenizer.encode(
        text, return_tensors="pt", max_length=512, truncation=True
    )
    with torch.no_grad():
        encoder_out = t5_model.encoder(token_ids)
    # Average over the sequence axis to get one fixed-size vector per input.
    return encoder_out.last_hidden_state.mean(dim=1).numpy()
# Create embeddings for the documents: one (1, d) vector each, stacked to
# (N, d) and cast to float32 as FAISS requires.
doc_embeddings = np.vstack([embed_text(doc) for doc in documents]).astype(np.float32)
# Initialize FAISS index: exact (brute-force) L2 search over all document vectors.
index = faiss.IndexFlatL2(doc_embeddings.shape[1])
index.add(doc_embeddings)
# Nearest-neighbour lookup of stored cases for a free-text query.
def retrieve_cases(query, top_k=3):
    """Return the `top_k` (document, title) pairs nearest to the query embedding."""
    _, neighbor_ids = index.search(embed_text(query), top_k)
    return [(documents[idx], titles[idx]) for idx in neighbor_ids[0]]
def summarize_cases(cases):
    """Generate a short T5 summary for each (document, title) pair in `cases`."""

    def _summarize_one(text):
        # Truncate long case texts to the model's 512-token input window.
        token_ids = t5_tokenizer.encode(
            text, return_tensors="pt", max_length=512, truncation=True
        )
        generated = t5_model.generate(
            token_ids,
            max_length=60,
            min_length=30,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        return t5_tokenizer.decode(generated[0], skip_special_tokens=True)

    return [_summarize_one(case_text) for case_text, _ in cases]
# Step 3: Streamlit App Code — page layout and the retrieve/summarize workflow.
st.title("Legal Case Summarizer")
st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")
# Dropdown for selecting dataset
dataset_options = list(datasets_dict.keys())
selected_dataset = st.selectbox("Choose a dataset:", dataset_options)
# Prepare the selected dataset.
# NOTE(review): this swaps `documents`/`titles`, but the FAISS index and
# `doc_embeddings` built at import time are NOT rebuilt here — after switching
# datasets, retrieval searches stale vectors and maps hits onto the new corpus.
prepare_dataset(selected_dataset)
query = st.text_input("Enter search keywords:", "healthcare")
if st.button("Retrieve and Summarize Cases"):
    with st.spinner("Retrieving and summarizing cases..."):
        cases = retrieve_cases(query)
        if cases:
            summaries = summarize_cases(cases)
            # Render each retrieved case alongside its generated summary.
            for i, (case, title) in enumerate(cases):
                summary = summaries[i]
                st.write(f"### Case {i + 1}")
                st.write(f"**Title:** {title}")
                st.write(f"**Case Text:** {case}")
                st.write(f"**Summary:** {summary}")
        else:
            st.write("No cases found for the given query.")
st.write("Using T5 for summarization and retrieval.")
# NOTE(review): lines here were an accidental byte-for-byte duplicate of the
# entire script above (imports, dataset/model setup, FAISS index build, and
# Streamlit UI). Executing the copy re-downloaded both datasets, recomputed
# every document embedding, and re-rendered identical widgets — Streamlit
# raises DuplicateWidgetID when the same st.selectbox/st.button/st.text_input
# is created twice with identical labels. The duplicate has been removed.