menikev committed on
Commit
7bf8be4
1 Parent(s): 043c50c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -47
app.py CHANGED
@@ -1,48 +1,113 @@
1
  import streamlit as st
2
- import torch
3
- from prediction_sinhala import MDFEND, TokenizerFromPreTrained
4
-
5
-
6
-
7
- # Set constants for model and tokenizer paths
8
- MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
9
- BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'
10
- DOMAIN_NUM = 3
11
- MAX_LEN = 160
12
- BATCH_SIZE = 100
13
-
14
- # Load model and tokenizer
15
@st.cache_resource
def load_model():
    """Load the tokenizer and the trained MDFEND classifier once per process.

    ``st.cache_resource`` replaces the deprecated
    ``st.cache(allow_output_mutation=True)``; it is Streamlit's cache for
    objects such as ML models that should be created once and shared
    across reruns rather than copied.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has its weights loaded from
        ``MODEL_SAVE_PATH`` onto the CPU and is switched to evaluation mode.
    """
    # Load the tokenizer from the pre-trained model name
    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)

    # Initialize the custom model, then restore its saved weights on CPU
    model = MDFEND(
        BERT_MODEL_NAME,
        DOMAIN_NUM,
        expert_num=18,
        mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400],
    )
    state_dict = torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode
    return model, tokenizer
24
-
25
model, tokenizer = load_model()

# Free-text input for the sentence to classify
text_input = st.text_area("Enter text here:")

# Only act when the button was pressed on this rerun
predict_clicked = st.button("Predict")

if predict_clicked and not text_input:
    # Nothing to classify
    st.error("Please enter some text to predict.")
