|
--- |
|
title: News Source Classifier |
|
emoji: 📰 |
|
colorFrom: blue |
|
colorTo: red |
|
sdk: streamlit |
|
app_file: eval_pipeline.py |
|
library_name: transformers |
|
pinned: false |
|
language: en |
|
license: mit |
|
tags: |
|
- text-classification |
|
- news-classification |
|
- BERT |
|
- pytorch |
|
- transformers |
|
pipeline_tag: text-classification |
|
widget: |
|
- example_title: "Politics News Headline" |
|
text: "Trump's campaign rival decides between voting for him or Biden" |
|
- example_title: "International News Headline" |
|
text: "World Food Programme Director Cindy McCain: Northern Gaza is in a 'full-blown famine'" |
|
- example_title: "Domestic News Headline" |
|
text: "Ohio sheriff suggests residents keep a list of homes with Harris yard signs" |
|
model-index: |
|
- name: News Source Classifier |
|
results: |
|
- task: |
|
type: text-classification |
|
name: Text Classification |
|
dataset: |
|
name: Custom FOX-NBC Dataset |
|
type: Custom |
|
metrics: |
|
- name: F1 Score |
|
type: f1 |
|
value: 0.94 |
|
--- |
|
|
|
# News Source Classifier - BERT Model |
|
|
|
## Model Overview |
|
This repository contains a fine-tuned BERT model that classifies news headlines as coming from Fox News or NBC News, along with a Streamlit-based evaluation pipeline for assessing model performance.
|
|
|
### Model Details |
|
- **Base Model**: BERT (bert-base-uncased) |
|
- **Task**: Binary classification (Fox News vs NBC News) |
|
- **Model ID**: CIS519PG/News_Classifier_Demo |
|
- **Training Data**: News headlines from Fox News and NBC News |
|
- **Input**: News article headlines (text) |
|
- **Output**: Binary classification with probability scores |
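
For quick one-off inference outside the Streamlit pipeline, here is a minimal sketch. It mirrors the assumptions made in the evaluation code below: the `bert-base-uncased` tokenizer and the label order `0 = FOXNEWS`, `1 = NBC`.

```python
# Minimal inference sketch; the label order (0 = FOXNEWS, 1 = NBC) follows the
# label2id mapping used by the evaluation pipeline in this repository.
import torch
from transformers import BertTokenizer, AutoModelForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("CIS519PG/News_Classifier_Demo")
model.eval()

headline = "Ohio sheriff suggests residents keep a list of homes with Harris yard signs"
inputs = tokenizer(headline, truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    probs = torch.softmax(model(**inputs).logits, dim=1).squeeze()

print({"FOXNEWS": probs[0].item(), "NBC": probs[1].item()})
```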
|
|
|
## Evaluation Pipeline Setup |
|
|
|
### Prerequisites |
|
- Python 3.8+ |
|
- pip package manager |
|
|
|
### Required Dependencies |
|
Install the required packages using pip: |
|
```bash |
|
pip install streamlit pandas torch transformers scikit-learn numpy plotly tqdm |
|
``` |
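
If you prefer a `requirements.txt`, here is an unpinned sketch (pin versions to match your environment):

```text
streamlit
pandas
torch
transformers
scikit-learn
numpy
plotly
tqdm
```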
|
|
|
### Running the Evaluation Pipeline |
|
|
|
1. Save the evaluation code below as `eval_pipeline.py` (it is also available for download in the repository's files).
|
|
|
```python
|
import streamlit as st |
|
import pandas as pd |
|
import torch |
|
from transformers import BertTokenizer, AutoModelForSequenceClassification |
|
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
|
from tqdm import tqdm |
|
|
|
def load_model_and_tokenizer(): |
|
try: |
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") |
|
model = AutoModelForSequenceClassification.from_pretrained("CIS519PG/News_Classifier_Demo") |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
model = model.to(device) |
|
model.eval() |
|
return model, tokenizer, device |
|
except Exception as e: |
|
st.error(f"Error loading model or tokenizer: {str(e)}") |
|
return None, None, None |
|
|
|
def preprocess_data(df): |
|
try: |
|
processed_data = [] |
|
for _, row in df.iterrows(): |
|
outlet = row["outlet"].strip().upper() |
|
if outlet == "FOX NEWS": |
|
outlet = "FOXNEWS" |
|
elif outlet == "NBC NEWS": |
|
outlet = "NBC" |
|
|
|
processed_data.append({ |
|
"title": row["title"], |
|
"outlet": outlet |
|
}) |
|
return processed_data |
|
except Exception as e: |
|
st.error(f"Error preprocessing data: {str(e)}") |
|
return None |
|
|
|
def evaluate_model(model, tokenizer, device, test_dataset): |
|
label2id = {"FOXNEWS": 0, "NBC": 1} |
|
all_logits = [] |
|
references = [] |
|
|
|
batch_size = 16 |
|
progress_bar = st.progress(0) |
|
|
|
for i in range(0, len(test_dataset), batch_size): |
|
progress = min(i / len(test_dataset), 1.0) |
|
progress_bar.progress(progress) |
|
|
|
batch = test_dataset[i:i + batch_size] |
|
texts = [item['title'] for item in batch] |
|
|
|
encoded = tokenizer( |
|
texts, |
|
padding=True, |
|
truncation=True, |
|
max_length=128, |
|
return_tensors="pt" |
|
) |
|
|
|
inputs = {k: v.to(device) for k, v in encoded.items()} |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
logits = outputs.logits.cpu().numpy() |
|
|
|
true_labels = [label2id[item['outlet']] for item in batch] |
|
all_logits.extend(logits) |
|
references.extend(true_labels) |
|
progress_bar.progress(1.0) |
|
probabilities = torch.softmax(torch.tensor(all_logits), dim=1).numpy() |
|
return references, probabilities |
|
|
|
def plot_roc_curve(references, probabilities): |
|
fpr, tpr, _ = roc_curve(references, probabilities[:, 1]) |
|
auc_score = roc_auc_score(references, probabilities[:, 1]) |
|
fig = go.Figure() |
|
fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'ROC Curve (AUC = {auc_score:.4f})')) |
|
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], name='Random Guess', line=dict(dash='dash'))) |
|
fig.update_layout( |
|
title='ROC Curve', |
|
xaxis_title='False Positive Rate', |
|
yaxis_title='True Positive Rate', |
|
showlegend=True |
|
) |
|
return fig, auc_score |
|
|
|
def plot_metrics_by_threshold(references, probabilities): |
|
thresholds = np.arange(0.0, 1.0, 0.01) |
|
metrics = { |
|
'threshold': thresholds, |
|
'f1': [], |
|
'precision': [], |
|
'recall': [] |
|
} |
|
best_f1 = 0 |
|
best_threshold = 0 |
|
best_metrics = {} |
|
for threshold in thresholds: |
|
preds = (probabilities[:, 1] > threshold).astype(int) |
|
f1 = f1_score(references, preds) |
|
precision, recall, _, _ = precision_recall_fscore_support(references, preds, average='binary') |
|
metrics['f1'].append(f1) |
|
metrics['precision'].append(precision) |
|
metrics['recall'].append(recall) |
|
if f1 > best_f1: |
|
best_f1 = f1 |
|
best_threshold = threshold |
|
cm = confusion_matrix(references, preds) |
|
report = classification_report(references, preds, target_names=['FOXNEWS', 'NBC'], digits=4) |
|
best_metrics = { |
|
'threshold': threshold, |
|
'f1_score': f1, |
|
'confusion_matrix': cm, |
|
'classification_report': report |
|
} |
|
fig = go.Figure() |
|
fig.add_trace(go.Scatter(x=thresholds, y=metrics['f1'], name='F1 Score')) |
|
fig.add_trace(go.Scatter(x=thresholds, y=metrics['precision'], name='Precision')) |
|
fig.add_trace(go.Scatter(x=thresholds, y=metrics['recall'], name='Recall')) |
|
fig.update_layout( |
|
title='Metrics by Threshold', |
|
xaxis_title='Threshold', |
|
yaxis_title='Score', |
|
showlegend=True |
|
) |
|
return fig, best_metrics |
|
|
|
def plot_confusion_matrix(cm): |
|
labels = ['FOXNEWS', 'NBC'] |
|
annotations = [] |
|
for i in range(len(labels)): |
|
for j in range(len(labels)): |
|
annotations.append( |
|
dict( |
|
text=str(cm[i, j]), |
|
x=labels[j], |
|
y=labels[i], |
|
showarrow=False, |
|
font=dict(color='white' if cm[i, j] > cm.max()/2 else 'black') |
|
) |
|
) |
|
fig = go.Figure(data=go.Heatmap( |
|
z=cm, |
|
x=labels, |
|
y=labels, |
|
colorscale='Blues', |
|
showscale=True |
|
)) |
|
fig.update_layout( |
|
title='Confusion Matrix', |
|
xaxis_title='Predicted Label', |
|
yaxis_title='True Label', |
|
annotations=annotations |
|
) |
|
return fig |
|
|
|
def main(): |
|
st.title("News Classifier Model Evaluation") |
|
uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv']) |
|
if uploaded_file is not None: |
|
df = pd.read_csv(uploaded_file) |
|
st.write("Preview of uploaded data:") |
|
st.dataframe(df.head()) |
|
model, tokenizer, device = load_model_and_tokenizer() |
|
if model and tokenizer: |
|
test_dataset = preprocess_data(df) |
|
if test_dataset: |
|
st.write(f"Total examples: {len(test_dataset)}") |
|
with st.spinner('Evaluating model...'): |
|
references, probabilities = evaluate_model(model, tokenizer, device, test_dataset) |
|
roc_fig, auc_score = plot_roc_curve(references, probabilities) |
|
st.plotly_chart(roc_fig) |
|
st.metric("AUC-ROC Score", f"{auc_score:.4f}") |
|
metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities) |
|
st.plotly_chart(metrics_fig) |
|
st.subheader("Best Threshold Evaluation") |
|
col1, col2 = st.columns(2) |
|
with col1: |
|
st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}") |
|
with col2: |
|
st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}") |
|
st.subheader("Confusion Matrix") |
|
cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix']) |
|
st.plotly_chart(cm_fig) |
|
st.subheader("Classification Report") |
|
st.text(best_metrics['classification_report']) |
|
if __name__ == "__main__": |
|
main() |
|
``` |
|
|
|
2. Run the Streamlit application: |
|
```bash |
|
streamlit run eval_pipeline.py |
|
``` |
|
|
|
3. The web interface will open automatically in your default browser.
|
|
|
### Using the Web Interface |
|
|
|
1. **Upload Test Data**: |
|
- Prepare your test data in CSV format |
|
- Required columns:

  - Index column (automatic numbering)

  - "title": The news headline text

  - "outlet": The source ("Fox News" or "NBC News"); the pipeline derives the ground-truth label from this column

  - "label": Binary label (0 for Fox News, 1 for NBC News); included in the sample data for reference but not read by the pipeline
|
|
|
2. **View Evaluation Results**: |
|
The pipeline will display: |
|
- Data preview |
|
- ROC curve with AUC score |
|
- Metrics vs threshold plot |
|
- Best threshold and F1 score |
|
- Confusion matrix visualization |
|
- Detailed classification report |
|
|
|
### Sample Data Format |
|
```csv |
|
,title,label,outlet |
|
0,"Jack Carr's take on the late Tom Clancy, born on this day in 1947",0,Fox News |
|
1,"Feeding America CEO asks community to help others amid today's high inflation",0,Fox News |
|
2,"World Food Programme Director Cindy McCain: Northern Gaza is in a 'full-blown famine'",1,NBC News |
|
3,"Ohio sheriff suggests residents keep a list of homes with Harris yard signs",1,NBC News |
|
``` |
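
To sanity-check a file against this format before uploading, the snippet below is a small sketch (the file name `test_headlines.csv` is an example, not part of this repository):

```python
# Validate a test CSV locally before uploading it to the Streamlit app.
import pandas as pd

df = pd.read_csv("test_headlines.csv", index_col=0)
assert {"title", "outlet"}.issubset(df.columns), "missing required columns"
assert df["outlet"].str.strip().str.upper().isin(["FOX NEWS", "NBC NEWS"]).all(), \
    "unexpected outlet values"
print(df["outlet"].value_counts())
```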
|
|
|
## Model Architecture |
|
- Base model: BERT (bert-base-uncased) |
|
- Fine-tuned for binary classification |
|
- Uses PyTorch and Hugging Face Transformers |
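
To confirm the checkpoint's shape locally, a short sketch (assuming the checkpoint exposes a standard Transformers sequence-classification config):

```python
# Inspect the fine-tuned checkpoint: a BERT encoder with a binary classification head.
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("CIS519PG/News_Classifier_Demo")
print(model.config.model_type)   # expected: "bert"
print(model.config.num_labels)   # expected: 2 (FOXNEWS vs NBC)
```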
|
|
|
## Limitations and Bias |
|
This model has been trained on news headlines from specific sources (Fox News and NBC News) and time periods, which may introduce certain biases: |
|
- Limited to two specific news sources |
|
- Temporal bias based on training data collection period |
|
- May not generalize well to other news sources or formats |
|
|
|
## Evaluation Metrics |
|
The pipeline provides comprehensive evaluation metrics: |
|
- AUC-ROC Score |
|
- F1 Score |
|
- Precision & Recall |
|
- Confusion Matrix |
|
- Detailed Classification Report |
|
|
|
## Troubleshooting |
|
|
|
Common issues and solutions: |
|
|
|
1. **CUDA/GPU Error**: |
|
- The pipeline automatically falls back to CPU if CUDA is not available |
|
- No action needed from user |
|
|
|
2. **Memory Issues**: |
|
- The batch size defaults to 16 (`batch_size = 16` in `evaluate_model()`)

- Lower it (for example, to 8 or 4) if you run into out-of-memory errors
|
|
|
3. **File Format Error**: |
|
- Ensure the CSV has the exact column names "title" and "outlet" (a "label" column is optional; the pipeline derives labels from "outlet")

- If a "label" column is present, verify its values are 0 or 1

- Confirm "outlet" values are "Fox News" or "NBC News"; matching is case-insensitive and ignores surrounding whitespace, and any other value raises a KeyError during evaluation
|
|
|
## License |
|
This project is licensed under the MIT License. |