wgcv commited on
Commit
bd18f96
1 Parent(s): d14459f

Add some generation

Browse files
Files changed (3) hide show
  1. app.py +35 -4
  2. utils.py +17 -0
  3. Ω +1 -0
app.py CHANGED
@@ -1,7 +1,38 @@
1
  import streamlit as st
2
 
3
- st.title("This is a demo")
4
- st.markdown("This is a description")
5
- x = st.slider("Put a number")
6
 
7
- st.write(x, "The square is", x*x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
+ from utils import generation,load_model
 
 
4
 
5
+ ## Site
6
+ st.title("Gen of butterfly")
7
+ st.markdown("This is lightweight_gan")
8
+
9
+ # Sidebar
10
+
11
+ st.sidebar.subheader("This is generated")
12
+ st.sidebar.image("assets/logo.png", width=200)
13
+ st.sidebar.caption("https://wgcv.me")
14
+
15
+ # Values
16
+ model_id="ceyda/butterfly_cropped_uniq1K_512"
17
+ model = load_model(model_id)
18
+ n_gen = 16
19
+
20
+ def run():
21
+ with st.spinner("Loading the model"):
22
+
23
+ ims = generation(model,batch_size=n_gen)
24
+ st.session_state["ims"] = ims
25
+
26
+ if("ims" not in st.session_state):
27
+ st.session_state["ims"] = None
28
+ run()
29
+
30
+
31
+ ims = st.session_state["ims"]
32
+ run_button = st.button("Gen AI butterfly", on_click=run,help="This would run the model")
33
+
34
+ if(ims is not None):
35
+ cols = st.columns(n_gen)
36
+ for j,im in enumerate(ims):
37
+ i = j % n_gen
38
+ cols[i].image(im, use_column_width=True)
utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
7
+
8
+ def load_model(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
9
+ gan = LightweightGAN.from_pretrained(model_name, version=model_version)
10
+ gan.eval()
11
+ return gan
12
+
13
+ def generation(gan, batch_size=1):
14
+ with torch.no_grad():
15
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
16
+ ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
17
+ return ims
Ω ADDED
@@ -0,0 +1 @@
 
 
1
+