# Spaces: Runtime error
import streamlit as st | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments | |
import torch | |
import os | |
# Directory that receives the fine-tuned model and tokenizer when they are saved.
FINE_TUNED_MODEL_DIR = "./fine_tuned_sms_spam_model"

# The tokenizer and the classifier are loaded from the same checkpoint,
# so the name is bound once and reused.
_BASE_CHECKPOINT = "wesleyacheng/sms-spam-classification-with-bert"
tokenizer = AutoTokenizer.from_pretrained(_BASE_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_BASE_CHECKPOINT)

# Page heading of the Streamlit app.
st.title("SMS Spam Classification")
def classify_spam_or_ham(text):
    """Classify *text* with the module-level tokenizer and model.

    Returns ``"Spam"`` when the logit at index 1 is strictly greater than
    the logit at index 0 for the first sequence of the batch, otherwise
    ``"Not-Spam"``.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(**encoded).logits

    first_row = logits[0]
    if first_row[1] > first_row[0]:
        return "Spam"
    return "Not-Spam"
st.write("Single SMS Example:")


def classify_single_sms(text):
    """Classify one SMS and render the message and its prediction on the page.

    Values that are not ``str`` are not classified; a warning is shown
    for them instead.
    """
    if not isinstance(text, str):
        st.warning("Skipping non-text data.")
        return

    label = classify_spam_or_ham(text)
    for line in (f"SMS: {text}", f"Prediction: {label}", "--------"):
        st.write(line)
# Sidebar section: upload a CSV file, pick the column that holds the SMS
# texts, and classify every row of that column.
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader("Upload a CSV file with SMS messages:", type=["csv"])

# Bound up front so later sections of the script can test these names even
# when no file was uploaded or the upload could not be parsed.
df = None
selected_column = None

if uploaded_file is not None:
    st.sidebar.write("Classifying SMS messages in the uploaded file...")
    try:
        # latin1 maps every byte to a character, so decoding cannot fail
        # for arbitrary input.
        df = pd.read_csv(uploaded_file, encoding='latin1')
    except UnicodeDecodeError:
        st.sidebar.error("Error: Unable to decode the CSV file. Please make sure it is in the correct encoding.")
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
        # Empty or malformed files are what actually go wrong with a
        # latin1 read; report them instead of crashing the app.
        st.sidebar.error(f"Error: Unable to parse the CSV file: {exc}")
    else:
        # Allow the user to select the column containing SMS messages.
        selected_column = st.sidebar.selectbox("Select the SMS column:", df.columns)
        sms_series = df[selected_column]
        # Accept classic object columns as well as pandas' dedicated
        # string dtype.
        if pd.api.types.is_object_dtype(sms_series) or pd.api.types.is_string_dtype(sms_series):
            st.write("Classifications:")
            for sms_text in sms_series:
                classify_single_sms(sms_text)
            # Reported only when rows were actually classified.
            st.sidebar.write("Classification completed!")
        else:
            st.sidebar.error("Selected column does not contain text data and cannot be tokenized.")
# Sidebar section: classify one message typed in by the user.
st.sidebar.write("Or classify a single SMS:")
user_input = st.sidebar.text_area("Enter an SMS message:")

_classify_clicked = st.sidebar.button("Classify")
if _classify_clicked and user_input:
    classify_single_sms(user_input)
elif _classify_clicked:
    # Button pressed with an empty text area.
    st.sidebar.warning("Please enter an SMS message.")
st.write("Or fine-tune the model:")


def _label_to_id(value):
    """Map a raw label cell to a class id (1 = spam, 0 = not spam) or None.

    Index 1 is the spam class, matching how predictions are read from the
    model's logits elsewhere in this script.
    """
    key = str(value).strip().lower()
    if key in ("spam", "1", "1.0"):
        return 1
    if key in ("ham", "not-spam", "0", "0.0"):
        return 0
    return None


class _SmsDataset(torch.utils.data.Dataset):
    """Tokenized SMS texts with labels, in the per-item dict form Trainer consumes."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(values[idx]) for key, values in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item


def _build_training_dataset(frame, text_column, label_column):
    """Build a dataset from rows that have a string text and a recognised label.

    Returns None when no usable row is left.
    """
    texts, labels = [], []
    for text, raw_label in zip(frame[text_column], frame[label_column]):
        label_id = _label_to_id(raw_label)
        if isinstance(text, str) and label_id is not None:
            texts.append(text)
            labels.append(label_id)
    if not texts:
        return None
    encodings = tokenizer(texts, padding=True, truncation=True)
    return _SmsDataset(encodings, labels)


# True only when a CSV was uploaded, parsed, and the chosen column holds text.
_text_column_ready = bool(
    uploaded_file is not None
    and df is not None
    and selected_column
    and (
        pd.api.types.is_object_dtype(df[selected_column])
        or pd.api.types.is_string_dtype(df[selected_column])
    )
)

# The label selector lives outside the button branch: widgets created inside
# an `if st.button(...)` body vanish on the next rerun.
label_column = None
if _text_column_ready:
    _label_candidates = [name for name in df.columns if name != selected_column]
    if _label_candidates:
        label_column = st.selectbox("Select the label column (spam/ham or 1/0):", _label_candidates)

if st.button("Fine-Tune Model"):
    if _text_column_ready:
        # Supervised fine-tuning needs tokenized inputs *and* labels; a bare
        # Series of strings cannot be consumed by Trainer.
        custom_dataset = None
        if label_column is not None:
            custom_dataset = _build_training_dataset(df, selected_column, label_column)

        if custom_dataset is None:
            st.warning("Fine-tuning needs a label column with spam/ham (or 1/0) values for the selected SMS column.")
        else:
            # Specify the fine-tuning training arguments.
            training_args = TrainingArguments(
                output_dir=FINE_TUNED_MODEL_DIR,
                overwrite_output_dir=True,
                per_device_train_batch_size=8,
                num_train_epochs=3,
            )
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=custom_dataset,
            )
            # Fine-tune the model.
            trainer.train()
            # Save the fine-tuned model together with its tokenizer.
            model.save_pretrained(FINE_TUNED_MODEL_DIR)
            tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
            st.write("Model has been fine-tuned and saved.")
    elif not uploaded_file:
        st.warning("Please upload a CSV file before fine-tuning.")
    elif not selected_column:
        st.warning("Please select the SMS column before fine-tuning.")
    else:
        st.warning("The selected column does not contain text data and cannot be used for fine-tuning.")