"""Gradio app: retrieve Enron emails from ChromaDB and summarize the hits.

On start-up the script reads ``./cleaned_data.csv``, makes sure the first
``MAX_INDEXED_EMAILS`` rows are stored in a ChromaDB collection, loads a
seq2seq model and its tokenizer, and builds a two-step UI: "Query" fetches
the closest emails, "Summarize" condenses whatever text is in the results box.
"""
import logging
import re

import chromadb
import gradio as gr
import nltk
import pandas as pd
from chromadb.utils import embedding_functions
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

nltk.download('punkt')

logger = logging.getLogger(__name__)

CSV_PATH = "./cleaned_data.csv"
CHROMA_PATH = "./content"
COLLECTION_NAME = "enron_emails"
MODEL_NAME = "varl42/modello42"
MAX_INDEXED_EMAILS = 10000
# Documents are added in slices rather than in one call.
# NOTE(review): some ChromaDB releases reject very large single `add` calls;
# 1000 is a conservative slice size -- tune if indexing is too slow.
ADD_BATCH_SIZE = 1000


def _index_emails(target, frame):
    """Add the ``body``/``file`` columns of *frame* to *target* in batches.

    Args:
        target: ChromaDB collection that receives the documents.
        frame: DataFrame with a ``body`` column (document text) and a
            ``file`` column (used as the document id).
    """
    bodies = frame["body"].tolist()
    ids = frame["file"].tolist()
    for start in range(0, len(bodies), ADD_BATCH_SIZE):
        stop = start + ADD_BATCH_SIZE
        batch_ids = ids[start:stop]
        target.add(
            documents=bodies[start:stop],
            ids=batch_ids,
            # Optional metadata; one dict per document in the batch.
            metadatas=[{"source": "enron_emails"} for _ in batch_ids],
        )


#######################################################
# Load the email dataset
emails = pd.read_csv(CSV_PATH)

######################################################
# One persistent client is used for both indexing and querying, so the
# collection survives restarts and is not re-embedded every time.
client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = client.get_or_create_collection(COLLECTION_NAME)

# Only index when the collection is empty; ChromaDB's built-in text encoding
# is used because no embedding function is passed.
if collection.count() == 0:
    _index_emails(collection, emails[:MAX_INDEXED_EMAILS])

####################################################
# Load the trained model and its tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


##################################################
def query_collection(query_text):
    """Return the two closest documents for *query_text* as one string.

    Args:
        query_text: Free-text query sent to the ChromaDB collection.

    Returns:
        The matching documents joined by a blank line, or a human-readable
        message when nothing was found or the query failed. This function
        does not raise for ordinary errors; they are logged and reported in
        the returned string.
    """
    try:
        response = collection.query(query_texts=[query_text], n_results=2)
        # One query was sent, so only the first result list is relevant.
        hits_per_query = response['documents'] if 'documents' in response else None
        if hits_per_query and hits_per_query[0]:
            return "\n\n".join(hits_per_query[0])
        # Missing key, empty outer list, or an empty hit list all end up here.
        return "No documents found or the response structure is not as expected."
    except Exception as e:
        logger.exception("ChromaDB query failed")
        return f"An error occurred while querying: {e}"


def _punctuate(summary):
    """Apply the punctuation clean-up regexes to a generated summary.

    NOTE(review): the first substitution turns "word? " into "word?. ", and
    the second inserts "." after any character (other than . ? !) that is
    followed by whitespace and a capital letter, or by end of string -- so a
    period is also added before capitalised words in mid-sentence. Kept
    unchanged; confirm this is the intended output.
    """
    summary = re.sub(r"(\w+)([?!])\s", r"\1\2. ", summary)
    summary = re.sub(r"([^.?!])(?=\s+[A-Z]|$)", r"\1.", summary)
    return summary


def summarize_documents(text_input):
    """Summarize *text_input* with the loaded seq2seq model.

    Args:
        text_input: Text to summarize; it is truncated to 512 tokens.

    Returns:
        The decoded summary after punctuation clean-up, a short message when
        there is no text, or an error message if generation fails.
    """
    if not text_input or not text_input.strip():
        # Avoid forcing the model to generate min_length tokens from nothing.
        return "No text to summarize."
    try:
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        summary_ids = model.generate(
            inputs['input_ids'],
            max_length=512,
            min_length=125,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return _punctuate(summary)
    except Exception as e:
        logger.exception("Summarization failed")
        return f"An error occurred while summarizing: {e}"


def query_then_summarize(query_text, _):
    """Gradio callback for "Query": fetch results and clear the summary box.

    Args:
        query_text: Text from the query box.
        _: Current content of the results box (ignored).

    Returns:
        ``(query_results, "")``. Failures are reported inside
        ``query_results`` by ``query_collection``.
    """
    return query_collection(query_text), ""


def summarize_from_query(_, query_results):
    """Gradio callback for "Summarize": summarize the results box content.

    Args:
        _: Text from the query box (ignored).
        query_results: Text currently shown in the results box.

    Returns:
        ``(query_results, summary)``. Failures are reported inside
        ``summary`` by ``summarize_documents``.
    """
    return query_results, summarize_documents(query_results)


###################################################
# Setup the Gradio interface
# NOTE(review): the row grouping below is a reconstruction -- confirm that
# only the query box and its button are meant to share a row.
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    query_results = gr.Text(
        label="Query Results",
        placeholder="Query results will appear here...",
        interactive=True,
    )
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    query_button.click(
        query_then_summarize,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )
    # The first input is the query textbox (ignored by the callback), not the
    # Query button component.
    summarize_button.click(
        summarize_from_query,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )

if __name__ == "__main__":
    app.launch()