Update README.md
# ByT5-base fine-tuned for Hate Speech Detection (on Tweets)
[ByT5](https://huggingface.co/google/byt5-base) base fine-tuned on the [tweets hate speech detection](https://huggingface.co/datasets/tweets_hate_speech_detection) dataset for the **Sequence Classification** downstream task.
# Details of ByT5 - Base 🧠
ByT5 is a tokenizer-free version of [Google's T5](https://ai.googleblog.com/2020/02/exploring-transfer-learning-with-t5.html) and generally follows the architecture of [MT5](https://huggingface.co/google/mt5-base).
ByT5 was pre-trained only on [mC4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual), with no supervised training, using an average span mask of 20 UTF-8 characters. Therefore, this model has to be fine-tuned before it is usable on a downstream task.
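Because ByT5 operates directly on UTF-8 bytes, its token IDs can be derived without any vocabulary file: each byte becomes one ID, offset by 3 to leave room for the pad/eos/unk special tokens (as in the released ByT5 tokenizer). A minimal sketch (`byt5_style_ids` is a name invented here for illustration):

```python
# ByT5 has no learned vocabulary: every UTF-8 byte is a token.
# IDs are offset by 3 to reserve 0=pad, 1=eos, 2=unk.
SPECIAL_TOKEN_OFFSET = 3

def byt5_style_ids(text: str) -> list[int]:
    """Map text to byte-level token IDs the way ByT5 does."""
    return [b + SPECIAL_TOKEN_OFFSET for b in text.encode("utf-8")]

print(byt5_style_ids("hi"))       # [107, 108]
print(len(byt5_style_ids("🧠")))  # 4 — one emoji is four UTF-8 bytes
```

This is why sequences get long quickly for ByT5: length is counted in bytes, not in words or subwords.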
```shell
pip install -q transformers
```

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration

ckpt = 'Narrativa/byt5-base-tweet-hate-detection'

tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt).to("cuda")


def classify_tweet(tweet):
    # Byte-level tokenization: max_length counts UTF-8 bytes, not words
    inputs = tokenizer([tweet], padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    input_ids = inputs.input_ids.to('cuda')
    attention_mask = inputs.attention_mask.to('cuda')
    # The label is produced as generated text, not as class logits
    output = model.generate(input_ids, attention_mask=attention_mask)
    return tokenizer.decode(output[0], skip_special_tokens=True)
```
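Note that the model reports its decision by *generating* text: `tokenizer.decode(output[0], skip_special_tokens=True)` turns the generated byte IDs back into a label string. A minimal sketch of that decoding step, assuming the byte+3 ID scheme described above (the `decode_ids` helper and the `"no-hate-speech"` label string are illustrative, not taken from the model card):

```python
# 0=pad, 1=eos, 2=unk are special IDs; real byte values start at 3.
OFFSET = 3

def decode_ids(ids):
    """Mirror tokenizer.decode(..., skip_special_tokens=True) for byte IDs."""
    return bytes(i - OFFSET for i in ids if i >= OFFSET).decode("utf-8")

# Hypothetical generated sequence ending with EOS (id 1):
generated = [b + OFFSET for b in b"no-hate-speech"] + [1]
print(decode_ids(generated))  # no-hate-speech
```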