import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer

# Importing the custom class ensures torch.load can unpickle the fine-tuned model object.
from finetuning import CustomDistilBertClass
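# Display names shown in the sidebar, mapped to the Hugging Face checkpoints used for tokenization.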
model_map = {
'BERT': 'bert-base-uncased',
'RoBERTa': 'roberta-base',
'DistilBERT': 'distilbert-base-uncased'
}
model_options = list(model_map.keys())
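# Toxicity labels, indexed in the same order as the model's output logits.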
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
@st.cache_resource
def load_model(model_name):
    """Load the fine-tuned classification model and a tokenizer for the selected backbone."""
    path = "finetuned_model.pt"
    # The checkpoint stores the full fine-tuned model object; map it onto whatever device is
    # available and switch to eval mode so dropout is disabled during inference.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.load(path, map_location=device)
    model.eval()
    # Only the tokenizer depends on the dropdown selection; the same checkpoint is loaded for every choice.
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer
def classify_text(model, tokenizer, text):
    """Classify text with the fine-tuned model and return the top label with its probability."""
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    # Move the encoded inputs to the same device as the model before the forward pass.
    device = next(model.parameters()).device
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask)[0]
    # Softmax over the labels, then report the highest-scoring class and its probability.
    probabilities = torch.softmax(logits, dim=1)[0]
    pred_class = torch.argmax(probabilities, dim=0).item()
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)
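# Streamlit UI: pick a backbone in the sidebar, then classify free-form text entered below.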
st.title('Toxicity Classification App')
model_name = st.sidebar.selectbox('Select model', model_options)
st.sidebar.write('You selected:', model_name)
model, tokenizer = load_model(model_name)
st.subheader('Enter your text below:')
text_input = st.text_area(label='Text to classify', height=100, max_chars=500, label_visibility='collapsed')
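# One-off classification: show the top toxicity class and its probability for the entered text.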
if st.button('Classify'):
if not text_input:
st.write('Please enter some text')
else:
class_label, class_prob = classify_text(model, tokenizer, text_input)
st.subheader('Result')
st.write('Input Text:', text_input)
st.write('Highest Toxicity Class:', class_label)
st.write('Probability:', class_prob)
st.subheader('Classification Results')
if 'classification_results' not in st.session_state:
st.session_state.classification_results = pd.DataFrame(columns=['text', 'toxicity_class', 'probability'])
if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
        # DataFrame.append was removed in pandas 2.x; build the new row and concatenate instead.
        new_row = pd.DataFrame([{
            'text': text_input,
            'toxicity_class': class_label,
            'probability': class_prob
        }])
        st.session_state.classification_results = pd.concat(
            [st.session_state.classification_results, new_row], ignore_index=True
        )
st.write(st.session_state.classification_results)