prithivMLmods
committed
Commit 3f5e2ad
1 Parent(s): e6ad0c0

Update README.md

README.md CHANGED
@@ -113,6 +113,115 @@ gr_interface = gr.Interface(
 # Launch the application
 gr_interface.launch()
 
+```
+### Train Details
+
+```python
+
+# Import necessary libraries
+from datasets import load_dataset, ClassLabel
+from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
+import torch
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+
+# Load dataset
+dataset = load_dataset("prithivMLmods/Spam-Text-Detect-Analysis", split="train")
+
+# Encode labels as integers
+label_mapping = {"ham": 0, "spam": 1}
+dataset = dataset.map(lambda x: {"label": label_mapping[x["Category"]]})
+dataset = dataset.rename_column("Message", "text").remove_columns(["Category"])
+
+# Convert label column to ClassLabel for stratification
+class_label = ClassLabel(names=["ham", "spam"])
+dataset = dataset.cast_column("label", class_label)
+
+# Split into train and test
+dataset = dataset.train_test_split(test_size=0.2, stratify_by_column="label")
+train_dataset = dataset["train"]
+test_dataset = dataset["test"]
+
+# Load BERT tokenizer
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+# Tokenize the data
+def tokenize_function(examples):
+    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
+
+train_dataset = train_dataset.map(tokenize_function, batched=True)
+test_dataset = test_dataset.map(tokenize_function, batched=True)
+
+# Set format for PyTorch
+train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
+
+# Load pre-trained BERT model
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
+
+# Move model to GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# Define evaluation metrics
+def compute_metrics(eval_pred):
+    predictions, labels = eval_pred
+    predictions = torch.argmax(torch.tensor(predictions), dim=-1)
+    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
+    acc = accuracy_score(labels, predictions)
+    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}
+
+# Training arguments
+training_args = TrainingArguments(
+    output_dir="./results",
+    evaluation_strategy="epoch",  # Evaluate after every epoch
+    save_strategy="epoch",        # Save a checkpoint after every epoch
+    learning_rate=2e-5,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    num_train_epochs=3,
+    weight_decay=0.01,
+    logging_dir="./logs",
+    logging_steps=10,
+    load_best_model_at_end=True,
+    metric_for_best_model="accuracy",
+    greater_is_better=True
+)
+
+# Trainer
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=test_dataset,
+    compute_metrics=compute_metrics
+)
+
+# Train the model
+trainer.train()
+
+# Evaluate the model
+results = trainer.evaluate()
+print("Evaluation Results:", results)
+
+# Save the trained model
+model.save_pretrained("./saved_model")
+tokenizer.save_pretrained("./saved_model")
+
+# Load the model for inference
+loaded_model = BertForSequenceClassification.from_pretrained("./saved_model").to(device)
+loaded_tokenizer = BertTokenizer.from_pretrained("./saved_model")
+
+# Test the model on a custom input
+def predict(text):
+    inputs = loaded_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
+    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to the same device as the model
+    outputs = loaded_model(**inputs)
+    prediction = torch.argmax(outputs.logits, dim=-1).item()
+    return "Spam" if prediction == 1 else "Ham"
+
+# Example test
+example_text = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."
+print("Prediction:", predict(example_text))
 ```
 
 ## **🚀 How to Train the Model**
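As a quick sanity check of the checkpoint this script writes to `./saved_model`, the snippet below is a minimal sketch (not part of the commit) that loads it through the high-level `pipeline` API from `transformers` instead of the manual `predict()` helper; it assumes `./saved_model` is the directory produced by the training code in this diff.

```python
# Minimal sketch (not from the commit): query the saved checkpoint with the
# transformers pipeline API. Assumes ./saved_model was produced by the
# training script shown in the diff above.
from transformers import pipeline

spam_classifier = pipeline(
    "text-classification",
    model="./saved_model",
    tokenizer="./saved_model",
)

print(spam_classifier("Congratulations! You've won a $1000 Walmart gift card."))
# Without an id2label mapping in the config, the label comes back as
# LABEL_0 (ham) or LABEL_1 (spam), each with a confidence score.
```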