File size: 3,909 Bytes
3f7144c
5f27c76
 
 
 
3f7144c
5f27c76
 
3f7144c
5f27c76
 
 
3f7144c
5f27c76
 
3f7144c
5f27c76
 
 
3f7144c
5f27c76
 
 
3f7144c
5f27c76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f26191
5f27c76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f26191
5f27c76
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch
import os

# Directory where the fine-tuned model and tokenizer are saved.
FINE_TUNED_MODEL_DIR = "./fine_tuned_sms_spam_model"

# Hugging Face Hub checkpoint used for classification (and as the starting
# point for fine-tuning).
BASE_MODEL_NAME = "wesleyacheng/sms-spam-classification-with-bert"


@st.cache_resource
def _load_model_and_tokenizer(model_name=BASE_MODEL_NAME):
    """Load and cache the tokenizer and sequence-classification model.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached ``from_pretrained`` call would reload the BERT weights on each
    click. ``st.cache_resource`` keeps one shared copy per server process.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# Load model and tokenizer (cached across reruns).
tokenizer, model = _load_model_and_tokenizer()
# The cached model object is shared between reruns; a fine-tuning run may leave
# it in training mode (dropout active), so restore inference mode every rerun.
model.eval()

# Create a Streamlit app
st.title("SMS Spam Classification")

def classify_spam_or_ham(text):
    """Return "Spam" or "Not-Spam" for a single SMS string.

    The text is tokenized (with truncation enabled) and passed through the
    module-level ``model`` with gradient tracking disabled. Logit index 1 is
    treated as the spam class; when the two logits tie, the result is
    "Not-Spam".
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        scores = model(**encoded).logits[0]

    spam_score, ham_score = scores[1], scores[0]
    if spam_score > ham_score:
        return "Spam"
    return "Not-Spam"

st.write("Single SMS Example:")


# Function to classify a single SMS
def classify_single_sms(text):
    """Classify one SMS and render the message and its verdict in the main area.

    Values that are not ``str`` (e.g. missing cells from a CSV column) are not
    classified; a warning is displayed instead. Always returns ``None``.
    """
    if not isinstance(text, str):
        st.warning("Skipping non-text data.")
        return

    verdict = classify_spam_or_ham(text)
    for line in (f"SMS: {text}", f"Prediction: {verdict}", "--------"):
        st.write(line)

# Main Streamlit code for CSV file upload
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader("Upload a CSV file with SMS messages:", type=["csv"])

# Bound unconditionally so code further down can test for "no usable CSV"
# (None) instead of hitting a NameError when nothing was uploaded or the file
# could not be parsed.
df = None
selected_column = None


def _read_sms_csv(file_obj):
    """Parse an uploaded CSV, trying UTF-8 first and falling back to Latin-1.

    Latin-1 maps every byte to a character, so it never raises
    ``UnicodeDecodeError``; using it unconditionally would silently garble
    UTF-8 text. It is therefore only used when strict UTF-8 decoding fails.
    Structural problems propagate as ``pd.errors.ParserError`` /
    ``pd.errors.EmptyDataError``.
    """
    file_obj.seek(0)
    try:
        return pd.read_csv(file_obj, encoding="utf-8")
    except UnicodeDecodeError:
        file_obj.seek(0)  # the failed UTF-8 attempt consumed part of the stream
        return pd.read_csv(file_obj, encoding="latin1")


if uploaded_file is not None:
    st.sidebar.write("Classifying SMS messages in the uploaded file...")
    try:
        df = _read_sms_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        st.sidebar.error("Error: Unable to read the CSV file. Please make sure it is a valid, non-empty CSV file.")
    else:
        # Allow the user to select the column containing SMS messages
        selected_column = st.sidebar.selectbox("Select the SMS column:", df.columns)

        if df[selected_column].dtype == "object":
            st.write("Classifications:")
            for sms_text in df[selected_column]:
                classify_single_sms(sms_text)
            # Only report completion when classification actually ran.
            st.sidebar.write("Classification completed!")
        else:
            st.sidebar.error("Selected column does not contain text data and cannot be tokenized.")

