BertChristiaens commited on
Commit
09a5e50
·
1 Parent(s): ce120cb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_drawable_canvas import st_canvas
3
+ from PIL import Image
4
+ import time
5
+ import io
6
+ import cv2
7
+ import numpy as np
8
+
9
+ from camera_input_live import camera_input_live
10
+
11
+ st.set_page_config(layout="wide")
12
+
13
+
14
+ def make_canvas(image):
15
+ canvas_dict = dict(
16
+ fill_color='#F00000',
17
+ stroke_color='#000000',
18
+ background_color="#FFFFFF",
19
+ background_image=image,
20
+ stroke_width=40,
21
+ update_streamlit=True,
22
+ height=512,
23
+ width=512,
24
+ drawing_mode='freedraw',
25
+ key="canvas"
26
+ )
27
+ return st_canvas(**canvas_dict)
28
+
29
+
30
+ def get_mask(image_mask: np.ndarray) -> np.ndarray:
31
+ """Get the mask from the segmentation mask.
32
+ Args:
33
+ image_mask (np.ndarray): segmentation mask
34
+ Returns:
35
+ np.ndarray: mask
36
+ """
37
+ # average the colors of the segmentation masks
38
+ average_color = np.mean(image_mask, axis=(2))
39
+ mask = average_color[:, :] > 0
40
+ if mask.sum() > 0:
41
+ mask = mask * 1
42
+ # 3 channels
43
+ mask = np.stack([mask, mask, mask], axis=2)
44
+ return mask
45
+
46
+
47
+ def make_prompt_fields():
48
+ st.write("### Prompting")
49
+ # prompt
50
+ prompt = st.text_input("Prompt", value="A person in a room with colored hair", key="prompt")
51
+ # negative prompt
52
+ negative_prompt = st.text_input("Negative Prompt", value="Facial hair", key="negative_prompt")
53
+
54
+ return prompt, negative_prompt
55
+
56
+ def make_input_fields():
57
+ st.write("### Parameters")
58
+ guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=50.0, value=7.5, step=0.25, key="guidance_scale")
59
+ inference_steps = st.slider("Inference Steps", min_value=1, max_value=50, value=20, step=1, key="inference_steps")
60
+ generator_seed = st.slider("Generator Seed", min_value=0, max_value=10_000, value=0, step=1, key="generator_seed")
61
+
62
+ st.write("### Latent walk")
63
+ static_latents = st.checkbox("Static Latents", value=False, key="static_latents")
64
+ latent_walk = st.slider("Latent Walk", min_value=0.0, max_value=1.0, value=0.0, step=0.01, key="latent_walk")
65
+
66
+ return guidance_scale, inference_steps, generator_seed, static_latents, latent_walk
67
+
68
+
69
+ def decode_image(image):
70
+ cv2_img = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
71
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
72
+ image = Image.fromarray(cv2_img).convert("RGB")
73
+ return image
74
+
75
+ if __name__ == "__main__":
76
+
77
+ st.sidebar.title("Sidebar")
78
+
79
+ with st.sidebar:
80
+ webcam = camera_input_live(debounce=1000, key="webcam", width=512, height=512)
81
+ prompt, negative_prompt = make_prompt_fields()
82
+
83
+ guidance_scale, inference_steps, generator_seed, static_latents, latent_walk = make_input_fields()
84
+
85
+
86
+ colA, colB = st.columns(2)
87
+
88
+ with colA:
89
+ st.write("## Webcam image")
90
+ st.write("You can draw the mask on the image below.")
91
+ image = decode_image(webcam.getvalue())
92
+
93
+ canvas = make_canvas(image)
94
+
95
+ mask = get_mask(np.array(canvas.image_data))
96
+
97
+ with colB:
98
+ st.write("## Generated image")
99
+ st.write("The generated image will appear here.")
100
+ if webcam:
101
+ st.image(webcam)
102
+ # st.image(mask*255)