import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
import faiss  # NOTE(review): imported but not used in this file — kept in case other code relies on it
import numpy as np  # NOTE(review): imported but not used in this file
import streamlit as st

# Registry of successfully loaded Hugging Face datasets, keyed by display name.
datasets_dict = {}


def load_datasets():
    """Populate ``datasets_dict`` with the BillSum and EurLex datasets.

    Each dataset is loaded in its own try/except so that a failure in one
    does not prevent the other from being available; load errors are shown
    in the Streamlit UI via ``st.error`` instead of being raised.
    """
    global datasets_dict
    try:
        datasets_dict["BillSum"] = load_dataset("billsum")
    except Exception as e:
        st.error(f"Error loading BillSum dataset: {e}")
    try:
        # EurLex ships a dataset loading script, so remote code must be trusted.
        datasets_dict["EurLex"] = load_dataset("eurlex", trust_remote_code=True)  # Set trust_remote_code=True
    except Exception as e:
        st.error(f"Error loading EurLex dataset: {e}")


# Load datasets at the start
load_datasets()

# Load the T5 model and tokenizer for summarization
t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Initialize variables for the selected dataset
selected_dataset = "BillSum"
documents = []
titles = []


def prepare_dataset(dataset_name):
    """Load a 100-document subset of *dataset_name* into the module globals.

    Sets ``documents`` and ``titles`` from the dataset's ``train`` split.
    Assumes the split exposes ``text`` and ``title`` columns — true for
    BillSum; TODO confirm for EurLex. Leaves the globals empty (instead of
    raising KeyError) when the dataset was never loaded successfully.
    """
    global documents, titles
    dataset = datasets_dict.get(dataset_name)
    if dataset is None:
        # Dataset failed to load (or name is unknown): degrade gracefully.
        documents = []
        titles = []
        st.error(f"Dataset '{dataset_name}' is not available.")
        return
    documents = dataset['train']['text'][:100]  # Use a subset for demo purposes
    titles = dataset['train']['title'][:100]  # Get corresponding titles


def retrieve_cases(query):
    """Return ``(document, title)`` pairs whose text contains *query*.

    Simple case-insensitive substring search over the prepared subset —
    a placeholder retriever for demo purposes.
    """
    return [
        (doc, title)
        for doc, title in zip(documents, titles)
        if query.lower() in doc.lower()
    ]


def summarize_cases(cases):
    """Summarize each ``(document, title)`` pair with T5.

    Parameters
    ----------
    cases : list[tuple[str, str]]
        Pairs as returned by :func:`retrieve_cases`; only the document
        text (``case[0]``) is summarized.

    Returns
    -------
    list[str]
        One summary per input case, in order.
    """
    summaries = []
    for case in cases:
        input_ids = t5_tokenizer.encode(
            case[0], return_tensors="pt", max_length=512, truncation=True
        )
        # Pure inference: disable gradient tracking to save memory/compute.
        with torch.no_grad():
            outputs = t5_model.generate(
                input_ids,
                max_length=60,
                min_length=30,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
            )
        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        summaries.append(summary)
    return summaries


# Streamlit App Code
st.title("Legal Case Summarizer")
st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")
# Dropdown for selecting dataset
dataset_options = list(datasets_dict.keys())
if not dataset_options:
    # Robustness: both dataset loads failed — without this guard,
    # st.selectbox returns None and prepare_dataset raises KeyError.
    st.error("No datasets could be loaded. Please check your connection and retry.")
    st.stop()
selected_dataset = st.selectbox("Choose a dataset:", dataset_options)

# Prepare the selected dataset
prepare_dataset(selected_dataset)

query = st.text_input("Enter search keywords:", "healthcare")

if st.button("Retrieve and Summarize Cases"):
    with st.spinner("Retrieving and summarizing cases..."):
        cases = retrieve_cases(query)
        if cases:
            summaries = summarize_cases(cases)
            for i, (case_text, title) in enumerate(cases):
                st.write(f"### Case {i + 1}")
                st.write(f"**Title:** {title}")
                # BUG FIX: the (doc, title) tuple is already unpacked here,
                # so the original ``case[0]`` printed only the FIRST
                # CHARACTER of the document; use the full text instead.
                st.write(f"**Case Text:** {case_text}")
                st.write(f"**Summary:** {summaries[i]}")
        else:
            st.write("No cases found for the given query.")

st.write("Using T5 for summarization and retrieval.")