---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
---

# DeTexD-RoBERTa-base delicate text detection

This is a baseline RoBERTa-base model for the delicate text detection task.

* Paper: [DeTexD: A Benchmark Dataset for Delicate Text Detection](TODO)
* [GitHub repository](https://github.com/grammarly/detexd)

## Classification example code

Here's a short usage example with PyTorch that turns the model's multiclass output into a binary delicate/not-delicate prediction:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("grammarly/detexd-roberta")
model = AutoModelForSequenceClassification.from_pretrained("grammarly/detexd-roberta")
model.eval()

def predict_binary_score(text: str, break_class_ix=3):
    with torch.no_grad():
        # get multiclass probability scores
        # truncation=True keeps inputs within RoBERTa's maximum sequence length
        logits = model(**tokenizer(text, return_tensors='pt', truncation=True)).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # convert to a binary prediction by summing the probability scores
        # for the higher-index classes, as defined by break_class_ix
        bin_score = probs[..., break_class_ix:].sum(dim=-1)

        return bin_score.item()

def predict_delicate(text: str, threshold=0.72496545):
    return predict_binary_score(text) > threshold

print(predict_delicate("Time flies like an arrow. Fruit flies like a banana."))
```

Expected output:

```
False
```
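You can reproduce the step that converts multiclass probabilities into a single binary score without downloading the model. The sketch below reimplements the softmax-and-sum logic in plain Python and runs it on hypothetical logits. The six-class layout and the logit values are made up for illustration. Only `break_class_ix=3` and the threshold `0.72496545` come from the example above, and the real number of classes is defined by the model's configuration.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a 1-D sequence of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def binary_score_from_logits(logits: Sequence[float], break_class_ix: int = 3) -> float:
    """Sum the probability mass of every class at index >= break_class_ix.

    Mirrors probs[..., break_class_ix:].sum(dim=-1) from the torch example.
    """
    probs = softmax(logits)
    return sum(probs[break_class_ix:])


def is_delicate(logits: Sequence[float], threshold: float = 0.72496545) -> bool:
    """Binary decision: True when the summed high-class probability exceeds the threshold."""
    return binary_score_from_logits(logits) > threshold


# Hypothetical six-class logits (illustrative values, not real model outputs).
mostly_low = [3.0, 1.0, 0.5, -1.0, -2.0, -2.5]
mostly_high = [-2.0, -1.5, 0.0, 2.0, 2.5, 1.5]

print(round(binary_score_from_logits(mostly_low), 3), is_delicate(mostly_low))
print(round(binary_score_from_logits(mostly_high), 3), is_delicate(mostly_high))
```

Most of the probability mass in `mostly_low` falls on the lower-index classes, so its summed score stays far below the threshold. `mostly_high` concentrates its mass on classes 3 to 5 and crosses it.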

## BibTeX entry and citation info

Please cite [our paper](TODO) if you use this model.

```bibtex
TODO
```