metadata
base_model: facebook/bart-large-mnli
datasets:
  - reddgr/nli-chatbot-prompt-categorization
library_name: transformers
license: mit
tags:
  - generated_from_keras_callback
model-index:
  - name: zero-shot-prompt-classifier-bart-ft
    results: []

zero-shot-prompt-classifier-bart-ft

This model is a fine-tuned version of facebook/bart-large-mnli on the reddgr/nli-chatbot-prompt-categorization dataset.

The model helps classify chatbot prompts into categories relevant to working with LLM conversational tools, such as coding assistance, language assistance, role play, creative writing, and general knowledge questions.

The model is fine-tuned and tested on the natural language inference (NLI) dataset reddgr/nli-chatbot-prompt-categorization.
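As a minimal usage sketch, the model can be called through the Transformers zero-shot classification pipeline. The candidate labels below are illustrative, taken from the examples named above rather than from the dataset's actual label set. `framework="tf"` is an assumption based on the Keras training setup, and `build_classifier` and `classify_prompt` are hypothetical helper names.

```python
# Illustrative labels only; the dataset's real category names may differ.
CANDIDATE_LABELS = [
    "coding assistance",
    "language assistance",
    "role play",
    "creative writing",
    "general knowledge question",
]


def build_classifier(model_id="reddgr/zero-shot-prompt-classifier-bart-ft"):
    """Load the zero-shot pipeline (downloads the model on first use)."""
    # Imported lazily so the helpers below work without transformers installed.
    from transformers import pipeline

    # Assumption: the checkpoint ships TensorFlow weights, hence framework="tf".
    return pipeline("zero-shot-classification", model=model_id, framework="tf")


def classify_prompt(classifier, prompt, labels=CANDIDATE_LABELS):
    """Return the top label and its score for a single prompt."""
    result = classifier(prompt, candidate_labels=labels)
    # For a single input, the pipeline returns a dict whose "labels" and
    # "scores" lists are sorted by descending score.
    return result["labels"][0], result["scores"][0]


if __name__ == "__main__":
    clf = build_classifier()
    print(classify_prompt(clf, "Write a Python function that reverses a string"))
```

Under the hood, the pipeline pairs the prompt with one hypothesis per candidate label and ranks labels by entailment score, which is why an NLI dataset is used for fine-tuning.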

Below is a confusion matrix calculated on zero-shot inferences for the 10 most frequent categories in the test split of reddgr/nli-chatbot-prompt-categorization at the time of the first model upload. The classification with the base model on the same small test dataset is shown for comparison:

Figure: zero-shot prompt classification confusion matrix for reddgr/zero-shot-prompt-classifier-bart-ft

As of the first version of the model uploaded to the Hub, the fine-tuned model outperforms the base model facebook/bart-large-mnli by 17 percentage points (51% vs. 34% accuracy) on this test set, using 10 candidate zero-shot classes (the most frequent categories in the test split of reddgr/nli-chatbot-prompt-categorization at the time of the first model upload).
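For readers who want to reproduce this kind of comparison, the sketch below shows how accuracy and a confusion matrix can be computed from true and predicted labels. The labels and predictions are toy values made up for illustration, not the model's actual outputs; only the 51% and 34% figures come from the results reported above.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def confusion_matrix(y_true, y_pred, labels):
    """Rows are true labels, columns are predicted labels."""
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


# Toy data, for illustration only.
labels = ["coding", "role play", "writing"]
y_true = ["coding", "coding", "role play", "writing"]
y_pred = ["coding", "writing", "role play", "writing"]

toy_accuracy = accuracy(y_true, y_pred)
toy_matrix = confusion_matrix(y_true, y_pred, labels)

# Reported accuracies, in percent, and their gap in percentage points.
fine_tuned_pct, base_pct = 51, 34
gap_points = fine_tuned_pct - base_pct

print(toy_accuracy)
print(toy_matrix)
print(gap_points)
```

The gap is stated in percentage points (51 minus 34), not as a relative improvement.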

The dataset and the model are continuously updated as they assist with content publishing on my website, Talking to Chatbots.

Model description

More information needed

Intended uses & limitations

More information needed

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • optimizer: {'name': 'Adam', 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'jit_compile': False, 'is_legacy_optimizer': False, 'learning_rate': 5e-06, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}
  • training_precision: float32
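The non-default values in the optimizer configuration above could be passed to Keras roughly as follows. This is a sketch, not the original training script: `ADAM_CONFIG` and `build_optimizer` are hypothetical names, and entries left at `None` or their defaults in the logged configuration are omitted.

```python
# Values copied from the logged optimizer configuration.
ADAM_CONFIG = {
    "learning_rate": 5e-06,
    "beta_1": 0.9,
    "beta_2": 0.999,
    "epsilon": 1e-07,
    "amsgrad": False,
}


def build_optimizer(config=ADAM_CONFIG):
    """Create a Keras Adam optimizer from the logged hyperparameters."""
    # Imported lazily so the config can be inspected without TensorFlow.
    import tensorflow as tf

    return tf.keras.optimizers.Adam(**config)


if __name__ == "__main__":
    optimizer = build_optimizer()
    print(optimizer.get_config())
```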

Training results

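| Train Loss | Train Accuracy | Validation Loss | Validation Accuracy | Epoch |
|------------|----------------|-----------------|---------------------|-------|
| 0.9969     | 0.5490         | 0.9182          | 0.6225              | 0     |
| 0.7647     | 0.6601         | 1.0025          | 0.5441              | 1     |
| 0.6465     | 0.7157         | 1.1472          | 0.5392              | 2     |
| 0.5849     | 0.7418         | 1.1974          | 0.5049              | 3     |
| 0.4474     | 0.7843         | 1.5942          | 0.4657              | 4     |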
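The short sketch below reads the training results programmatically, using only the figures reported above. It picks the epoch with the lowest validation loss and tracks the gap between validation and training loss. The variable names are illustrative, and the card does not say which epoch's weights were actually uploaded.

```python
# Figures copied from the training results above.
HISTORY = [
    {"epoch": 0, "train_loss": 0.9969, "train_acc": 0.5490, "val_loss": 0.9182, "val_acc": 0.6225},
    {"epoch": 1, "train_loss": 0.7647, "train_acc": 0.6601, "val_loss": 1.0025, "val_acc": 0.5441},
    {"epoch": 2, "train_loss": 0.6465, "train_acc": 0.7157, "val_loss": 1.1472, "val_acc": 0.5392},
    {"epoch": 3, "train_loss": 0.5849, "train_acc": 0.7418, "val_loss": 1.1974, "val_acc": 0.5049},
    {"epoch": 4, "train_loss": 0.4474, "train_acc": 0.7843, "val_loss": 1.5942, "val_acc": 0.4657},
]

# Epoch with the lowest validation loss.
best = min(HISTORY, key=lambda row: row["val_loss"])

# Gap between validation and training loss at each epoch.
gaps = [round(row["val_loss"] - row["train_loss"], 4) for row in HISTORY]

print(best["epoch"], best["val_loss"], best["val_acc"])
print(gaps)
```

In these figures, training loss falls at every epoch while validation loss rises after epoch 0, so the gap widens steadily. That pattern usually indicates overfitting, and early stopping on validation loss would be one way to guard against it.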

Framework versions

  • Transformers 4.44.2
  • TensorFlow 2.18.0-dev20240717
  • Datasets 2.21.0
  • Tokenizers 0.19.1