In [None]:
# !pip install tensorflow==2.10

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import string
from unidecode import unidecode
import tensorflow as tf 
from sklearn.utils import class_weight
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import cloudpickle
import os
from transformers import DistilBertTokenizerFast
from transformers import TFDistilBertModel, DistilBertConfig
from tensorflow.keras.layers import Input, Dense, Dropout, Average, BatchNormalization
from tensorflow.keras.optimizers.schedules import PolynomialDecay
from tensorflow.keras.callbacks import EarlyStopping

In [3]:
class TextPreprocessor:
    """
    Configurable text-cleaning pipeline for pandas string Series.

    ``preprocess`` always lower-cases the text, transliterates it with ``unidecode``,
    strips URLs and expands contractions. The remaining steps (punctuation, digits,
    stop words, short words, frequent/rare words) are controlled by the constructor
    arguments.

    :param remove_punct: bool, replace punctuation characters with spaces.
    :param remove_digits: bool, replace digits with spaces.
    :param remove_stop_words: bool, drop tokens listed in ``self.stop_words``.
    :param remove_short_words: bool, drop tokens whose length is in [minlen, maxlen].
    :param minlen: int, minimum length of a token to be removed as "short".
    :param maxlen: int, maximum length of a token to be removed as "short".
    :param top_p: float, fraction of the most frequent distinct words to remove.
    :param bottom_p: float, fraction of the rarest distinct words to remove.
    """

    def __init__(self, remove_punct: bool = True, remove_digits: bool = True,
                 remove_stop_words: bool = True,
                 remove_short_words: bool = True, minlen: int = 1, maxlen: int = 1, top_p: float = None,
                 bottom_p: float = None):
        self.remove_punct = remove_punct
        self.remove_digits = remove_digits
        self.remove_stop_words = remove_stop_words
        self.remove_short_words = remove_short_words
        self.minlen = minlen
        self.maxlen = maxlen
        self.top_p = top_p
        self.bottom_p = bottom_p
        # Words learned in "train" mode by __remove_top_bottom_words; reused in other modes.
        self.words_to_remove = []
        self.stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you',
                           'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself',
                           'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them',
                           'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
                           'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has',
                           'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'if', 'or',
                           'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about',
                           'into', 'through', 'during', 'before', 'after', 'to', 'from',
                           'in', 'out', 'on', 'off', 'further', 'then', 'once',
                           'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each',
                           'other', 'such', 'own', 'same', 'so', 'than', 'can', 'will', 'should','now']

        self.contraction_to_expansion = {"ain't": "am not",
                                         "aren't": "are not",
                                         "can't": "cannot",
                                         "can't've": "cannot have",
                                         "'cause": "because",
                                         "could've": "could have",
                                         "couldn't": "could not",
                                         "couldn't've": "could not have",
                                         "didn't": "did not",
                                         "doesn't": "does not",
                                         "don't": "do not",
                                         "hadn't": "had not",
                                         "hadn't've": "had not have",
                                         "hasn't": "has not",
                                         "haven't": "have not",
                                         "he'd": "he would",
                                         "he'd've": "he would have",
                                         "he'll": "he will",
                                         "he'll've": "he will have",
                                         "he's": "he is",
                                         "how'd": "how did",
                                         "how'd'y": "how do you",
                                         "how'll": "how will",
                                         "how's": "how is",
                                         "i'd": "i would",
                                         "i'd've": "i would have",
                                         "i'll": "i will",
                                         "i'll've": "i will have",
                                         "i'm": "i am",
                                         "i've": "i have",
                                         "isn't": "is not",
                                         "it'd": "it had",
                                         "it'd've": "it would have",
                                         "it'll": "it will",
                                         "it'll've": "it will have",
                                         "it's": "it is",
                                         "let's": "let us",
                                         "ma'am": "madam",
                                         "mayn't": "may not",
                                         "might've": "might have",
                                         "mightn't": "might not",
                                         "mightn't've": "might not have",
                                         "must've": "must have",
                                         "mustn't": "must not",
                                         "mustn't've": "must not have",
                                         "needn't": "need not",
                                         "needn't've": "need not have",
                                         "o'clock": "of the clock",
                                         "oughtn't": "ought not",
                                         "oughtn't've": "ought not have",
                                         "shan't": "shall not",
                                         "sha'n't": "shall not",
                                         "shan't've": "shall not have",
                                         "she'd": "she would",
                                         "she'd've": "she would have",
                                         "she'll": "she will",
                                         "she'll've": "she will have",
                                         "she's": "she is",
                                         "should've": "should have",
                                         "shouldn't": "should not",
                                         "shouldn't've": "should not have",
                                         "so've": "so have",
                                         "so's": "so is",
                                         "that'd": "that would",
                                         "that'd've": "that would have",
                                         "that's": "that is",
                                         "there'd": "there had",
                                         "there'd've": "there would have",
                                         "there's": "there is",
                                         "they'd": "they would",
                                         "they'd've": "they would have",
                                         "they'll": "they will",
                                         "they'll've": "they will have",
                                         "they're": "they are",
                                         "they've": "they have",
                                         "to've": "to have",
                                         "wasn't": "was not",
                                         "we'd": "we had",
                                         "we'd've": "we would have",
                                         "we'll": "we will",
                                         "we'll've": "we will have",
                                         "we're": "we are",
                                         "we've": "we have",
                                         "weren't": "were not",
                                         "what'll": "what will",
                                         "what'll've": "what will have",
                                         "what're": "what are",
                                         "what's": "what is",
                                         "what've": "what have",
                                         "when's": "when is",
                                         "when've": "when have",
                                         "where'd": "where did",
                                         "where's": "where is",
                                         "where've": "where have",
                                         "who'll": "who will",
                                         "who'll've": "who will have",
                                         "who's": "who is",
                                         "who've": "who have",
                                         "why's": "why is",
                                         "why've": "why have",
                                         "will've": "will have",
                                         "won't": "will not",
                                         "won't've": "will not have",
                                         "would've": "would have",
                                         "wouldn't": "would not",
                                         "wouldn't've": "would not have",
                                         "y'all": "you all",
                                         "y'alls": "you alls",
                                         "y'all'd": "you all would",
                                         "y'all'd've": "you all would have",
                                         "y'all're": "you all are",
                                         "y'all've": "you all have",
                                         "you'd": "you had",
                                         "you'd've": "you would have",
                                         "you'll": "you will",
                                         "you'll've": "you will have",
                                         "you're": "you are",
                                         "you've": "you have"
                                         }

    @staticmethod
    def __remove_double_whitespaces(string: str):
        """
        Collapses every run of whitespace into a single space and trims both ends.
        :param string: str, input string
        :return: str, whitespace-normalised string
        """
        return " ".join(string.split())

    def __remove_url(self, string_series: pd.Series):
        """
        Removes URLs from text.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        clean_string_series = string_series.str.replace(
            pat=r"(https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s]{2,}|www\.[a-zA-Z0-9]+\.[^\s]{2,})",
            repl=" ", regex=True)
        return clean_string_series.map(self.__remove_double_whitespaces)

    def __expand(self, string_series: pd.Series):
        """
        Replaces contractions with expansions, e.g. "don't" with "do not".
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        clean_string_series = string_series.copy()
        # Longest contractions first: otherwise "couldn't" is rewritten before
        # "couldn't've" is tried and the compound form can never match.
        # sorted() is stable, so equal-length keys keep their dictionary order.
        ordered_contractions = sorted(self.contraction_to_expansion.items(),
                                      key=lambda pair: len(pair[0]), reverse=True)
        for c, e in ordered_contractions:
            clean_string_series = clean_string_series.str.replace(pat=c, repl=e, regex=False)
        return clean_string_series.map(self.__remove_double_whitespaces)

    def __remove_punct(self, string_series: pd.Series):
        """
        Replaces punctuation and newline/carriage-return/tab characters with spaces.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        # One translation pass instead of one replace per character. The control
        # characters are real "\n", "\r", "\t" here; the raw two-character forms
        # would also eat the letter following a literal backslash.
        to_space = str.maketrans({char: " " for char in "\n\r\t" + string.punctuation})
        clean_string_series = string_series.str.translate(to_space)
        return clean_string_series.map(self.__remove_double_whitespaces)

    def __remove_digits(self, string_series: pd.Series):
        """
        Removes digits from the input string.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        clean_string_series = string_series.str.replace(pat=r'\d', repl=" ", regex=True)
        return clean_string_series.map(self.__remove_double_whitespaces)

    @staticmethod
    def __remove_short_words(string_series: pd.Series, minlen: int = 1, maxlen: int = 1):
        """
        Removes words/tokens where minlen <= len <= maxlen.
        :param string_series: pd.Series, input string series
        :param minlen: int, minimum length of token to be removed.
        :param maxlen:  int, maximum length of token to be removed.
        :return: pd.Series, cleaned string series
        """
        def drop_short(text: str):
            return " ".join([word for word in text.split()
                             if (len(word) > maxlen) or (len(word) < minlen)])

        return string_series.map(drop_short)

    def __remove_stop_words(self, string_series: pd.Series):
        """
        Removes stop words from the input string.
        :param string_series: pd.Series, input string series
        :return: pd.Series, cleaned string series
        """
        # Set lookup keeps the per-token membership test O(1).
        stops = set(self.stop_words)

        def str_remove_stop_words(text: str):
            return " ".join([token for token in text.split() if token not in stops])

        return string_series.map(str_remove_stop_words)

    def __remove_top_bottom_words(self, string_series: pd.Series, top_p: float = None,
                                  bottom_p: float = None, dataset: str = 'train'):
        """
        Removes top_p percent (frequent) words and bottom_p percent (rare) words.

        In "train" mode the words to drop are derived from ``string_series`` and added
        to ``self.words_to_remove``; in any other mode the stored list is applied as is.
        :param string_series: pd.Series, input string series
        :param top_p: float, fraction of frequent words to remove.
        :param bottom_p: float, fraction of rare words to remove.
        :param dataset: str, "train" for training set, "test" for val/dev/test set.
        :return: pd.Series, cleaned string series
        """
        if dataset == 'train':
            if top_p is None:
                top_p = 0
            if bottom_p is None:
                bottom_p = 0

            if top_p > 0 or bottom_p > 0:
                word_freq = pd.Series(" ".join(string_series).split()).value_counts()
                n_words = len(word_freq)

                candidates = []
                if top_p > 0:
                    candidates.extend(word_freq.index[: int(np.ceil(top_p * n_words))])
                if bottom_p > 0:
                    candidates.extend(word_freq.index[-int(np.ceil(bottom_p * n_words)):])

                # Append only unseen words so repeated "train" calls do not bloat the list.
                already_listed = set(self.words_to_remove)
                for word in candidates:
                    if word not in already_listed:
                        already_listed.add(word)
                        self.words_to_remove.append(word)

        if len(self.words_to_remove) == 0:
            return string_series

        blocked = set(self.words_to_remove)
        return string_series.map(lambda text: " ".join([word for word in text.split()
                                                        if word not in blocked]))

    def preprocess(self, string_series: pd.Series, dataset: str = "train"):
        """
        Entry point.
        :param string_series: pd.Series, input string series
        :param dataset: str, "train" for training set, "test" for val/dev/test set.
        :return: pd.Series, cleaned string series; rows that end up empty are replaced
                 with the placeholder "this is an empty message".
        """
        # Missing values become empty strings (and so the placeholder below) instead of
        # crashing the string operations further down.
        string_series = string_series.astype(object).fillna("").astype(str)
        string_series = string_series.str.lower()
        string_series = string_series.map(unidecode)
        string_series = self.__remove_url(string_series=string_series)
        string_series = self.__expand(string_series=string_series)

        if self.remove_punct:
            string_series = self.__remove_punct(string_series=string_series)
        if self.remove_digits:
            string_series = self.__remove_digits(string_series=string_series)
        if self.remove_stop_words:
            string_series = self.__remove_stop_words(string_series=string_series)
        if self.remove_short_words:
            string_series = self.__remove_short_words(string_series=string_series,
                                                      minlen=self.minlen,
                                                      maxlen=self.maxlen)
        string_series = self.__remove_top_bottom_words(string_series=string_series,
                                                       top_p=self.top_p,
                                                       bottom_p=self.bottom_p, dataset=dataset)

        string_series = string_series.str.strip()
        string_series = string_series.replace(to_replace="", value="this is an empty message")

        return string_series

