license: apache-2.0
datasets:
- webis/tldr-17
language:
- en
library_name: transformers
pipeline_tag: text-classification
widget:
- text: >-
    Biden says US is at tipping point on gun control: We will ban assault
    weapons in this country
  example_title: classification
RedBERT - a Reddit post classifier
This model, based on DistilBERT, is fine-tuned to predict the subreddit of a Reddit post.
Usage
Preparations
The model uses the transformers library, so make sure to install it.
pip install transformers[torch]
After the installation, the model can be loaded from Hugging Face.
The model is stored locally, so if you run these lines multiple times, it will be loaded from the cache.
from transformers import pipeline
pipe = pipeline("text-classification", model="traberph/RedBERT")
Basic
For a simple classification task, just call the pipeline with the text of your choice.
text = "I (33f) need to explain to my coworker (30m) I don't want his company on the commute back home"
pipe(text)
output:
[{'label': 'relationships', 'score': 0.9622366428375244}]
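The pipeline always returns its best guess, so it can be useful to discard predictions with a low score. The following is a minimal sketch that relies only on the output format shown above. The helper name `confident_label` and the default threshold of 0.5 are hypothetical choices for illustration, not part of the model or of transformers.

```python
def confident_label(predictions, min_score=0.5):
    """Return the top label if its score reaches min_score, else None.

    `predictions` is expected to look like the pipeline output above:
    a list of dicts with 'label' and 'score' keys.
    (Hypothetical helper; the 0.5 default is an arbitrary assumption.)
    """
    if not predictions:
        return None
    best = max(predictions, key=lambda p: p["score"])
    return best["label"] if best["score"] >= min_score else None


# The output shown above, reused so no model download is needed here
sample = [{"label": "relationships", "score": 0.9622366428375244}]
print(confident_label(sample))        # relationships
print(confident_label(sample, 0.99))  # None
```

With a loaded pipeline, this would be called as `confident_label(pipe(text))`.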
Multiclass with visualization
Everyone likes visualizations! This example outputs the 5 most probable labels and plots the result.
Make sure that all requirements are satisfied.
pip install pandas seaborn
import pandas as pd
import seaborn as sns
# if the model is already loaded this can be skipped
from transformers import pipeline
pipe = pipeline("text-classification", model="traberph/RedBERT")
text = "Today I spilled coffee over my pc. It started to smoke and the screen turned black. I guess I have a problem now."
# predict the 5 most probable labels
res = pipe(text, top_k=5)
# create a pandas dataframe from the result
df = pd.DataFrame(res)
# use seaborn to create a barplot
sns.barplot(data=df, x='score', y='label', color='steelblue')
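If no plotting backend is available (for example in a plain terminal session), the same result can be shown as text. This is a rough sketch with a hypothetical helper, `text_barplot`, that only assumes the list-of-dicts format returned by the pipeline. The labels and scores in the example are made up for illustration and are not real model output.

```python
def text_barplot(predictions, width=40):
    """Render pipeline-style predictions as lines of '#' bars.

    Hypothetical helper: `predictions` is a list of dicts with
    'label' and 'score' keys, as returned by pipe(text, top_k=5).
    """
    if not predictions:
        return ""
    label_width = max(len(p["label"]) for p in predictions)
    lines = []
    # sort by descending score so the most probable label comes first
    for p in sorted(predictions, key=lambda p: p["score"], reverse=True):
        bar = "#" * round(p["score"] * width)
        lines.append(f"{p['label']:<{label_width}} {p['score']:.3f} {bar}")
    return "\n".join(lines)


# Made-up scores for illustration only; real values come from the pipeline
example = [
    {"label": "techsupport", "score": 0.50},
    {"label": "tifu", "score": 0.25},
]
print(text_barplot(example, width=20))
```

With the model loaded, `print(text_barplot(pipe(text, top_k=5)))` would print one bar per label.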
Training
The training of the final version of this model took 130h on a single Tesla P100 GPU.
90% of the webis/tldr-17 dataset was used for this version.
Bias and Limitations
The webis/tldr-17 dataset used to train this model contains 3 848 330 posts from 29 651 subreddits.
Those posts, however, are not equally distributed over the subreddits: 589 947 posts belong to the subreddit AskReddit, which is 15% of the whole dataset. Other subreddits are underrepresented.
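The size of this imbalance can be checked with a few lines of arithmetic. The sketch below uses only the three figures quoted in this section. The averages are derived values added for illustration and say nothing about how posts are spread over individual subreddits.

```python
# Figures quoted in this section
total_posts = 3_848_330
total_subreddits = 29_651
askreddit_posts = 589_947

# Share of the dataset that belongs to AskReddit
askreddit_share = askreddit_posts / total_posts
# Average number of posts per subreddit over the whole dataset
mean_posts = total_posts / total_subreddits
# Average over the remaining subreddits once AskReddit is excluded
mean_without_askreddit = (total_posts - askreddit_posts) / (total_subreddits - 1)

print(f"AskReddit share: {askreddit_share:.1%}")                 # 15.3%
print(f"Mean posts per subreddit: {mean_posts:.0f}")             # 130
print(f"Mean without AskReddit: {mean_without_askreddit:.0f}")   # 110
```

In other words, AskReddit alone holds several thousand times as many posts as an average subreddit in the dataset.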
This bias in the subreddit distribution is also represented in the model and can be observed during inference.
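One way to observe this during inference is to classify a batch of texts and count how often each label comes out on top. The following is a sketch under stated assumptions: `label_shares` and `top_label` are hypothetical helpers, the label string `AskReddit` is assumed to match the model's label names, and the predictions in the example are made up rather than real model output.

```python
from collections import Counter


def top_label(item):
    """Return the best label for one input.

    Accepts either a single prediction dict or a list of dicts,
    since the pipeline output shape depends on the top_k setting.
    """
    if isinstance(item, dict):
        return item["label"]
    return max(item, key=lambda p: p["score"])["label"]


def label_shares(batch_predictions):
    """Map each top-1 label to the fraction of inputs that received it."""
    counts = Counter(top_label(item) for item in batch_predictions)
    total = sum(counts.values())
    return {label: n / total for label, n in counts.most_common()}


# Made-up predictions for illustration; real ones would come from pipe(texts)
fake_batch = [
    {"label": "AskReddit", "score": 0.41},
    {"label": "AskReddit", "score": 0.38},
    {"label": "relationships", "score": 0.96},
    {"label": "AskReddit", "score": 0.52},
]
print(label_shares(fake_batch))  # {'AskReddit': 0.75, 'relationships': 0.25}
```

With the model loaded, `label_shares(pipe(texts))` for a list of strings `texts` gives the same kind of summary. A share for one label that is far above what the input texts suggest would hint at the bias described above.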