# tt-ai / app.py
# Author: Kingston Yip
# Commit 5e81a63: "done with milestone-3"
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast
import pandas as pd
import comments
from random import randint
def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Score one sentence and return one sigmoid probability per model output.

    Args:
        sentence: The text to classify.
        tokenizer: A Hugging Face tokenizer callable that returns a mapping
            with ``'input_ids'`` and ``'attention_mask'`` PyTorch tensors.
        model: A sequence-classification model whose output has ``.logits``.

    Returns:
        A list of Python floats, one per logit. Each logit is passed through
        an independent sigmoid (not a softmax), so the values need not sum
        to 1.
    """
    # Tokenize into a batch of one, padded/truncated to 512 tokens.
    # Token-type ids are not requested; only input ids and the attention
    # mask are fed to the model.
    encoded = tokenizer(
        sentence,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    # Keep the leading batch dimension: the model expects
    # (batch_size, sequence_length) tensors, so flattening to 1-D breaks
    # the forward pass.
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        probs = torch.sigmoid(outputs.logits.flatten())
    return probs.tolist()
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_cyberbullying_model():
    """Load the fine-tuned classifier and its tokenizer once per process.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned torch model on every call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained('kingsotn/finetuned_cyberbullying')
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return tokenizer, model


@st.cache
def perform_cyberbullying_analysis(tweet):
    """Classify ``tweet`` and return a one-row DataFrame of label scores.

    The returned DataFrame has a ``'comment'`` column holding the tweet,
    plus one column per name in the module-level ``labels[1:]``. Each of
    those columns holds the matching probability from
    ``predict_cyberbullying_probability``. Results are memoized per tweet
    by ``st.cache``.
    """
    with st.spinner(text="loading model..."):
        # Loading is cached separately, so analyzing a new tweet no longer
        # re-reads the model weights and tokenizer from disk.
        tokenizer, model = _load_cyberbullying_model()
    df = pd.DataFrame({'comment': [tweet]})
    list_probs = predict_cyberbullying_probability(tweet, tokenizer, model)
    # labels[0] is the 'comment' column; the remaining names are paired
    # positionally with the model outputs.
    for i, label in enumerate(labels[1:]):
        df[label] = list_probs[i]
    return df
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_sentiment_pipeline(model_name):
    """Build a PyTorch "sentiment-analysis" pipeline, cached per model name.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned pipeline (and its model) on every call.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")


def perform_default_analysis(model_name):
    """Render a text box and "Analyze" button for a generic sentiment model.

    Uses ``st.form_submit_button``, so it must be called inside an
    ``st.form``. When the form is submitted, the entered text is classified.
    The raw pipeline output is shown as JSON, followed by balloons and a
    success message for a positive label, or an error message otherwise.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
    """
    # Streamlit reruns this script on every interaction; the cached loader
    # avoids rebuilding the tokenizer, model and pipeline each time.
    with st.spinner(text="loading model..."):
        clf = _load_sentiment_pipeline(model_name)

    tweet = st.text_area(label="Enter Text:", value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")
    if not submitted:
        return

    with st.spinner(text="loading..."):
        out = clf(tweet)
    st.json(out)
    # NOTE(review): only these two label strings count as positive; a
    # checkpoint that emits any other label name always takes the
    # "bad tweet!" branch.
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")
# main -->
st.title("Toxic Tweets Analyzer")
image = "kanye_tweet.jpg"
st.image(image, use_column_width=True)

# Results-table columns: 'comment' holds the input text; the remaining
# names are paired positionally with the fine-tuned model's outputs in
# perform_cyberbullying_analysis.
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

with st.form("my_form"):
    # Select the model to run.
    model_name = st.selectbox("Enter a text and select a pre-trained model to get the sentiment analysis", ["kingsotn/finetuned_cyberbullying", "distilbert-base-uncased-finetuned-sst-2-english", "finiteautomata/bertweet-base-sentiment-analysis", "distilbert-base-uncased"])
    if model_name == "kingsotn/finetuned_cyberbullying":
        default = "I'm nice at ping pong"
        tweet = st.text_area(label="Enter Text:", value=default)
        # Both buttons submit the form. The textbox contents are analyzed on
        # every run unless the random-tweet button was the one pressed.
        st.form_submit_button("Analyze textbox")
        random_clicked = st.form_submit_button("Analyze a random 😈😈😈 tweet")
        if random_clicked:
            # Sample from the whole list instead of a hard-coded upper bound,
            # which would raise IndexError if the list shrinks.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
        df = perform_cyberbullying_analysis(tweet)
        # Display the (st.cache-memoized) table.
        st.table(df)
    else:
        perform_default_analysis(model_name)