import pandas as pd
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import chromadb
from chromadb.utils import embedding_functions
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import gradio as gr

import re 
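
# Overview: this script builds a small retrieve-and-summarize demo over the
# Enron email dataset. It indexes cleaned email bodies in ChromaDB, retrieves
# the closest matches for a user query, summarizes them with the fine-tuned
# seq2seq model varl42/modello42, and exposes both steps through a Gradio UI.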

#######################################################

# Load the email dataset
emails = pd.read_csv("./cleaned_data.csv")

######################################################
# Create a persistent ChromaDB client so the index is stored on disk
client = chromadb.PersistentClient(path="./content")
collection = client.get_or_create_collection("enron_emails")

# Add documents and IDs to the collection, using ChromaDB's built-in text encoding
collection.add(
    documents=emails["body"].tolist()[:10000],
    ids=emails["file"].tolist()[:10000],
    metadatas=[{"source": "enron_emails"}] * len(emails[:10000]),  # Optional metadata
)
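
# Note: no embedding_function is supplied above, so ChromaDB falls back to its
# built-in default embedder (a MiniLM sentence-transformer at the time of
# writing). To pin a specific model, pass an
# embedding_functions.SentenceTransformerEmbeddingFunction when creating the
# collection.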


####################################################
# Load the fine-tuned summarization model from the Hugging Face Hub
model = AutoModelForSeq2SeqLM.from_pretrained("varl42/modello42")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("varl42/modello42")

##################################################

# Fetch the collection back from the same persistent client
collection = client.get_collection("enron_emails")

##################################################

def query_collection(query_text):
    try:
        # Perform the query
        response = collection.query(query_texts=[query_text], n_results=2)

        # Extract documents from the response
        if 'documents' in response and len(response['documents']) > 0:
            # Assuming each query only has one set of responses, hence response['documents'][0]
            documents = response['documents'][0]  # This gets the first (and possibly only) list of documents
            return "\n\n".join(documents)
        else:
            # Handle cases where no documents are found or the structure is unexpected
            return "No documents found or the response structure is not as expected."
    except Exception as e:
        return f"An error occurred while querying: {e}"


def summarize_documents(text_input):
    try:
        # Tokenize input text for the model
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        # Generate a summary with the model
        summary_ids = model.generate(inputs['input_ids'], max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        summary = re.sub(r"(\w+)([?!])\s", r"\1\2. ", summary)  # Ensures that sentences ending in ? ! .
        summary = re.sub(r"([^.?!])(?=\s+[A-Z]|$)", r"\1.", summary) 
        
        return summary
    except Exception as e:
        return f"An error occurred while summarizing: {e}"

def query_then_summarize(query_text, _):
    try:
        # Perform the query
        query_results = query_collection(query_text)
        # Return empty summary initially
        return query_results, ""
    except Exception as e:
        return f"An error occurred: {e}", ""

def summarize_from_query(_, query_results):
    try:
        # Use the query results for summarization
        summary = summarize_documents(query_results)
        return query_results, summary
    except Exception as e:
        return query_results, f"An error occurred while summarizing: {e}"


###################################################
        
# Setup the Gradio interface
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    query_results = gr.Text(label="Query Results", placeholder="Query results will appear here...", interactive=True)
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    query_button.click(query_then_summarize, inputs=[query_input, query_results], outputs=[query_results, summary_output])
    summarize_button.click(summarize_from_query, inputs=[query_input, query_results], outputs=[query_results, summary_output])

app.launch()
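
# gr.Blocks serves the interface locally (http://127.0.0.1:7860 by default);
# pass share=True to app.launch() to get a temporary public link.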