In [4]:
# Load the training data from train.csv (path is relative to the working directory).
data = pd.read_csv('train.csv')

In [5]:
data

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0
...,...,...,...,...,...,...,...,...
159566,ffe987279560d7ff,""":::::And for the second time of asking, when ...",0,0,0,0,0,0
159567,ffea4adeee384e90,You should be ashamed of yourself \n\nThat is ...,0,0,0,0,0,0
159568,ffee36eab5c267c9,"Spitzer \n\nUmm, theres no actual article for ...",0,0,0,0,0,0
159569,fff125370e4aaaf3,And it looks like it was actually you who put ...,0,0,0,0,0,0


In [6]:
# The id column is not used as a feature or a label. errors='ignore' keeps this cell
# re-runnable: a second execution no longer raises KeyError for the missing column.
data = data.drop(columns='id', errors='ignore')

In [7]:
data

Unnamed: 0,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0
...,...,...,...,...,...,...,...
159566,""":::::And for the second time of asking, when ...",0,0,0,0,0,0
159567,You should be ashamed of yourself \n\nThat is ...,0,0,0,0,0,0
159568,"Spitzer \n\nUmm, theres no actual article for ...",0,0,0,0,0,0
159569,And it looks like it was actually you who put ...,0,0,0,0,0,0


