---
license: creativeml-openrail-m
datasets:
  - prithivMLmods/Spam-Text-Detect-Analysis
language:
  - en
base_model:
  - google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---

# SPAM DETECTION UNCASED [ SPAM / HAM ]

This model fine-tunes BERT (Bidirectional Encoder Representations from Transformers) for binary sequence classification (Spam / Ham). It is trained on the prithivMLmods/Spam-Text-Detect-Analysis dataset and integrates Weights & Biases (wandb) for experiment tracking.
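For a quick sanity check, the checkpoint can be loaded with the `transformers` text-classification pipeline. This is a minimal sketch; the exact label strings it returns (e.g. `LABEL_0` / `LABEL_1` vs. `Ham` / `Spam`) depend on the `id2label` mapping stored in the model config.

```python
from transformers import pipeline

# Load the fine-tuned spam classifier from the Hugging Face Hub
classifier = pipeline(
    "text-classification",
    model="prithivMLmods/Spam-Bert-Uncased",
)

# Returns a list of dicts like [{"label": "...", "score": 0.99}];
# the label string depends on the model's id2label config.
print(classifier("Win $1000 gift cards now by clicking here!"))
```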


## 🛠️ Overview

**Core Details:**

- **Model:** BERT for sequence classification
- **Pre-trained Model:** `bert-base-uncased`
- **Task:** Spam detection, a binary classification task (Spam vs. Ham)
- **Metrics Tracked** (a metric-computation sketch follows this list):
  - Accuracy
  - Precision
  - Recall
  - F1 score
  - Evaluation loss
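The metrics above can be reported during evaluation through a `compute_metrics` callback passed to the Hugging Face `Trainer`. The sketch below uses scikit-learn and is illustrative only; it is not necessarily the exact function used to produce this checkpoint's numbers.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall, and F1 from Trainer predictions."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary"
    )
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```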

## 📊 Key Results

Results obtained with the fine-tuned BERT model on the validation split:

- **Validation Accuracy:** 0.9937
- **Precision:** 0.9931
- **Recall:** 0.9597
- **F1 Score:** 0.9761

## 📈 Model Training Details

**Model Architecture:**

The model uses `bert-base-uncased` as the pre-trained backbone and is fine-tuned for sequence classification.

**Training Parameters** (a fine-tuning sketch with these values follows below):

- **Learning Rate:** 2e-5
- **Batch Size:** 16
- **Epochs:** 3
- **Loss:** Cross-entropy
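A minimal fine-tuning sketch with these hyperparameters is shown below. The column names (`"text"`, `"label"`), the `"train"` split, and the `results` output directory are placeholders for illustration and should be adjusted to the actual layout of prithivMLmods/Spam-Text-Detect-Analysis; cross-entropy is the default loss used by `AutoModelForSequenceClassification` when integer labels are provided.

```python
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

# Column names ("text", "label") are placeholders; adjust them to the dataset's
# actual columns. If labels are stored as strings (e.g. "ham"/"spam"),
# map them to integers (0/1) before training.
dataset = load_dataset("prithivMLmods/Spam-Text-Detect-Analysis")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)

tokenized = dataset.map(tokenize, batched=True)

# Hyperparameters from the section above
args = TrainingArguments(
    output_dir="results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    # add eval_dataset=... and compute_metrics=... if a validation split is available
)

trainer.train()
```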

## Gradio Build

```python
import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Load the pre-trained BERT model and tokenizer
MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)

# Function to predict whether a given text is Spam or Ham
def predict_spam(text):
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        prediction = torch.argmax(logits, dim=-1).item()

    # Map the predicted class index to a label
    if prediction == 1:
        return "Spam"
    else:
        return "Ham"


# Gradio UI - Input and Output components
inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...")
outputs = gr.Label(label="Prediction")

# List of example inputs
examples = [
    ["Win $1000 gift cards now by clicking here!"],
    ["You have been selected for a lottery."],
    ["Hello, how was your day?"],
    ["Earn money without any effort. Click here."],
    ["Meeting tomorrow at 10 AM. Don't be late."],
    ["Claim your free prize now!"],
    ["Are we still on for dinner tonight?"],
    ["Exclusive offer just for you, act now!"],
    ["Let's catch up over coffee soon."],
    ["Congratulations, you've won a new car!"]
]

# Create the Gradio interface
gr_interface = gr.Interface(
    fn=predict_spam,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title="Spam Detection with BERT",
    description="Type a message in the text box to check if it's Spam or Ham using a pre-trained BERT model."
)

# Launch the application
gr_interface.launch()
```

## 🚀 How to Train the Model

1. **Clone the repository:**

   ```bash
   git clone <repository-url>
   cd <project-directory>
   ```

2. **Install dependencies:**

   ```bash
   pip install -r requirements.txt
   ```

   or manually:

   ```bash
   pip install transformers datasets wandb scikit-learn
   ```

3. **Train the model:** assuming the training entry point lives in a script like `train.py` (a minimal fine-tuning sketch is shown under *Training Parameters* above), run it directly or import and call its `main` function:

   ```bash
   python train.py
   ```

   ```python
   from train import main

   main()
   ```

## ✨ Weights & Biases Integration

**Why use wandb?**

- Monitor experiments in real time through live visualizations.
- Log metrics such as loss, accuracy, precision, recall, and F1 score.
- Keep a history of past runs and compare them side by side.

**Initialize Weights & Biases**

Include this snippet in your training script:

```python
import wandb

wandb.init(project="spam-detection")
```
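When training with the Hugging Face `Trainer`, metrics can also be sent to wandb automatically by setting `report_to="wandb"` in `TrainingArguments`; outside the `Trainer`, metrics can be logged manually with `wandb.log`. The project name, run name, and metric values below are placeholders, not values from this model's actual runs.

```python
import wandb
from transformers import TrainingArguments

# Option 1: let the Trainer handle logging (assumes wandb is installed and logged in)
args = TrainingArguments(
    output_dir="results",
    report_to="wandb",
    run_name="spam-bert-uncased",  # example run name
)

# Option 2: log metrics manually from a custom training loop
wandb.init(project="spam-detection")
wandb.log({"eval/accuracy": 0.99, "eval/f1": 0.97})  # placeholder values
wandb.finish()
```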

πŸ“ Directory Structure

The project directory is organized for scalability and a clear separation of components:

```
project-directory/
│
├── data/               # Dataset processing scripts
├── wandb/              # Logged artifacts from wandb runs
├── results/            # Training and evaluation results
├── model/              # Trained model checkpoints
├── requirements.txt    # List of dependencies
└── train.py            # Main script for training the model
```

## 🔗 Dataset Information

The training data comes from the Spam-Text-Detect-Analysis dataset available on Hugging Face (a loading sketch follows below):

- **Dataset:** `prithivMLmods/Spam-Text-Detect-Analysis`
- **Size:** ~5.57k entries
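The dataset can be pulled directly with the `datasets` library for inspection before training. This is a minimal sketch; the split names, column names, and row counts printed will depend on how the dataset is actually published.

```python
from datasets import load_dataset

# Download the spam/ham dataset from the Hugging Face Hub
dataset = load_dataset("prithivMLmods/Spam-Text-Detect-Analysis")

# Inspect the available splits, number of rows, and column names
print(dataset)
for split_name, split in dataset.items():
    print(split_name, split.num_rows, split.column_names)
```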

Let me know if you need assistance setting up the training pipeline, optimizing metrics, visualizing with wandb, or deploying this fine-tuned model. 🚀