"""Streamlit app: classify SMS messages as spam/not-spam and optionally fine-tune the model.

The app loads a BERT-based sequence classifier from the Hugging Face Hub. It lets the
user classify messages from an uploaded CSV or a text box, and can fine-tune the
model on the uploaded CSV when a spam/ham label column is selected.
"""
import numbers
import os

import pandas as pd
import streamlit as st
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Directory where the fine-tuned model and tokenizer are saved.
FINE_TUNED_MODEL_DIR = "./fine_tuned_sms_spam_model"

# Hugging Face Hub checkpoint used as the starting point.
BASE_MODEL_NAME = "wesleyacheng/sms-spam-classification-with-bert"

# Label spellings (compared after strip + lower-case) accepted in the label column.
# Class index 1 is "spam", matching the logit comparison in classify_spam_or_ham.
_SPAM_LABELS = {"spam", "1", "true", "yes"}
_HAM_LABELS = {"ham", "not-spam", "not spam", "0", "false", "no"}


@st.cache_resource
def load_model_and_tokenizer(model_name=BASE_MODEL_NAME):
    """Load (once per server process) the tokenizer and classifier for *model_name*.

    Streamlit re-executes this script on every widget interaction. Caching avoids
    re-downloading and re-initialising the model each time. The cached model object
    is shared, so weights updated by fine-tuning persist in memory for later
    classifications.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# Load model and tokenizer
tokenizer, model = load_model_and_tokenizer()

# Create a Streamlit app
st.title("SMS Spam Classification")


def classify_spam_or_ham(text):
    """Return "Spam" if the model scores class 1 above class 0 for *text*, else "Not-Spam"."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # After fine-tuning, the model may live on a GPU; keep inputs on the same device.
    inputs = inputs.to(model.device)
    # Trainer.train() leaves the model in training mode (dropout active); force inference mode.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[0]
    return "Spam" if logits[1] > logits[0] else "Not-Spam"


st.write("Single SMS Example:")


def classify_single_sms(text):
    """Classify *text* and render the message and prediction; warn on non-string input (e.g. NaN)."""
    if isinstance(text, str):
        prediction = classify_spam_or_ham(text)
        st.write(f"SMS: {text}")
        st.write(f"Prediction: {prediction}")
        st.write("--------")
    else:
        st.warning("Skipping non-text data.")


def _is_text_column(series):
    """Return True if *series* holds text (object or pandas string dtype)."""
    return pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)


def _to_binary_label(value):
    """Map a raw label cell to 1 (spam), 0 (ham), or None if it is not recognised."""
    if isinstance(value, numbers.Real):
        if pd.isna(value):
            return None
        return int(value) if value in (0, 1) else None
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in _SPAM_LABELS:
            return 1
        if normalized in _HAM_LABELS:
            return 0
    return None


def _prepare_training_data(text_series, label_series):
    """Pair string texts with recognised 0/1 labels, dropping rows where either is unusable."""
    texts, labels = [], []
    for text, raw_label in zip(text_series, label_series):
        label = _to_binary_label(raw_label)
        if isinstance(text, str) and label is not None:
            texts.append(text)
            labels.append(label)
    return texts, labels


class _SmsDataset(torch.utils.data.Dataset):
    """Tokenised texts plus 0/1 labels, yielding the per-item dicts Trainer expects."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(values[idx]) for key, values in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item


def fine_tune_model(texts, labels):
    """Fine-tune the loaded model on (*texts*, *labels*) and save it to FINE_TUNED_MODEL_DIR."""
    encodings = tokenizer(list(texts), truncation=True, padding=True)
    train_dataset = _SmsDataset(encodings, list(labels))
    training_args = TrainingArguments(
        output_dir=FINE_TUNED_MODEL_DIR,
        overwrite_output_dir=True,
        per_device_train_batch_size=8,
        num_train_epochs=3,
        # Disable experiment-tracker integrations (e.g. wandb) that may prompt interactively.
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    trainer.train()
    model.eval()
    model.save_pretrained(FINE_TUNED_MODEL_DIR)
    tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)


# Bound up-front so the fine-tuning section never hits a NameError when the upload fails.
df = None
selected_column = None

# Main Streamlit code for CSV file upload
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader("Upload a CSV file with SMS messages:", type=["csv"])

if uploaded_file is not None:
    st.sidebar.write("Classifying SMS messages in the uploaded file...")
    try:
        # latin1 maps every byte to a character, so decoding itself rarely fails;
        # malformed or empty files surface as parser errors instead.
        df = pd.read_csv(uploaded_file, encoding="latin1")
    except UnicodeDecodeError:
        st.sidebar.error("Error: Unable to decode the CSV file. Please make sure it is in the correct encoding.")
    except (pd.errors.ParserError, pd.errors.EmptyDataError):
        df = None
        st.sidebar.error("Error: Unable to parse the CSV file. Please make sure it is a valid, non-empty CSV.")
    else:
        # Allow the user to select the column containing SMS messages
        selected_column = st.sidebar.selectbox("Select the SMS column:", df.columns)
        if selected_column is not None and _is_text_column(df[selected_column]):
            st.write("Classifications:")
            for sms_text in df[selected_column]:
                classify_single_sms(sms_text)
            st.sidebar.write("Classification completed!")
        else:
            st.sidebar.error("Selected column does not contain text data and cannot be tokenized.")

st.sidebar.write("Or classify a single SMS:")
user_input = st.sidebar.text_area("Enter an SMS message:")
if st.sidebar.button("Classify"):
    if user_input:
        classify_single_sms(user_input)
    else:
        st.sidebar.warning("Please enter an SMS message.")

st.write("Or fine-tune the model:")
label_column = None
if df is not None:
    label_column = st.selectbox("Select the label column (spam/ham) for fine-tuning:", df.columns)

if st.button("Fine-Tune Model"):
    if uploaded_file is None:
        st.warning("Please upload a CSV file before fine-tuning.")
    elif df is None:
        st.warning("The uploaded CSV file could not be read, so it cannot be used for fine-tuning.")
    elif selected_column is None:
        st.warning("Please select the SMS column before fine-tuning.")
    elif not _is_text_column(df[selected_column]):
        st.warning("The selected column does not contain text data and cannot be used for fine-tuning.")
    elif label_column is None or label_column == selected_column:
        st.warning("Please select a label column (different from the SMS column) before fine-tuning.")
    else:
        train_texts, train_labels = _prepare_training_data(df[selected_column], df[label_column])
        if not train_texts:
            st.warning("No rows with text and a recognised spam/ham label were found; nothing to fine-tune on.")
        else:
            with st.spinner("Fine-tuning the model..."):
                fine_tune_model(train_texts, train_labels)
            st.write("Model has been fine-tuned and saved.")