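"""Streamlit app that classifies text toxicity with a finetuned transformer model."""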
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer

# finetuning defines CustomDistilBertClass, which must be importable so that
# torch.load can unpickle the saved finetuned model object.
import finetuning
from finetuning import CustomDistilBertClass


model_map = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased'
}

model_options = list(model_map.keys())
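# Toxicity labels, in the order produced by the model's classification head.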
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']


@st.cache_resource
def load_model(model_name):
    """Load the finetuned model and the tokenizer for the selected architecture."""
    # The same finetuned checkpoint is used for every choice; only the tokenizer
    # follows the sidebar selection.
    path = "finetuned_model.pt"
    model = torch.load(path)
    model.eval()  # inference only: disable dropout
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer

def classify_text(model, tokenizer, text):
    """Classify text with the finetuned model and return the top label and its probability."""
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])[0]
        probabilities = torch.softmax(logits, dim=1)[0]
        pred_class = torch.argmax(probabilities, dim=0)
    # Return the most likely toxicity label together with its probability.
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)


st.title('Toxicity Classification App')
model_name = st.sidebar.selectbox('Select model', model_options)
st.sidebar.write('You selected:', model_name)
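# Cached by st.cache_resource, so loading only happens once per selection.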
model, tokenizer = load_model(model_name)


st.subheader('Enter your text below:')
text_input = st.text_area(label='Input text', height=100, max_chars=500, label_visibility='collapsed')

if st.button('Classify'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)

st.subheader('Classification Results')
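# Keep previously classified examples across Streamlit reruns via session state.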
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = pd.DataFrame(columns=['text', 'toxicity_class', 'probability'])
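# Classify the current text and append the result to the stored table.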
if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
        # Append via pd.concat (DataFrame.append was removed in pandas 2.x).
        new_row = pd.DataFrame([{
            'text': text_input,
            'toxicity_class': class_label,
            'probability': class_prob
        }])
        st.session_state.classification_results = pd.concat(
            [st.session_state.classification_results, new_row], ignore_index=True
        )
st.write(st.session_state.classification_results)