import torch
import faiss
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, T5ForConditionalGeneration
from datasets import load_dataset

# Load the datasets from Hugging Face
datasets_dict = {
    "BillSum": load_dataset("billsum"),
    "EurLex": load_dataset("eurlex"),
}

# Load the T5 model and tokenizer for summarization
t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Embed text for retrieval: mean-pool the T5 encoder's last hidden state
# into a single fixed-size vector
def embed_text(text):
    input_ids = t5_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        outputs = t5_model.encoder(input_ids)
    return outputs.last_hidden_state.mean(dim=1).numpy()

# Prepare a dataset and build its FAISS index in one step, so the index
# always matches the documents being searched (rebuilding only the document
# lists while keeping a stale index would return wrong results)
def build_index(dataset_name):
    dataset = datasets_dict[dataset_name]
    documents = dataset["train"]["text"][:100]  # Use a subset for demo purposes
    titles = dataset["train"]["title"][:100]    # Get corresponding titles
    doc_embeddings = np.vstack([embed_text(doc) for doc in documents]).astype(np.float32)
    index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    index.add(doc_embeddings)
    return documents, titles, index

# Retrieve the top_k nearest documents for a query
def retrieve_cases(query, top_k=3):
    query_embedding = embed_text(query).astype(np.float32)
    distances, indices = index.search(query_embedding, top_k)
    return [(documents[i], titles[i]) for i in indices[0]]  # Return documents and their titles

# Summarize each retrieved case with T5; the "summarize: " task prefix is
# required because T5 was trained with text prefixes to select the task
def summarize_cases(cases):
    summaries = []
    for case, _ in cases:
        input_ids = t5_tokenizer.encode("summarize: " + case, return_tensors="pt", max_length=512, truncation=True)
        outputs = t5_model.generate(
            input_ids,
            max_length=60,
            min_length=30,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summaries.append(t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
    return summaries

# Step 3: Streamlit App Code
st.title("Legal Case Summarizer")
st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")

# Dropdown for selecting a dataset
dataset_options = list(datasets_dict.keys())
selected_dataset = st.selectbox("Choose a dataset:", dataset_options)

# Build documents, titles, and the FAISS index for the selected dataset
documents, titles, index = build_index(selected_dataset)

query = st.text_input("Enter search keywords:", "healthcare")

if st.button("Retrieve and Summarize Cases"):
    with st.spinner("Retrieving and summarizing cases..."):
        cases = retrieve_cases(query)
        if cases:
            summaries = summarize_cases(cases)
            for i, (case, title) in enumerate(cases):
                st.write(f"### Case {i + 1}")
                st.write(f"**Title:** {title}")
                st.write(f"**Case Text:** {case}")
                st.write(f"**Summary:** {summaries[i]}")
        else:
            st.write("No cases found for the given query.")

st.write("Using T5 for summarization and retrieval.")
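# Optional performance note (a sketch, not part of the app above): Streamlit
# re-runs the whole script on every widget interaction, so the datasets, the
# T5 model, and the document embeddings are rebuilt each time. On Streamlit
# >= 1.18, st.cache_resource can memoize the expensive pieces; the wrapper
# names below (load_t5, build_index_cached) are hypothetical, shown only to
# illustrate the pattern.

@st.cache_resource
def load_t5():
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    return tokenizer, model

@st.cache_resource
def build_index_cached(dataset_name):
    # Delegates to build_index() above; the result is cached per dataset
    # name, so switching back to an already-indexed dataset is instant.
    return build_index(dataset_name)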
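# Optional persistence note (a sketch): with only 100 documents the index is
# cheap to rebuild, but for a larger corpus the embeddings can be computed
# once and the index stored on disk with faiss.write_index / faiss.read_index.
# The file name here is just an illustration.
#
#     faiss.write_index(index, "billsum.index")  # after build_index(...)
#     index = faiss.read_index("billsum.index")  # on later runs
#
# To launch the app, save the script (e.g. as app.py) and run:
#
#     streamlit run app.py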