In [8]:
# Use the shorter column name 'text' for the comment body from here on.
data = data.rename(columns={'comment_text': 'text'})

In [9]:
data.shape

(159571, 7)

In [10]:
data.dtypes

text             object
toxic             int64
severe_toxic      int64
obscene           int64
threat            int64
insult            int64
identity_hate     int64
dtype: object

In [11]:
# Discard every row that contains a missing value.
data = data.dropna()

In [12]:
data

Unnamed: 0,text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0
...,...,...,...,...,...,...,...
159566,""":::::And for the second time of asking, when ...",0,0,0,0,0,0
159567,You should be ashamed of yourself \n\nThat is ...,0,0,0,0,0,0
159568,"Spitzer \n\nUmm, theres no actual article for ...",0,0,0,0,0,0
159569,And it looks like it was actually you who put ...,0,0,0,0,0,0


In [13]:
data['text'][2]

"Hey man, I'm really not trying to edit war. It's just that this guy is constantly removing relevant information and talking to me through edits instead of my talk page. He seems to care more about the formatting than the actual info."

In [14]:
# Every column after the first one ('text') is a label column.
CLASS_NAMES = data.columns[1:].tolist()
print(CLASS_NAMES)

['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']


In [15]:
# Clean the comments with the default pipeline settings.
# preprocess() is called with its default dataset="train" mode here.
tp = TextPreprocessor()
data['text'] = tp.preprocess(data['text'])

In [16]:
data['text'][2]

'hey man really not trying edit war just guy constantly removing relevant information talking edits instead talk page seems care more formatting actual info'

In [17]:
# Persist the preprocessor together with the label names so that inference can
# reload both from a single file.
with open("toxic_comment_preprocessor_classnames.bin", "wb") as model_file_obj:
    cloudpickle.dump((tp, CLASS_NAMES), model_file_obj)

In [18]:
# Features: the cleaned text column.
x = data['text']
# Labels: all remaining columns, copied into a standalone NumPy array.
y = data.drop(columns='text').values.copy()

In [19]:
x

0         explanation edits made under username hardcore...
1         aww matches background colour seemingly stuck ...
2         hey man really not trying edit war just guy co...
3         more cannot make real suggestions improvement ...
4                             sir hero chance remember page
                                ...                        
159566    second time asking view completely contradicts...
159567                 ashamed horrible thing put talk page
159568    spitzer umm theres no actual article prostitut...
159569    looks like actually put speedy first version d...
159570    really not think understand came idea bad righ...
Name: text, Length: 159571, dtype: object

In [20]:
y

array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       ...,
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]])

In [21]:
# Hold out 20% of the rows; the fixed random_state makes the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

In [22]:
x_train.shape, x_test.shape, y_train.shape, y_test.shape

((127656,), (31915,), (127656, 6), (31915, 6))

In [23]:
# The tokenizer is fed plain Python lists of strings. list() accepts both a Series and
# an existing list, so re-running this cell does not fail the way .to_list() would.
x_train, x_test = list(x_train), list(x_test)

In [24]:
from transformers import DistilBertTokenizerFast

In [None]:
# Checkpoint name shared by the tokenizer, the model config and the pretrained model.
model_checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

In [26]:
# Sanity check on one training example: the cleaned text, its tokens,
# and the encoded input_ids / attention_mask.
print(x_train[0])
print(tokenizer.tokenize(x_train[0]))
print(tokenizer(x_train[0]))

grandma terri burn trash grandma terri trash hate grandma terri hell
['grandma', 'terri', 'burn', 'trash', 'grandma', 'terri', 'trash', 'hate', 'grandma', 'terri', 'hell']
{'input_ids': [101, 13055, 26568, 6402, 11669, 13055, 26568, 11669, 5223, 13055, 26568, 3109, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [27]:
# Distribution strategy; its replica count scales BATCH_SIZE and its scope wraps model creation.
strategy = tf.distribute.MirroredStrategy()

In [28]:
BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # 16 samples per replica in sync
N_TOKENS = 512  # max_length used for tokenization and the width of the model inputs
N_CLASSES = len(CLASS_NAMES)  # one output unit per label column

In [29]:
# Both splits are encoded with identical settings: pad/truncate to N_TOKENS,
# return TensorFlow tensors, include the attention mask, omit token type ids.
tokenize_kwargs = dict(
    max_length=N_TOKENS,
    padding="max_length",
    truncation=True,
    return_tensors="tf",
    return_attention_mask=True,
    return_token_type_ids=False,
)
train_tokens = tokenizer(x_train, **tokenize_kwargs)
test_tokens = tokenizer(x_test, **tokenize_kwargs)

In [30]:
train_tokens[:5]

[Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]

In [31]:
# Per-sample 'balanced' weights derived from the training labels.
# NOTE(review): these weights are not passed to the datasets built below (the variant
# that includes them is commented out), so they currently do not affect training.
sample_weight_param = class_weight.compute_sample_weight(class_weight='balanced', y=y_train)
sample_weight_param

array([0.18495376, 0.01961102, 0.01961102, ..., 0.18495376, 0.01961102,
       0.01961102])

In [32]:
len(sample_weight_param)

127656

In [33]:
# Disabled variant that would also yield the per-sample weights:
# train_tf_data = tf.data.Dataset.from_tensor_slices((dict(train_tokens), y_train, sample_weight_param))
# Each dataset element is ({'input_ids', 'attention_mask'}, label_vector) for one example.
train_tf_data = tf.data.Dataset.from_tensor_slices((dict(train_tokens), y_train))
test_tf_data = tf.data.Dataset.from_tensor_slices((dict(test_tokens), y_test))

In [34]:
# The DataFrame and the raw token tensors are no longer referenced by name after
# the tf.data datasets were built, so drop those names.
del data, train_tokens, test_tokens

In [35]:
# Add prefetching to both (still unbatched) datasets; shuffling and batching are applied at fit time.
train_tf_data=train_tf_data.prefetch(tf.data.AUTOTUNE)
test_tf_data=test_tf_data.prefetch(tf.data.AUTOTUNE)

In [36]:
# Print the first dataset element to check the structure of inputs and labels.
for i in train_tf_data.take(1):
    print(i)

({'input_ids': <tf.Tensor: shape=(512,), dtype=int32, numpy=
array([  101, 13055, 26568,  6402, 11669, 13055, 26568, 11669,  5223,
       13055, 26568,  3109,   102,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,  

In [37]:
from transformers import TFDistilBertModel, DistilBertConfig
from tensorflow.keras.layers import Input, Dense, Dropout, Average, BatchNormalization

In [38]:
# Load the checkpoint's configuration with output_hidden_states disabled.
config = DistilBertConfig.from_pretrained(model_checkpoint, output_hidden_states=False)

In [39]:
from tensorflow.keras.optimizers.schedules import PolynomialDecay
with strategy.scope():
    # Pretrained encoder, kept under its own name so `model` only ever refers to the
    # final Keras classifier.
    distilbert_encoder = TFDistilBertModel.from_pretrained(model_checkpoint, config=config)

    # PolynomialDecay counts optimizer steps (one per batch). train_tf_data is still
    # unbatched here, so its length is a number of samples and has to be converted to
    # batches per epoch; otherwise the schedule is BATCH_SIZE times too long and the
    # learning rate hardly decays over the 10 training epochs.
    n_epochs = 10
    steps_per_epoch = -(-len(train_tf_data) // BATCH_SIZE)  # ceiling division
    learning_schedule = PolynomialDecay(initial_learning_rate=2e-5,
                                        decay_steps=steps_per_epoch * n_epochs,
                                        end_learning_rate=0)

    input_ids = Input(shape=(N_TOKENS,), dtype=tf.int32, name="input_ids")
    attention_mask = Input(shape=(N_TOKENS,), dtype=tf.int32, name="attention_mask")

    # [CLS] token of last hidden state
    hidden = distilbert_encoder([input_ids, attention_mask])[0][:, 0, :]

    # Classification head: two dense blocks, each preceded by dropout + batch norm.
    hidden = Dropout(0.3)(hidden)
    hidden = BatchNormalization()(hidden)
    hidden = Dense(1024, activation="relu")(hidden)
    hidden = Dropout(0.3)(hidden)
    hidden = BatchNormalization()(hidden)
    hidden = Dense(512, activation="relu")(hidden)
    hidden = Dropout(0.3)(hidden)
    hidden = BatchNormalization()(hidden)

    # Independent sigmoid per label (multi-label), paired with binary cross-entropy.
    output = Dense(N_CLASSES, activation="sigmoid", name="output")(hidden)
    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    metric = [tf.keras.metrics.AUC(multi_label=True, num_labels=N_CLASSES)]
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_schedule), metrics=metric,
                  loss=tf.keras.losses.BinaryCrossentropy())

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


In [40]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 512)]        0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 512)]        0           []                               
                                                                                                  
 tf_distil_bert_model (TFDistil  TFBaseModelOutput(l  66362880   ['input_ids[0][0]',              
 BertModel)                     ast_hidden_state=(N               'attention_mask[0][0]']         
                                one, 512, 768),                                                   
                                 hidden_states=None                                           

In [41]:
from tensorflow.keras.callbacks import EarlyStopping
# Stop after one epoch without val_loss improvement. With patience=1 the last epoch run
# is by definition worse than the best one, so restore the best weights before the
# model is saved.
early_stop = EarlyStopping(monitor="val_loss", patience=1, mode="min", restore_best_weights=True)

In [None]:
# Training batches: shuffle over the full dataset, then batch, then prefetch so that
# whole batches are prepared ahead of the training step.
train_batches = train_tf_data.shuffle(len(train_tf_data)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
# Validation metrics do not depend on example order, so the validation set is not shuffled.
val_batches = test_tf_data.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
model.fit(train_batches, validation_data=val_batches,
          epochs=10, callbacks=[early_stop])

In [43]:
# Save the trained Keras model to an .h5 file.
model.save("toxic_comment_classifier_hf_distilbert.h5")

In [44]:
# Reload the saved model; custom_objects lets Keras resolve the TFDistilBertModel layer class.
tf_model = tf.keras.models.load_model('toxic_comment_classifier_hf_distilbert.h5', custom_objects={"TFDistilBertModel": TFDistilBertModel})

In [45]:
import pathlib
# Convert the reloaded Keras model to TFLite with the default optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Write the serialized model into ./tflite_models, creating the directory if needed.
tflite_models_dir = pathlib.Path("tflite_models")
tflite_models_dir.mkdir(parents=True, exist_ok=True)
tflite_model_file = tflite_models_dir / "toxic_comment_classifier_hf_distilbert.tflite"
tflite_model_file.write_bytes(tflite_model)

69201064

In [46]:
# Reload the (preprocessor, class names) pair written earlier in this notebook.
# NOTE: unpickling executes arbitrary code - only load files you created yourself.
with open("toxic_comment_preprocessor_classnames.bin", "rb") as preprocessor_file:
    text_preprocessor, class_names = cloudpickle.load(preprocessor_file)

# TFLite interpreter over the converted model file.
tflite_model_path = os.path.join("tflite_models", "toxic_comment_classifier_hf_distilbert.tflite")
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)


In [47]:
def inference(text):
    """
    Score one text with the TFLite toxic-comment model.

    :param text: str, raw input text.
    :return: pd.DataFrame with columns 'class' and 'prob', sorted by 'prob' ascending.
    """
    # "test" mode applies the stored preprocessor without re-deriving its
    # frequent/rare word list from this single input.
    text = text_preprocessor.preprocess(pd.Series(text), dataset="test").iloc[0]

    # Load the tokenizer once and reuse it on later calls instead of re-reading it
    # from the checkpoint for every prediction.
    tokenizer = getattr(inference, "_tokenizer", None)
    if tokenizer is None:
        model_checkpoint = "distilbert-base-uncased"
        tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
        inference._tokenizer = tokenizer
    tokens = tokenizer(text, max_length=512, padding="max_length", truncation=True, return_tensors="np")

    # tflite model inference
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]

    # Bind inputs by tensor name rather than by position, so a different input order
    # in the converted model cannot silently swap ids and mask.
    feeds = {}
    for detail in input_details:
        for key in ("attention_mask", "input_ids"):
            if key in detail["name"]:
                feeds[key] = detail
    if len(feeds) != 2:
        # Names not recognisable: fall back to the positional order used previously.
        feeds = {"attention_mask": input_details[0], "input_ids": input_details[1]}

    for key, detail in feeds.items():
        # Cast to the dtype the interpreter expects for this input tensor.
        interpreter.set_tensor(detail["index"], tokens[key].astype(detail["dtype"]))
    interpreter.invoke()

    tflite_pred = interpreter.get_tensor(output_details["index"])[0]
    result_df = pd.DataFrame({'class': class_names, 'prob': tflite_pred})
    return result_df.sort_values(by='prob', ascending=True)

In [49]:
inference("Hello!! How are you?")

Unnamed: 0,class,prob
3,threat,0.000621
1,severe_toxic,0.000848
5,identity_hate,0.000876
2,obscene,0.001126
4,insult,0.00154
0,toxic,0.00289
