File size: 3,619 Bytes
d06496c
 
 
 
b409a80
d06496c
 
 
9615424
d06496c
2e39235
9615424
1cc7996
9615424
b409a80
d06496c
1cc7996
1eec6c3
1cc7996
 
 
d06496c
1cc7996
d06496c
 
 
2113989
d06496c
 
1cc7996
 
 
 
d06496c
2113989
d06496c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c8d160
d06496c
 
4c8d160
d06496c
 
c6afe0a
4c8d160
c6afe0a
 
 
d06496c
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st
import pandas as pd
import numpy as np
import altair as alt
import os
from PIL import Image
from embeddings.embeddings import load_model
from sentence_transformers import  util
import warnings

# Suppress all Python warnings so they do not clutter the app output.
warnings.filterwarnings('ignore')

# Set the page title and icon for this app.
st.set_page_config(page_title="Sinhala Embedding Space", page_icon=":bar_chart:")
# Open the static cluster plot (PNG); it is displayed on the "Clustering Results" tab.
image = Image.open('plots/clusters.png')

# Load data
# @st.cache_data
def load_data():
    """Read the clustered-headline dataset from disk.

    Returns:
        pandas.DataFrame: contents of ``data/top_cluster_dataset.csv`` with
        ``Headline`` and ``labels`` parsed as strings and ``x`` / ``y``
        parsed as 64-bit floats.
    """
    # Explicit per-column types so the label column is not inferred as numeric.
    column_types = {
        'Headline': str,
        'x': np.float64,
        'y': np.float64,
        'labels': str,
    }
    return pd.read_csv(r"data/top_cluster_dataset.csv", dtype=column_types)

# Load the headline/cluster table that both tabs read from.
chart_data = load_data()
# Create a Streamlit app

# Define tabs: the sidebar radio button decides which of the two views
# is rendered by the if/elif further down.
tabs = ["Clustering Results","Sentences Similarity"]
selected_tab = st.sidebar.radio("Select a Tab", tabs)

def get_altair_chart():
    """Build the interactive scatter plot of the clustered headlines.

    Returns:
        An Altair chart drawn from the module-level ``chart_data``: one
        circle per row positioned by ``x`` / ``y``, coloured by ``labels``,
        with ``Headline`` shown as the tooltip.
    """
    points = alt.Chart(chart_data).mark_circle(size=60)
    encoded = points.encode(
        x='x',
        y='y',
        color='labels',
        tooltip=['Headline'],
    )
    return encoded.interactive()

# Main content
if selected_tab == "Sentences Similarity":
    # Show up to 10 example headlines. The size is capped at the number of
    # available rows because DataFrame/Series.sample raises ValueError when
    # asked for more rows than exist. The fixed seed keeps the default
    # examples the same from run to run.
    example_count = min(10, len(chart_data))
    sample_sentences = chart_data['Headline'].sample(example_count, random_state=1).tolist()
    st.title("Calculate Sentences Similarity")
    # select model to use dropdown
    st.subheader("Select a model to use")
    model_list = ["Ransaka/SinhalaRoberta","keshan/SinhalaBERTo"]
    selected_model = st.selectbox("Select Model", model_list)
    model = load_model(selected_model)

    sentence1 = st.text_input("Enter Sentence 1", "")
    sentence2 = st.text_input("Enter Sentence 2", "")

    if sentence1 and sentence2:
        # Only compute when the user explicitly asks for it.
        if st.button("Calculate Similarity"):
            with st.spinner('Calculating Similarity...'):
                embedding1 = model.encode(sentence1)
                embedding2 = model.encode(sentence2)
                # The result is indexed as a 2-D score matrix; take the single
                # entry and convert it to a plain float so the threshold
                # comparisons and the ":.3f" formatting work on a number
                # rather than on a tensor object.
                similarity = float(util.pytorch_cos_sim(embedding1, embedding2)[0][0])
                if similarity > 0.7:
                    st.success(f"Sentences are similar (Score: {similarity:.3f})")
                elif similarity > 0.5:
                    st.warning(f"Sentences are somewhat similar (Score: {similarity:.3f})")
                else:
                    st.error(f"Sentences are not similar (Score: {similarity:.3f})")
    else:
        st.write("Enter two sentences to calculate similarity. Or start with sample sentences below.")
        # The button replaces the seeded default examples with a fresh,
        # unseeded random sample.
        if st.button("Randomize Sentences"):
            sample_sentences = chart_data['Headline'].sample(example_count).tolist()
        for sentence in sample_sentences:
            st.write(sentence)

elif selected_tab == "Clustering Results":
    st.title("Clustering Results")

    # Display PNG image
    st.subheader("Full Clustering Results")
    st.image(image, use_column_width=False, caption='Static PNG File',width=750)

    # Display Altair chart
    st.subheader("Interactive Chart")
    chart = get_altair_chart()
    st.altair_chart(chart, use_container_width=True)

    # Dropdown functionality to update DataFrame
    st.subheader("Select a cluster")
    unique_clusters = chart_data['labels'].unique().tolist()
    selected_value = st.selectbox("Select Value", unique_clusters)

    # Filter and display results based on selected cluster
    if selected_value is not None:
        # Compare for equality instead of a case-insensitive regex substring
        # match: a substring match on label "1" would also return rows from
        # "10", "11", "21", ... and would treat the label as a regex pattern.
        cluster_rows = chart_data[chart_data['labels'] == selected_value]
        # Cap the sample size so clusters with fewer than 10 headlines are
        # shown in full instead of raising ValueError.
        row_count = min(10, len(cluster_rows))
        filtered_data = cluster_rows.sample(row_count)[['Headline']].reset_index(drop=True)
        st.dataframe(filtered_data,width=750)
    else:
        st.write("Select a cluster to display results.")