elif predict_clicked:
    # Tokenize, add a batch dimension and move to the model's device
    token_ids = tokenizer.tokenize(text_input)
    batch = torch.tensor(token_ids).unsqueeze(0).to(model.device)

    # Inference only: no gradients needed
    with torch.no_grad():
        output_prob = model.predict(batch)

    # Threshold the probability at 0.5 to pick the label
    is_offensive = bool(output_prob >= 0.5)
    result = "offensive" if is_offensive else "not offensive"
    st.write(f"Prediction: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.feature_extraction.text import CountVectorizer
5
+ import seaborn as sns
6
+ import plotly.express as px
7
+ import plotly.io as pio
8
+ import plotly.graph_objects as go
9
+
10
# Use the full browser width for the dashboard
st.set_page_config(layout="wide")

# Read the four source datasets, in order: English social media,
# English news, Tamil social media, Tamil news.
df1, df2, df3, df4 = (
    pd.read_csv(csv_path)
    for csv_path in (
        "data/reviewed_social_media_english.csv",
        "data/reviewed_news_english.csv",
        "data/tamil_social_media",
        "data/tamil_news",
    )
)
18
+
19
def _clean_frame(frame):
    """Return a cleaned copy of one source dataframe.

    Normalizes the "MUSLIM" domain label to "Muslim" and drops rows whose
    Domain is 'Not relevant' or 'None', or whose Discrimination or Sentiment
    is 'None'. The input frame is not modified.
    """
    # assign() builds a new frame, avoiding in-place mutation through a
    # column selection (deprecated chained-assignment pattern in pandas 2.x).
    normalized = frame.assign(Domain=frame['Domain'].replace("MUSLIM", "Muslim"))

    # NOTE(review): depending on the pandas version, read_csv may already
    # have parsed the literal text "None" as NaN; confirm whether NaN rows
    # should be dropped as well.
    keep = (
        (normalized['Domain'] != 'Not relevant')
        & (normalized['Domain'] != 'None')
        & (normalized['Discrimination'] != 'None')
        & (normalized['Sentiment'] != 'None')
    )
    return normalized[keep]


# Normalize text and drop irrelevant data.
# The previous loop rebound its loop variable (`df = df[...]`), which left
# the frames in the list untouched, so no rows were ever dropped. Rebinding
# the cleaned results fixes that.
df1, df2, df3, df4 = (_clean_frame(frame) for frame in (df1, df2, df3, df4))
frames = [df1, df2, df3, df4]

# Concatenate/merge dataframes. ignore_index gives the merged frame unique
# row labels instead of repeating each file's own 0..n-1 range.
df = pd.concat(frames, ignore_index=True)
35
+
36
+ # Visualization function
37
def create_visualizations(df):
    """Placeholder for the page-level visualizations.

    Accepts the dataframe to visualize but does not use it yet.

    Returns:
        None.
    """
    # TODO: [Existing visualization code]
    return None
40
+
41
+ # Page navigation
42
page_names = ["Overview", "Sentiment Analysis", "Discrimination Analysis", "Channel Analysis"]
page = st.sidebar.selectbox("Choose a page", page_names)

# Every page currently routes to the same placeholder renderer, so a single
# membership test replaces the four identical branches.
if page in page_names:
    create_visualizations(df)
52
+
53
+ # [Place the rest of the code for the visualizations here]
54
+
55
+
56
+ # Define a color palette for consistent visualization styles
57
# Shared colour palette so all charts use a consistent style
color_palette = px.colors.sequential.Viridis


def create_domain_distribution_chart(df):
    """Build a donut chart showing the share of rows per 'Domain' value.

    Args:
        df: Dataframe with a 'Domain' column.

    Returns:
        The Plotly figure.
    """
    donut = px.pie(df, names='Domain', title='Distribution of Domains', hole=0.35)

    # Colour the slices from the shared palette
    donut.update_traces(marker={'colors': color_palette})

    # Centre the title, tighten the margins and place the legend
    donut.update_layout(
        title_x=0.5,
        margin={'l': 20, 'r': 20, 't': 30, 'b': 20},
        legend={'x': 0.1, 'y': 1},
    )
    return donut
65
+
66
+ # Function for Sentiment Distribution Across Domains Chart
67
def create_sentiment_distribution_chart(df):
    """Build a grouped bar chart of sentiment counts for each domain.

    The previous body referenced ``fig`` without ever creating it, so every
    call raised NameError. The figure is now built here: one bar group per
    'Domain', one bar per 'Sentiment', bar height = number of rows.

    Args:
        df: Dataframe with 'Domain' and 'Sentiment' columns.

    Returns:
        The Plotly figure.
    """
    # px.histogram counts rows per x category, split by colour
    fig = px.histogram(
        df,
        x='Domain',
        color='Sentiment',
        barmode='group',
        title='Sentiment Distribution Across Domains',
    )
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig
71
+
72
+ # ... [Define other chart functions following the same pattern]
73
+
74
+ # Function for Channel-wise Sentiment Over Time Chart
75
def create_channel_sentiment_over_time_chart(df):
    """Build a line chart of monthly row counts per channel and sentiment.

    Fixes relative to the previous version:
      * the x axis was taken from ``index.levels[1]`` (the Channel level)
        rather than the dates, and its length did not match the rows;
      * ``color='Channel'`` referred to an index level, not a column;
      * the hard-coded 'Positive'/'Negative'/'Neutral' columns failed when a
        sentiment was absent from the data;
      * the caller's dataframe was modified in place.

    The counts are now kept in long form (one row per month, channel and
    sentiment), so whatever sentiments are present get plotted.

    Args:
        df: Dataframe with 'Date', 'Channel' and 'Sentiment' columns. It is
            not modified.

    Returns:
        The Plotly figure: one colour per channel, one dash style per sentiment.
    """
    # Bucket each row into the first day of its month, without touching df
    month_start = (
        pd.to_datetime(df['Date'])
        .dt.to_period('M')
        .dt.to_timestamp()
        .rename('Date')
    )

    # Long-form counts; groupby sorts by its keys, so points are in time order
    monthly_counts = (
        df.groupby([month_start, 'Channel', 'Sentiment'])
        .size()
        .reset_index(name='Count')
    )

    fig = px.line(
        monthly_counts,
        x='Date',
        y='Count',
        color='Channel',
        line_dash='Sentiment',
    )
    fig.update_layout(title='Channel-wise Sentiment Over Time', margin=dict(l=20, r=20, t=40, b=20))
    return fig
81
+
82
+ # Function for Channel-wise Distribution of Discriminative Content Chart
83
def create_channel_discrimination_chart(df):
    """Build a grouped bar chart of discriminative vs. non-discriminative rows per channel.

    Args:
        df: Dataframe with 'Channel' and 'Discrimination' columns.

    Returns:
        The Plotly figure.
    """
    categories = ['Discriminative', 'Non-Discriminative']

    # Count rows per channel and category. reindex() guarantees both category
    # columns exist (as zeros) even when the data contains only one of them,
    # e.g. after a sidebar filter; px.bar would otherwise fail on the
    # missing column name.
    channel_discrimination = (
        df.groupby(['Channel', 'Discrimination'])
        .size()
        .unstack(fill_value=0)
        .reindex(columns=categories, fill_value=0)
    )

    fig = px.bar(
        channel_discrimination,
        x=channel_discrimination.index,
        y=categories,
        barmode='group',
    )
    fig.update_layout(title='Channel-wise Distribution of Discriminative Content', margin=dict(l=20, r=20, t=40, b=20))
    return fig
88
+
89
+ # Dashboard Layout
90
def render_dashboard(data=None):
    """Render the charts for the currently selected page.

    The function used to take no parameters, yet the script calls it with the
    filtered dataframe, which raised TypeError, and the charts were drawn
    from the unfiltered module-level ``df``. It now accepts the data to plot;
    calling it with no argument still works and falls back to ``df``.

    Args:
        data: Dataframe to visualize. Defaults to the module-level ``df``.
    """
    if data is None:
        data = df

    # Overview page layout
    if page == "Overview":
        st.header("Overview of Domains and Sentiments")
        # st.columns replaces st.beta_columns, which has been removed
        # from Streamlit.
        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(create_domain_distribution_chart(data))
        with col2:
            st.plotly_chart(create_sentiment_distribution_chart(data))
        # ... [Additional overview charts]

    # ... [Other pages]
102
+
103
+ # Sidebar Filters
104
def _sidebar_multiselect(label, column):
    """Show a sidebar multiselect over the unique values of df[column], all preselected."""
    choices = df[column].unique()
    return st.sidebar.multiselect(label, options=choices, default=choices)


# Sidebar filters, one per dimension
domain_filter = _sidebar_multiselect('Select Domain', 'Domain')
channel_filter = _sidebar_multiselect('Select Channel', 'Channel')
sentiment_filter = _sidebar_multiselect('Select Sentiment', 'Sentiment')
discrimination_filter = _sidebar_multiselect('Select Discrimination', 'Discrimination')

# Keep only the rows that match every active filter
selected_rows = (
    df['Domain'].isin(domain_filter)
    & df['Channel'].isin(channel_filter)
    & df['Sentiment'].isin(sentiment_filter)
    & df['Discrimination'].isin(discrimination_filter)
)
df_filtered = df[selected_rows]

# Render the dashboard with filtered data
render_dashboard(df_filtered)