# apps/sdg.py
# Last change: "Update apps/sdg.py" by peter2000 (commit 85f22bf); file size 2.77 kB.
import functools
import os

import plotly.express as px
import streamlit as st
from sentence_transformers import SentenceTransformer
import umap.umap_ as umap
import pandas as pd
# Sentence-transformer used to embed the SDG texts.
_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
# Tab-separated OSDG community dataset (version 21-09-30) on Zenodo.
_OSDG_URL = 'https://zenodo.org/record/5550238/files/osdg-community-dataset-v21-09-30.csv'

# Display name for each value of the dataset's ``sdg`` column (0 = no category).
_SDG_LABELS = {
    0: 'no_cat',
    1: 'SDG 1 - No poverty',
    2: 'SDG 2 - Zero hunger',
    3: 'SDG 3 - Good health and well-being',
    4: 'SDG 4 - Quality education',
    5: 'SDG 5 - Gender equality',
    6: 'SDG 6 - Clean water and sanitation',
    7: 'SDG 7 - Affordable and clean energy',
    8: 'SDG 8 - Decent work and economic growth',
    9: 'SDG 9 - Industry, Innovation and Infrastructure',
    10: 'SDG 10 - Reduced inequality',
    11: 'SDG 11 - Sustainable cities and communities',
    12: 'SDG 12 - Responsible consumption and production',
    13: 'SDG 13 - Climate action',
    14: 'SDG 14 - Life below water',
    15: 'SDG 15 - Life on land',
    16: 'SDG 16 - Peace, justice and strong institutions',
    17: 'SDG 17 - Partnership for the goals',
}

# Streamlit re-executes app() on every rerun. The helpers below are memoised
# at module level with functools.lru_cache, so the model load, CSV download,
# text encoding and UMAP fit run once per process (for as long as this module
# stays imported) instead of on every page load. The objects they return are
# shared across reruns and must be treated as read-only.


@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Return a SentenceTransformer for *model_name*, loading it only once."""
    return SentenceTransformer(model_name)


@functools.lru_cache(maxsize=1)
def _load_sdg_texts(url, min_agreement=0.95, min_labels_positive=4):
    """Read the tab-separated OSDG dataset at *url* and keep confident rows.

    A row is kept when its ``agreement`` is strictly greater than
    *min_agreement* and its ``labels_positive`` is strictly greater than
    *min_labels_positive*.

    Returns:
        ``(docs, labels)``: equal-length tuples holding each kept row's
        ``text`` and the display name of its ``sdg`` code.

    Raises:
        KeyError: if a kept row has an ``sdg`` code not in ``_SDG_LABELS``.
    """
    df = pd.read_csv(url, sep='\t')
    keep = (df['agreement'] > min_agreement) & (df['labels_positive'] > min_labels_positive)
    df = df[keep]
    docs = tuple(df['text'])
    labels = tuple(_SDG_LABELS[code] for code in df['sdg'])
    return docs, labels


@functools.lru_cache(maxsize=1)
def _embed_docs(model_name, docs):
    """Encode the tuple *docs* with the cached model for *model_name*."""
    return _load_model(model_name).encode(list(docs))


@functools.lru_cache(maxsize=1)
def _project_to_3d(model_name, docs, n_neighbors=15, random_state=42):
    """Reduce the embeddings of *docs* to 3 dimensions with cosine-metric UMAP.

    ``fit_transform`` returns the embedding learned during fitting. Calling
    ``transform`` on the same training data after ``fit`` would instead run
    UMAP's approximate out-of-sample projection, which is slower and need not
    reproduce that embedding. The fixed *random_state* keeps the layout
    reproducible between runs.
    """
    embeddings = _embed_docs(model_name, docs)
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        n_components=3,
        metric='cosine',
        random_state=random_state,
    )
    return reducer.fit_transform(embeddings)


def app():
    """Render the "SDG Embedding Visualisation" Streamlit page.

    Embeds the filtered OSDG texts with a sentence-transformer, projects the
    embeddings to 3-D with UMAP, and shows them as a Plotly 3-D scatter
    coloured by SDG label.
    """
    st.title("SDG Embedding Visualisation")
    with st.spinner("👑 load language model (sentence transformer)"):
        # Called only to populate the cache, so that this spinner covers the load.
        _load_model(_MODEL_NAME)
    with st.spinner("👑 load and embed SDG texts"):
        docs, labels = _load_sdg_texts(_OSDG_URL)
        if not docs:
            # UMAP needs at least one sample; report the empty result instead of erroring.
            st.warning("No OSDG texts passed the agreement filters; nothing to plot.")
            return
        _embed_docs(_MODEL_NAME, docs)
    with st.spinner("👑 map to 3D for visualisation"):
        docs_umap = _project_to_3d(_MODEL_NAME, docs)
    with st.spinner("👑 create visualisation"):
        fig = px.scatter_3d(
            docs_umap, x=0, y=1, z=2,
            color=list(labels),
            opacity=.5,
        )
        # The UMAP axes have no intrinsic meaning, so hide them.
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
        fig.update_traces(marker_size=4)
    st.plotly_chart(fig)