prithivMLmods
commited on
Commit
β’
250c203
1
Parent(s):
ac00149
Update README.md
Browse files
README.md
CHANGED
@@ -9,12 +9,12 @@ base_model:
|
|
9 |
pipeline_tag: text-classification
|
10 |
library_name: transformers
|
11 |
---
|
12 |
-
|
13 |
|
14 |
-
This implementation
|
15 |
|
16 |
---
|
17 |
-
|
18 |
|
19 |
| **File Name** | **Size** | **Description** | **Upload Status** |
|
20 |
|------------------------------------|-----------|-----------------------------------------------------|-------------------|
|
@@ -49,9 +49,7 @@ Results were obtained using BERT and the provided training dataset:
|
|
49 |
- **Precision:** **0.9931**
|
50 |
- **Recall:** **0.9597**
|
51 |
- **F1 Score:** **0.9761**
|
52 |
-
|
53 |
---
|
54 |
-
|
55 |
## **π Model Training Details**
|
56 |
|
57 |
### **Model Architecture:**
|
@@ -62,74 +60,27 @@ The model uses `bert-base-uncased` as the pre-trained backbone and is fine-tuned
|
|
62 |
- **Batch Size:** 16
|
63 |
- **Epochs:** 3
|
64 |
- **Loss:** Cross-Entropy
|
65 |
-
|
66 |
---
|
67 |
-
##
|
68 |
-
|
69 |
-
```python
|
70 |
-
import gradio as gr
|
71 |
-
import torch
|
72 |
-
from transformers import BertTokenizer, BertForSequenceClassification
|
73 |
-
|
74 |
-
# Load the pre-trained BERT model and tokenizer
|
75 |
-
MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased"
|
76 |
-
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
|
77 |
-
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
|
78 |
-
|
79 |
-
# Function to predict if a given text is Spam or Ham
|
80 |
-
def predict_spam(text):
|
81 |
-
# Tokenize the input text
|
82 |
-
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
83 |
-
|
84 |
-
# Perform inference
|
85 |
-
with torch.no_grad():
|
86 |
-
outputs = model(**inputs)
|
87 |
-
logits = outputs.logits
|
88 |
-
prediction = torch.argmax(logits, axis=-1).item()
|
89 |
-
|
90 |
-
# Map prediction to label
|
91 |
-
if prediction == 1:
|
92 |
-
return "Spam"
|
93 |
-
else:
|
94 |
-
return "Ham"
|
95 |
-
|
96 |
-
|
97 |
-
# Gradio UI - Input and Output components
|
98 |
-
inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...")
|
99 |
-
outputs = gr.Label(label="Prediction")
|
100 |
-
|
101 |
-
# List of example inputs
|
102 |
-
examples = [
|
103 |
-
["Win $1000 gift cards now by clicking here!"],
|
104 |
-
["You have been selected for a lottery."],
|
105 |
-
["Hello, how was your day?"],
|
106 |
-
["Earn money without any effort. Click here."],
|
107 |
-
["Meeting tomorrow at 10 AM. Don't be late."],
|
108 |
-
["Claim your free prize now!"],
|
109 |
-
["Are we still on for dinner tonight?"],
|
110 |
-
["Exclusive offer just for you, act now!"],
|
111 |
-
["Let's catch up over coffee soon."],
|
112 |
-
["Congratulations, you've won a new car!"]
|
113 |
-
]
|
114 |
-
|
115 |
-
# Create the Gradio interface
|
116 |
-
gr_interface = gr.Interface(
|
117 |
-
fn=predict_spam,
|
118 |
-
inputs=inputs,
|
119 |
-
outputs=outputs,
|
120 |
-
examples=examples,
|
121 |
-
title="Spam Detection with BERT",
|
122 |
-
description="Type a message in the text box to check if it's Spam or Ham using a pre-trained BERT model."
|
123 |
-
)
|
124 |
-
|
125 |
-
# Launch the application
|
126 |
-
gr_interface.launch()
|
127 |
|
|
|
|
|
|
|
|
|
128 |
```
|
129 |
-
### Train Details
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
```python
|
132 |
-
|
133 |
# Import necessary libraries
|
134 |
from datasets import load_dataset, ClassLabel
|
135 |
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
|
@@ -235,33 +186,14 @@ def predict(text):
|
|
235 |
example_text = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."
|
236 |
print("Prediction:", predict(example_text))
|
237 |
```
|
238 |
-
|
239 |
-
## **π How to Train the Model**
|
240 |
-
|
241 |
-
1. **Clone Repository:**
|
242 |
-
```bash
|
243 |
-
git clone <repository-url>
|
244 |
-
cd <project-directory>
|
245 |
-
```
|
246 |
-
|
247 |
-
2. **Install Dependencies:**
|
248 |
-
Install all necessary dependencies.
|
249 |
-
```bash
|
250 |
-
pip install -r requirements.txt
|
251 |
-
```
|
252 |
-
or manually:
|
253 |
-
```bash
|
254 |
-
pip install transformers datasets wandb scikit-learn
|
255 |
-
```
|
256 |
-
|
257 |
-
3. **Train the Model:**
|
258 |
-
Assuming you have a script like `train.py`, run:
|
259 |
-
```python
|
260 |
-
from train import main
|
261 |
-
```
|
262 |
-
|
263 |
---
|
|
|
|
|
|
|
264 |
|
|
|
|
|
|
|
265 |
## **β¨ Weights & Biases Integration**
|
266 |
|
267 |
### Why Use wandb?
|
@@ -275,10 +207,8 @@ Include this snippet in your training script:
|
|
275 |
import wandb
|
276 |
wandb.init(project="spam-detection")
|
277 |
```
|
278 |
-
|
279 |
---
|
280 |
-
|
281 |
-
## π **Directory Structure**
|
282 |
|
283 |
The directory is organized to ensure scalability and clear separation of components:
|
284 |
|
@@ -292,14 +222,57 @@ project-directory/
|
|
292 |
βββ requirements.txt # List of dependencies
|
293 |
βββ train.py # Main script for training the model
|
294 |
```
|
295 |
-
|
296 |
---
|
|
|
297 |
|
298 |
-
|
299 |
-
The training dataset comes from **Spam-Text-Detect-Analysis** available on Hugging Face:
|
300 |
-
- **Dataset Link:** [Spam Text Detection Dataset - Hugging Face](https://huggingface.co/datasets)
|
301 |
|
302 |
-
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
---
|
|
|
9 |
pipeline_tag: text-classification
|
10 |
library_name: transformers
|
11 |
---
|
12 |
+
# **Spam Detection with BERT**
|
13 |
|
14 |
+
This repository contains an implementation of a **Spam Detection** model using **BERT (Bidirectional Encoder Representations from Transformers)** for binary classification (Spam / Ham). The model is trained on the **`prithivMLmods/Spam-Text-Detect-Analysis` dataset** and leverages **Weights & Biases (wandb)** for comprehensive experiment tracking.
|
15 |
|
16 |
---
|
17 |
+
## **ποΈ Summary of Uploaded Files**
|
18 |
|
19 |
| **File Name** | **Size** | **Description** | **Upload Status** |
|
20 |
|------------------------------------|-----------|-----------------------------------------------------|-------------------|
|
|
|
49 |
- **Precision:** **0.9931**
|
50 |
- **Recall:** **0.9597**
|
51 |
- **F1 Score:** **0.9761**
|
|
|
52 |
---
|
|
|
53 |
## **π Model Training Details**
|
54 |
|
55 |
### **Model Architecture:**
|
|
|
60 |
- **Batch Size:** 16
|
61 |
- **Epochs:** 3
|
62 |
- **Loss:** Cross-Entropy
|
|
|
63 |
---
|
64 |
+
## **π How to Use the Model**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
+
### **1. Clone the Repository**
|
67 |
+
```bash
|
68 |
+
git clone <repository-url>
|
69 |
+
cd <project-directory>
|
70 |
```
|
|
|
71 |
|
72 |
+
### **2. Install Dependencies**
|
73 |
+
Install all necessary dependencies.
|
74 |
+
```bash
|
75 |
+
pip install -r requirements.txt
|
76 |
+
```
|
77 |
+
or manually:
|
78 |
+
```bash
|
79 |
+
pip install transformers datasets wandb scikit-learn
|
80 |
+
```
|
81 |
+
### **3. Train the Model**
|
82 |
+
Assuming you have a script like `train.py`, run:
|
83 |
```python
|
|
|
84 |
# Import necessary libraries
|
85 |
from datasets import load_dataset, ClassLabel
|
86 |
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
|
|
|
186 |
example_text = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."
|
187 |
print("Prediction:", predict(example_text))
|
188 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
---
|
190 |
+
## **π Dataset Information**
|
191 |
+
The training dataset comes from **Spam-Text-Detect-Analysis** available on Hugging Face:
|
192 |
+
- **Dataset Link:** [Spam Text Detection Dataset - Hugging Face](https://huggingface.co/datasets)
|
193 |
|
194 |
+
Dataset size:
|
195 |
+
- **5.57k entries**
|
196 |
+
---
|
197 |
## **β¨ Weights & Biases Integration**
|
198 |
|
199 |
### Why Use wandb?
|
|
|
207 |
import wandb
|
208 |
wandb.init(project="spam-detection")
|
209 |
```
|
|
|
210 |
---
|
211 |
+
## **π Directory Structure**
|
|
|
212 |
|
213 |
The directory is organized to ensure scalability and clear separation of components:
|
214 |
|
|
|
222 |
βββ requirements.txt # List of dependencies
|
223 |
βββ train.py # Main script for training the model
|
224 |
```
|
|
|
225 |
---
|
226 |
+
## **π Gradio Interface**
|
227 |
|
228 |
+
A Gradio interface is provided to test the model interactively. The interface allows users to input text and get predictions on whether the text is **Spam** or **Ham**.
|
|
|
|
|
229 |
|
230 |
+
### **Example Usage**
|
231 |
+
```python
|
232 |
+
import gradio as gr
|
233 |
+
import torch
|
234 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
235 |
+
|
236 |
+
# Load the pre-trained BERT model and tokenizer
|
237 |
+
MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased"
|
238 |
+
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
|
239 |
+
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
|
240 |
+
|
241 |
+
# Function to predict if a given text is Spam or Ham
|
242 |
+
def predict_spam(text):
|
243 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
244 |
+
with torch.no_grad():
|
245 |
+
outputs = model(**inputs)
|
246 |
+
logits = outputs.logits
|
247 |
+
prediction = torch.argmax(logits, axis=-1).item()
|
248 |
+
return "Spam" if prediction == 1 else "Ham"
|
249 |
|
250 |
+
# Gradio UI
|
251 |
+
inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...")
|
252 |
+
outputs = gr.Label(label="Prediction")
|
253 |
+
|
254 |
+
examples = [
|
255 |
+
["Win $1000 gift cards now by clicking here!"],
|
256 |
+
["You have been selected for a lottery."],
|
257 |
+
["Hello, how was your day?"],
|
258 |
+
["Earn money without any effort. Click here."],
|
259 |
+
["Meeting tomorrow at 10 AM. Don't be late."],
|
260 |
+
["Claim your free prize now!"],
|
261 |
+
["Are we still on for dinner tonight?"],
|
262 |
+
["Exclusive offer just for you, act now!"],
|
263 |
+
["Let's catch up over coffee soon."],
|
264 |
+
["Congratulations, you've won a new car!"]
|
265 |
+
]
|
266 |
+
|
267 |
+
gr_interface = gr.Interface(
|
268 |
+
fn=predict_spam,
|
269 |
+
inputs=inputs,
|
270 |
+
outputs=outputs,
|
271 |
+
examples=examples,
|
272 |
+
title="Spam Detection with BERT",
|
273 |
+
description="Type a message in the text box to check if it's Spam or Ham using a pre-trained BERT model."
|
274 |
+
)
|
275 |
+
|
276 |
+
gr_interface.launch()
|
277 |
+
```
|
278 |
---
|