# tt-ai / app.py
# Author: Kingston Yip
# Commit 9bceaae: "added desc" (5.11 kB)
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast
import pandas as pd
import comments
from random import randint
import requests
def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Return per-label sigmoid probabilities for ``sentence``.

    Each raw logit produced by ``model`` is passed through an element-wise
    sigmoid (so every label is scored independently rather than via a
    softmax), and the result is flattened into a plain list of floats.

    Args:
        sentence: Text to classify.
        tokenizer: Hugging Face tokenizer callable that returns PyTorch tensors.
        model: Sequence-classification model whose output exposes ``.logits``.

    Returns:
        list[float]: One probability in [0, 1] per logit, in model output order.
    """
    # padding=True pads only up to the longest input (i.e. not at all for a
    # single sentence) instead of always to 512 tokens, which made every call
    # run the model over hundreds of pad tokens. The attention mask is still
    # passed so any padding that does occur is masked out.
    encoded = tokenizer(
        sentence,
        padding=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])
    # Tensor.tolist() works for CPU and CUDA tensors alike, unlike .numpy().
    return torch.sigmoid(outputs.logits).flatten().tolist()
@getattr(st, "cache_resource", lambda func: func)
def _load_cyberbullying_model():
    """Load the fine-tuned cyberbullying classifier and its tokenizer.

    Wrapped with ``st.cache_resource`` when this Streamlit version provides it,
    so the weights are loaded once and reused across reruns instead of on
    every button press. On Streamlit versions without it, the decorator is a
    no-op and the model is loaded on every call, as before.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model = AutoModelForSequenceClassification.from_pretrained('kingsotn/finetuned_cyberbullying')
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return model, tokenizer


def perform_cyberbullying_analysis(tweet):
    """Score ``tweet`` with the fine-tuned cyberbullying model.

    Reads the module-level ``labels`` list, which must be defined before this
    function is called. ``labels[0]`` ('comment') is the text column; the
    remaining entries name the model's outputs, in order.

    Args:
        tweet: Text to analyze.

    Returns:
        pandas.DataFrame: A single row whose 'comment' column holds ``tweet``
        and whose other columns hold the per-label probabilities.
    """
    with st.spinner(text="loading model, wait until spinner ends..."):
        model, tokenizer = _load_cyberbullying_model()
    probabilities = predict_cyberbullying_probability(tweet, tokenizer, model)
    df = pd.DataFrame({'comment': [tweet]})
    for i, label in enumerate(labels[1:]):
        df[label] = probabilities[i]
    return df
@getattr(st, "cache_resource", lambda func: func)
def _load_sentiment_pipeline(model_name):
    """Build a PyTorch "sentiment-analysis" pipeline for ``model_name``.

    Wrapped with ``st.cache_resource`` when available, so each model is loaded
    once rather than on every script rerun. The wrapper is a no-op on Streamlit
    versions that lack it.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")


def perform_default_analysis(model_name):
    """Render a text box and classify its contents with ``model_name``.

    Creates a form submit button, so it must run inside an ``st.form``
    context (the main script calls it from within one). On submit it shows the
    raw pipeline output. It then celebrates a positive label or flags any
    other label.

    Args:
        model_name: Hugging Face model id to load into the pipeline.
    """
    tweet = st.text_area(label="Enter Text:", value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")
    if not submitted:
        return

    # Load lazily on submit so merely rendering the form never loads a model.
    with st.spinner(text="loading..."):
        clf = _load_sentiment_pipeline(model_name)
        out = clf(tweet)
    st.json(out)

    # The offered models spell their positive label differently.
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")
# main -->
st.title("Toxic Tweets Analyzer")
# NOTE(review): the leading characters of this description look like a
# mis-encoded emoji; confirm the intended glyph before changing the text.
st.write("πŸ’‘ Toxic Tweets Analyzer is an app that helps you determine the likelihood of a tweet or any text being toxic, abusive or cyberbullying. The app offers different pre-trained models to choose from, each with their own strengths and limitations. kingsotn/finetuned_cyberbullying is a finetuned distilbert. It uses artificial intelligence to analyze the text you input and then calculates a probability score for each label: toxic, severe_toxic, obscene, threat, insult, and identity_hate. The scores range from 0 to 1, with 1 being the highest probability of that label being present in the tweet. The output is a table that shows the probability scores for each label, giving you an idea of the toxicity of the tweet. This can be helpful in identifying and preventing cyberbullying and other forms of online abuse.")
image = "kanye_loves_tweet.jpg"
st.image(image, use_column_width=True)

# Column names of the result table; perform_cyberbullying_analysis reads this.
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# The model picker lives outside the form. Widgets inside an st.form do not
# rerun the script until a submit button is pressed, so a choice made inside
# the form would not switch the form's contents until the next submit.
model_name = st.selectbox("Enter a text and select a pre-trained model to get the sentiment analysis", ["kingsotn/finetuned_cyberbullying", "distilbert-base-uncased-finetuned-sst-2-english", "finiteautomata/bertweet-base-sentiment-analysis", "distilbert-base-uncased"])

with st.form("my_form"):
    if model_name == "kingsotn/finetuned_cyberbullying":
        default = "I'm not even going to lie to you. I love me so much right now."
        tweet = st.text_area(label="Enter Text:", value=default)
        submitted = st.form_submit_button("Analyze textbox")
        random_clicked = st.form_submit_button("Get a random 😈😈😈 tweet (warning!!)")
        kanye_clicked = st.form_submit_button("Get a ye quote 🐻🎀🎧🎢")

        if random_clicked:
            # Index over the whole list rather than a hard-coded 0..354 range.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
            st.write(tweet)
            submitted = True

        if kanye_clicked:
            # Without a timeout, a stalled API call would hang the app.
            try:
                response = requests.get('https://api.kanye.rest/', timeout=10)
            except requests.RequestException as err:
                st.error(f"Error getting Kanye quote | {err}")
            else:
                if response.status_code == 200:
                    tweet = response.json()['quote']
                    st.write(tweet)
                    # Analyze only a quote that was actually fetched.
                    submitted = True
                else:
                    st.error("Error getting Kanye quote | status code: " + str(response.status_code))

        if submitted:
            df = perform_cyberbullying_analysis(tweet)
            st.table(df)
    else:
        perform_default_analysis(model_name)