---
license: creativeml-openrail-m
datasets:
- prithivMLmods/Spam-Text-Detect-Analysis
language:
- en
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---

### **SPAM DETECTION UNCASED [ SPAM / HAM ]**

This implementation uses **BERT (Bidirectional Encoder Representations from Transformers)** for binary sequence classification (Spam / Ham). The model is trained on the **`prithivMLmods/Spam-Text-Detect-Analysis` dataset** and integrates **Weights & Biases (wandb)** for experiment tracking.

---

## **🛠️ Overview**

### **Core Details:**
- **Model:** BERT for sequence classification
  - Pre-trained model: `bert-base-uncased`
- **Task:** Spam detection
  - Binary classification (Spam vs. Ham).
- **Metrics Tracked:**
  - Accuracy
  - Precision
  - Recall
  - F1 Score
  - Evaluation loss

---

## **📊 Key Results**

Results were obtained using BERT and the provided training dataset:
- **Validation Accuracy:** **0.9937**
- **Precision:** **0.9931**
- **Recall:** **0.9597**
- **F1 Score:** **0.9761**

---

## **📈 Model Training Details**

### **Model Architecture:**
The model uses `bert-base-uncased` as the pre-trained backbone and is fine-tuned for the sequence classification task.
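The precision, recall, and F1 figures in Key Results are binary metrics with spam as the positive class: the training code maps `ham` to 0 and `spam` to 1 and calls scikit-learn's `precision_recall_fscore_support` with `average="binary"`. The minimal pure-Python sketch below spells out those definitions. The helper names `binary_metrics` and `f1_from` are hypothetical and not part of the training code; the sketch also recomputes F1 from the reported precision and recall as a consistency check.

```python
def f1_from(precision, recall):
    """F1 score: the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def binary_metrics(labels, predictions, positive=1):
    """Hypothetical helper: accuracy, precision, recall, and F1 with spam (1) as the positive class."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == positive and p == positive)  # spam caught
    fp = sum(1 for y, p in pairs if y != positive and p == positive)  # ham flagged as spam
    fn = sum(1 for y, p in pairs if y == positive and p != positive)  # spam missed
    correct = sum(1 for y, p in pairs if y == p)

    precision = tp / (tp + fp) if tp + fp else 0.0  # 0.0 when nothing is predicted as spam
    recall = tp / (tp + fn) if tp + fn else 0.0     # 0.0 when there is no spam in `labels`
    return {
        "accuracy": correct / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1_from(precision, recall),
    }


# Toy example: 0 = ham, 1 = spam
labels = [0, 0, 1, 1, 1, 0]
predictions = [0, 1, 1, 1, 0, 0]
print(binary_metrics(labels, predictions))  # precision, recall, and F1 are all 2/3 here

# Consistency check: F1 recomputed from the reported precision and recall
print(round(f1_from(0.9931, 0.9597), 4))  # → 0.9761
```

The harmonic mean of 0.9931 and 0.9597 comes out at about 0.9761, matching the reported F1. Because recall is lower than precision, missed spam (false negatives) was more common than ham wrongly flagged as spam (false positives).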
### **Training Parameters:**
- **Learning Rate:** 2e-5
- **Batch Size:** 16
- **Epochs:** 3
- **Loss:** Cross-Entropy

---

## Gradio Build

```python
import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Load the fine-tuned BERT model and tokenizer from the Hugging Face Hub
MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()  # disable dropout for inference

# Predict whether a given text is Spam or Ham
def predict_spam(text):
    # Tokenize the input text (BERT accepts at most 512 tokens)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Perform inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()

    # Map the predicted class index to its label (0 = ham, 1 = spam)
    return "Spam" if prediction == 1 else "Ham"

# Gradio UI: input and output components
inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...")
outputs = gr.Label(label="Prediction")

# Example inputs shown below the text box
examples = [
    ["Win $1000 gift cards now by clicking here!"],
    ["You have been selected for a lottery."],
    ["Hello, how was your day?"],
    ["Earn money without any effort. Click here."],
    ["Meeting tomorrow at 10 AM. Don't be late."],
    ["Claim your free prize now!"],
    ["Are we still on for dinner tonight?"],
    ["Exclusive offer just for you, act now!"],
    ["Let's catch up over coffee soon."],
    ["Congratulations, you've won a new car!"],
]

# Create the Gradio interface
gr_interface = gr.Interface(
    fn=predict_spam,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title="Spam Detection with BERT",
    description="Type a message in the text box to check if it's Spam or Ham using a fine-tuned BERT model.",
)

# Launch the application
gr_interface.launch()
```

### Train Details

```python
# Import necessary libraries
import numpy as np
import torch
from datasets import load_dataset, ClassLabel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Load dataset
dataset = load_dataset("prithivMLmods/Spam-Text-Detect-Analysis", split="train")

# Encode labels as integers
label_mapping = {"ham": 0, "spam": 1}
dataset = dataset.map(lambda x: {"label": label_mapping[x["Category"]]})
dataset = dataset.rename_column("Message", "text").remove_columns(["Category"])

# Convert the label column to ClassLabel (required for stratified splitting)
class_label = ClassLabel(names=["ham", "spam"])
dataset = dataset.cast_column("label", class_label)

# Split into train (80%) and test (20%), preserving the spam/ham ratio
dataset = dataset.train_test_split(test_size=0.2, stratify_by_column="label")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# Load BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the data (pad/truncate every message to 128 tokens)
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

# Set format for PyTorch
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Load pre-trained BERT with a newly initialized 2-class classification head
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move model to GPU if available (Trainer also handles device placement)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define evaluation metrics (spam = 1 is the positive class)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    acc = accuracy_score(labels, predictions)
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # evaluate after every epoch ("evaluation_strategy" in transformers < 4.41)
    save_strategy="epoch",  # save a checkpoint after every epoch
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Evaluate the model (the best checkpoint by accuracy is loaded at the end of training)
results = trainer.evaluate()
print("Evaluation Results:", results)

# Save the trained model and tokenizer
model.save_pretrained("./saved_model")
tokenizer.save_pretrained("./saved_model")

# Load the model for inference
loaded_model = BertForSequenceClassification.from_pretrained("./saved_model").to(device)
loaded_tokenizer = BertTokenizer.from_pretrained("./saved_model")

# Test the model on a custom input
def predict(text):
    inputs = loaded_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # move inputs to the model's device
    with torch.no_grad():
        outputs = loaded_model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    return "Spam" if prediction == 1 else "Ham"

# Example test
example_text = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."
print("Prediction:", predict(example_text))
```

## **🚀 How to Train the Model**

1. **Clone Repository:**
   ```bash
   git clone
   cd
   ```

2. **Install Dependencies:** Install all required packages from `requirements.txt`:
   ```bash
   pip install -r requirements.txt
   ```
   or manually:
   ```bash
   pip install torch transformers datasets accelerate scikit-learn wandb gradio
   ```

3. **Train the Model:** Assuming you have a script like `train.py` (for example, the training code above saved to a file), run:
   ```bash
   python train.py
   ```

---

## **✨ Weights & Biases Integration**

### Why Use wandb?
- **Monitor experiments in real time** through live charts.
- Log metrics such as loss, accuracy, precision, recall, and F1 score.
- Keep a history of past runs for side-by-side comparison.

### Initialize Weights & Biases
Include this snippet in your training script:
```python
import wandb

wandb.init(project="spam-detection")
```

When `wandb` is installed, the `Trainer` can log to it automatically; setting `report_to="wandb"` in `TrainingArguments` makes this explicit.

---

## 📁 **Directory Structure**

The directory is organized to keep each component clearly separated:

```
project-directory/
│
├── data/               # Dataset processing scripts
├── wandb/              # Logged artifacts from wandb runs
├── results/            # Training and evaluation results
├── model/              # Trained model checkpoints
├── requirements.txt    # List of dependencies
└── train.py            # Main script for training the model
```

---

## 🔗 Dataset Information

The training dataset comes from **Spam-Text-Detect-Analysis**, available on Hugging Face:
- **Dataset Link:** [Spam Text Detection Dataset - Hugging Face](https://huggingface.co/datasets)

Dataset size:
- **5.57k entries**

---
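The dataset's `Category` (`ham` / `spam`) and `Message` columns are what the training code consumes: it maps `ham` to 0 and `spam` to 1, then makes an 80/20 split with `stratify_by_column="label"` so the spam/ham ratio stays roughly the same in both splits. The pure-Python sketch below illustrates those two steps on toy rows. The helpers `encode` and `stratified_split` are hypothetical, and this is not how the `datasets` library implements stratification internally.

```python
import random
from collections import defaultdict

LABEL_MAPPING = {"ham": 0, "spam": 1}  # same mapping as the training code


def encode(rows):
    """Map raw {"Category", "Message"} rows to {"text", "label"} rows."""
    return [{"text": r["Message"], "label": LABEL_MAPPING[r["Category"]]} for r in rows]


def stratified_split(rows, test_size=0.2, seed=0):
    """Hypothetical helper: split each label group separately so both splits keep the class ratio."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for row in rows:
        by_label[row["label"]].append(row)

    train, test = [], []
    for group in by_label.values():
        group = list(group)  # copy before shuffling
        rng.shuffle(group)
        n_test = round(len(group) * test_size)
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


# Toy data: 20 ham and 5 spam messages (20% spam)
rows = [{"Category": "ham", "Message": f"see you at {i} pm"} for i in range(20)]
rows += [{"Category": "spam", "Message": f"WIN prize #{i} now!"} for i in range(5)]

train, test = stratified_split(encode(rows))
print(len(train), len(test))          # → 20 5
print(sum(r["label"] for r in test))  # → 1 (1 of 5 test messages is spam, still 20%)
```

Splitting each class separately guarantees the smaller class is represented in the test split in proportion; a plain random split of a small or imbalanced dataset can leave the test split with noticeably more or less spam than the full data.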