Spaces:
Runtime error
Runtime error
Commit
·
09a5e50
1
Parent(s):
ce120cb
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit_drawable_canvas import st_canvas
|
3 |
+
from PIL import Image
|
4 |
+
import time
|
5 |
+
import io
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from camera_input_live import camera_input_live
|
10 |
+
|
11 |
+
st.set_page_config(layout="wide")
|
12 |
+
|
13 |
+
|
14 |
+
def make_canvas(image):
|
15 |
+
canvas_dict = dict(
|
16 |
+
fill_color='#F00000',
|
17 |
+
stroke_color='#000000',
|
18 |
+
background_color="#FFFFFF",
|
19 |
+
background_image=image,
|
20 |
+
stroke_width=40,
|
21 |
+
update_streamlit=True,
|
22 |
+
height=512,
|
23 |
+
width=512,
|
24 |
+
drawing_mode='freedraw',
|
25 |
+
key="canvas"
|
26 |
+
)
|
27 |
+
return st_canvas(**canvas_dict)
|
28 |
+
|
29 |
+
|
30 |
+
def get_mask(image_mask: np.ndarray) -> np.ndarray:
|
31 |
+
"""Get the mask from the segmentation mask.
|
32 |
+
Args:
|
33 |
+
image_mask (np.ndarray): segmentation mask
|
34 |
+
Returns:
|
35 |
+
np.ndarray: mask
|
36 |
+
"""
|
37 |
+
# average the colors of the segmentation masks
|
38 |
+
average_color = np.mean(image_mask, axis=(2))
|
39 |
+
mask = average_color[:, :] > 0
|
40 |
+
if mask.sum() > 0:
|
41 |
+
mask = mask * 1
|
42 |
+
# 3 channels
|
43 |
+
mask = np.stack([mask, mask, mask], axis=2)
|
44 |
+
return mask
|
45 |
+
|
46 |
+
|
47 |
+
def make_prompt_fields():
|
48 |
+
st.write("### Prompting")
|
49 |
+
# prompt
|
50 |
+
prompt = st.text_input("Prompt", value="A person in a room with colored hair", key="prompt")
|
51 |
+
# negative prompt
|
52 |
+
negative_prompt = st.text_input("Negative Prompt", value="Facial hair", key="negative_prompt")
|
53 |
+
|
54 |
+
return prompt, negative_prompt
|
55 |
+
|
56 |
+
def make_input_fields():
|
57 |
+
st.write("### Parameters")
|
58 |
+
guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=50.0, value=7.5, step=0.25, key="guidance_scale")
|
59 |
+
inference_steps = st.slider("Inference Steps", min_value=1, max_value=50, value=20, step=1, key="inference_steps")
|
60 |
+
generator_seed = st.slider("Generator Seed", min_value=0, max_value=10_000, value=0, step=1, key="generator_seed")
|
61 |
+
|
62 |
+
st.write("### Latent walk")
|
63 |
+
static_latents = st.checkbox("Static Latents", value=False, key="static_latents")
|
64 |
+
latent_walk = st.slider("Latent Walk", min_value=0.0, max_value=1.0, value=0.0, step=0.01, key="latent_walk")
|
65 |
+
|
66 |
+
return guidance_scale, inference_steps, generator_seed, static_latents, latent_walk
|
67 |
+
|
68 |
+
|
69 |
+
def decode_image(image):
|
70 |
+
cv2_img = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
|
71 |
+
cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
|
72 |
+
image = Image.fromarray(cv2_img).convert("RGB")
|
73 |
+
return image
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
|
77 |
+
st.sidebar.title("Sidebar")
|
78 |
+
|
79 |
+
with st.sidebar:
|
80 |
+
webcam = camera_input_live(debounce=1000, key="webcam", width=512, height=512)
|
81 |
+
prompt, negative_prompt = make_prompt_fields()
|
82 |
+
|
83 |
+
guidance_scale, inference_steps, generator_seed, static_latents, latent_walk = make_input_fields()
|
84 |
+
|
85 |
+
|
86 |
+
colA, colB = st.columns(2)
|
87 |
+
|
88 |
+
with colA:
|
89 |
+
st.write("## Webcam image")
|
90 |
+
st.write("You can draw the mask on the image below.")
|
91 |
+
image = decode_image(webcam.getvalue())
|
92 |
+
|
93 |
+
canvas = make_canvas(image)
|
94 |
+
|
95 |
+
mask = get_mask(np.array(canvas.image_data))
|
96 |
+
|
97 |
+
with colB:
|
98 |
+
st.write("## Generated image")
|
99 |
+
st.write("The generated image will appear here.")
|
100 |
+
if webcam:
|
101 |
+
st.image(webcam)
|
102 |
+
# st.image(mask*255)
|