|
--- |
|
license: creativeml-openrail-m |
|
datasets: |
|
- prithivMLmods/Spam-Text-Detect-Analysis |
|
language: |
|
- en |
|
base_model: |
|
- google-bert/bert-base-uncased |
|
pipeline_tag: text-classification |
|
library_name: transformers |
|
--- |
|
### **SPAM DETECTION UNCASED [ SPAM / HAM ]** |
|
|
|
This implementation leverages **BERT (Bidirectional Encoder Representations from Transformers)** for binary classification (Spam / Ham) using sequence classification. The model uses the **`prithivMLmods/Spam-Text-Detect-Analysis` dataset** and integrates **Weights & Biases (wandb)** for comprehensive experiment tracking. |
|
|
|
--- |
|
|
|
## **Overview**
|
|
|
### **Core Details:** |
|
- **Model:** BERT for sequence classification
  - **Pre-trained backbone:** `bert-base-uncased`
- **Task:** Spam detection, a binary classification task (Spam vs. Ham)
|
- **Metrics Tracked** (see the sketch after this list):
|
- Accuracy |
|
- Precision |
|
- Recall |
|
- F1 Score |
|
- Evaluation loss |
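
A minimal sketch of how these metrics can be computed with `scikit-learn` inside a Hugging Face `Trainer` `compute_metrics` callback. The callback wiring is an illustrative assumption, not the original training script; evaluation loss is reported separately by the `Trainer` itself.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    # The Trainer passes a (logits, labels) tuple of numpy arrays
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary"
    )
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```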
|
|
|
--- |
|
|
|
## **Key Results**
|
Validation results of the fine-tuned model on the `prithivMLmods/Spam-Text-Detect-Analysis` dataset:
|
|
|
- **Validation Accuracy:** **0.9937** |
|
- **Precision:** **0.9931** |
|
- **Recall:** **0.9597** |
|
- **F1 Score:** **0.9761** |
|
|
|
--- |
|
|
|
## **Model Training Details**
|
|
|
### **Model Architecture:** |
|
The model uses `bert-base-uncased` as the pre-trained backbone and is fine-tuned for the sequence classification task. |
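
As an illustrative sketch (not necessarily the exact training code), the backbone can be instantiated with a two-way classification head and readable label names. The label mapping below (index 1 = Spam) is an assumption consistent with the Gradio demo further down.

```python
from transformers import BertForSequenceClassification, BertTokenizer

# Readable class names attached to the model config (assumed mapping)
id2label = {0: "Ham", 1: "Spam"}
label2id = {"Ham": 0, "Spam": 1}

# Pre-trained backbone plus a freshly initialized 2-way classification head
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
```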
|
|
|
### **Training Parameters:** |
|
- **Learning Rate:** 2e-5 |
|
- **Batch Size:** 16 |
|
- **Epochs:** 3 |
|
- **Loss:** Cross-Entropy |
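
An illustrative mapping of these hyperparameters onto the Hugging Face `Trainer` API. This is a sketch, not the original script: `train_dataset` and `eval_dataset` stand for tokenized splits of the spam corpus, `compute_metrics` is the callback sketched in the Overview section, and cross-entropy is already the default loss of `BertForSequenceClassification`, so it is not set explicitly.

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy="epoch",  # named `eval_strategy` in newer transformers releases
)

trainer = Trainer(
    model=model,                      # BertForSequenceClassification from the sketch above
    args=training_args,
    train_dataset=train_dataset,      # assumed: tokenized training split
    eval_dataset=eval_dataset,        # assumed: tokenized validation split
    compute_metrics=compute_metrics,
)
trainer.train()
```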
|
|
|
--- |
|
## **Gradio Build**
|
|
|
```python |
|
import gradio as gr |
|
import torch |
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
|
|
# Load the pre-trained BERT model and tokenizer |
|
MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased" |
|
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH) |
|
model = BertForSequenceClassification.from_pretrained(MODEL_PATH) |
|
|
|
# Function to predict if a given text is Spam or Ham |
|
def predict_spam(text): |
|
# Tokenize the input text |
|
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) |
|
|
|
# Perform inference |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
logits = outputs.logits |
|
        prediction = torch.argmax(logits, dim=-1).item()
|
|
|
# Map prediction to label |
|
if prediction == 1: |
|
return "Spam" |
|
else: |
|
return "Ham" |
|
|
|
|
|
# Gradio UI - Input and Output components |
|
inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...") |
|
outputs = gr.Label(label="Prediction") |
|
|
|
# List of example inputs |
|
examples = [ |
|
["Win $1000 gift cards now by clicking here!"], |
|
["You have been selected for a lottery."], |
|
["Hello, how was your day?"], |
|
["Earn money without any effort. Click here."], |
|
["Meeting tomorrow at 10 AM. Don't be late."], |
|
["Claim your free prize now!"], |
|
["Are we still on for dinner tonight?"], |
|
["Exclusive offer just for you, act now!"], |
|
["Let's catch up over coffee soon."], |
|
["Congratulations, you've won a new car!"] |
|
] |
|
|
|
# Create the Gradio interface |
|
gr_interface = gr.Interface( |
|
fn=predict_spam, |
|
inputs=inputs, |
|
outputs=outputs, |
|
examples=examples, |
|
title="Spam Detection with BERT", |
|
description="Type a message in the text box to check if it's Spam or Ham using a pre-trained BERT model." |
|
) |
|
|
|
# Launch the application |
|
gr_interface.launch() |
|
|
|
``` |
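
For quick checks outside Gradio, the same checkpoint can also be queried through the `transformers` `pipeline` API. The label strings returned depend on the configuration saved with the checkpoint, so the output shown below is only indicative.

```python
from transformers import pipeline

# Text-classification pipeline over the published checkpoint
classifier = pipeline("text-classification", model="prithivMLmods/Spam-Bert-Uncased")

print(classifier("Win $1000 gift cards now by clicking here!"))
# Indicative output; label names depend on the saved config:
# [{'label': 'Spam', 'score': 0.99}]
```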
|
|
|
## **How to Train the Model**
|
|
|
1. **Clone Repository:** |
|
```bash |
|
git clone <repository-url> |
|
cd <project-directory> |
|
``` |
|
|
|
2. **Install Dependencies:** |
|
Install all necessary dependencies. |
|
```bash |
|
pip install -r requirements.txt |
|
``` |
|
or manually: |
|
```bash |
|
pip install transformers datasets wandb scikit-learn |
|
``` |
|
|
|
3. **Train the Model:**
Assuming the training logic lives in `train.py` (see the directory structure below), run:
```bash
python train.py
```
|
|
|
--- |
|
|
|
## **Weights & Biases Integration**
|
|
|
### Why Use wandb? |
|
- **Monitor experiments in real time** through live dashboards.
- Log metrics such as loss, accuracy, precision, recall, and F1 score.
- Keep a history of past runs for easy comparison.
|
|
|
### Initialize Weights & Biases |
|
Include this snippet in your training script: |
|
```python |
|
import wandb |
|
wandb.init(project="spam-detection") |
|
``` |
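
If training goes through the `transformers` `Trainer`, the Trainer's logs can also be forwarded to the same run by setting `report_to` in `TrainingArguments` (a standard `transformers` option, shown here only as an illustration; the run name is hypothetical):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="results",
    report_to="wandb",               # forward Trainer logs (loss, eval metrics) to wandb
    run_name="bert-spam-detection",  # hypothetical run name shown in the wandb dashboard
)
```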
|
|
|
--- |
|
|
|
## **Directory Structure**
|
|
|
The directory is organized to ensure scalability and clear separation of components: |
|
|
|
``` |
|
project-directory/
│
├── data/               # Dataset processing scripts
├── wandb/              # Logged artifacts from wandb runs
├── results/            # Save training and evaluation results
├── model/              # Trained model checkpoints
├── requirements.txt    # List of dependencies
└── train.py            # Main script for training the model
|
``` |
|
|
|
--- |
|
|
|
## Dataset Information
|
The training dataset comes from **Spam-Text-Detect-Analysis** available on Hugging Face: |
|
- **Dataset Link:** [prithivMLmods/Spam-Text-Detect-Analysis](https://huggingface.co/datasets/prithivMLmods/Spam-Text-Detect-Analysis)
|
|
|
Dataset size: |
|
- **5.57k entries** |
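
A minimal sketch of pulling the corpus with the `datasets` library (split and column names are assumptions; check the dataset card for the exact schema):

```python
from datasets import load_dataset

# Load the spam/ham corpus from the Hugging Face Hub
dataset = load_dataset("prithivMLmods/Spam-Text-Detect-Analysis")

print(dataset)              # available splits and sizes
print(dataset["train"][0])  # one raw example; column names vary by dataset
```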
|
|
|
--- |
|
|
|