File size: 2,250 Bytes
f01d363
 
 
c06e027
f01d363
 
 
 
 
ad992d1
 
 
 
 
 
0bde925
ad992d1
0eb8992
1f50356
919b58f
 
0eb8992
15e720d
 
8dc5cc8
15e720d
a1149da
919b58f
f01d363
 
ad992d1
 
0bde925
ad992d1
4a03159
 
 
1f50356
919b58f
 
1f50356
4a03159
 
104babd
 
f01d363
104babd
 
 
 
1f50356
919b58f
1f50356
104babd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import plotly.express as px
import streamlit as st
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_url, cached_download
import umap.umap_ as umap
import pandas as pd
import os
import joblib

def init_models():
    """Load the sentence encoder and the serialized UMAP model.

    Returns:
        tuple: ``(model, umap_model)`` where ``model`` is the
        ``SentenceTransformer`` for ``all-MiniLM-L6-v2`` and ``umap_model``
        is whatever object ``joblib`` deserialises from the ``.sav`` file
        fetched from the Hugging Face Hub.
    """
    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Resolve the hub URL of the serialized model, then fetch it through the
    # hub's download cache.
    download_url = hf_hub_url(
        "peter2000/umap_embed_3d_all-MiniLM-L6-v2",
        "umap_embed_3d_all-MiniLM-L6-v2.sav",
    )
    local_path = cached_download(download_url)

    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code -- this is only safe while the hub repo is trusted.
    reducer = joblib.load(local_path)
    return encoder, reducer

def app():
    """Render the embedding page and plot all submitted texts in 3D.

    Reads the running lists ``st.session_state['embed_list']`` (texts) and
    ``st.session_state['cat_list']`` (category labels); both keys must
    already exist. When the "Embed" button is pressed with non-empty text,
    the text and the selected category are appended to those lists, every
    stored text is encoded (prefixed with "The book is about "), projected
    with the UMAP model and shown as a Plotly 3D scatter plot.

    The first entry of each list is skipped when plotting.
    NOTE(review): presumably index 0 is a seed/placeholder created where the
    session lists are initialised -- confirm against that code.
    """
    word_to_embed_list = st.session_state['embed_list']
    cat_list = st.session_state['cat_list']

    with st.container():
        col1, col2 = st.columns(2)
        with col1:
            word_to_embed = st.text_input(
                "Please enter your text here and we will embed it for you.",
                value="",
            )
        with col2:
            cat = st.selectbox('Categorie', ('1', '2', '3', '4', '5'))

        if st.button("Embed"):
            # Do not store blank submissions: they would stay in the session
            # lists and add a meaningless point to every later plot.
            if not word_to_embed.strip():
                st.warning("Please enter some text to embed.")
                return

            with st.spinner("👑 Embedding your input"):
                # Models are (re)loaded on every click; init_models relies on
                # the hub download cache for the UMAP file.
                model, umap_model = init_models()

                word_to_embed_list.append(word_to_embed)
                st.session_state['embed_list'] = word_to_embed_list

                cat_list.append(cat)
                # Write back under the same key that was read above (the
                # previous key had a trailing space and created a stray entry).
                st.session_state['cat_list'] = cat_list

                phrase_to_embed = [
                    "The book is about " + wte for wte in word_to_embed_list
                ]
                examples_embeddings = model.encode(phrase_to_embed)
                examples_umap = umap_model.transform(examples_embeddings)

                with st.spinner("👑 create visualisation"):
                    # Columns 0, 1, 2 of the projection are used as x, y, z.
                    fig = px.scatter_3d(
                        examples_umap[1:],
                        x=0,
                        y=1,
                        z=2,
                        color=cat_list[1:],
                        opacity=.7,
                        hover_data=[word_to_embed_list[1:]],
                    )
                    fig.update_scenes(
                        xaxis_visible=False,
                        yaxis_visible=False,
                        zaxis_visible=False,
                    )
                    fig.update_traces(marker_size=4)
                    st.plotly_chart(fig)