Philipp Traber
commited on
Commit
•
1038a32
1
Parent(s):
a23704d
Updated README and added images
Browse files- README.md +64 -36
- assets/classify01.png +0 -0
- assets/classify02.png +0 -0
- assets/distribution01.png +0 -0
- assets/distribution02.png +0 -0
README.md
CHANGED
@@ -1,48 +1,76 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
inference: false
|
10 |
-
---
|
11 |
-
|
12 |
-
## Reddit post classification
|
13 |
-
|
14 |
-
This model predicts the subreddit of a provided post
|
15 |
-
The transformers library is required
|
16 |
```
|
17 |
-
pip install
|
18 |
```
|
19 |
|
|
|
|
|
|
|
20 |
```py
|
21 |
from transformers import pipeline
|
22 |
-
pipe = pipeline(
|
23 |
-
pipe("Biden says US is at tipping point on gun control: We will ban assault weapons in this country")
|
24 |
```
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
```py
|
31 |
-
import
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
```
|
43 |
|
44 |
-
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RedBERT - a Reddit post classifier
|
2 |
+
|
3 |
+
This model based on distilbert is finetuned to predict the subreddit of a Reddit post.
|
4 |
+
|
5 |
+
|
6 |
+
## Usage
|
7 |
+
### Preparations
|
8 |
+
The model uses the transformers library, so make sure to install it.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
```
|
10 |
+
pip install transformers[torch]
|
11 |
```
|
12 |
|
13 |
+
After the installation, the model can be loaded from Hugging Face.
|
14 |
+
The model will be sored localy so if you run this lines multiple times the model will be loaded from cache.
|
15 |
+
|
16 |
```py
|
17 |
from transformers import pipeline
|
18 |
+
pipe = pipeline("text-classification", model="traberph/RedBERT")
|
|
|
19 |
```
|
20 |
|
21 |
+
### Basic
|
22 |
+
For a simple classification task just call the pipeline with the text of your choice
|
23 |
+
```py
|
24 |
+
text = "I (33f) need to explain to my coworker (30m) I don't want his company on the commute back home"
|
25 |
+
pipe(text)
|
26 |
+
```
|
27 |
+
output:
|
28 |
+
[{'label': 'relationships', 'score': 0.9622366428375244}]
|
29 |
|
30 |
+
### Multiclass with visualization
|
31 |
+
Everyone likes visualizations! Therefore this is an example to output the 5 most probable labels and visualize the result.
|
32 |
+
Make sure that all requirements are satisfied.
|
33 |
+
```
|
34 |
+
pip install pandas seaborn
|
35 |
+
```
|
36 |
```py
|
37 |
+
import pandas as pd
|
38 |
+
import seaborn as sns
|
39 |
+
|
40 |
+
# if the model is already loaded this can be skipped
|
41 |
+
from transformers import pipeline
|
42 |
+
pipe = pipeline("text-classification", model="traberph/RedBERT")
|
43 |
+
|
44 |
+
text = "Today I spilled coffee over my pc. It started to smoke and the screen turned black. I guess I have a problem now."
|
45 |
+
|
46 |
+
# predict the 5 most probable labels
|
47 |
+
res = pipe(text, top_k=5)
|
48 |
+
|
49 |
+
# create a pandas dataframe from the result
|
50 |
+
df = pd.DataFrame(res)
|
51 |
+
|
52 |
+
# use seaborn to create a barplot
|
53 |
+
sns.barplot(df, x='score', y='label', color='steelblue')
|
54 |
```
|
55 |
|
56 |
+
output:
|
57 |
+
![](./assets/classify01.png)
|
58 |
|
59 |
+
|
60 |
+
## Training
|
61 |
+
The training of the final version of this model took `130h` on a single `Tesla P100 GPU`.
|
62 |
+
90% of the [webis/tldr-17](https://huggingface.co/datasets/webis/tldr-17/) where used for this version.
|
63 |
+
|
64 |
+
|
65 |
+
## Bias and Limitations
|
66 |
+
The webis/tldr-17 dataset used to train this model contains 3 848 330 posts from 29 651 subreddits.
|
67 |
+
Those posts however are not equally distributed over the subreddits. 589 947 posts belong to the subreddit `AskReddit`, which is `15%` of the whole dataset. Other subreddits are underrepresented.
|
68 |
+
| top subreddits | distribution |
|
69 |
+
| --- | --- |
|
70 |
+
| ![distribution](./assets/distribution01.png) | ![distribution](./assets/distribution02.png) |
|
71 |
+
|
72 |
+
|
73 |
+
This bias in the subreddit distribution is also represented in the model and can be observed during inference.
|
74 |
+
| class labels for `"Biden says US is at tipping point on gun control: We will ban assault weapons in this country"`, from r/politics |
|
75 |
+
| --- |
|
76 |
+
| ![classification](./assets/classify02.png) |
|
assets/classify01.png
ADDED
assets/classify02.png
ADDED
assets/distribution01.png
ADDED
assets/distribution02.png
ADDED