Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast | |
import pandas as pd | |
import comments | |
from random import randint | |
import requests | |
def predict_cyberbullying_probability(sentence, tokenizer, model, max_length=512):
    """Return per-label probabilities for a single sentence.

    The sentence is tokenized (truncated to ``max_length`` tokens) and run
    through ``model`` without gradient tracking. Each output logit is then
    passed through an independent sigmoid, so the scores are multi-label
    probabilities that need not sum to 1.

    Args:
        sentence: Text to score.
        tokenizer: Hugging Face tokenizer callable that returns a mapping with
            ``input_ids`` and ``attention_mask`` tensors when called with
            ``return_tensors='pt'``.
        model: Sequence-classification model whose output exposes ``.logits``.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        A flat ``list[float]`` with one probability per output logit.
        NOTE(review): intended for one sentence; with a batch the flattened
        result would mix rows together.
    """
    encoded = tokenizer(
        sentence,
        # Pad only to the longest input (i.e. not at all for one sentence)
        # instead of always to max_length: the attention mask already hides
        # padding, so extra pad positions only cost compute.
        padding=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(encoded['input_ids'], attention_mask=encoded['attention_mask'])
    # Sigmoid rather than softmax: each label is scored independently.
    probs = torch.sigmoid(outputs.logits.flatten())
    # .cpu() keeps this working when the model runs on a GPU; Tensor.tolist()
    # already yields Python floats, so no NumPy round-trip is needed.
    return probs.cpu().tolist()
# @st.cache
# NOTE: the caching decorator above is disabled, so the model and tokenizer
# are loaded again on every call.
def perform_cyberbullying_analysis(tweet):
    """Build a one-row toxicity table for ``tweet``.

    The fine-tuned ``kingsotn/finetuned_cyberbullying`` classifier and the
    ``distilbert-base-uncased`` fast tokenizer are loaded while a Streamlit
    spinner is shown. The text is then scored with
    ``predict_cyberbullying_probability``.

    Args:
        tweet: Text to analyze.

    Returns:
        A ``pd.DataFrame`` with a single row. The ``comment`` column holds
        ``tweet``, and every later name in the module-level ``labels`` list
        holds the model probability at the same position, offset by one.
    """
    with st.spinner(text="loading model, wait until spinner ends..."):
        classifier = AutoModelForSequenceClassification.from_pretrained('kingsotn/finetuned_cyberbullying')
        tok = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        scores = predict_cyberbullying_probability(tweet, tok, classifier)
        table = pd.DataFrame({'comment': [tweet]})
        # labels[0] names the text column itself, so score k goes to labels[k + 1].
        for k in range(len(labels) - 1):
            table[labels[k + 1]] = scores[k]
    return table
def perform_default_analysis(model_name):
    """Render a generic sentiment-analysis section for a Hugging Face model.

    Shows a text area and an "Analyze" submit button. The function uses
    ``st.form_submit_button``, so it must be called inside an ``st.form``.
    On submit, it builds a ``sentiment-analysis`` pipeline for
    ``model_name``, shows the raw pipeline output as JSON, and reports
    whether the top label is positive.

    Args:
        model_name: Hugging Face model id used for both the tokenizer and
            the model.
    """
    tweet = st.text_area(label="Enter Text:",value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")
    if not submitted:
        return
    # Load the model only once the user actually submits. Streamlit reruns
    # the whole script on every interaction, so eager loading would stall
    # every rerun (without any progress indicator) even when nothing is
    # being analyzed.
    with st.spinner(text="loading..."):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")
        out = clf(tweet)
    st.json(out)
    # Both spellings of the positive label are accepted; presumably the
    # selectable models differ in which one they emit.
    # NOTE(review): every other label (including any neutral class a model
    # may have) is reported as a bad tweet.
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")
# main -->
st.title("Toxic Tweets Analyzer")
st.write("💡 Toxic Tweets Analyzer uses AI with kingsotn/finetuned_cyberbullying (distilbert) to score tweets for toxicity, threat, and insult.")
image = "kanye_loves_tweet.jpg"
st.image(image, use_column_width=True)

# Column names for the table built by perform_cyberbullying_analysis():
# 'comment' is the raw text, and the rest are matched positionally to the
# model's output probabilities. It must stay a module-level name because
# that function reads it as a global.
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Seconds to wait for the kanye.rest API before giving up.
KANYE_API_TIMEOUT = 10

# The model picker lives outside the form. Widgets inside an st.form do not
# trigger a rerun, so a picker inside the form would leave the
# model-specific widgets below stale until some unrelated submit button
# was clicked.
model_name = st.selectbox("Enter a text and select a pre-trained model to get the sentiment analysis", ["kingsotn/finetuned_cyberbullying", "distilbert-base-uncased-finetuned-sst-2-english", "finiteautomata/bertweet-base-sentiment-analysis", "distilbert-base-uncased"])

with st.form("my_form"):
    if model_name == "kingsotn/finetuned_cyberbullying":
        default = "I'm not even going to lie to you. I love me so much right now."
        tweet = st.text_area(label="Enter Text:",value=default)
        submitted = st.form_submit_button("Analyze textbox")
        # NOTE(review): the emoji in the next two labels look mis-encoded in
        # this copy of the file; restore them from the original source.
        random_clicked = st.form_submit_button("Get a random πππ tweet (warning!!)")
        kanye_clicked = st.form_submit_button("Get a ye quote π»π€π§πΆ")

        if random_clicked:
            # Derive the index range from the data instead of hard-coding its size.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
            st.write(tweet)
            submitted = True

        if kanye_clicked:
            try:
                # A timeout keeps a stalled API from hanging the app indefinitely.
                response = requests.get('https://api.kanye.rest/', timeout=KANYE_API_TIMEOUT)
            except requests.RequestException as err:
                st.error(f"Error getting Kanye quote | {err}")
            else:
                if response.status_code == 200:
                    tweet = response.json()['quote']
                else:
                    st.error("Error getting Kanye quote | status code: " + str(response.status_code))
            # On failure, fall back to analyzing the text-area contents.
            st.write(tweet)
            submitted = True

        if submitted:
            df = perform_cyberbullying_analysis(tweet)
            st.table(df)
    else:
        perform_default_analysis(model_name)