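"""Legal Case Summarizer: a Streamlit demo.

Retrieves cases from the BillSum or EurLex datasets by keyword match and
summarizes them with t5-base. Launch with `streamlit run <this-file>.py`
(substitute whatever name this file is saved under).
"""
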
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
import faiss
import numpy as np
import streamlit as st

# Load the datasets from Hugging Face, cached so Streamlit reruns
# don't reload them on every user interaction
@st.cache_resource
def load_datasets():
    loaded = {}
    try:
        loaded["BillSum"] = load_dataset("billsum")
    except Exception as e:
        st.error(f"Error loading BillSum dataset: {e}")

    try:
        # eurlex ships a dataset loading script, which requires trust_remote_code=True
        loaded["EurLex"] = load_dataset("eurlex", trust_remote_code=True)
    except Exception as e:
        st.error(f"Error loading EurLex dataset: {e}")

    return loaded

# Load datasets at the start
datasets_dict = load_datasets()

# Load the T5 model and tokenizer for summarization, cached across reruns
@st.cache_resource
def load_t5():
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    return tokenizer, model

t5_tokenizer, t5_model = load_t5()

# Holders for the documents and titles of the currently selected dataset
documents = []
titles = []

# Prepare the dataset for retrieval based on user selection
def prepare_dataset(dataset_name):
    global documents, titles
    dataset = datasets_dict[dataset_name]
    documents = dataset['train']['text'][:100]  # Use a subset for demo purposes
    titles = dataset['train']['title'][:100]  # Get corresponding titles

# Function for case retrieval: a simple substring match for demo purposes
# (a dense-retrieval alternative using FAISS is sketched below)
def retrieve_cases(query):
    return [(doc, title) for doc, title in zip(documents, titles) if query.lower() in doc.lower()]
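
# --- Optional: dense retrieval with FAISS (sketch) ---
# faiss and numpy are imported above but never used by the keyword search.
# The functions below are a minimal, hypothetical sketch of how they could
# back an embedding-based retriever, reusing the already-loaded T5 encoder
# (mean-pooled hidden states as document vectors). A dedicated sentence-
# embedding model would likely produce better vectors than a vanilla T5
# encoder; this sketch only avoids adding a new dependency.
def embed_texts(texts):
    # Mean-pool the T5 encoder's hidden states into one vector per text
    vectors = []
    with torch.no_grad():
        for text in texts:
            inputs = t5_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
            hidden = t5_model.encoder(**inputs).last_hidden_state  # (1, seq_len, dim)
            vectors.append(hidden.mean(dim=1).squeeze(0).numpy())
    return np.stack(vectors).astype("float32")

def retrieve_cases_dense(query, k=5):
    # Index the document vectors and return the k nearest to the query
    doc_vectors = embed_texts(documents)
    index = faiss.IndexFlatL2(doc_vectors.shape[1])
    index.add(doc_vectors)
    _, ids = index.search(embed_texts([query]), k)
    return [(documents[i], titles[i]) for i in ids[0] if i != -1]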

def summarize_cases(cases):
    summaries = []
    for case in cases:
        # t5-base was trained with task prefixes, so prepend "summarize: "
        input_ids = t5_tokenizer.encode("summarize: " + case[0], return_tensors="pt", max_length=512, truncation=True)
        outputs = t5_model.generate(input_ids, max_length=60, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        summaries.append(summary)
    return summaries

# Streamlit App Code
st.title("Legal Case Summarizer")
st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")

# Dropdown for selecting dataset; stop early if nothing loaded
dataset_options = list(datasets_dict.keys())
if not dataset_options:
    st.error("No datasets could be loaded.")
    st.stop()
selected_dataset = st.selectbox("Choose a dataset:", dataset_options)

# Prepare the selected dataset
prepare_dataset(selected_dataset)

query = st.text_input("Enter search keywords:", "healthcare")

if st.button("Retrieve and Summarize Cases"):
    with st.spinner("Retrieving and summarizing cases..."):
        cases = retrieve_cases(query)
        if cases:
            summaries = summarize_cases(cases)
            for i, (case, title) in enumerate(cases):
                summary = summaries[i]
                st.write(f"### Case {i + 1}")
                st.write(f"**Title:** {title}")
                st.write(f"**Case Text:** {case[0]}")
                st.write(f"**Summary:** {summary}")
        else:
            st.write("No cases found for the given query.")

st.write("Using T5 for summarization and retrieval.")