# cs482-project / app.py
# Author: cgr28 — commit cf5d81e ("milestone-3")
import streamlit as st
from transformers import AutoTokenizer, RobertaForSequenceClassification
import numpy as np
import torch
# assignment 2
# Streamlit front end: the user enters text, picks one of two Cardiff NLP
# Twitter-RoBERTa checkpoints, and clicks "Analyze" to see the label the
# chosen model predicts.

# Radio-button option -> Hugging Face Hub checkpoint. The dict order is the
# order the options appear in the UI.
MODEL_CHECKPOINTS = {
    "Model 1": "cardiffnlp/twitter-roberta-base-emotion",
    "Model 2": "cardiffnlp/twitter-roberta-base-sentiment-latest",
}

# RoBERTa-base accepts at most 512 input tokens (514 position embeddings
# minus the 2-slot padding offset); longer inputs must be truncated or the
# forward pass raises an index error.
MAX_TOKENS = 512


@st.cache_resource
def _load_model(checkpoint):
    """Load the tokenizer and classifier for *checkpoint*, cached across reruns.

    Streamlit re-executes this whole script on every interaction, so without
    caching the weights would be reloaded on each "Analyze" click.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = RobertaForSequenceClassification.from_pretrained(checkpoint)
    return tokenizer, model


def _classify(text, tokenizer, model):
    """Return the label name that *model* assigns to *text*."""
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_TOKENS
    )
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    # Batch of one, so argmax over the label axis yields a single index.
    prediction_id = logits.argmax(dim=-1).item()
    return model.config.id2label[prediction_id]


st.title("CS482 Project Sentiment Analysis")
text = st.text_area(label="Text to be analyzed", value="This sentiment analysis app is great!")
selected_model = st.radio(label="Model", options=list(MODEL_CHECKPOINTS))
analyze_button = st.button(label="Analyze")
st.markdown("**:red[Sentiment:]**")

if analyze_button:
    if not text.strip():
        # Nothing meaningful to classify; don't spend a model call on it.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner(text="Analyzing..."):
            tokenizer, model = _load_model(MODEL_CHECKPOINTS[selected_model])
            results = _classify(text, tokenizer, model)
        st.write(results)