Files changed (1)
  1. README.md +80 -3
README.md CHANGED
@@ -2,15 +2,50 @@
2
  tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
 
5
  ---
6
  # nvidia/domain-classifier
7
 
8
- This repository contains the code for the domain classifier model.
9
 
10
  # How to use in transformers
11
  To use the Domain classifier, use the following code:
12
 
13
- ```python3
14
 
15
  import torch
16
  from torch import nn
@@ -45,4 +80,46 @@ predicted_classes = torch.argmax(outputs, dim=1)
45
  predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
46
  print(predicted_domains)
47
  # ['Sports', 'News']
48
- ```
2
  tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
5
+ license: apache-2.0
6
  ---
7
  # nvidia/domain-classifier
8
 
9
+ # Model Overview
10
+ This text classification model assigns each document to one of 26 domain classes:
11
+
12
+ 'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'
13
+ # Model Architecture
14
+ The model architecture is DeBERTa V3 Base.
15
+ The context length is 512 tokens.
16
+ # Training (details)
17
+ ## Training data:
18
+ - 1 million Common Crawl samples, labeled using Google Cloud’s Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
19
+ - 500k Wikipedia articles, curated using Wikipedia-API: https://pypi.org/project/Wikipedia-API/
20
+ ## Training steps:
21
+ The model was trained in multiple rounds on Wikipedia and Common Crawl data, labeled with a combination of pseudo-labels and the Google Cloud Natural Language API.
22
+ # How To Use This Model
23
+ ## Input
24
+ The model takes one or several paragraphs of text as input.
25
+ Example input:
26
+ ```
27
+ q Directions
28
+ 1. Mix 2 flours and baking powder together
29
+ 2. Mix water and egg in a separate bowl. Add dry to wet little by little
30
+ 3. Heat frying pan on medium
31
+ 4. Pour batter into pan and then put blueberries on top before flipping
32
+ 5. Top with desired toppings!
33
+ ```
34
+ ## Output
35
+ The model outputs one of the 26 domain classes as the predicted domain for each input sample.
36
+ Example output:
37
+ ```
38
+ Food_and_Drink
39
+ ```
40
+
41
+ # How to use in NeMo Curator
42
+
43
+ The inference code is available on NeMo Curator's GitHub repository. Download the [model.pth](https://huggingface.co/nvidia/domain-classifier/blob/main/model.pth) file and check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/distributed_data_classification/distributed_data_classification.ipynb) to get started.
44
 
45
  # How to use in transformers
46
  To use the Domain classifier, use the following code:
47
 
48
+ ```python
49
 
50
  import torch
51
  from torch import nn
 
80
  predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
81
  print(predicted_domains)
82
  # ['Sports', 'News']
83
+ ```
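The final lines of the snippet above turn per-class scores into domain names with an argmax followed by a lookup in `config.id2label`. The dependency-free sketch below illustrates that step only. `DOMAIN_LABELS` and `predict_domains` are hypothetical names, and the label order is assumed to follow the list in the model overview; the real mapping is whatever `config.id2label` contains.

```python
# Hypothetical label list, in the order the model card lists the classes.
# The real index-to-label mapping lives in config.id2label and may differ.
DOMAIN_LABELS = [
    "Adult", "Arts_and_Entertainment", "Autos_and_Vehicles",
    "Beauty_and_Fitness", "Books_and_Literature", "Business_and_Industrial",
    "Computers_and_Electronics", "Finance", "Food_and_Drink", "Games",
    "Health", "Hobbies_and_Leisure", "Home_and_Garden",
    "Internet_and_Telecom", "Jobs_and_Education", "Law_and_Government",
    "News", "Online_Communities", "People_and_Society", "Pets_and_Animals",
    "Real_Estate", "Science", "Sensitive_Subjects", "Shopping", "Sports",
    "Travel_and_Transportation",
]


def predict_domains(score_rows):
    """Return the highest-scoring label for each row of per-class scores."""
    predictions = []
    for scores in score_rows:
        if len(scores) != len(DOMAIN_LABELS):
            raise ValueError("expected one score per domain class")
        # Index of the largest score, i.e. a plain-Python argmax.
        best_index = max(range(len(scores)), key=scores.__getitem__)
        predictions.append(DOMAIN_LABELS[best_index])
    return predictions


# Toy scores for two documents; only one class per row is non-zero.
toy_scores = [[0.0] * len(DOMAIN_LABELS) for _ in range(2)]
toy_scores[0][DOMAIN_LABELS.index("Sports")] = 0.9
toy_scores[1][DOMAIN_LABELS.index("News")] = 0.8

print(predict_domains(toy_scores))
```

Each input row yields exactly one label, which matches the one-domain-per-sample output described above.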
84
+
85
+ # Evaluation Benchmarks
86
+
87
+ Evaluation Metric: PR-AUC
88
+ PR-AUC score on the evaluation set of 105k samples: 0.9873
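The card does not say how PR-AUC was computed. One common estimate of the area under a precision-recall curve is average precision, evaluated one-vs-rest for each domain. The sketch below is an illustration under that assumption, not the authors' evaluation code; `average_precision` and the toy data are hypothetical.

```python
def average_precision(labels, scores):
    """Average precision for binary labels (1 = positive, 0 = negative).

    Samples are ranked by descending score and the precision observed at
    each positive sample is averaged. Tied scores get no special handling.
    """
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    positives_seen = 0
    precision_sum = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            positives_seen += 1
            precision_sum += positives_seen / rank
    if positives_seen == 0:
        raise ValueError("at least one positive label is required")
    return precision_sum / positives_seen


# Toy one-vs-rest example: 1 marks documents that belong to the domain.
ap_labels = [1, 0, 1, 0]
ap_scores = [0.9, 0.8, 0.7, 0.1]
print(round(average_precision(ap_labels, ap_scores), 4))  # 0.8333
```

In the toy data the positives sit at ranks 1 and 3, so the result is (1/1 + 2/3) / 2.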
89
+ PR-AUC score for each domain:
90
+ | Domain | PR-AUC |
91
+ |--------------------------|--------|
92
+ | Adult | 0.999 |
93
+ | Arts_and_Entertainment | 0.997 |
94
+ | Autos_and_Vehicles | 0.997 |
95
+ | Beauty_and_Fitness | 0.997 |
96
+ | Books_and_Literature | 0.995 |
97
+ | Business_and_Industrial | 0.982 |
98
+ | Computers_and_Electronics| 0.992 |
99
+ | Finance | 0.989 |
100
+ | Food_and_Drink | 0.998 |
101
+ | Games | 0.997 |
102
+ | Health | 0.997 |
103
+ | Hobbies_and_Leisure | 0.984 |
104
+ | Home_and_Garden | 0.997 |
105
+ | Internet_and_Telecom | 0.982 |
106
+ | Jobs_and_Education | 0.993 |
107
+ | Law_and_Government | 0.967 |
108
+ | News | 0.918 |
109
+ | Online_Communities | 0.983 |
110
+ | People_and_Society | 0.975 |
111
+ | Pets_and_Animals | 0.997 |
112
+ | Real_Estate | 0.997 |
113
+ | Science | 0.988 |
114
+ | Sensitive_Subjects | 0.982 |
115
+ | Shopping | 0.995 |
116
+ | Sports | 0.995 |
117
+ | Travel_and_Transportation| 0.996 |
118
+ | Mean | 0.9873 |
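As a quick consistency check on the table, the sketch below takes an unweighted (macro) mean of the 26 rounded per-domain values. It comes out near 0.988, slightly above the reported 0.9873. The card does not say whether the reported mean is weighted by class size or computed from unrounded scores, so the small gap is left unexplained here.

```python
# Per-domain PR-AUC values copied from the table above (rounded to 3 places).
PER_DOMAIN_PR_AUC = {
    "Adult": 0.999,
    "Arts_and_Entertainment": 0.997,
    "Autos_and_Vehicles": 0.997,
    "Beauty_and_Fitness": 0.997,
    "Books_and_Literature": 0.995,
    "Business_and_Industrial": 0.982,
    "Computers_and_Electronics": 0.992,
    "Finance": 0.989,
    "Food_and_Drink": 0.998,
    "Games": 0.997,
    "Health": 0.997,
    "Hobbies_and_Leisure": 0.984,
    "Home_and_Garden": 0.997,
    "Internet_and_Telecom": 0.982,
    "Jobs_and_Education": 0.993,
    "Law_and_Government": 0.967,
    "News": 0.918,
    "Online_Communities": 0.983,
    "People_and_Society": 0.975,
    "Pets_and_Animals": 0.997,
    "Real_Estate": 0.997,
    "Science": 0.988,
    "Sensitive_Subjects": 0.982,
    "Shopping": 0.995,
    "Sports": 0.995,
    "Travel_and_Transportation": 0.996,
}

# Unweighted mean: every domain counts equally, regardless of sample count.
macro_mean = sum(PER_DOMAIN_PR_AUC.values()) / len(PER_DOMAIN_PR_AUC)

# The lowest-scoring domain is a useful thing to look at alongside the mean.
weakest_domain = min(PER_DOMAIN_PR_AUC, key=PER_DOMAIN_PR_AUC.get)

print(len(PER_DOMAIN_PR_AUC), round(macro_mean, 3), weakest_domain)
```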
119
+
120
+ # References
121
+ https://arxiv.org/abs/2111.09543
122
+ https://github.com/microsoft/DeBERTa
123
+ # License
124
+ Use of this model is covered by the Apache License 2.0. By downloading the public release version of the model, you accept the terms and conditions of the Apache License 2.0.
125
+ This repository contains the code for the domain classifier model.