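"""Streamlit demo app: tweet sentiment analysis with drift monitoring.

The sidebar exposes two pages: "Sentiment Analysis" classifies tweets with the
cardiffnlp/twitter-roberta-base-sentiment-latest model, and "Drift Detection"
compares the production sentiment distribution against a reference distribution
using KL divergence and plots the comparison. Run with: streamlit run <this_file>.py
"""
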
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.stats import entropy
from transformers import AutoTokenizer, AutoModelForSequenceClassification

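# Sidebar navigation between the app's two features.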
st.sidebar.title("App Navigation")
page = st.sidebar.radio("Choose a feature", ["Sentiment Analysis", "Drift Detection"])

# Load the tokenizer and model once; st.cache_resource keeps them cached across
# Streamlit reruns instead of reloading the model on every interaction.
@st.cache_resource
def load_sentiment_model():
    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
    model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
    return tokenizer, model

tokenizer, model = load_sentiment_model()


def analyze_sentiments(tweets):
    """Classify a list of tweets; return predicted labels and class probabilities."""
    inputs = tokenizer(tweets, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the three sentiment classes (negative, neutral, positive).
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).numpy()
    labels = ["Negative", "Neutral", "Positive"]
    predictions = [labels[np.argmax(prob)] for prob in probs]
    return predictions, probs


def detect_drift(reference_distribution, production_probs):
    """Measure drift between a reference label distribution and production predictions."""
    # Average the per-tweet probabilities to get the production label distribution.
    production_distribution = np.mean(production_probs, axis=0)
    # scipy's entropy(p, q) returns the KL divergence D_KL(p || q).
    kl_divergence = entropy(reference_distribution, production_distribution)
    return kl_divergence, production_distribution


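# Page 1: classify user-supplied tweets.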
if page == "Sentiment Analysis":
    st.title("Twitter Sentiment Analysis App")

    tweets = st.text_area("Enter tweets (one per line):")
    if st.button("Analyze"):
        # Drop blank lines so empty strings are not sent to the model.
        tweets_list = [t.strip() for t in tweets.split("\n") if t.strip()]
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
            st.stop()
        predictions, _ = analyze_sentiments(tweets_list)
        st.write("Predictions:", predictions)

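# Page 2: compare the production sentiment distribution against a reference and flag drift.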
if page == "Drift Detection":
    st.title("Drift Detection for Sentiment Analysis")

    # Reference (expected) sentiment distribution over [Negative, Neutral, Positive].
    reference_distribution = np.array([0.2, 0.5, 0.3])

    st.write("Enter production tweets to monitor drift:")
    production_tweets = st.text_area("Tweets (one per line):")

    drift_threshold = st.number_input("Set Drift Alert Threshold (KL Divergence)", value=0.1)

    if st.button("Detect Drift"):
        # Drop blank lines so empty strings are not sent to the model.
        tweets_list = [t.strip() for t in production_tweets.split("\n") if t.strip()]
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
            st.stop()
        _, probs = analyze_sentiments(tweets_list)

        kl_divergence, production_distribution = detect_drift(reference_distribution, probs)

        st.write(f"KL Divergence: {kl_divergence:.4f}")
        if kl_divergence > drift_threshold:
            st.error("⚠️ Drift Alert: KL Divergence exceeds the threshold!")
        else:
            st.success("✅ No significant drift detected.")

        st.write("Reference Distribution:", reference_distribution)
        st.write("Production Distribution:", production_distribution)

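        # Bar chart comparing the reference and production sentiment distributions.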
        labels = ["Negative", "Neutral", "Positive"]
        x = np.arange(len(labels))
        width = 0.35

        fig, ax = plt.subplots()
        ax.bar(x - width/2, reference_distribution, width, label="Reference")
        ax.bar(x + width/2, production_distribution, width, label="Production")
        ax.set_ylabel("Proportion")
        ax.set_title("Sentiment Distribution Comparison")
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()

        st.pyplot(fig)

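        # Histogram of tweet lengths (in characters) for the production batch.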
        tweet_lengths = [len(tweet) for tweet in tweets_list]
        st.write("Tweet Length Analysis:")
        fig, ax = plt.subplots()
        ax.hist(tweet_lengths, bins=10, color='skyblue', edgecolor='black')
        ax.set_title("Tweet Length Distribution")
        ax.set_xlabel("Length of Tweet")
        ax.set_ylabel("Frequency")
        st.pyplot(fig)