import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from scipy.stats import entropy
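# Run locally with: streamlit run app.py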
# Sidebar navigation
st.sidebar.title("App Navigation")
page = st.sidebar.radio("Choose a feature", ["Sentiment Analysis", "Drift Detection"])
# Load tokenizer and model once; st.cache_resource avoids reloading them on every Streamlit rerun
@st.cache_resource
def load_model():
    name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    return AutoTokenizer.from_pretrained(name), AutoModelForSequenceClassification.from_pretrained(name)
tokenizer, model = load_model()
# Helper function: classify a list of tweets and return predicted labels and class probabilities
def analyze_sentiments(tweets):
    inputs = tokenizer(tweets, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).numpy()
    labels = ["Negative", "Neutral", "Positive"]
    predictions = [labels[np.argmax(prob)] for prob in probs]
    return predictions, probs
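# Note: the hard-coded label order above assumes the checkpoint's output classes are ordered
# negative, neutral, positive; model.config.id2label can be inspected to confirm this mapping.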
# Drift detection: compare a reference sentiment distribution against the average
# predicted distribution of production tweets using KL divergence
def detect_drift(reference_distribution, production_probs):
    # Average per-tweet class probabilities into a single production distribution
    production_distribution = np.mean(production_probs, axis=0)
    # KL divergence D(reference || production)
    kl_divergence = entropy(reference_distribution, production_distribution)
    return kl_divergence, production_distribution
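# Illustrative check (not part of the app): with reference [0.2, 0.5, 0.3] and a production
# distribution of, say, [0.4, 0.4, 0.2],
#   KL = 0.2*ln(0.2/0.4) + 0.5*ln(0.5/0.4) + 0.3*ln(0.3/0.2) ≈ 0.095
# (scipy.stats.entropy uses the natural logarithm by default).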
# Sentiment Analysis Page
if page == "Sentiment Analysis":
    st.title("Twitter Sentiment Analysis App")
    # Input tweets
    tweets = st.text_area("Enter tweets (one per line):")
    if st.button("Analyze"):
        # Drop blank lines so empty input is not sent to the model
        tweets_list = [t.strip() for t in tweets.split("\n") if t.strip()]
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
            st.stop()
        predictions, _ = analyze_sentiments(tweets_list)
        st.write("Predictions:", predictions)
# Drift Detection Page
if page == "Drift Detection":
    st.title("Drift Detection for Sentiment Analysis")
    # Reference distribution (e.g., class frequencies from the training data): Negative, Neutral, Positive
    reference_distribution = np.array([0.2, 0.5, 0.3])
    # Input tweets
    st.write("Enter production tweets to monitor drift:")
    production_tweets = st.text_area("Tweets (one per line):")
    drift_threshold = st.number_input("Set Drift Alert Threshold (KL Divergence)", value=0.1)
    if st.button("Detect Drift"):
        # Drop blank lines before scoring
        tweets_list = [t.strip() for t in production_tweets.split("\n") if t.strip()]
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
            st.stop()
        _, probs = analyze_sentiments(tweets_list)
        # Detect drift
        kl_divergence, production_distribution = detect_drift(reference_distribution, probs)
        # Display results
        st.write(f"KL Divergence: {kl_divergence:.4f}")
        if kl_divergence > drift_threshold:
            st.error("⚠️ Drift Alert: KL Divergence exceeds the threshold!")
        else:
            st.success("✅ No significant drift detected.")
        st.write("Reference Distribution:", reference_distribution)
        st.write("Production Distribution:", production_distribution)
        # Plot reference vs. production sentiment distributions side by side
        labels = ["Negative", "Neutral", "Positive"]
        x = np.arange(len(labels))
        width = 0.35
        fig, ax = plt.subplots()
        ax.bar(x - width / 2, reference_distribution, width, label="Reference")
        ax.bar(x + width / 2, production_distribution, width, label="Production")
        ax.set_ylabel("Proportion")
        ax.set_title("Sentiment Distribution Comparison")
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()
        st.pyplot(fig)
        # Input visualization: distribution of tweet lengths in the production batch
        tweet_lengths = [len(tweet) for tweet in tweets_list]
        st.write("Tweet Length Analysis:")
        fig, ax = plt.subplots()
        ax.hist(tweet_lengths, bins=10, color="skyblue", edgecolor="black")
        ax.set_title("Tweet Length Distribution")
        ax.set_xlabel("Length of Tweet")
        ax.set_ylabel("Frequency")
        st.pyplot(fig)