import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
import faiss
import numpy as np
import streamlit as st

# Load the datasets from Hugging Face (cached so Streamlit reruns don't re-download)
@st.cache_resource
def load_datasets():
    return {
        "BillSum": load_dataset("billsum"),
        "EurLex": load_dataset("eurlex"),
    }

datasets_dict = load_datasets()

# Load the T5 model and tokenizer for summarization (cached across reruns)
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    model.eval()
    return tokenizer, model

t5_tokenizer, t5_model = load_model()

# Prepare the selected dataset for retrieval
def prepare_dataset(dataset_name):
    dataset = datasets_dict[dataset_name]
    documents = dataset["train"]["text"][:100]  # Use a subset for demo purposes
    titles = dataset["train"]["title"][:100]    # Get the corresponding titles
    return documents, titles

# Embed text for retrieval: mean-pool the T5 encoder's hidden states
def embed_text(text):
    input_ids = t5_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        outputs = t5_model.encoder(input_ids)
    return outputs.last_hidden_state.mean(dim=1).numpy().astype(np.float32)

# Build a FAISS index over the document embeddings. Caching per dataset name
# ensures the index is rebuilt when the user switches datasets, instead of
# searching stale embeddings from the previously selected one.
@st.cache_resource
def build_index(dataset_name):
    documents, titles = prepare_dataset(dataset_name)
    doc_embeddings = np.vstack([embed_text(doc) for doc in documents])
    index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    index.add(doc_embeddings)
    return index, documents, titles

# Retrieve the top-k most similar cases for a query
def retrieve_cases(query, index, documents, titles, top_k=3):
    query_embedding = embed_text(query)
    distances, indices = index.search(query_embedding, top_k)
    return [(documents[i], titles[i]) for i in indices[0]]  # Return documents and their titles

# Summarize each retrieved case with T5
def summarize_cases(cases):
    summaries = []
    for case, _ in cases:
        # t5-base expects a task prefix such as "summarize: " for summarization
        input_ids = t5_tokenizer.encode("summarize: " + case, return_tensors="pt", max_length=512, truncation=True)
        outputs = t5_model.generate(input_ids, max_length=60, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
        summaries.append(t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
    return summaries

# Step 3: Streamlit App Code
st.title("Legal Case Summarizer")
st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")

# Dropdown for selecting a dataset
selected_dataset = st.selectbox("Choose a dataset:", list(datasets_dict.keys()))

# Build (or fetch the cached) index, documents, and titles for the selected dataset
index, documents, titles = build_index(selected_dataset)

query = st.text_input("Enter search keywords:", "healthcare")

if st.button("Retrieve and Summarize Cases"):
    with st.spinner("Retrieving and summarizing cases..."):
        cases = retrieve_cases(query, index, documents, titles)
        if cases:
            summaries = summarize_cases(cases)
            for i, (case, title) in enumerate(cases):
                st.write(f"### Case {i + 1}")
                st.write(f"**Title:** {title}")
                st.write(f"**Case Text:** {case}")
                st.write(f"**Summary:** {summaries[i]}")
        else:
            st.write("No cases found for the given query.")

st.write("Using T5 for summarization and retrieval.")