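"""Streamlit app that scores input text against six toxicity categories
(toxic, severe toxic, obscene, threat, insult, identity hate) using a BERT
sequence-classification model.

Run it with the Streamlit CLI, e.g. (the filename below is illustrative):

    streamlit run toxicity_app.py
"""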
import streamlit as st
import pandas as pd
import transformers
import torch
import seaborn as sns
import matplotlib.pyplot as plt

# Set up the Streamlit page before any other Streamlit call
# (st.set_page_config must be the first Streamlit command in the script)
st.set_page_config(layout="wide")
st.title('Toxicity Classification App')

# Load the pre-trained BERT tokenizer and sequence-classification model.
# NOTE: bert-base-uncased ships only the pretrained encoder weights; the
# 6-label classification head is randomly initialized until fine-tuned.
try:
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
    model.eval()  # inference mode: disables dropout
except Exception as e:
    st.error(f"Error loading the model: {e}")
    st.stop()  # halt the app so the code below never references an undefined model

# Create a text input for the user to enter their text
text_input = st.text_input('Enter text to classify')

# Create a button to run the classification
if st.button('Classify'):
    if not text_input:
        st.warning("Please enter text to classify.")
    else:
        # Tokenize the text and convert to input IDs
        encoded_text = tokenizer.encode_plus(
            text_input,
            max_length=512,
            padding='max_length',
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # Run the text through the model (no gradients needed at inference time)
        with torch.no_grad():
            output = model(input_ids=encoded_text['input_ids'],
                           attention_mask=encoded_text['attention_mask'])
            # Softmax treats the six labels as mutually exclusive classes
            probabilities = torch.nn.functional.softmax(output.logits, dim=1).tolist()[0]

        # Display the classification results
        st.write('Toxic:', probabilities[0])
        st.write('Severe Toxic:', probabilities[1])
        st.write('Obscene:', probabilities[2])
        st.write('Threat:', probabilities[3])
        st.write('Insult:', probabilities[4])
        st.write('Identity Hate:', probabilities[5])

        # Build a one-row DataFrame with the classification results
        # (DataFrame.append was removed in pandas 2.0, so construct the frame directly)
        results_df = pd.DataFrame([{
            'Text': text_input,
            'Toxic': probabilities[0],
            'Severe Toxic': probabilities[1],
            'Obscene': probabilities[2],
            'Threat': probabilities[3],
            'Insult': probabilities[4],
            'Identity Hate': probabilities[5]
        }])

        # Append the classification results to the DataFrame persisted in session state
        if 'results' not in st.session_state:
            st.session_state['results'] = pd.DataFrame(columns=results_df.columns)
        st.session_state['results'] = pd.concat(
            [st.session_state['results'], results_df], ignore_index=True
        )

# Display the persistent DataFrame
st.write('Classification Results:', st.session_state.get('results', pd.DataFrame()))

# Plot the distribution of probabilities for selected categories
# (st.pyplot expects a matplotlib figure, not the Axes returned by seaborn)
if len(st.session_state.get('results', pd.DataFrame())) > 0:
    df = st.session_state['results']
    for column in ['Toxic', 'Severe Toxic']:
        fig, ax = plt.subplots()
        sns.histplot(data=df, x=column, kde=True, ax=ax)
        st.pyplot(fig)
        plt.close(fig)  # free the figure after rendering