File size: 3,876 Bytes
82baf39
3295fa4
82baf39
 
cc59fa1
 
7c33340
8cff4a6
3295fa4
82baf39
7c33340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b77bc7b
32cc64b
82baf39
 
7c33340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82baf39
b11a3be
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from scipy.stats import entropy

# Sidebar navigation: the radio selection decides which page section of this
# script renders on the current run.
sidebar = st.sidebar
sidebar.title("App Navigation")
page = sidebar.radio(
    "Choose a feature",
    ["Sentiment Analysis", "Drift Detection"],
)

# Load the tokenizer and model once per server process. Streamlit re-executes
# this whole script on every widget interaction, so without caching the model
# weights would be re-instantiated on every rerun.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"


@st.cache_resource
def _load_sentiment_model(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, cached across reruns."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    loaded_model.eval()  # inference only: keep dropout disabled
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_sentiment_model(MODEL_NAME)

# Helper function for sentiment analysis
def analyze_sentiments(tweets):
    """Classify each tweet as Negative, Neutral or Positive in one batch.

    Args:
        tweets: List of tweet strings to classify.

    Returns:
        A ``(predictions, probs)`` tuple. ``predictions`` holds one label string
        per tweet. ``probs`` is a NumPy array with one row per tweet of softmax
        class probabilities, one column per entry of ``labels`` (assumes the
        model emits exactly three logits in Negative/Neutral/Positive order).
    """
    labels = ["Negative", "Neutral", "Positive"]
    if not tweets:
        # Nothing to classify: skip the model call and return empty results
        # with the same column layout callers expect.
        return [], np.empty((0, len(labels)), dtype=np.float32)

    inputs = tokenizer(tweets, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: don't record an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1).numpy()
    predictions = [labels[idx] for idx in probs.argmax(axis=1)]
    return predictions, probs

# Drift Detection Function
def detect_drift(reference_distribution, production_probs):
    """Measure how far production predictions drift from a reference distribution.

    Args:
        reference_distribution: Expected class proportions, one entry per class.
        production_probs: Per-sample class probabilities, shape
            ``(n_samples, n_classes)``. A single 1-D probability vector is
            treated as one sample.

    Returns:
        ``(kl_divergence, production_distribution)``.
        ``production_distribution`` is the column-wise mean of
        ``production_probs``. ``kl_divergence`` is KL(reference || production)
        as computed by ``scipy.stats.entropy``, which normalizes both inputs
        to sum to 1 first.

    Raises:
        ValueError: If ``production_probs`` contains no samples. Otherwise the
            mean would be NaN, and a NaN divergence never exceeds any threshold.
    """
    probs = np.atleast_2d(np.asarray(production_probs, dtype=float))
    if probs.size == 0:
        raise ValueError("production_probs must contain at least one sample")

    # Average per-sample probabilities to get the production class distribution.
    production_distribution = probs.mean(axis=0)

    # Compute KL Divergence
    kl_divergence = entropy(reference_distribution, production_distribution)

    return kl_divergence, production_distribution

# Sentiment Analysis Page
if page == "Sentiment Analysis":
    st.title("Twitter Sentiment Analysis App")

    # Input tweets
    tweets = st.text_area("Enter tweets (one per line):")
    if st.button("Analyze"):
        # splitlines() also handles "\r\n" endings from pasted text; blank
        # lines are dropped so they are not classified as if they were tweets.
        tweets_list = [line for line in tweets.splitlines() if line.strip()]
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
        else:
            predictions, _ = analyze_sentiments(tweets_list)
            st.write("Predictions:", predictions)

# Drift Detection Page
def _plot_distribution_comparison(reference, production):
    """Return a grouped bar chart of reference vs. production class proportions."""
    labels = ["Negative", "Neutral", "Positive"]
    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots()
    ax.bar(x - width / 2, reference, width, label="Reference")
    ax.bar(x + width / 2, production, width, label="Production")
    ax.set_ylabel("Proportion")
    ax.set_title("Sentiment Distribution Comparison")
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend()
    return fig


def _plot_tweet_lengths(tweets):
    """Return a histogram of tweet lengths in characters."""
    fig, ax = plt.subplots()
    ax.hist([len(tweet) for tweet in tweets], bins=10, color="skyblue", edgecolor="black")
    ax.set_title("Tweet Length Distribution")
    ax.set_xlabel("Length of Tweet")
    ax.set_ylabel("Frequency")
    return fig


def _show_and_close(fig):
    """Render *fig* in the app, then close it so pyplot does not keep it alive across reruns."""
    st.pyplot(fig)
    plt.close(fig)


if page == "Drift Detection":
    st.title("Drift Detection for Sentiment Analysis")

    # Reference distribution: hard-coded example proportions in
    # Negative, Neutral, Positive order (not computed from data here).
    reference_distribution = np.array([0.2, 0.5, 0.3])

    # Input tweets
    st.write("Enter production tweets to monitor drift:")
    production_tweets = st.text_area("Tweets (one per line):")

    # KL divergence is never negative, so the threshold is clamped at 0.
    drift_threshold = st.number_input(
        "Set Drift Alert Threshold (KL Divergence)", min_value=0.0, value=0.1
    )

    if st.button("Detect Drift"):
        # Drop blank lines: an empty batch would make the production
        # distribution NaN, and blank lines would skew the length histogram.
        tweets_list = [line for line in production_tweets.splitlines() if line.strip()]
        if not tweets_list:
            st.warning("Please enter at least one tweet.")
        else:
            _, probs = analyze_sentiments(tweets_list)
            kl_divergence, production_distribution = detect_drift(reference_distribution, probs)

            # Display results
            st.write(f"KL Divergence: {kl_divergence:.4f}")
            if kl_divergence > drift_threshold:
                st.error("⚠️ Drift Alert: KL Divergence exceeds the threshold!")
            else:
                st.success("✅ No significant drift detected.")

            st.write("Reference Distribution:", reference_distribution)
            st.write("Production Distribution:", production_distribution)

            # Plot comparison
            _show_and_close(
                _plot_distribution_comparison(reference_distribution, production_distribution)
            )

            # Input visualization: Analyze tweet lengths
            st.write("Tweet Length Analysis:")
            _show_and_close(_plot_tweet_lengths(tweets_list))