ksvmuralidhar committed on
Commit
8ae18de
1 Parent(s): ef41a74

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from unidecode import unidecode
5
+ import tensorflow as tf
6
+ import cloudpickle
7
+ from transformers import DistilBertTokenizerFast
8
+ import os
9
+ from matplotlib import pyplot as plt
10
+ from PIL import Image
11
+
12
+
13
# Load the pickled (text preprocessor, class names) pair that ships with the app.
# NOTE(review): unpickling executes code stored in the file, so this artifact
# must only ever come from a trusted source.
with open(
    os.path.join("models", "toxic_comment_preprocessor_classnames.bin"),
    "rb",
) as model_file_obj:
    text_preprocessor, class_names = cloudpickle.load(model_file_obj)

# TFLite interpreter for the DistilBERT toxic-comment classifier; it is shared
# by every call to inference().
interpreter = tf.lite.Interpreter(
    model_path=os.path.join("models", "toxic_comment_classifier_hf_distilbert.tflite")
)
16
+
17
+
18
def inference(text):
    """Classify one comment and return the per-class probabilities.

    The text is passed through the module-level ``text_preprocessor``,
    tokenized with the ``distilbert-base-uncased`` tokenizer (padded and
    truncated to 512 tokens) and fed to the module-level TFLite
    ``interpreter``.

    Args:
        text: The raw comment string to classify.

    Returns:
        A DataFrame with the columns ``class`` (taken from ``class_names``)
        and ``prob`` (the model output for that class), sorted by ``prob``
        in ascending order.
    """
    text = text_preprocessor.preprocess(pd.Series(text))[0]
    model_checkpoint = "distilbert-base-uncased"
    # NOTE(review): the tokenizer is re-created on every call. If load time
    # matters, consider caching it with Streamlit's resource cache
    # (st.cache_resource) -- confirm the deployed Streamlit version has it.
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    tokens = tokenizer(text, max_length=512, padding="max_length", truncation=True, return_tensors="tf")

    # tflite model inference
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]

    # Bind each tokenizer output to the model input whose tensor name mentions
    # it, so a re-exported model with a different input order cannot silently
    # receive the attention mask as token ids (or vice versa).
    feature_order = ("attention_mask", "input_ids")
    bound = {}
    for detail in input_details:
        for feature in feature_order:
            if feature in detail["name"]:
                bound[feature] = detail
    distinct_inputs = {detail["index"] for detail in bound.values()}
    if len(bound) != len(feature_order) or len(distinct_inputs) != len(feature_order):
        # Tensor names are not informative: fall back to the original
        # positional convention (input 0 = attention mask, input 1 = ids).
        bound = {
            "attention_mask": input_details[0],
            "input_ids": input_details[1],
        }

    for feature, detail in bound.items():
        # set_tensor expects a numpy array of the dtype the model declares.
        interpreter.set_tensor(
            detail["index"],
            np.asarray(tokens[feature], dtype=detail["dtype"]),
        )

    interpreter.invoke()
    tflite_pred = interpreter.get_tensor(output_details["index"])[0]
    result_df = pd.DataFrame({'class': class_names, 'prob': tflite_pred})
    result_df.sort_values(by='prob', ascending=True, inplace=True)
    return result_df
36
+
37
+
38
def display_image(df):
    """Draw the class probabilities as a horizontal bar chart in Streamlit.

    Args:
        df: DataFrame with a ``class`` column (bar labels) and a ``prob``
            column (bar lengths, drawn on a fixed 0-1 axis). Rows are
            plotted in the order given.
    """
    import io  # local import so the block does not depend on file-level imports

    # Must be set before the figure exists: the setting is only consulted
    # when a figure is created, so setting it afterwards never affected the
    # figure being drawn here.
    plt.rcParams["figure.autolayout"] = True
    fig, ax = plt.subplots(figsize=(2, 1.8))
    try:
        df.plot(x='class', y='prob', kind='barh', ax=ax, color='black', ylabel='')
        ax.tick_params(axis='both', which='major', labelsize=8.5)
        ax.get_legend().remove()
        # Strip the frame and the x ticks; the value labels below carry the
        # numbers instead.
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.get_xaxis().set_ticks([])
        ax.set_xlim(0, 1)
        # Print each probability just to the right of its bar.
        for n, i in enumerate(df['prob']):
            ax.text(i+0.015, n-0.15, f'{str(np.round(i, 3))} ', fontsize=7.5)

        # Render into memory instead of a fixed "prediction.png" on disk: a
        # shared file can be overwritten by another session between save and
        # load, and needs a writable working directory.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches='tight', dpi=100)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    st.write('')
    st.image(image, output_format="PNG", caption="Prediction")
57
+
58
############## ENTRY POINT START #######################
def main():
    """Render the page: a comment box, a Submit button and the prediction.

    When Submit is pressed with a non-blank comment, the comment is
    classified with ``inference`` and the result is drawn with
    ``display_image``. A blank comment produces a warning instead of a
    model run.
    """
    st.title("Toxic Comment Classifier")
    comment_txt = st.text_area("Enter a comment:", "", height=100)
    if st.button("Submit"):
        # Do not run the model on empty or whitespace-only input; the
        # resulting chart would not describe any comment.
        if not comment_txt.strip():
            st.warning("Please enter a comment to classify.")
            return
        df = inference(comment_txt)
        display_image(df)

############## ENTRY POINT END #######################

# Streamlit executes this file as a script, so the guard is what starts the UI.
if __name__ == "__main__":
    main()