st.sidebar.write("Or classify a single SMS:")
user_input = st.sidebar.text_area("Enter an SMS message:")
if st.sidebar.button("Classify"):
    # Whitespace-only input carries no message to classify; treat it like an
    # empty box rather than sending it to the model.
    if user_input and user_input.strip():
        classify_single_sms(user_input)
    else:
        st.sidebar.warning("Please enter an SMS message.")

# Label spellings accepted for fine-tuning, compared after strip().casefold().
# Class id 1 means "Spam" and 0 means "Not-Spam", matching the logit order used
# by classify_spam_or_ham.
_SPAM_LABEL_SPELLINGS = frozenset({"spam", "1", "true"})
_HAM_LABEL_SPELLINGS = frozenset({"ham", "not-spam", "not spam", "0", "false"})


def _encode_spam_label(value):
    """Map one raw label cell to a class id (1 = Spam, 0 = Not-Spam), or None if unrecognised."""
    if pd.isna(value):
        return None
    if isinstance(value, str):
        normalized = value.strip().casefold()
        if normalized in _SPAM_LABEL_SPELLINGS:
            return 1
        if normalized in _HAM_LABEL_SPELLINGS:
            return 0
        return None
    # Numeric / boolean cells: only exact 1 or 0 are accepted.
    if value == 1:
        return 1
    if value == 0:
        return 0
    return None


def _build_training_records(texts, labels):
    """Tokenize *texts* and pair each encoding with its label for ``Trainer``.

    Returns one dict per example (tokenizer outputs such as ``input_ids`` and
    ``attention_mask``, plus ``labels``). All sequences are padded to a common
    length so the default data collator can stack them into tensors.
    """
    encodings = tokenizer(list(texts), padding=True, truncation=True)
    feature_names = list(encodings.keys())
    feature_rows = zip(*(encodings[name] for name in feature_names))
    return [
        {**dict(zip(feature_names, row)), "labels": label}
        for row, label in zip(feature_rows, labels)
    ]


def _fine_tune_on(frame, text_column, label_column):
    """Fine-tune the module-level model on the labelled rows of *frame* and save it.

    Rows whose text is not a string or whose label is not recognised by
    ``_encode_spam_label`` are skipped.
    """
    pairs = [
        (text, _encode_spam_label(raw_label))
        for text, raw_label in zip(frame[text_column], frame[label_column])
    ]
    usable = [(text, label) for text, label in pairs if isinstance(text, str) and label is not None]
    if not usable:
        st.warning("No rows with both SMS text and a recognised spam/ham label were found; nothing to fine-tune on.")
        return
    skipped = len(pairs) - len(usable)
    if skipped:
        st.info(f"Skipping {skipped} row(s) with missing text or an unrecognised label.")

    texts, labels = zip(*usable)

    # Specify your fine-tuning training arguments
    training_args = TrainingArguments(
        output_dir=FINE_TUNED_MODEL_DIR,
        overwrite_output_dir=True,
        per_device_train_batch_size=8,
        num_train_epochs=3,
    )

    # Trainer needs tokenized inputs with labels, not raw text.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=_build_training_records(texts, labels),
    )

    # Fine-tune the model
    with st.spinner("Fine-tuning the model..."):
        trainer.train()
    # Training may leave the model in training mode (dropout active); switch
    # back so any further predictions are deterministic.
    model.eval()

    # Save the fine-tuned model
    model.save_pretrained(FINE_TUNED_MODEL_DIR)
    tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
    st.write("Model has been fine-tuned and saved.")


st.write("Or fine-tune the model:")

# df / selected_column are bound by the upload section only after a CSV has been
# uploaded and parsed; look them up defensively so their absence produces a
# warning instead of a NameError.
_upload_df = globals().get("df")
_text_column = globals().get("selected_column")

_label_column = None
if _upload_df is not None:
    _label_column = st.selectbox("Select the label column (spam/ham or 1/0):", _upload_df.columns)

if st.button("Fine-Tune Model"):
    if uploaded_file is None:
        st.warning("Please upload a CSV file before fine-tuning.")
    elif _upload_df is None:
        st.warning("The uploaded CSV file could not be read; please upload a valid CSV before fine-tuning.")
    elif _text_column is None:
        st.warning("Please select the SMS column before fine-tuning.")
    elif _upload_df[_text_column].dtype != "object":
        st.warning("The selected column does not contain text data and cannot be used for fine-tuning.")
    else:
        _fine_tune_on(_upload_df, _text_column, _label_column)