File size: 4,244 Bytes
f3a6fe0 3c4cefa c16cd03 5b0b77b 39f7c6e 5b0b77b 07ccc3e 6557c9b b4d1efc 3c4cefa b4d1efc 3c4cefa edcc84e b4d1efc c16cd03 3c4cefa b4d1efc 3c4cefa edcc84e 6557c9b edcc84e 3c4cefa b4d1efc 3c4cefa 86a9050 3c4cefa 86a9050 3c4cefa b4d1efc 3c4cefa 6557c9b 3c4cefa b4d1efc 3c4cefa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import re

import chromadb
import gradio as gr
import nltk
import pandas as pd
from chromadb.utils import embedding_functions
from nltk.tokenize import sent_tokenize
# AutoModelForSeq2SeqLM / AutoTokenizer are classes exported by the transformers
# package; a bare `import AutoModelForSeq2SeqLM, ...` would look for top-level
# modules with those names and fail with ModuleNotFoundError.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Download the NLTK 'punkt' resource at startup.
nltk.download('punkt')
#######################################################
# Load the email dataset.
# NOTE(review): assumes cleaned_data.csv has a "body" column (email text) and a
# "file" column (unique per row, used as the document id) -- confirm.
emails = pd.read_csv("./cleaned_data.csv")
######################################################
# Number of emails (from the top of the CSV) indexed into the vector store.
MAX_EMAILS = 10000
# Chroma rejects a single add() larger than its maximum batch size (roughly
# 5.4k records with the default SQLite backend), so records are added in chunks.
ADD_BATCH_SIZE = 5000

# One persistent client is used for both indexing and querying. Previously a
# PersistentClient was created and immediately replaced by an in-memory
# Client(), and a third Client() was created later just to re-fetch the
# collection.
client = chromadb.PersistentClient(path="./content")
# get_or_create_collection lets the app restart against an existing ./content
# store instead of failing because "enron_emails" already exists.
collection = client.get_or_create_collection("enron_emails")

# Index only when the collection is empty, so restarts do not re-add (and
# re-embed) the same documents. Chroma computes the embeddings with the
# collection's default embedding function.
# NOTE(review): a run interrupted mid-indexing leaves a partial collection that
# will not be topped up; delete ./content to rebuild.
if collection.count() == 0:
    subset = emails.head(MAX_EMAILS)
    # Chroma requires string documents and string ids; NaN bodies become "".
    documents = subset["body"].fillna("").astype(str).tolist()
    ids = subset["file"].astype(str).tolist()
    for start in range(0, len(ids), ADD_BATCH_SIZE):
        batch_ids = ids[start:start + ADD_BATCH_SIZE]
        collection.add(
            documents=documents[start:start + ADD_BATCH_SIZE],
            ids=batch_ids,
            metadatas=[{"source": "enron_emails"}] * len(batch_ids),  # Optional metadata
        )
####################################################
# Load the trained seq2seq model and its tokenizer (Hub repo id or local path).
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "varl42/modello42"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def query_collection(query_text, n_results=2):
    """Return the stored emails most similar to *query_text* as one string.

    Args:
        query_text: Free-text query, embedded by the collection's embedding
            function.
        n_results: Maximum number of documents to retrieve (default 2).

    Returns:
        The matching documents separated by blank lines. A prompt is returned
        for a blank query, a "no documents" message when nothing matches, and
        an error message if the Chroma query raises.
    """
    if not query_text or not query_text.strip():
        return "Please enter a query."
    try:
        response = collection.query(query_texts=[query_text], n_results=n_results)
    except Exception as e:
        # Surface the failure in the UI rather than crashing the Gradio callback.
        return f"An error occurred while querying: {e}"
    # response["documents"] holds one list of documents per query text. Only one
    # query was sent, so its results are the first inner list. That inner list
    # (not the outer one) is what is empty when nothing matched.
    per_query = response.get("documents") or []
    documents = (per_query[0] or []) if per_query else []
    if not documents:
        return "No documents found or the response structure is not as expected."
    return "\n\n".join(documents)
def summarize_documents(text_input):
try:
# Tokenize input text for the model
inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
# Generate a summary with the model
summary_ids = model.generate(inputs['input_ids'], max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summary = re.sub(r"(\w+)([?!])\s", r"\1\2. ", summary) # Ensures that sentences ending in ? ! .
summary = re.sub(r"([^.?!])(?=\s+[A-Z]|$)", r"\1.", summary)
return summary
except Exception as e:
return f"An error occurred while summarizing: {e}"
def query_then_summarize(query_text, _):
    """Gradio handler for the Query button.

    Runs the retrieval query and clears the summary box. Despite the name, no
    summarization happens here; the Summarize button does that.

    Returns:
        A ``(query_results, summary)`` pair whose summary is always "".
    """
    cleared_summary = ""
    try:
        results_text = query_collection(query_text)
    except Exception as err:
        return f"An error occurred: {err}", cleared_summary
    return results_text, cleared_summary
def summarize_from_query(_, query_results):
    """Gradio handler for the Summarize button.

    Summarizes the current contents of the query-results box. The first
    argument is ignored.

    Returns:
        A ``(query_results, summary)`` pair. The results are passed through
        unchanged so the box keeps its text.
    """
    try:
        summary_text = summarize_documents(query_results)
    except Exception as err:
        summary_text = f"An error occurred while summarizing: {err}"
    return query_results, summary_text
###################################################
# Set up the Gradio interface. "Query" fills an editable results box with the
# matching emails; "Summarize" runs the model over that box's current text.
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    # interactive=True lets the user edit the retrieved text before summarizing.
    query_results = gr.Text(
        label="Query Results",
        placeholder="Query results will appear here...",
        interactive=True,
    )
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")
    query_button.click(
        query_then_summarize,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )
    # summarize_from_query ignores its first argument. Pass a real input
    # component here rather than the Query Button, as the original did.
    summarize_button.click(
        summarize_from_query,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )
app.launch()
|