"""Streamlit app that scores a piece of text with a selectable classifier.

The "kingsotn/finetuned_cyberbullying" model is shown as a table of per-label
probabilities; every other model in the select box is run through a
Hugging Face "sentiment-analysis" pipeline.
"""

from random import randint

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast

import comments

# Column names of the result table: the text column first, followed by the
# names attached (in order) to the probabilities returned by the model.
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

CYBERBULLYING_MODEL = "kingsotn/finetuned_cyberbullying"


def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Return the sigmoid of every logit the model produces for ``sentence``.

    Args:
        sentence: Text to classify.
        tokenizer: Hugging Face tokenizer used to encode ``sentence``.
        model: Sequence-classification model whose output exposes ``.logits``.

    Returns:
        A flat list of floats, one per logit.
    """
    encoded = tokenizer(
        sentence,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )

    # Keep the tensors exactly as the tokenizer returns them: the model
    # expects (batch_size, sequence_length) inputs, so they must not be
    # flattened to 1-D before the forward pass.
    with torch.no_grad():
        outputs = model(
            encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )

    # Independent sigmoid per logit (not softmax); flattening here only
    # drops the batch axis of the logits.
    probs = torch.sigmoid(outputs.logits.flatten())
    return probs.tolist()


@st.cache(allow_output_mutation=True)
def _load_cyberbullying_model():
    """Load the cyberbullying model and its tokenizer once per process."""
    # allow_output_mutation=True stops Streamlit from hashing the returned
    # objects on every access.
    model = AutoModelForSequenceClassification.from_pretrained(CYBERBULLYING_MODEL)
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return model, tokenizer


@st.cache
def perform_cyberbullying_analysis(tweet):
    """Build a one-row table of label probabilities for ``tweet``.

    Args:
        tweet: Text to analyse.

    Returns:
        A ``pandas.DataFrame`` with a ``comment`` column holding ``tweet``
        and one column per entry of ``labels[1:]``.
    """
    with st.spinner(text="loading model..."):
        # Loading is delegated to a separately cached helper so that a new
        # tweet does not reload the weights.
        model, tokenizer = _load_cyberbullying_model()

    df = pd.DataFrame({'comment': [tweet]})
    list_probs = predict_cyberbullying_probability(tweet, tokenizer, model)
    for i, label in enumerate(labels[1:]):
        df[label] = list_probs[i]

    return df


def perform_default_analysis(model_name):
    """Render the text box and button for a generic sentiment model.

    Must be called inside an ``st.form`` block because it creates a form
    submit button. When the button is pressed, the text is classified with a
    "sentiment-analysis" pipeline built from ``model_name`` and the raw
    pipeline output is displayed.

    Args:
        model_name: Hugging Face model identifier to load.
    """
    tweet = st.text_area(label="Enter Text:", value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")

    if not submitted:
        return

    with st.spinner(text="loading..."):
        # Build the pipeline only once the user asks for a result, so that
        # merely rendering the form does not load a model.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")
        out = clf(tweet)

    st.json(out)

    # Both spellings are checked; any other label falls through to the
    # "bad tweet!" branch.
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")


# main -->
st.title("Toxic Tweets Analyzer")
image = "kanye_tweet.jpg"
st.image(image, use_column_width=True)

with st.form("my_form"):
    # select model
    model_name = st.selectbox(
        "Enter a text and select a pre-trained model to get the sentiment analysis",
        [
            "kingsotn/finetuned_cyberbullying",
            "distilbert-base-uncased-finetuned-sst-2-english",
            "finiteautomata/bertweet-base-sentiment-analysis",
            "distilbert-base-uncased",
        ],
    )

    if model_name == CYBERBULLYING_MODEL:
        default = "I'm nice at ping pong"
        tweet = st.text_area(label="Enter Text:", value=default)

        # The return value of this button is not needed: pressing it reruns
        # the script, and the analysis below runs on every rerun.
        st.form_submit_button("Analyze textbox")
        analyze_random = st.form_submit_button("Analyze a random 😈😈😈 tweet")

        if analyze_random:
            # Index against the actual list length rather than a fixed bound
            # so the pick stays in range if the comment list changes.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]

        df = perform_cyberbullying_analysis(tweet)

        # Display the cached table
        st.table(df)
    else:
        perform_default_analysis(model_name)