import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer

# CustomDistilBertClass must be importable so torch.load can unpickle the
# fine-tuned model checkpoint saved from finetuning.py.
from finetuning import CustomDistilBertClass

model_map = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased'
}
model_options = list(model_map.keys())
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']


@st.cache_resource
def load_model(model_name):
    """Load the fine-tuned model and the tokenizer for the selected backbone."""
    path = "finetuned_model.pt"
    # Load on CPU so the app also runs on machines without a GPU.
    model = torch.load(path, map_location=torch.device('cpu'))
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer


def classify_text(model, tokenizer, text):
    """Return the most likely toxicity label and its probability."""
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])[0]
        probabilities = torch.softmax(logits, dim=1)[0]
    pred_class = torch.argmax(probabilities, dim=0)
    # Report the probability of the predicted class, not of the first label.
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)


st.title('Toxicity Classification App')

model_name = st.sidebar.selectbox('Select model', model_options)
st.sidebar.write('You selected:', model_name)
model, tokenizer = load_model(model_name)

st.subheader('Enter your text below:')
text_input = st.text_area(label='Input text', height=100, max_chars=500,
                          label_visibility='collapsed')

if st.button('Classify'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)

st.subheader('Classification Results')
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = pd.DataFrame(
        columns=['text', 'toxicity_class', 'probability'])

if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter some text')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Result')
        st.write('Input Text:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
        # DataFrame.append was removed in pandas 2.0; build a one-row frame
        # and concatenate it instead.
        new_row = pd.DataFrame([{
            'text': text_input,
            'toxicity_class': class_label,
            'probability': class_prob
        }])
        st.session_state.classification_results = pd.concat(
            [st.session_state.classification_results, new_row],
            ignore_index=True)

st.write(st.session_state.classification_results)
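
# Usage note (a minimal sketch; the filename app.py and the presence of
# finetuned_model.pt next to it are assumptions, not confirmed by the repo):
#
#   streamlit run app.py
#
# The sidebar dropdown only switches which pretrained tokenizer is loaded;
# the model weights always come from the single finetuned_model.pt checkpoint.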