|
--- |
|
license: apache-2.0 |
|
datasets: |
|
- webis/tldr-17 |
|
language: |
|
- en |
|
library_name: transformers |
|
pipeline_tag: text-classification |
|
widget: |
|
- text: "Biden says US is at tipping point on gun control: We will ban assault weapons in this country" |
|
example_title: "classification" |
|
--- |
|
# RedBERT - a Reddit post classifier |
|
|
|
This model is based on DistilBERT and fine-tuned to predict the subreddit of a Reddit post.
|
|
|
|
|
## Usage |
|
### Preparations |
|
The model uses the transformers library, so make sure to install it. |
|
``` |
|
pip install transformers[torch] |
|
``` |
|
|
|
After the installation, the model can be loaded from Hugging Face. |
|
The model is stored locally, so if you run these lines multiple times it will be loaded from the cache.
|
|
|
```py |
|
from transformers import pipeline |
|
pipe = pipeline("text-classification", model="traberph/RedBERT") |
|
``` |
|
|
|
### Basic |
|
For a simple classification task, just call the pipeline with the text of your choice:
|
```py |
|
text = "I (33f) need to explain to my coworker (30m) I don't want his company on the commute back home" |
|
pipe(text) |
|
``` |
|
output: |
|
`[{'label': 'relationships', 'score': 0.9622366428375244}]`
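
If you only need the predicted subreddit and its score as plain Python values, you can index into the returned list (a small sketch based on the output format above):

```py
# the pipeline returns a list with one dict per input text
result = pipe(text)[0]

print(result["label"])  # 'relationships'
print(result["score"])  # 0.96...
```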
|
|
|
### Multiclass with visualization |
|
Everyone likes visualizations! Therefore, this example outputs the 5 most probable labels and visualizes the result.
|
Make sure that all requirements are satisfied. |
|
``` |
|
pip install pandas seaborn |
|
``` |
|
```py |
|
import pandas as pd |
|
import seaborn as sns |
|
|
|
# if the model is already loaded this can be skipped |
|
from transformers import pipeline |
|
pipe = pipeline("text-classification", model="traberph/RedBERT") |
|
|
|
text = "Today I spilled coffee over my pc. It started to smoke and the screen turned black. I guess I have a problem now." |
|
|
|
# predict the 5 most probable labels |
|
res = pipe(text, top_k=5) |
|
|
|
# create a pandas dataframe from the result |
|
df = pd.DataFrame(res) |
|
|
|
# use seaborn to create a barplot |
|
sns.barplot(data=df, x='score', y='label', color='steelblue')
|
``` |
|
|
|
output: |
|
![](./assets/classify01.png) |
|
|
|
|
|
## Training |
|
Training the final version of this model took `130h` on a single `Tesla P100` GPU.

90% of the [webis/tldr-17](https://huggingface.co/datasets/webis/tldr-17/) dataset was used for training this version.
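
The exact data preparation is not documented here, but a 90/10 split can be reproduced along these lines. This is a minimal sketch assuming the standard `datasets` API; the split size comes from above, while the seed is hypothetical:

```py
from datasets import load_dataset

# load the Reddit TL;DR corpus (large download; newer versions of `datasets`
# may require trust_remote_code=True for script-based datasets)
dataset = load_dataset("webis/tldr-17", split="train")

# hypothetical 90/10 split; the actual split and seed used for RedBERT are not documented
split = dataset.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = split["train"], split["test"]
```

From here, fine-tuning follows the usual `transformers` sequence-classification recipe.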
|
|
|
|
|
## Bias and Limitations |
|
The webis/tldr-17 dataset used to train this model contains 3,848,330 posts from 29,651 subreddits.

Those posts, however, are not equally distributed across the subreddits: 589,947 posts belong to the subreddit `AskReddit`, which is `15%` of the whole dataset, while other subreddits are underrepresented.
|
| top subreddits | distribution | |
|
| --- | --- | |
|
| ![distribution](./assets/distribution01.png) | ![distribution](./assets/distribution02.png) | |
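
The imbalance shown above can be checked directly on the dataset. This is a minimal sketch, assuming the subreddit name is stored in a `subreddit` column and that the full corpus is downloaded:

```py
from collections import Counter

from datasets import load_dataset

# count how many posts each subreddit contributes
dataset = load_dataset("webis/tldr-17", split="train")
counts = Counter(dataset["subreddit"])

for subreddit, n in counts.most_common(5):
    print(f"{subreddit}: {n} posts ({n / len(dataset):.1%})")
```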
|
|
|
|
|
This bias in the subreddit distribution is also reflected in the model and can be observed during inference.
|
| class labels for `"Biden says US is at tipping point on gun control: We will ban assault weapons in this country"`, from r/politics | |
|
| --- | |
|
| ![classification](./assets/classify02.png) | |
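
The classification above can be reproduced with the same pipeline as in the usage examples:

```py
text = "Biden says US is at tipping point on gun control: We will ban assault weapons in this country"
pipe(text, top_k=5)
```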