---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
---

# DeTexD-RoBERTa-base delicate text detection

This is a baseline RoBERTa-base model for the delicate text detection task.

* Paper: [DeTexD: A Benchmark Dataset for Delicate Text Detection](TODO)
* [GitHub repository](https://github.com/grammarly/detexd)

## Classification example code

The following example uses `transformers` and `torch` to turn the model's multiclass output into a binary delicate / not delicate prediction:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("grammarly/detexd-roberta")
model = AutoModelForSequenceClassification.from_pretrained("grammarly/detexd-roberta")
model.eval()

def predict_binary_score(text: str, break_class_ix=3):
    with torch.no_grad():
        # get multiclass probability scores
        logits = model(**tokenizer(text, return_tensors='pt'))[0]
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # convert to a binary prediction by summing the probability scores
        # for the higher-index classes, as defined by break_class_ix
        bin_score = probs[..., break_class_ix:].sum(dim=-1)

    return bin_score.item()

def predict_delicate(text: str, threshold=0.72496545):
    return predict_binary_score(text) > threshold

print(predict_delicate("Time flies like an arrow. Fruit flies like a banana."))
```

Expected output:

```
False
```
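
The binary conversion performed by `predict_binary_score` can be traced by hand without downloading the model. The sketch below is a plain-Python illustration, not the model's actual output: the six-class logit vector is made up, the class count is an assumption, and the helper names `softmax` and `binary_score` are hypothetical. Only `break_class_ix=3` and the threshold come from the example above.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def binary_score(probs, break_class_ix=3):
    """Sum the probability mass of the classes at or above break_class_ix."""
    return sum(probs[break_class_ix:])

# Hypothetical logits for a six-class output (the class count is an assumption).
example_logits = [2.0, 1.0, 0.5, 0.0, -1.0, -2.0]
example_probs = softmax(example_logits)

# Same decision rule as predict_delicate: compare the summed mass to the threshold.
score = binary_score(example_probs)
is_delicate = score > 0.72496545
print(f"binary score = {score:.4f}, delicate = {is_delicate}")
```

Because most of the probability mass in this made-up vector sits on the low-index classes, the summed score for classes 3 and above stays well below the threshold and the text would be labeled as not delicate.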

## BibTeX entry and citation info

Please cite [our paper](TODO) if you use this model.

```bibtex
TODO
```