File size: 6,384 Bytes
b0b9e1f
47cfe13
cb5f8d1
579c59e
cb5f8d1
47cfe13
 
 
21feb87
cb5f8d1
b0b9e1f
47cfe13
b0b9e1f
47cfe13
b0b9e1f
 
9fbe234
 
 
 
 
21feb87
 
 
 
 
 
 
 
 
 
 
b0b9e1f
579c59e
 
47cfe13
9fbe234
21feb87
b0b9e1f
47cfe13
 
 
 
21feb87
cb5f8d1
47cfe13
cb5f8d1
47cfe13
9fbe234
cb5f8d1
47cfe13
cb5f8d1
 
 
 
47cfe13
 
 
 
21feb87
 
 
 
cb5f8d1
47cfe13
cb5f8d1
47cfe13
 
1bd7bf1
47cfe13
cb5f8d1
 
 
 
 
 
 
1bd7bf1
21feb87
 
cb5f8d1
47cfe13
21feb87
 
 
 
47cfe13
 
 
21feb87
47cfe13
21feb87
 
47cfe13
21feb87
47cfe13
21feb87
 
 
 
 
 
 
 
cb5f8d1
47cfe13
 
21feb87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47cfe13
21feb87
 
 
47cfe13
 
 
21feb87
 
 
 
 
 
 
 
b0b9e1f
 
cb5f8d1
47cfe13
 
 
 
 
 
21feb87
47cfe13
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed,make_meme

import io

st.sidebar.subheader("This butterfly does not exist! ")
st.sidebar.image("assets/logo.png", width=200)

st.title("ButterflyGAN")

@st.experimental_singleton
def load_model_intocache(model_name,model_version):
    """Load the GAN through ``demo.load_model`` and return it.

    Wrapped in ``st.experimental_singleton`` so Streamlit can hand back the
    already-loaded object on later script reruns instead of loading it again.

    Args:
        model_name: model identifier forwarded to ``load_model``
            (e.g. 'ceyda/butterfly_512_base').
        model_version: version/revision forwarded to ``load_model``.

    Returns:
        Whatever ``load_model`` returns for the given name and version.
    """
    return load_model(model_name, model_version)

@st.experimental_singleton
def load_dataset():
    """Return the dataset produced by ``demo.get_dataset``.

    Wrapped in ``st.experimental_singleton`` so Streamlit can reuse the same
    dataset object across script reruns rather than fetching it every time.
    """
    return get_dataset()

@st.experimental_singleton
def _read_code_snippets():
    """Read the latent-walk code snippets from disk and return them.

    Cached with ``st.experimental_singleton`` so the files are not opened and
    read again on every script rerun.

    Returns:
        dict mapping the session-state key to the snippet's source text.
    """
    snippet_paths = {
        'latent_walk_code': "assets/code_snippets/latent_walk.py",
        'latent_walk_code_music': "assets/code_snippets/latent_walk_music.py",
    }
    snippets = {}
    for state_key, path in snippet_paths.items():
        # Context manager so the file handle is closed deterministically.
        with open(path, encoding="utf-8") as snippet_file:
            snippets[state_key] = snippet_file.read()
    return snippets


def load_variables():
    """Make the code snippets available in the current session's state.

    The file reads are cached in ``_read_code_snippets``; the copy into
    ``st.session_state`` is deliberately NOT cached. Session state is
    per-session, while a singleton-cached body only executes once, so caching
    this step would leave every session after the first without the
    'latent_walk_code' / 'latent_walk_code_music' keys.
    """
    for state_key, snippet_text in _read_code_snippets().items():
        st.session_state[state_key] = snippet_text

def img2download(image):
    """Encode *image* as JPEG and return the raw bytes.

    Args:
        image: object exposing ``save(fp, format=...)``
            (presumably a PIL image coming from ``make_meme`` -- TODO confirm).

    Returns:
        bytes: the content written by ``image.save`` into an in-memory buffer,
        suitable as the ``data`` argument of ``st.download_button``.
    """
    with io.BytesIO() as jpeg_buffer:
        image.save(jpeg_buffer, format="JPEG")
        # Grab the bytes before the buffer is closed on leaving the block.
        return jpeg_buffer.getvalue()

# Model identifier and pinned version handed to load_model_intocache
# (the version string looks like a commit hash -- presumably a repo revision).
model_name = 'ceyda/butterfly_cropped_uniq1K_512'
model_version = '57d36a15546909557d9f967f47713236c8288838'
# model_version = None  # alternative left by the original author

# Shared resources used by the screens below.
model = load_model_intocache(model_name, model_version)
dataset = load_dataset()
load_variables()

# Sidebar destination labels. The emoji prefixes were mojibake (UTF-8 bytes
# decoded with a legacy Windows code page); they are restored to the intended
# characters here. These values double as the keys compared against the
# selected radio option, so they must only be changed in this one place.
generate_menu = "🦋 Make butterflies"
latent_walk_menu = "🎧 Take a latent walk"
make_meme_menu = "🐦 Make a meme"
mosaic_menu = "👀 See the mosaic"
fun_menu = "Release the butterflies"

# Sidebar navigation: the selected label decides which screen is rendered.
# fun_menu is intentionally not part of the offered destinations.
menu_options = [
    generate_menu,
    latent_walk_menu,
    make_meme_menu,
    mosaic_menu,
]
screen = st.sidebar.radio("Pick a destination", menu_options)

