File size: 1,810 Bytes
f01d363
 
 
c06e027
f01d363
 
 
 
 
0eb8992
1f50356
0eb8992
fb41620
f01d363
 
 
 
 
 
c06e027
 
104babd
fb41620
 
1f50356
 
fb41620
104babd
 
f01d363
104babd
 
 
 
1f50356
104babd
1f50356
104babd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import functools
import os

import joblib
import pandas as pd
import plotly.express as px
import streamlit as st
import umap.umap_ as umap
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
from sentence_transformers import SentenceTransformer

@functools.lru_cache(maxsize=1)
def _load_models():
    """Load the sentence encoder and the 3-D projection model once per process.

    Streamlit reruns the page on every widget interaction; memoising here keeps
    the model download/deserialisation from being repeated on each click.

    Returns:
        A ``(sentence_model, umap_model)`` tuple. ``umap_model`` is the object
        unpickled from the Hub file and is used only through ``.transform``.
    """
    model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    repo_id = "peter2000/umap_embed_3d_all-MiniLM-L6-v2"
    filename = "umap_embed_3d_all-MiniLM-L6-v2.sav"

    sentence_model = SentenceTransformer(model_name)
    # NOTE(review): joblib.load unpickles the downloaded file, which can execute
    # arbitrary code -- only point this at a repository you trust.
    umap_model = joblib.load(hf_hub_download(repo_id=repo_id, filename=filename))
    return sentence_model, umap_model


def app():
    """Render a text box that embeds user text and plots all stored entries in 3-D.

    On each click of "Embed", the entered text is appended to
    ``st.session_state['embed_list']``. The whole list is encoded with the
    sentence-transformer and projected with the loaded reducer. The projected
    points (except the first stored entry) are shown as a 3-D scatter plot.
    """
    # NOTE(review): assumes 'embed_list' was initialised in session state by the
    # caller; a KeyError is raised otherwise.
    stored_words = st.session_state['embed_list']
    with st.container():
        word_to_embed = st.text_input(
            "Please enter your text here and we will embed it for you.",
            value="Woman",
        )

        if not st.button("Embed"):
            return

        if not word_to_embed.strip():
            st.warning("Please enter some text to embed.")
            return

        with st.spinner("👑 load language model (sentence transformer)"):
            model, umap_model = _load_models()

            # Build the extended list without touching session state, and
            # commit it only once encoding and projection have succeeded, so a
            # failure does not leave a half-processed entry behind.
            updated_words = [*stored_words, word_to_embed]
            examples_embeddings = model.encode(updated_words)
            examples_umap = umap_model.transform(examples_embeddings)

        st.session_state['embed_list'] = updated_words

        with st.spinner("👑 create visualisation"):
            # NOTE(review): the first stored entry is excluded from the plot
            # (original behaviour); presumably it is a seed value placed in
            # 'embed_list' where it is initialised -- confirm.
            fig = px.scatter_3d(
                examples_umap[1:], x=0, y=1, z=2,
                opacity=.7, hover_data=[updated_words[1:]],
            )
            fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
            fig.update_traces(marker_size=4)
            st.plotly_chart(fig)