"""Streamlit app that classifies a comment's toxicity with a TFLite DistilBERT model."""

import io
import os
import threading

import cloudpickle
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from transformers import DistilBertTokenizerFast
from unidecode import unidecode  # noqa: F401  # NOTE(review): unused here; kept in case the pickled preprocessor relies on it — confirm before removing.

MODEL_CHECKPOINT = "distilbert-base-uncased"
MAX_TOKENS = 512

# Streamlit re-executes this whole script on every widget interaction, so any
# expensive load must go through Streamlit's resource cache to happen only once.
# ``st.cache_resource`` needs Streamlit >= 1.18; fall back to the older
# ``st.experimental_singleton`` or, failing that, to no caching at all.
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", lambda func: func
)


@_cache_resource
def _load_preprocessor():
    """Load the pickled ``(text_preprocessor, class_names)`` pair from disk."""
    # cloudpickle can execute arbitrary code while loading: this file must be a
    # trusted, locally produced artifact — never point it at user-supplied data.
    path = os.path.join("models", "toxic_comment_preprocessor_classnames.bin")
    with open(path, "rb") as model_file_obj:
        return cloudpickle.load(model_file_obj)


@_cache_resource
def _load_interpreter():
    """Create the TFLite interpreter and allocate its tensors once."""
    tflite_interpreter = tf.lite.Interpreter(
        model_path=os.path.join("models", "toxic_comment_classifier_hf_distilbert.tflite")
    )
    # Input shapes are never resized, so a single allocation suffices.
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter


@_cache_resource
def _load_tokenizer(checkpoint):
    """Load the Hugging Face fast tokenizer for *checkpoint*."""
    return DistilBertTokenizerFast.from_pretrained(checkpoint)


@_cache_resource
def _get_inference_lock():
    """Return the lock that serializes use of the shared model objects."""
    return threading.Lock()


text_preprocessor, class_names = _load_preprocessor()
interpreter = _load_interpreter()


def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    Evaluated as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp`` so that
    large negative logits do not overflow ``np.exp``.
    """
    return np.exp(-np.logaddexp(0, -x))


def _find_input(input_details, name, fallback_position):
    """Return the interpreter input detail whose tensor name contains *name*.

    Falls back to ``input_details[fallback_position]`` (the positional order
    this app originally assumed) when no input name matches.
    """
    for detail in input_details:
        if name in detail["name"]:
            return detail
    return input_details[fallback_position]


def inference(text):
    """Classify *text* and return per-class probabilities.

    Args:
        text: the raw comment string.

    Returns:
        A DataFrame with columns ``class`` and ``prob`` (sigmoid of the model
        logits), sorted by ``prob`` ascending.
    """
    tokenizer = _load_tokenizer(MODEL_CHECKPOINT)
    # The interpreter, tokenizer and preprocessor are cached objects shared by
    # every Streamlit session. TFLite forbids calling other interpreter methods
    # while invoke() runs, so the whole pipeline is serialized.
    with _get_inference_lock():
        cleaned = text_preprocessor.preprocess(pd.Series([text]))[0]
        tokens = tokenizer(
            cleaned,
            max_length=MAX_TOKENS,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()[0]
        for key, fallback_position in (("attention_mask", 0), ("input_ids", 1)):
            detail = _find_input(input_details, key, fallback_position)
            # Cast to whatever integer dtype the exported model declares.
            interpreter.set_tensor(
                detail["index"], np.asarray(tokens[key], dtype=detail["dtype"])
            )
        interpreter.invoke()
        # First (and only) row of the batch.
        tflite_logits = interpreter.get_tensor(output_details["index"])[0]

    result_df = pd.DataFrame({"class": class_names, "prob": sigmoid(tflite_logits)})
    # Ascending order puts the most likely class at the top of a barh chart.
    return result_df.sort_values(by="prob", ascending=True)


def display_image(df):
    """Render *df* as a horizontal bar chart and show it on the page.

    Args:
        df: DataFrame with ``class`` and ``prob`` columns, as returned by
            :func:`inference`; probabilities are expected in ``[0, 1]``.
    """
    fig, ax = plt.subplots(figsize=(2, 1.8))
    try:
        df.plot(x="class", y="prob", kind="barh", ax=ax, color="black", ylabel="", legend=False)
        ax.tick_params(axis="both", which="major", labelsize=8.5)
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_xlim(0, 1)
        # Annotate each bar with its probability, just right of the bar end.
        for position, prob in enumerate(df["prob"]):
            ax.text(prob + 0.015, position - 0.15, f"{str(np.round(prob, 3))} ", fontsize=7.5)
        fig.tight_layout()

        # Render in memory rather than to a shared file in the working
        # directory, which concurrent sessions would overwrite.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight", dpi=100)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    st.write("")
    st.image(image, output_format="PNG", caption="Prediction")


############## ENTRY POINT START #######################
def main():
    """Build the page: a text area plus a Submit button that runs the classifier."""
    st.title("Toxic Comment Classifier")
    comment_txt = st.text_area("Enter a comment:", "", height=100)
    if st.button("Submit"):
        if not comment_txt.strip():
            st.warning("Please enter a comment to classify.")
            return
        display_image(inference(comment_txt))


############## ENTRY POINT END #######################

if __name__ == "__main__":
    main()