zero-shot-prompt-classifier-bart-ft

This model is a fine-tuned version of facebook/bart-large-mnli on the reddgr/nli-chatbot-prompt-categorization dataset.

The model is intended to classify chatbot prompts into categories relevant to working with LLM conversational tools, such as coding assistance, language assistance, role play, creative writing, and general knowledge questions.
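
One plausible way to query the model is sketched below. It assumes the checkpoint loads with the standard Hugging Face `zero-shot-classification` pipeline, as its base model facebook/bart-large-mnli does. The label strings come from the category list above and may not match the dataset's exact category names. `get_classifier`, `classify_prompt` and `top_label` are hypothetical helper names, not part of any published API for this model.

```python
from functools import lru_cache
from typing import Any, Dict, Sequence

MODEL_ID = "reddgr/zero-shot-prompt-classifier-bart-ft"

# Example categories taken from the model description; the dataset's
# actual label strings may differ (assumption).
EXAMPLE_LABELS = [
    "coding assistance",
    "language assistance",
    "role play",
    "creative writing",
    "general knowledge questions",
]


@lru_cache(maxsize=1)
def get_classifier():
    """Build the pipeline once and reuse it (downloads ~400M params on first call)."""
    # Imported lazily so the pure helpers below work without transformers installed.
    from transformers import pipeline

    return pipeline("zero-shot-classification", model=MODEL_ID)


def classify_prompt(
    prompt: str, labels: Sequence[str] = EXAMPLE_LABELS, multi_label: bool = False
) -> Dict[str, Any]:
    """Score one chatbot prompt against the candidate categories."""
    return get_classifier()(prompt, candidate_labels=list(labels), multi_label=multi_label)


def top_label(result: Dict[str, Any]) -> str:
    """The pipeline returns 'labels' sorted by descending score; take the first."""
    return result["labels"][0]
```

Usage: `print(top_label(classify_prompt("Write a haiku about autumn leaves")))`. Because candidate labels are supplied at inference time, the same checkpoint can be queried with any subset of categories, which is how the 10-class and 12-class evaluations below are set up.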

Both fine-tuning and testing use the natural language inference (NLI) dataset reddgr/nli-chatbot-prompt-categorization.
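
NLI-based zero-shot classification turns each candidate category into a hypothesis and asks the model whether the prompt (the premise) entails it. The minimal pure-Python sketch below covers the scoring step, assuming you already have the model's logits for each premise/hypothesis pair. The template string mirrors the Hugging Face pipeline default, and the function names are illustrative only. In single-label mode the entailment logits are softmaxed across all candidates. In multi-label mode each candidate is scored independently from its contradiction and entailment logits.

```python
import math
from typing import List, Sequence, Tuple

# Assumption: same default hypothesis template as the Hugging Face zero-shot pipeline.
HYPOTHESIS_TEMPLATE = "This example is {}."


def build_nli_pairs(prompt: str, labels: Sequence[str]) -> List[Tuple[str, str]]:
    """Pair the chatbot prompt (premise) with one hypothesis per candidate label."""
    return [(prompt, HYPOTHESIS_TEMPLATE.format(label)) for label in labels]


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(entailment_logits: Sequence[float], labels: Sequence[str]) -> List[Tuple[str, float]]:
    """Single-label mode: softmax entailment logits across labels, best first."""
    probs = softmax(entailment_logits)
    return sorted(zip(labels, probs), key=lambda item: item[1], reverse=True)


def multi_label_scores(
    contradiction_logits: Sequence[float], entailment_logits: Sequence[float]
) -> List[float]:
    """Multi-label mode: per label, P(entailment) from a 2-way softmax vs. contradiction."""
    return [softmax([c, e])[1] for c, e in zip(contradiction_logits, entailment_logits)]
```

With hypothetical logits `[2.0, 0.5, -1.0]` for `["coding assistance", "language assistance", "role play"]`, `rank_labels` puts "coding assistance" first. Fine-tuning on the NLI-formatted dataset changes the logits the model produces, not this scoring logic.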

Below is a confusion matrix calculated from zero-shot inferences on the 10 most popular categories in the Test split of reddgr/nli-chatbot-prompt-categorization at the time of the first model upload. Results from the base model on the same small test set are shown for comparison:

[Figure: Zero-shot prompt classification confusion matrix for reddgr/zero-shot-prompt-classifier-bart-ft, 10 candidate classes]

The current version of the fine-tuned model outperforms the base model facebook/bart-large-mnli by 23 percentage points (57% vs. 34% accuracy) on a test set with 10 candidate zero-shot classes, namely the most frequent categories in the test split of reddgr/nli-chatbot-prompt-categorization.
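
Accuracy here is the share of predictions that fall on the confusion matrix diagonal. The reported gain is an absolute difference in percentage points, not a relative percentage change. The minimal sketch below computes both. The 3x3 matrix in the test is a toy example, not data from the test split.

```python
from typing import Sequence


def accuracy_from_confusion(matrix: Sequence[Sequence[int]]) -> float:
    """Accuracy = correctly classified (diagonal) / all predictions.

    Rows are true classes, columns are predicted classes.
    """
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


def percentage_point_gain(acc_new: float, acc_base: float) -> float:
    """Absolute difference in percentage points, rounded to absorb float noise."""
    return round((acc_new - acc_base) * 100, 2)
```

`percentage_point_gain(0.57, 0.34)` reproduces the 23-point gap quoted above.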

The chart below compares the results for the 12 most popular candidate classes in the Test split, where the fine-tuned model outperforms the base model's zero-shot accuracy by 25 percentage points:

[Figure: Zero-shot prompt classification confusion matrix for reddgr/zero-shot-prompt-classifier-bart-ft, 12 candidate classes]
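
Both evaluations restrict the test split to its N most frequent categories and use those N categories as the candidate labels. The sketch below shows one way to do that subsetting and build the confusion matrix. It assumes gold and predicted labels are plain strings and that every prediction comes from the same N-label candidate list. The function names and the toy labels in the test are hypothetical.

```python
from collections import Counter
from typing import List, Sequence


def top_n_categories(gold: Sequence[str], n: int) -> List[str]:
    """The n most frequent gold categories (ties keep first-seen order)."""
    return [category for category, _ in Counter(gold).most_common(n)]


def confusion_matrix(
    gold: Sequence[str], predicted: Sequence[str], classes: Sequence[str]
) -> List[List[int]]:
    """Rows = true class, columns = predicted class, restricted to `classes`.

    Examples whose gold label is outside `classes` are skipped. Predictions are
    assumed to come from the same candidate list, so an unknown prediction
    raises KeyError instead of being silently dropped.
    """
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for true_label, pred_label in zip(gold, predicted):
        if true_label not in index:
            continue
        matrix[index[true_label]][index[pred_label]] += 1
    return matrix
```

Calling `top_n_categories(gold, 10)` or `top_n_categories(gold, 12)` gives the candidate lists for the two charts. The resulting matrix can be fed to any accuracy calculation.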

The dataset and the model are continuously updated as they support content publishing on my website, Talking to Chatbots.

Model description

More information needed

Intended uses & limitations

More information needed

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • optimizer: {'name': 'Adam', 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'jit_compile': False, 'is_legacy_optimizer': False, 'learning_rate': 5e-06, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}
  • training_precision: float32
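
The sketch below shows what the optimizer settings above do: one Adam update (Kingma and Ba, Algorithm 1) in pure Python with the listed learning rate, betas and epsilon. It is an illustration, not the Keras implementation. Keras folds the bias correction into the step size and applies epsilon slightly differently, so its numbers differ negligibly. `AdamState`, `init_state` and `adam_step` are hypothetical names.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence

# Values copied from the optimizer config above.
LEARNING_RATE = 5e-06
BETA_1 = 0.9
BETA_2 = 0.999
EPSILON = 1e-07


@dataclass
class AdamState:
    m: List[float]  # first-moment (mean) estimates, one per parameter
    v: List[float]  # second-moment (uncentred variance) estimates
    t: int = 0      # number of steps taken so far


def init_state(n_params: int) -> AdamState:
    return AdamState(m=[0.0] * n_params, v=[0.0] * n_params)


def adam_step(
    params: Sequence[float],
    grads: Sequence[float],
    state: AdamState,
    lr: float = LEARNING_RATE,
    beta_1: float = BETA_1,
    beta_2: float = BETA_2,
    eps: float = EPSILON,
) -> List[float]:
    """Return updated parameters; mutates `state` in place."""
    state.t += 1
    updated = []
    for i, (p, g) in enumerate(zip(params, grads)):
        state.m[i] = beta_1 * state.m[i] + (1 - beta_1) * g
        state.v[i] = beta_2 * state.v[i] + (1 - beta_2) * g * g
        # Bias correction: the moments start at zero, so early estimates are scaled up.
        m_hat = state.m[i] / (1 - beta_1 ** state.t)
        v_hat = state.v[i] / (1 - beta_2 ** state.t)
        # weight_decay and clipnorm are None in the config: no decay or clipping here.
        updated.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return updated
```

On the first step the ratio `m_hat / sqrt(v_hat)` is close to ±1, so each weight moves by roughly the learning rate (5e-06) whatever the gradient's scale. A learning rate this small therefore keeps early fine-tuning updates small.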

Training results

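Train Loss | Train Accuracy | Validation Loss | Validation Accuracy | Epoch
-----------|----------------|-----------------|---------------------|------
0.9969     | 0.5490         | 0.9182          | 0.6225              | 0
0.7647     | 0.6601         | 1.0025          | 0.5441              | 1
0.6465     | 0.7157         | 1.1472          | 0.5392              | 2
0.5849     | 0.7418         | 1.1974          | 0.5049              | 3
0.4474     | 0.7843         | 1.5942          | 0.4657              | 4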
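
In the logged values, training loss falls every epoch while validation loss rises after epoch 0, and validation accuracy peaks at epoch 0. This is the usual signature of overfitting. The sketch below checks this directly on the numbers from the table. The table alone does not say which epoch's weights were uploaded, and the helper names are hypothetical.

```python
from typing import Dict, List, Sequence

# Per-epoch metrics copied from the training results table above.
HISTORY: List[Dict[str, float]] = [
    {"epoch": 0, "train_loss": 0.9969, "train_acc": 0.5490, "val_loss": 0.9182, "val_acc": 0.6225},
    {"epoch": 1, "train_loss": 0.7647, "train_acc": 0.6601, "val_loss": 1.0025, "val_acc": 0.5441},
    {"epoch": 2, "train_loss": 0.6465, "train_acc": 0.7157, "val_loss": 1.1472, "val_acc": 0.5392},
    {"epoch": 3, "train_loss": 0.5849, "train_acc": 0.7418, "val_loss": 1.1974, "val_acc": 0.5049},
    {"epoch": 4, "train_loss": 0.4474, "train_acc": 0.7843, "val_loss": 1.5942, "val_acc": 0.4657},
]


def best_epoch(history: Sequence[Dict[str, float]], metric: str, mode: str = "min") -> int:
    """Epoch with the lowest (mode='min') or highest (mode='max') value of `metric`."""
    pick = min if mode == "min" else max
    return int(pick(history, key=lambda row: row[metric])["epoch"])


def is_strictly_increasing(values: Sequence[float]) -> bool:
    return all(b > a for a, b in zip(values, values[1:]))
```

Selecting by `val_loss` or `val_acc` both point to epoch 0. Selecting by training loss points to epoch 4.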

Framework versions

  • Transformers 4.44.2
  • TensorFlow 2.18.0-dev20240717
  • Datasets 2.21.0
  • Tokenizers 0.19.1

Model weights

  • Format: Safetensors
  • Model size: 407M params
  • Tensor type: F32