distilbert-base-q-cat
Model Description
distilbert-base-q-cat is a lightweight DistilBERT model fine-tuned for text classification. It sorts questions into three categories: fact, opinion, and hypothetical. The model was trained on a dataset of Quora questions that were labeled automatically with keyword-based rules and sentiment analysis.
Features
Lightweight: Built on DistilBERT, which offers faster inference and lower computational requirements than standard BERT.
Three Class Categories:
- Fact: Questions seeking factual or objective information.
- Opinion: Questions that elicit subjective views or opinions.
- Hypothetical: Questions exploring hypothetical scenarios or speculative ideas.
Pretrained and Fine-Tuned: Utilizes DistilBERT’s pretrained weights with additional fine-tuning on labeled data.
Dataset
The model was trained using a custom dataset derived from Quora questions:
Data Preparation:
Labeling involved keyword-based rules for fact and hypothetical questions.
Sentiment analysis was used to label questions as opinion-based.
Dataset Size: ~50k samples, split into training, validation, and test sets.
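The card does not publish its keyword lists, the sentiment analyzer it used, or the order in which the rules were applied. The sketch below is therefore only a hypothetical illustration of how keyword rules and a sentiment score can be combined into weak labels, followed by a shuffled train/validation/test split. The keyword sets, the toy sentiment lexicon, the 0.5 threshold, the rule order, and the 80/10/10 split ratio are all assumptions, not the author's settings.

```python
import random

# Hypothetical keyword cues; the card does not list the actual keywords used.
HYPOTHETICAL_CUES = ("what if", "would you", "suppose", "imagine", "if you could")
FACT_CUES = ("what is", "when did", "who is", "how many", "where is")

# Toy sentiment lexicon standing in for a real sentiment analyzer (assumption).
SENTIMENT_WORDS = {"best": 1.0, "love": 1.0, "good": 0.5,
                   "worst": -1.0, "hate": -1.0, "bad": -0.5}


def sentiment_score(text):
    """Crude polarity score: the sum of lexicon weights of the words in the text."""
    words = text.lower().replace("?", " ").split()
    return sum(SENTIMENT_WORDS.get(w, 0.0) for w in words)


def weak_label(question, opinion_threshold=0.5):
    """Return 'hypothetical', 'opinion', 'fact', or None (rule order is assumed)."""
    q = question.lower()
    if any(cue in q for cue in HYPOTHETICAL_CUES):
        return "hypothetical"
    if abs(sentiment_score(q)) >= opinion_threshold:
        return "opinion"
    if any(cue in q for cue in FACT_CUES):
        return "fact"
    return None  # no rule fired; such questions could be dropped


def split_dataset(rows, train=0.8, val=0.1, seed=42):
    """Shuffle and split rows into train/validation/test lists (ratios are assumed)."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n_train = int(len(rows) * train)
    n_val = int(len(rows) * val)
    return rows[:n_train], rows[n_train:n_train + n_val], rows[n_train + n_val:]
```

Typical use: label every question with `weak_label`, drop the rows that come back as `None`, then pass the remaining `(question, label)` pairs to `split_dataset`.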
Performance
The model achieves the following metrics on the validation set:
- Accuracy: 93.33%
- Precision: 93.41%
- Recall: 93.33%
- F1-Score: 93.32%
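The card does not state how precision, recall, and F1 were averaged over the three classes. Recall matching accuracy exactly (93.33%) is consistent with support-weighted averaging, since weighted recall always equals accuracy. The dependency-free sketch below computes the metrics under that assumption; it should agree with scikit-learn's `precision_recall_fscore_support(y_true, y_pred, average="weighted")` when 0/0 is treated as 0.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Accuracy plus support-weighted precision, recall, and F1 (0/0 treated as 0)."""
    n = len(y_true)
    support = Counter(y_true)      # number of true examples per class
    predicted = Counter(y_pred)    # number of predictions per class
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct per class

    accuracy = sum(tp.values()) / n
    precision = recall = f1 = 0.0
    for cls, n_cls in support.items():
        p = tp[cls] / predicted[cls] if predicted[cls] else 0.0
        r = tp[cls] / n_cls
        f = 2 * p * r / (p + r) if p + r else 0.0
        weight = n_cls / n         # each class is weighted by its share of true labels
        precision += weight * p
        recall += weight * r
        f1 += weight * f
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
```

Because each class's recall is multiplied by its own support share, the weighted recall collapses to total correct predictions divided by total examples, which is exactly accuracy.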
Installation
To use this model, install the required dependencies:
pip install transformers torch
Usage
Load Model and Tokenizer
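from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub
model_name = "alwanadi17/distilbert-base-q-cat"
# num_labels=3 matches the three classes; ignore_mismatched_sizes is left at its default
# so a head-size mismatch raises an error instead of silently re-initializing the classifier
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(model_name)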
Inference Example
import torch

def predict_question(question):
    # Tokenize a single question into PyTorch tensors
    inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    # Gradients are not needed for inference
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = logits.argmax(dim=-1).item()
    # Index-to-label mapping used by this model
    label_map = {0: "fact", 1: "opinion", 2: "hypothetical"}
    return label_map[predicted_class]

# Example usage
question = "What is artificial intelligence?"
print(predict_question(question))
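`predict_question` returns only the top label for one question. When a confidence score is also useful, or when several questions are classified at once, the logits can be converted to probabilities with a softmax. The helper below is a hypothetical, dependency-free sketch (`softmax` and `logits_to_predictions` are illustrative names): it takes logits as nested Python lists, for example `outputs.logits.tolist()` from a batched call, and reuses the index-to-label mapping from the inference example.

```python
import math

LABEL_MAP = {0: "fact", 1: "opinion", 2: "hypothetical"}  # mapping from this card


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_predictions(batch_logits):
    """Map each row of logits to (label, probability of that label)."""
    results = []
    for row in batch_logits:
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        results.append((LABEL_MAP[best], probs[best]))
    return results

# Usage with the loaded model and tokenizer (not run here):
# questions = ["What is artificial intelligence?", "What if the Moon disappeared?"]
# inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
# with torch.no_grad():
#     batch_logits = model(**inputs).logits.tolist()
# print(logits_to_predictions(batch_logits))
```

Subtracting the row maximum before exponentiating keeps `math.exp` from overflowing on large logits without changing the resulting probabilities.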
Base Model
Fine-tuned from distilbert/distilbert-base-uncased and published on the Hugging Face Hub as alwanadi17/distilbert-base-q-cat.