# Milestone3 / app.py
# Last commit: 1d7207a "Update app.py" by cdb24
from pathlib import Path

import requests
import streamlit as st
import torch
from torch.utils.data import Dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Hugging Face Hub identifiers for the two models offered in the UI radio.
# NOTE(review): the first name suggests an SST-2 *sentiment* classifier
# (positive/negative), not a toxicity model — confirm this is intended.
model_one, model_two = (
    "distilbert-base-uncased-finetuned-sst-2-english",
    "Newtral/xlm-r-finetuned-toxic-political-tweets-es",
)
def toxicRating(text, model):
    """Classify *text* with the Hugging Face model identified by *model*.

    Args:
        text: The string (or list of strings) to classify.
        model: A model identifier or local path accepted by
            ``from_pretrained``.

    Returns:
        The pipeline output: a list with one ``{"label": ..., "score": ...}``
        dict per input text.
    """
    # Keep the identifier separate from the loaded network: the tokenizer
    # must be loaded from the name/path, not from the model object.
    model_name = model
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): weights are re-loaded on every call (every "Rate" click);
    # consider caching the pipeline, e.g. with Streamlit's resource cache.
    classifier = pipeline("sentiment-analysis", model=loaded_model, tokenizer=tokenizer)
    return classifier(text)
def main():
    """Render the Streamlit UI: pick a model, enter a tweet, show its rating."""
    st.title("TOXIC TWEETS, \n TOXIC OR NOT?")
    st.header("Select Model")
    selection = st.radio("Models", ('Model 1', 'Model 2'))
    tweet = st.text_area("Enter Tweet: ")

    # Map each radio option to the model identifier it selects.
    models = {'Model 1': model_one, 'Model 2': model_two}

    if not st.button('Rate'):
        return
    # Warn only after the user actually asked for a rating.
    if not tweet.strip():
        st.warning("Enter Tweet")
        return

    rating = toxicRating(tweet, models[selection])
    # The classification pipeline returns a list with one
    # {"label": ..., "score": ...} dict per input text.
    result = rating[0]
    # Two trailing spaces before "\n" force a Markdown line break in st.write.
    st.write(f"Label: {result['label']}  \nScore : {result['score']}")
# `streamlit run app.py` executes this file as __main__, so the UI is built here.
if __name__ == "__main__":
    main()