import streamlit as st
import pandas as pd
import torch
from torch import cuda
from transformers import AutoTokenizer

# finetuning must be importable so torch.load can unpickle the saved
# CustomDistilBertClass instance.
import finetuning
from finetuning import CustomDistilBertClass
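
# Pick the compute device once so the model and input tensors agree; the
# hosting Space may be CPU-only, so fall back gracefully.
device = 'cuda' if cuda.is_available() else 'cpu'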

model_map = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased'
}
model_options = list(model_map.keys())
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
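
# Cache the loaded model and tokenizer across Streamlit reruns so they are not
# reloaded on every widget interaction (assumes Streamlit >= 1.18, where
# st.cache_resource is available).
@st.cache_resource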
def load_model(model_name):
    """Load the fine-tuned checkpoint and a matching tokenizer."""
    path = "finetuned_model.pt"
    # map_location keeps a GPU-trained checkpoint loadable on a CPU-only host.
    model = torch.load(path, map_location=device)
    model.eval()
    # The checkpoint is a single fine-tuned model; the sidebar choice only
    # selects which tokenizer is paired with it, and that tokenizer must match
    # the architecture the checkpoint was fine-tuned from.
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer

def classify_text(model, tokenizer, text):
    """Classify text with the fine-tuned model and return the top label."""
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'].to(device),
                       inputs['attention_mask'].to(device))[0]
    probabilities = torch.softmax(logits, dim=1)[0]
    pred_class = torch.argmax(probabilities).item()
    # Report the probability of the predicted class, paired with its label.
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)
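
# Hedged alternative, not wired into the UI below: the Jigsaw toxicity labels
# are multi-label (a comment can be both 'toxic' and 'insult'), so a head
# trained with BCEWithLogitsLoss should be read through a per-label sigmoid
# rather than a softmax across labels. A minimal sketch, assuming that
# training setup:
def classify_text_multilabel(model, tokenizer, text, threshold=0.5):
    """Return every label whose independent probability clears the threshold."""
    inputs = tokenizer.encode_plus(
        text, add_special_tokens=True, max_length=512,
        padding='max_length', return_tensors='pt', truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'].to(device),
                       inputs['attention_mask'].to(device))[0]
    probs = torch.sigmoid(logits)[0]
    return {label: round(p, 2)
            for label, p in zip(label_cols, probs.tolist())
            if p >= threshold}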

st.title('Toxicity Classification App')

model_name = st.sidebar.selectbox('Select model', model_options)
st.sidebar.write('You selected:', model_name)
model, tokenizer = load_model(model_name)

st.subheader('Enter your text below:')
# label_visibility='collapsed' keeps the widget accessible (empty labels raise
# a warning in recent Streamlit releases) while the subheader above remains
# the visible prompt.
text_input = st.text_area('Input text', height=100, max_chars=500,
                          label_visibility='collapsed')

if st.button('Classify'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)

st.subheader('Classification Results')
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = pd.DataFrame(
        columns=['text', 'toxicity_class', 'probability']
    )

if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
        # DataFrame.append was removed in pandas 2.0; build a one-row frame
        # and concatenate instead.
        new_row = pd.DataFrame([{
            'text': text_input,
            'toxicity_class': class_label,
            'probability': class_prob
        }])
        st.session_state.classification_results = pd.concat(
            [st.session_state.classification_results, new_row],
            ignore_index=True
        )

st.write(st.session_state.classification_results)
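
# To try the app locally (the file name app.py is assumed):
#   streamlit run app.py
# The checkpoint "finetuned_model.pt" and the finetuning module must sit next
# to this script for load_model and unpickling to succeed.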