# Render the screen chosen in the sidebar. Each branch is one page of the app.
if screen == generate_menu:
    # --- Generate a batch of butterflies and look up nearest training images ---
    batch_size=4 #generate 4 butterflies
    col_num=4

    def run():
        """Generate a fresh batch and keep it in session state across reruns."""
        with st.spinner("Generating..."):
            ims=generate(model,batch_size)
            st.session_state['ims'] = ims

    # First visit in this session: nothing generated yet, so generate once.
    if 'ims' not in st.session_state:
        st.session_state['ims'] = None
        run()
    ims=st.session_state["ims"]
    st.write(
        "Light-GAN model trained on 1000 butterfly images taken from the Smithsonian Museum collection. \n "
        "Based on [paper:](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*"
    )

    # The button's return value is not needed; regeneration happens in the callback.
    st.button("Generate", on_click=run ,help="generated on the fly maybe slow")
    if ims is not None:
        cols=st.columns(col_num)
        picks=[]
        for j,im in enumerate(ims):
            column=cols[j%col_num]
            column.image(im, use_column_width=True)
            # One "Find Nearest" button per image; True only on the rerun it was clicked.
            picks.append(column.button("Find Nearest",key="pick_"+str(j)))

        for i,pick in enumerate(picks):
            if pick:
                _scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
                for r in retrieved_examples["image"]:
                    # Same column mapping as the generated image above, so the
                    # neighbours stay under their query even if batch_size > col_num.
                    cols[i%col_num].image(r, use_column_width=True)
    st.write("Nearest neighbors found in the training set according to L2 distance on 'microsoft/beit-base-patch16-224' embeddings")
    st.write(f"Latent dimension: {model.latent_dim}, image size:{model.image_size}")

elif screen == latent_walk_menu:
    # --- Pre-rendered latent-walk videos plus the code that produced them ---
    # Context managers close the snippet files deterministically.
    with open("assets/code_snippets/latent_walk.py", encoding="utf-8") as snippet_file:
        latent_walk_code=snippet_file.read()
    with open("assets/code_snippets/latent_walk_music.py", encoding="utf-8") as snippet_file:
        latent_walk_music_code=snippet_file.read()
    st.write("Take a latent walk :musical_note: with cute butterflies")

    cols=st.columns(3)

    cols[0].caption("A regular walk (no music)")
    cols[0].video("assets/latent_walks/regular_walk.mp4")

    cols[1].caption("Walk with music :butterfly:")
    cols[1].video("assets/latent_walks/walk_happyrock.mp4")
    cols[2].caption("Walk with music :butterfly:")
    cols[2].video("assets/latent_walks/walk_cute.mp4")

    st.caption("Royalty Free Music from Bensound")
    st.write("🎧Did those butterflies seem to be dancing to the music?!Here is the secret:")
    with st.expander("See the Code Snippets"):
        # Show the text read above directly, so this page does not depend on
        # session state having been populated for the current session.
        st.write("A regular latent walk:")
        st.code(latent_walk_code, language='python')
        st.write(":musical_note: latent walk with music:")
        st.code(latent_walk_music_code, language='python')


elif screen == make_meme_menu:
    # --- "Is this a pigeon?" meme built from one generated butterfly ---
    def get_pigeon():
        """Generate a single image and store it as the current meme subject."""
        st.session_state['pigeon'] = generate(model,1)[0]

    # Keep the same butterfly across reruns until the user asks for a new one.
    if "pigeon" not in st.session_state:
        get_pigeon()

    cols= st.columns(2)
    cols[0].button("change pigeon",on_click=get_pigeon)
    no_bg=cols[1].checkbox("Remove background?",True,help="Remove the background from pigeon")
    show_text=cols[1].checkbox("Show text?",True)

    meme_text=st.text_input("Enter text","Is this a pigeon?")

    meme=make_meme(st.session_state['pigeon'],text=meme_text,show_text=show_text,remove_background=no_bg)
    st.image(meme)
    coly=st.columns(2)
    coly[0].download_button("Download", img2download(meme),mime="image/jpeg")
    coly[1].write("Made a cool one? [Share](https://twitter.com/intent/tweet?text=Check%20out%20the%20demo%20for%20Butterfly%20GAN%20%F0%9F%A6%8Bhttps%3A//huggingface.co/spaces/huggan/butterfly-gan%0Amade%20by%20%40ceyda_cinarel%20%26%20%40johnowhitaker%20) on Twitter")


elif screen == mosaic_menu:
    # --- Side-by-side mosaics: training data vs. generated samples ---
    cols=st.columns(2)
    cols[0].markdown("These are all the butterflies in our [training set](https://huggingface.co/huggan/smithsonian_butterflies_subset)")
    cols[0].image("assets/train_data_mosaic_lowres.jpg")
    # Magnifier emoji restored from mojibake in both captions.
    cols[0].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/0c77e0e716f14ea7bc235447e5a4c397)")

    cols[1].markdown("These are the butterflies our model generated.")
    cols[1].image("assets/gen_mosaic_lowres.jpg")
    cols[1].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/cbb04e81106c4c54a9d9f9dbfb236eab)")


# Sidebar footer: resource links first, then credits, rendered in this order.
# TODO: link the project repo (scripts etc.)
footer_captions = (
    "[Model](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) & [Dataset](https://huggingface.co/huggan/smithsonian_butterflies_subset) used",
    # Credits
    "Made during the [huggan](https://github.com/huggingface/community-events) hackathon",
    "Contributors:",
    "[Ceyda Cinarel](https://github.com/cceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)",
)
for footer_caption in footer_captions:
    st.sidebar.caption(footer_caption)

## Feel free to add more & change stuff ^