import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from scipy.stats import entropy

# Sidebar navigation
st.sidebar.title("App Navigation")
page = st.sidebar.radio("Choose a feature", ["Sentiment Analysis", "Drift Detection"])

# Load tokenizer and model once and cache them across Streamlit reruns
@st.cache_resource
def load_model():
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()

# Helper function for sentiment analysis
def analyze_sentiments(tweets):
    """Return predicted labels and per-class probabilities for a list of tweets."""
    inputs = tokenizer(tweets, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).numpy()
    labels = ["Negative", "Neutral", "Positive"]
    predictions = [labels[np.argmax(prob)] for prob in probs]
    return predictions, probs

# Drift detection: compare the production label distribution to a reference
def detect_drift(reference_distribution, production_probs):
    # Average the per-tweet probabilities into a production distribution
    production_distribution = np.mean(production_probs, axis=0)
    # scipy's entropy(p, q) computes KL(p || q); here KL(reference || production)
    kl_divergence = entropy(reference_distribution, production_distribution)
    return kl_divergence, production_distribution

# Split a text area into non-empty tweet lines
def parse_tweets(text):
    return [line.strip() for line in text.split("\n") if line.strip()]

# Sentiment Analysis page
if page == "Sentiment Analysis":
    st.title("Twitter Sentiment Analysis App")

    tweets = st.text_area("Enter tweets (one per line):")
    if st.button("Analyze"):
        tweets_list = parse_tweets(tweets)
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
        else:
            predictions, _ = analyze_sentiments(tweets_list)
            st.write("Predictions:", predictions)

# Drift Detection page
elif page == "Drift Detection":
    st.title("Drift Detection for Sentiment Analysis")

    # Reference distribution (e.g. estimated from training data)
    reference_distribution = np.array([0.2, 0.5, 0.3])  # Negative, Neutral, Positive

    st.write("Enter production tweets to monitor drift:")
    production_tweets = st.text_area("Tweets (one per line):")
    drift_threshold = st.number_input("Set Drift Alert Threshold (KL Divergence)", value=0.1)

    if st.button("Detect Drift"):
        tweets_list = parse_tweets(production_tweets)
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
        else:
            _, probs = analyze_sentiments(tweets_list)

            kl_divergence, production_distribution = detect_drift(reference_distribution, probs)

            st.write(f"KL Divergence: {kl_divergence:.4f}")
            if kl_divergence > drift_threshold:
                st.error("⚠️ Drift Alert: KL Divergence exceeds the threshold!")
            else:
                st.success("✅ No significant drift detected.")
            st.write("Reference Distribution:", reference_distribution)
            st.write("Production Distribution:", production_distribution)

            # Bar chart comparing reference vs. production distributions
            labels = ["Negative", "Neutral", "Positive"]
            x = np.arange(len(labels))
            width = 0.35
            fig, ax = plt.subplots()
            ax.bar(x - width / 2, reference_distribution, width, label="Reference")
            ax.bar(x + width / 2, production_distribution, width, label="Production")
            ax.set_ylabel("Proportion")
            ax.set_title("Sentiment Distribution Comparison")
            ax.set_xticks(x)
            ax.set_xticklabels(labels)
            ax.legend()
            st.pyplot(fig)

            # Input visualization: histogram of tweet lengths
            tweet_lengths = [len(tweet) for tweet in tweets_list]
            st.write("Tweet Length Analysis:")
            fig, ax = plt.subplots()
            ax.hist(tweet_lengths, bins=10, color="skyblue", edgecolor="black")
            ax.set_title("Tweet Length Distribution")
            ax.set_xlabel("Length of Tweet")
            ax.set_ylabel("Frequency")
            st.pyplot(fig)
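
# ---------------------------------------------------------------------------
# Optional sketch: the reference distribution above is hardcoded for
# illustration. One way to derive it from real data is to run the same model
# over a held-out training sample and average the per-class probabilities.
# The CSV path and the "text" column name below are assumptions for this
# sketch, not part of the original app; large samples would need batching.
import pandas as pd  # assumption: pandas is available in the environment

def compute_reference_distribution(csv_path="train_sample.csv"):
    """Average class probabilities over a held-out tweet sample."""
    texts = pd.read_csv(csv_path)["text"].dropna().tolist()
    _, sample_probs = analyze_sentiments(texts)
    return np.mean(sample_probs, axis=0)  # shape (3,): Negative, Neutral, Positive

# Example usage (uncomment once the sample file exists):
# reference_distribution = compute_reference_distribution()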