File size: 2,857 Bytes
c8f27cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f42cb1
c8f27cf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from datasets import load_dataset
import streamlit as st

from data_utils import get_embedding

from bokeh.plotting import figure,show
from bokeh.io import push_notebook, output_notebook
# output_notebook()
from bokeh.palettes import d3

from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter
from bokeh.transform import factor_cmap, factor_mark
import base64
from io import BytesIO

# Dataset columns the user can pick from in the "Color by" selector; the
# unique values of each are also collected per model in cache_embedding().
label_columns = [
    "gender",
    "subCategory",
    "masterCategory",
]

# Model checkpoints offered in the sidebar "Model" selector.
model_interest = [
    "facebook/deit-tiny-patch16-224",  # very small model (5M parameters)
    "microsoft/beit-base-patch16-224",  # big model
    "facebook/dino-vits8",
    "facebook/levit-128S",
]

def convert_base64(img):
    """Encode an image as a JPEG ``data:`` URI string.

    Args:
        img: Image object exposing ``save(fp, format=...)`` and, when a mode
            conversion is needed, ``mode`` / ``convert(mode)``. Presumably a
            ``PIL.Image.Image`` -- confirm against ``get_embedding``.

    Returns:
        str: ``"data:image/jpeg;base64,"`` followed by the base64-encoded
        JPEG bytes.
    """
    # Modes Pillow's JPEG writer accepts. Anything else (e.g. RGBA or
    # palette images) would make save() raise, so normalise those to RGB.
    # Objects without a ``mode`` attribute are passed through unchanged.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if getattr(img, "mode", "RGB") not in jpeg_modes:
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + img_str
                    
@st.experimental_singleton
def cache_embedding(model_name):
    """Compute the embedding table and label values for one model.

    The ``ceyda/fashion-products-small`` train split is shuffled with a fixed
    seed and only its 10% "test" slice is kept as the visualization subset.
    The result is memoized through ``st.experimental_singleton``.

    Args:
        model_name: Model identifier forwarded to ``get_embedding``.

    Returns:
        tuple: ``(embedding, labels)``. ``embedding`` is whatever
        ``get_embedding`` returns, with its ``"image"`` column replaced by
        base64 JPEG data URIs. ``labels`` maps every name in
        ``label_columns`` to that column's unique values in the subset.
    """
    shuffled = load_dataset(
        "ceyda/fashion-products-small", split="train"
    ).shuffle(seed=100)  # fixed seed so the same subset is picked every run

    # Keep only a fraction of the data for visualization; the split itself
    # is not reshuffled because the dataset was already shuffled above.
    subset = shuffled.train_test_split(0.1, shuffle=False)["test"]

    embedding = get_embedding(model_name, subset)
    # Replace image objects with data URIs so they can be shown in tooltips.
    # NOTE(review): assumes ``embedding`` is a pandas-like frame -- confirm
    # in data_utils.get_embedding.
    encoded_images = embedding["image"].apply(convert_base64)
    embedding["image"] = encoded_images

    labels = {}
    for column in label_columns:
        labels[column] = subset.unique(column)
    return embedding, labels
    
@st.experimental_singleton
def cache_graph(model_name,color_column):
    """Build the Bokeh scatter plot of embeddings, colored by one label column.

    The result is memoized through ``st.experimental_singleton``.

    Args:
        model_name: Model identifier forwarded to ``cache_embedding``.
        color_column: Name of the label column used to color the points;
            must be a key of the ``labels`` mapping returned by
            ``cache_embedding``.

    Returns:
        The Bokeh figure containing the scatter glyph, with a hover tooltip
        that shows each point's image.
    """
    embedding, labels = cache_embedding(model_name)
    factors = labels[color_column]

    # Chain three 20-color d3 palettes (60 colors in total) and trim the
    # result to the number of distinct values being colored.
    color_palette = (
        d3['Category20'][20] + d3['Category20b'][20] + d3['Category20c'][20]
    )[:len(factors)]
    source = ColumnDataSource(data=embedding)

    TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,box_select,lasso_select,"
    # Custom HTML tooltip; "@image" is substituted from the source's "image"
    # column. The outer <div> is now closed (it was left open before).
    TOOLTIPS = """
        <div>
            <div>
                <img
                    src="@image" height="42" alt="@image" width="42"
                    style="float: left; margin: 0px 15px 15px 0px;"
                    border="2"
                ></img>
            </div>
        </div>
    """
    p = figure(tools=TOOLS,tooltips=TOOLTIPS)

    p.scatter(x="x", y="y", source=source,
              color=factor_cmap(color_column, color_palette, factors)
             )

    return p

# --- Page body: runs top to bottom on every Streamlit rerun ---

# Let the user know the first render can be slow.
st.write("It takes some time for the graph to load...wait please")

# User choices: model in the sidebar, coloring column in the main area.
model_name = st.sidebar.selectbox(label="Model", options=model_interest)
color_column = st.selectbox(label="Color by", options=label_columns)

# Build (or fetch the cached) figure for this selection and render it.
p = cache_graph(model_name, color_column)
st.bokeh_chart(figure=p, use_container_width=False)