distilbert-base-q-cat

Model Description

distilbert-base-q-cat is a lightweight, fine-tuned DistilBERT model for text classification. It assigns each question to one of three categories: fact, opinion, or hypothetical. The model was trained on a dataset of Quora questions whose labels were assigned automatically using keyword-based rules and sentiment analysis.

Features

Lightweight: Built on DistilBERT, which gives faster inference and lower computational requirements than standard BERT.

Three Class Categories:

  • Fact: Questions seeking factual or objective information.
  • Opinion: Questions that elicit subjective views or opinions.
  • Hypothetical: Questions exploring hypothetical scenarios or speculative ideas.

Pretrained and Fine-Tuned: Utilizes DistilBERT’s pretrained weights with additional fine-tuning on labeled data.

Dataset

The model was trained on a custom dataset derived from Quora questions.

Data Preparation:

  • Keyword-based rules were used to label fact and hypothetical questions.
  • Sentiment analysis was used to identify opinion-based questions.

Dataset Size: ~50k samples, split into training, validation, and test sets.
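
The labeling rules themselves are not published in this card. The sketch below only illustrates how keyword rules and a sentiment-style score could be combined to produce the three labels. Everything in it is an assumption made for illustration: the patterns, the word list, the threshold, the order in which the rules are applied, and the function names. The word-list scorer is a stand-in for a real sentiment analysis tool.

```python
import re

# Hypothetical keyword patterns; the actual rules used for this dataset are not published.
HYPOTHETICAL_PATTERNS = [r"\bwhat if\b", r"\bwould\b.*\bif\b", r"\bsuppose\b", r"\bimagine\b"]
FACT_PATTERNS = [r"^(what|who|when|where|which)\s+(is|are|was|were)\b", r"^how (many|much)\b"]

# Stand-in for a sentiment/subjectivity model: a tiny list of opinion-bearing words.
OPINION_WORDS = {"best", "worst", "better", "think", "should", "favorite", "opinion", "prefer"}


def subjectivity_score(question):
    """Fraction of words in the question that appear in OPINION_WORDS."""
    words = re.findall(r"[a-z']+", question.lower())
    if not words:
        return 0.0
    return sum(w in OPINION_WORDS for w in words) / len(words)


def label_question(question, opinion_threshold=0.1):
    """Return 'hypothetical', 'opinion', 'fact', or None if no rule applies."""
    text = question.lower().strip()
    # Rule order is an assumption: hypothetical keywords are checked first.
    if any(re.search(p, text) for p in HYPOTHETICAL_PATTERNS):
        return "hypothetical"
    if subjectivity_score(text) >= opinion_threshold:
        return "opinion"
    if any(re.search(p, text) for p in FACT_PATTERNS):
        return "fact"
    return None  # left unlabeled rather than guessed


for q in ["What if the Moon disappeared?",
          "Which is the best programming language?",
          "What is artificial intelligence?"]:
    print(q, "->", label_question(q))
```

Returning None for questions that match no rule keeps ambiguous examples out of the training set instead of forcing them into a class.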

Performance

The model achieves the following metrics on the validation set:

  • Accuracy: 93.33%
  • Precision: 93.41%
  • Recall: 93.33%
  • F1-Score: 93.32%
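
The card does not say how precision, recall, and F1 were averaged across the three classes. One common choice is support-weighted averaging, under which recall is always equal to accuracy. The sketch below computes the four metrics that way in plain Python; the function name and the toy labels are illustrative, not taken from the actual evaluation.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Return accuracy and support-weighted precision, recall, and F1."""
    total = len(y_true)
    support = Counter(y_true)  # number of true examples per class
    labels = sorted(set(y_true) | set(y_pred))
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / total

    precision = recall = f1 = 0.0
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        actual = support[label]
        p_c = tp / predicted if predicted else 0.0
        r_c = tp / actual if actual else 0.0
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        weight = actual / total  # weight each class by its share of true examples
        precision += weight * p_c
        recall += weight * r_c
        f1 += weight * f_c

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Toy example with made-up labels
y_true = ["fact", "fact", "opinion", "hypothetical"]
y_pred = ["fact", "opinion", "opinion", "hypothetical"]
print(weighted_metrics(y_true, y_pred))
```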

Installation

To use this model, install the required dependencies:

pip install transformers torch

Usage

Load Model and Tokenizer

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub
model_name = "alwanadi17/distilbert-base-q-cat"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Inference Example

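import torch

def predict_question(question):
    # Tokenize a single question; truncation guards against over-length input
    inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    # Gradients are not needed for inference
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = logits.argmax(dim=-1).item()

    # Map the predicted class index to its category name
    label_map = {0: "fact", 1: "opinion", 2: "hypothetical"}
    return label_map[predicted_class]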

# Example usage
question = "What is artificial intelligence?"
print(predict_question(question))
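
The inference example returns only the top label. To see how strongly the model favors each category, the raw logits can be passed through a softmax. The sketch below does this in plain Python and assumes the logits arrive as a list of three floats in the same order as label_map (for instance from outputs.logits[0].tolist()). The helper name and the sample logits are illustrative.

```python
import math

# Same index order as label_map in the inference example
LABELS = ["fact", "opinion", "hypothetical"]


def logits_to_probabilities(logits):
    """Convert a list of raw logits into a label -> probability dict via softmax."""
    largest = max(logits)
    # Subtracting the maximum keeps exp() numerically stable
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return {label: e / total for label, e in zip(LABELS, exps)}


# Made-up logits for illustration
probabilities = logits_to_probabilities([2.0, 0.0, 0.0])
for label, p in probabilities.items():
    print(f"{label}: {p:.3f}")
```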

Model Details

  • Repository: alwanadi17/distilbert-base-q-cat
  • Model size: 67M params
  • Tensor type: F32 (Safetensors)