import streamlit as st
import pandas as pd
import transformers
import torch
import seaborn as sns
import matplotlib.pyplot as plt
# Set up the Streamlit app (st.set_page_config must be the first Streamlit command)
st.set_page_config(layout="wide")
st.title('Toxicity Classification App')

# Load the pre-trained BERT model and tokenizer.
# Note: 'bert-base-uncased' ships with a randomly initialized classification
# head, so the six scores below are only meaningful once the model has been
# fine-tuned on a toxicity dataset.
try:
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    model = transformers.BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=6)
    model.eval()
except Exception as e:
    st.error(f"Error loading the model: {e}")
    st.stop()
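# A minimal caching sketch (an assumption, not part of the original app):
# wrapping the load in st.cache_resource would keep the model in memory
# across Streamlit reruns instead of reloading BERT on every interaction.
#
#     @st.cache_resource
#     def load_model():
#         tok = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
#         mdl = transformers.BertForSequenceClassification.from_pretrained(
#             'bert-base-uncased', num_labels=6)
#         return tok, mdl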
# Create a text input for the user to enter their text
text_input = st.text_input('Enter text to classify')

# Create a button to run the classification
if st.button('Classify'):
    if not text_input:
        st.warning("Please enter text to classify.")
    else:
        # Tokenize the text and convert to input IDs
        encoded_text = tokenizer.encode_plus(
            text_input,
            max_length=512,
            padding='max_length',
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
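        # encoded_text is a dict of tensors; with the settings above, both
        # encoded_text['input_ids'] and encoded_text['attention_mask'] have
        # shape (1, 512): a batch of one sequence padded to max_length.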
        # Run the text through the model without tracking gradients
        with torch.no_grad():
            output = model(input_ids=encoded_text['input_ids'],
                           attention_mask=encoded_text['attention_mask'])
        probabilities = torch.nn.functional.softmax(output.logits, dim=1).tolist()[0]
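        # The six toxicity labels are not mutually exclusive, so multi-label
        # models usually apply a per-label sigmoid rather than a softmax.
        # A one-line alternative (an assumption, not the original behavior):
        #     probabilities = torch.sigmoid(output.logits).tolist()[0]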
        # Display the classification results
        labels = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']
        for label, probability in zip(labels, probabilities):
            st.write(f'{label}:', probability)
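        # An at-a-glance alternative view (an assumption, not in the original
        # app): render the same scores as a single bar chart.
        #     st.bar_chart(pd.Series(probabilities, index=labels))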
        # Collect this classification into a one-row DataFrame
        # (DataFrame.append was removed in pandas 2.x, so build the row
        # directly and use pd.concat to extend the history)
        results_df = pd.DataFrame([{
            'Text': text_input,
            **dict(zip(labels, probabilities))
        }])

        # Append the results to the DataFrame persisted in session state
        # (st.session_state survives reruns within the same browser session)
        if 'results' not in st.session_state:
            st.session_state['results'] = results_df
        else:
            st.session_state['results'] = pd.concat(
                [st.session_state['results'], results_df], ignore_index=True)
        # Display the persistent DataFrame
        st.write('Classification Results:', st.session_state['results'])
        # Plot the distribution of probabilities for each category so far
        # (st.pyplot expects a Figure, not the Axes that sns.histplot returns)
        df = st.session_state['results']
        for column in ['Toxic', 'Severe Toxic']:
            fig, ax = plt.subplots()
            sns.histplot(data=df, x=column, kde=True, ax=ax)
            st.pyplot(fig)
            plt.close(fig)
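# To run the app locally (assuming this script is saved as app.py and that
# streamlit, torch, transformers, pandas, seaborn, and matplotlib are
# installed):
#     streamlit run app.py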