init sketch2image
- .gitignore +5 -0
- S2I/__init__.py +2 -0
- S2I/commons/__init__.py +2 -0
- S2I/commons/controller.py +96 -0
- S2I/commons/css.py +196 -0
- S2I/logger.py +4 -0
- S2I/modules/__init__.py +1 -0
- S2I/modules/models.py +91 -0
- S2I/modules/sketch2image.py +79 -0
- S2I/modules/utils.py +78 -0
- S2I/samer/__init__.py +7 -0
- S2I/samer/automatic_mask_generator_prob.py +402 -0
- S2I/samer/model_args.py +17 -0
- S2I/samer/sam_controller.py +307 -0
- S2I/samer/seg_anything.py +54 -0
- S2I/samer/segment.py +69 -0
- S2I/samer/segmentor.py +103 -0
- S2I/samer/transfer_tools.py +47 -0
- app.py +306 -125
- requirements.txt +87 -6
.gitignore
ADDED
@@ -0,0 +1,5 @@
+__pycache__
+.idea
+*.pyc
+debug
+.DS_Store
S2I/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .modules import Sketch2ImagePipeline
+from .commons import Sketch2ImageController, css, scripts
S2I/commons/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .controller import Sketch2ImageController
+from .css import css, scripts
S2I/commons/controller.py
ADDED
@@ -0,0 +1,96 @@
+from PIL import Image
+from io import BytesIO
+import numpy as np
+import base64
+import torch
+import torchvision.transforms.functional as F
+from S2I import Sketch2ImagePipeline
+
+
+
+class Sketch2ImageController():
+    def __init__(self, gr):
+        super().__init__()
+        self.gr = gr
+        self.style_list = [
+            {"name": "Comic",
+             "prompt": "comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed"},
+            {"name": "Cinematic",
+             "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"},
+            {"name": "3D Model",
+             "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"},
+            {"name": "Anime",
+             "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"},
+            {"name": "Digital Art",
+             "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"},
+            {"name": "Photographic",
+             "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"},
+            {"name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"},
+            {"name": "Fantasy art",
+             "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"},
+            {"name": "Neonpunk",
+             "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"},
+            {"name": "Manga",
+             "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"},
+        ]
+
+        self.styles = {k["name"]: k["prompt"] for k in self.style_list}
+        self.STYLE_NAMES = list(self.styles.keys())
+        self.DEFAULT_STYLE_NAME = "Fantasy art"
+        self.MAX_SEED = np.iinfo(np.int32).max
+
+        # Initialize the model once here
+        self.pipe = None
+        self.zero_options = None
+    def load_pipeline(self, zero_options):
+        if self.pipe is None or zero_options != self.zero_options:
+            self.pipe = Sketch2ImagePipeline()
+            self.zero_options = zero_options
+
+    def update_canvas(self, use_line, use_eraser):
+        brush_size = 20 if use_eraser else 4
+        _color = "#ffffff" if use_eraser else "#000000"
+        return self.gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
+
+    def upload_sketch(self, file):
+        _img = Image.open(file.name).convert("L")
+        return self.gr.update(value=_img, source="upload", interactive=True)
+
+    @staticmethod
+    def pil_image_to_data_uri(img, format="PNG"):
+        buffered = BytesIO()
+        img.save(buffered, format=format)
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return f"data:image/{format.lower()};base64,{img_str}"
+
+    def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag):
+        self.load_pipeline(zero_options=options)
+
+        prompt = prompt_template.replace("{prompt}", prompt)
+
+        if type_flag == 'live-sketch':
+            img = Image.fromarray(np.array(image["composite"])[:, :, -1])
+        elif type_flag == 'upload':
+            img = image["composite"]
+
+        img = img.convert("RGB")
+        img = img.resize((512, 512))
+
+        image_t = F.to_tensor(img) > 0.5
+        c_t = image_t.unsqueeze(0).cuda().float()
+
+        torch.manual_seed(seed)
+        _, _, H, W = c_t.shape
+        noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
+
+        with torch.no_grad():
+            output_image = self.pipe.generate(c_t, prompt, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
+
+        output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
+
+        if type_flag == 'live-sketch':
+            input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
+        else:
+            input_uri = self.pil_image_to_data_uri(img)
+
+        return output_pil, self.gr.update(link=input_uri)
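A minimal usage sketch of the controller outside the full Gradio wiring (the authoritative integration is in app.py; the file name and argument values below are hypothetical). It assumes a CUDA device and drives the 'upload' path with a plain PIL image:

```python
# Hypothetical standalone driver for Sketch2ImageController (not part of this commit).
import gradio as gr
from PIL import Image
from S2I import Sketch2ImageController

controller = Sketch2ImageController(gr)
sketch = Image.open("my_sketch.png")  # assumed local sketch file

# With type_flag='upload', artwork() reads image["composite"] directly as a PIL image.
result, link_update = controller.artwork(
    options="default",                  # pipeline is (re)built when this value changes
    image={"composite": sketch},
    prompt="a castle on a hill",
    prompt_template=controller.styles[controller.DEFAULT_STYLE_NAME],
    style_name=controller.DEFAULT_STYLE_NAME,
    seed=42,
    val_r=0.8,                          # sketch-vs-noise interpolation weight
    faster="float16",                   # selects the fp16 generation path
    model_name="350k",                  # key defined in S2I.modules.utils.download_models()
    type_flag="upload",
)
result.save("output.png")
```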
S2I/commons/css.py
ADDED
@@ -0,0 +1,196 @@
+css = """
+@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
+
+/* Outer container */
+.main {
+    display: flex;
+    justify-content: center;
+    align-items: flex-start;
+    width: 100%;
+    max-width: 1200px;
+    margin: 0 auto;
+    padding: 10px;
+    # background: linear-gradient(to right, #6a11cb, #2575fc);
+    # animation: diffusionArtAnimation 10s infinite alternate;
+}
+
+@keyframes diffusionArtAnimation {
+    0% {
+        background: linear-gradient(135deg, #ff9a9e, #fad0c4);
+    }
+    20% {
+        background: linear-gradient(135deg, #a1c4fd, #c2e9fb);
+    }
+    40% {
+        background: linear-gradient(135deg, #fbc2eb, #a6c1ee);
+    }
+    60% {
+        background: linear-gradient(135deg, #ffecd2, #fcb69f);
+    }
+    80% {
+        background: linear-gradient(135deg, #cfd9df, #e2ebf0);
+    }
+    100% {
+        background: linear-gradient(135deg, #ff9a9e, #fad0c4);
+    }
+}
+#main_row{
+    justify-content: center;
+}
+/* Hide class */
+.svelte-p4aq0j {
+    display: none;
+}
+
+.wrap.svelte-p4aq0j.svelte-p4aq0j {
+    display: none;
+}
+
+#download_sketch {
+    display: none;
+}
+
+#download_output {
+    display: none;
+}
+
+#column_input, #column_output {
+    width: 100%;
+    max-width: 500px;
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    padding: 10px;
+}
+
+#tools_header, #input_header, #output_header, #process_header {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    width: 100%;
+    max-width: 400px;
+    font-size: 1.2em;
+    color: #fff;
+    text-shadow: 1px 1px 2px #000;
+}
+
+#nn {
+    width: 100px;
+    height: 100px;
+}
+
+#column_process {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    height: 600px;
+}
+
+#output_image, #input_image {
+    border-radius: 10px;
+    border: 5px solid #fff;
+    width: 100%;
+    max-width: 500px;
+    height: 500px;
+    box-sizing: border-box;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    background: rgba(255, 255, 255, 0.1);
+    animation: zoomInOut 5s infinite alternate;
+}
+
+@keyframes zoomInOut {
+    0% {
+        transform: scale(1);
+    }
+    50% {
+        transform: scale(1.05);
+    }
+    100% {
+        transform: scale(1);
+    }
+}
+
+#output_image > img {
+    border: 5px solid #fff;
+    border-radius: 10px;
+    width: 100%;
+    height: 100%;
+    box-sizing: border-box;
+}
+
+#input_image > div.image-container.svelte-p3y7hu > div.wrap.svelte-yigbas > canvas:nth-child(1) {
+    border: 5px solid #fff;
+    border-radius: 10px;
+    width: 100%;
+    height: 100%;
+    box-sizing: border-box;
+}
+
+/* Responsive styles */
+@media (max-width: 768px) {
+    .main {
+        flex-direction: column;
+        width: 100%;
+    }
+
+    #column_input, #column_output {
+        width: 100%;
+        max-width: 100%;
+        padding: 10px 0;
+    }
+
+    #tools_header, #input_header, #output_header, #process_header {
+        width: 100%;
+    }
+
+    #column_process {
+        height: auto;
+    }
+
+    #output_image, #input_image {
+        max-width: 100%;
+        height: auto;
+    }
+}
+
+@media (max-width: 480px) {
+    #nn {
+        width: 80px;
+        height: 80px;
+    }
+
+    #tools_header, #input_header, #output_header, #process_header {
+        max-width: 100%;
+        font-size: 14px;
+    }
+
+    #column_input, #column_output {
+        max-width: 100%;
+        padding: 10px;
+    }
+}
+# .flex{
+#     background-color: #0b0f19;
+# }
+"""
+
+scripts = """
+async () => {
+    globalThis.theSketchDownloadFunction = () => {
+        console.log("test")
+        var link = document.createElement("a");
+        dataUri = document.getElementById('download_sketch').href
+        link.setAttribute("href", dataUri)
+        link.setAttribute("download", "sketch.png")
+        document.body.appendChild(link); // Required for Firefox
+        link.click();
+        document.body.removeChild(link); // Clean up
+
+        // also call the output download function
+        theOutputDownloadFunction();
+        return false
+    }
+}
+"""
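For context, `css` and `scripts` are meant to be consumed by the Gradio app; a hedged sketch of how they could be wired up (the real layout lives in app.py, and the exact JS hook name depends on the Gradio version pinned in requirements.txt):

```python
# Hypothetical wiring sketch; the actual Blocks layout is defined in app.py.
import gradio as gr
from S2I import css, scripts

with gr.Blocks(css=css) as demo:
    gr.Markdown("Sketch2Image")
    # In Gradio 3.x-style apps, a JS snippet can be attached on page load via `_js`.
    demo.load(None, None, None, _js=scripts)

demo.launch()
```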
S2I/logger.py
ADDED
@@ -0,0 +1,4 @@
+import logging
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger()
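The module exposes a single shared logger; the other modules in this commit (e.g. S2I/samer/sam_controller.py) use it like this:

```python
from S2I.logger import logger

# Emits e.g. "2024-01-01 12:00:00,000 - INFO - vit_b model downloaded."
logger.info("vit_b model downloaded.")
```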
S2I/modules/__init__.py
ADDED
@@ -0,0 +1 @@
+from .sketch2image import *
S2I/modules/models.py
ADDED
@@ -0,0 +1,91 @@
+import torch
+import copy
+from diffusers import DDPMScheduler
+from transformers import AutoTokenizer, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from peft import LoraConfig
+from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path
+
+
+class RelationShipConvolution(torch.nn.Module):
+    def __init__(self, conv_in_pretrained, conv_in_curr, r):
+        super(RelationShipConvolution, self).__init__()
+        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
+        self.conv_in_curr = copy.deepcopy(conv_in_curr)
+        self.r = r
+
+    def forward(self, x):
+        x1 = self.conv_in_pretrained(x).detach()
+        x2 = self.conv_in_curr(x)
+        return x1 * (1 - self.r) + x2 * self.r
+
+
+class PrimaryModel:
+    def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
+        self.backbone_diffusion_path = backbone_diffusion_path
+        self.global_unet = None
+        self.global_vae = None
+        self.global_tokenizer = None
+        self.global_text_encoder = None
+        self.global_scheduler = None
+
+    @staticmethod
+    def _load_model(path, model_class, unet_mode=False):
+        model = model_class.from_pretrained(path, subfolder='unet' if unet_mode else 'vae').to('cuda')
+        return model
+
+
+    def one_step_scheduler(self):
+        noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
+        noise_scheduler_1step.set_timesteps(1, device="cuda")
+        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+        return noise_scheduler_1step
+
+    def skip_connections(self, vae):
+        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
+        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
+        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.ignore_skip = False
+        return vae
+
+    def from_pretrained(self, model_name, r):
+        if self.global_tokenizer is None:
+            # self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
+            #                                                       subfolder="tokenizer")
+            self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
+
+        if self.global_text_encoder is None:
+            self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
+                                                                     subfolder="text_encoder").to(device='cuda')
+
+        if self.global_scheduler is None:
+            self.global_scheduler = self.one_step_scheduler()
+
+        if self.global_vae is None:
+            self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
+            self.global_vae = self.skip_connections(self.global_vae)
+
+        if self.global_unet is None:
+            self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
+            p_ckpt_path = download_models()
+            p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
+            sd = torch.load(p_ckpt, map_location="cpu")
+            conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
+            self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
+            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
+                                          target_modules=sd["unet_lora_target_modules"])
+            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
+                                         target_modules=sd["vae_lora_target_modules"])
+            self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            _sd_vae = self.global_vae.state_dict()
+            for k in sd["state_dict_vae"]:
+                _sd_vae[k] = sd["state_dict_vae"][k]
+            self.global_vae.load_state_dict(_sd_vae)
+            self.global_unet.add_adapter(unet_lora_config)
+            _sd_unet = self.global_unet.state_dict()
+            for k in sd["state_dict_unet"]:
+                _sd_unet[k] = sd["state_dict_unet"][k]
+            self.global_unet.load_state_dict(_sd_unet, strict=False)
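`RelationShipConvolution` blends the frozen pretrained `conv_in` with a trainable copy by a scalar `r`; a small self-contained sketch of the same blending idea with dummy convolutions (toy shapes, not the actual SD-Turbo layers):

```python
# Toy demonstration of the r-weighted blend used for the UNet conv_in (hypothetical shapes).
import torch

conv_pretrained = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1)
conv_current = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1)
r = 0.75

x = torch.randn(1, 4, 64, 64)
x1 = conv_pretrained(x).detach()   # frozen branch, no gradients flow back
x2 = conv_current(x)               # trainable branch
blended = x1 * (1 - r) + x2 * r    # r=1 -> purely the new conv, r=0 -> purely pretrained
print(blended.shape)               # torch.Size([1, 8, 64, 64])
```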
S2I/modules/sketch2image.py
ADDED
@@ -0,0 +1,79 @@
+from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+from S2I.modules.models import PrimaryModel
+import gc
+import torch
+import warnings
+
+warnings.filterwarnings("ignore")
+
+
+class Sketch2ImagePipeline(PrimaryModel):
+    def __init__(self):
+        super().__init__()
+        self.timestep = torch.tensor([999], device="cuda").long()
+
+    def generate(self, c_t, prompt=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
+        self.from_pretrained(model_name=model_name, r=r)
+        assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
+
+        if half_model == 'float16':
+            output_image = self._generate_fp16(c_t, prompt, prompt_tokens, r, noise_map)
+        else:
+            output_image = self._generate_full_precision(c_t, prompt, prompt_tokens, r, noise_map)
+
+        return output_image
+
+    def _generate_fp16(self, c_t, prompt, prompt_tokens, r, noise_map):
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            caption_enc = self._get_caption_enc(prompt, prompt_tokens)
+
+            self._set_weights_and_activate_adapters(r)
+            encoded_control = self.global_vae.encode(c_t).latent_dist.sample() * self.global_vae.config.scaling_factor
+
+            unet_input = encoded_control * r + noise_map * (1 - r)
+            unet_output = self.global_unet(unet_input, self.timestep, encoder_hidden_states=caption_enc).sample
+            x_denoise = self.global_scheduler.step(unet_output, self.timestep, unet_input, return_dict=True).prev_sample
+
+            self.global_vae.decoder.incoming_skip_acts = self.global_vae.encoder.current_down_blocks
+            self.global_vae.decoder.gamma = r
+
+            output_image = self.global_vae.decode(x_denoise / self.global_vae.config.scaling_factor).sample.clamp(-1, 1)
+
+        return output_image
+
+    def _generate_full_precision(self, c_t, prompt, prompt_tokens, r, noise_map):
+        caption_enc = self._get_caption_enc(prompt, prompt_tokens)
+
+        self._set_weights_and_activate_adapters(r)
+        encoded_control = self.global_vae.encode(c_t).latent_dist.sample() * self.global_vae.config.scaling_factor
+
+        unet_input = encoded_control * r + noise_map * (1 - r)
+        unet_output = self.global_unet(unet_input, self.timestep, encoder_hidden_states=caption_enc).sample
+        x_denoise = self.global_scheduler.step(unet_output, self.timestep, unet_input, return_dict=True).prev_sample
+
+        self.global_vae.decoder.incoming_skip_acts = self.global_vae.encoder.current_down_blocks
+        self.global_vae.decoder.gamma = r
+
+        output_image = self.global_vae.decode(x_denoise / self.global_vae.config.scaling_factor).sample.clamp(-1, 1)
+
+        return output_image
+
+    def _get_caption_enc(self, prompt, prompt_tokens):
+        if prompt is not None:
+            caption_tokens = self.global_tokenizer(prompt, max_length=self.global_tokenizer.model_max_length,
+                                                   padding="max_length", truncation=True,
+                                                   return_tensors="pt").input_ids.cuda()
+        else:
+            caption_tokens = prompt_tokens.cuda()
+
+        return self.global_text_encoder(caption_tokens)[0]
+
+    def _set_weights_and_activate_adapters(self, r):
+        self.global_unet.set_adapters(["default"], weights=[r])
+        set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
+
+    def _move_to_cpu(self, module):
+        module.to("cpu")
+
+    def _move_to_gpu(self, module):
+        module.to("cuda")
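A minimal sketch of calling the one-step pipeline directly with a binarized sketch tensor (assumes a CUDA device; the LoRA checkpoint download happens inside `from_pretrained` on first use, and the app normally reaches this through the controller):

```python
# Hypothetical direct call to Sketch2ImagePipeline (not how app.py drives it).
import torch
from S2I import Sketch2ImagePipeline

pipe = Sketch2ImagePipeline()
c_t = (torch.rand(1, 3, 512, 512) > 0.5).float().cuda()        # stand-in for a 512x512 sketch
noise = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda")    # latent-space noise map

with torch.no_grad():
    out = pipe.generate(c_t, prompt="a lighthouse at dusk", r=0.8,
                        noise_map=noise, half_model="float16", model_name="350k")

image = (out[0].cpu() * 0.5 + 0.5).clamp(0, 1)  # rescale from [-1, 1] to [0, 1] for display
```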
S2I/modules/utils.py
ADDED
@@ -0,0 +1,78 @@
+import os
+import requests
+from tqdm import tqdm
+from S2I.logger import logger
+
+def sc_vae_encoder_fwd(self, sample):
+    sample = self.conv_in(sample)
+    self.current_down_blocks = []
+
+    for down_block in self.down_blocks:
+        self.current_down_blocks.append(sample)
+        sample = down_block(sample)
+
+    sample = self.mid_block(sample)
+    sample = self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    return sample
+
+def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
+    sample = self.conv_in(sample)
+    upscale_dtype = next(self.up_blocks.parameters()).dtype
+    sample = self.mid_block(sample, latent_embeds)
+    sample = sample.to(upscale_dtype)
+
+    if not self.ignore_skip:
+        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
+        reversed_skip_acts = self.incoming_skip_acts[::-1]
+        for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
+            skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
+            sample += skip_in
+            sample = up_block(sample, latent_embeds)
+    else:
+        for up_block in self.up_blocks:
+            sample = up_block(sample, latent_embeds)
+
+    sample = self.conv_norm_out(sample, latent_embeds) if latent_embeds else self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    return sample
+
+def downloading(url, outf):
+    if not os.path.exists(outf):
+        print(f"Downloading checkpoint to {outf}")
+        response = requests.get(url, stream=True)
+        total_size_in_bytes = int(response.headers.get('content-length', 0))
+        block_size = 1024  # 1 Kibibyte
+        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+        with open(outf, 'wb') as file:
+            for data in response.iter_content(block_size):
+                progress_bar.update(len(data))
+                file.write(data)
+        progress_bar.close()
+        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+            print("ERROR, something went wrong")
+        print(f"Downloaded successfully to {outf}")
+
+
+def download_models():
+    urls = {
+        '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true',
+        '100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
+    }
+    # Get the current working directory
+    ckpt_folder = os.path.join(os.getcwd(), 'checkpoints')
+    os.makedirs(ckpt_folder, exist_ok=True)
+
+    model_paths = {}
+    for model_name, url in urls.items():
+        outf = os.path.join(ckpt_folder, f"sketch2image_lora_{model_name}.pkl")
+        downloading(url, outf)
+        model_paths[model_name] = outf
+
+    return model_paths
+
+
+def get_model_path(model_name, model_paths):
+    return model_paths.get(model_name, "Model not found")
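For reference, the checkpoint helpers are used roughly like this; downloads land in `./checkpoints` relative to the working directory:

```python
# Standalone use of the download helpers (the pipeline calls them internally).
from S2I.modules.utils import download_models, get_model_path

model_paths = download_models()              # fetches both LoRA checkpoints if missing
ckpt = get_model_path('350k', model_paths)   # .../checkpoints/sketch2image_lora_350k.pkl
print(ckpt)

missing = get_model_path('1M', model_paths)  # unknown keys fall back to "Model not found"
```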
S2I/samer/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .model_args import generate_sam_args
+from .segmentor import *
+from .seg_anything import *
+from .segment import *
+from .transfer_tools import *
+from .automatic_mask_generator_prob import SamAutomaticMaskAndProbabilityGenerator
+from .sam_controller import SAMController
S2I/samer/automatic_mask_generator_prob.py
ADDED
@@ -0,0 +1,402 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from segment_anything import SamAutomaticMaskGenerator
+from segment_anything.modeling import Sam
+from segment_anything.utils.amg import (MaskData, area_from_rle,
+                                        batched_mask_to_box, box_xyxy_to_xywh,
+                                        batch_iterator,
+                                        uncrop_boxes_xyxy, uncrop_points,
+                                        calculate_stability_score,
+                                        coco_encode_rle, generate_crop_boxes,
+                                        is_box_near_crop_edge,
+                                        mask_to_rle_pytorch, rle_to_mask,
+                                        uncrop_masks)
+from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
+
+
+def batched_mask_to_prob(masks: torch.Tensor) -> torch.Tensor:
+    """
+    For implementation, see the following issue comment:
+
+    "To get the probability map for a mask,
+    we simply do element-wise sigmoid over the logits."
+    URL: https://github.com/facebookresearch/segment-anything/issues/226
+
+    Args:
+        masks: Tensor of shape [B, H, W] representing batch of binary masks.
+
+    Returns:
+        Tensor of shape [B, H, W] representing batch of probability maps.
+    """
+    probs = torch.sigmoid(masks).to(masks.device)
+    return probs
+
+
+def batched_sobel_filter(probs: torch.Tensor, masks: torch.Tensor, bzp: int
+                         ) -> torch.Tensor:
+    """
+    For implementation, see section D.2 of the paper:
+
+    "we apply a Sobel filter to the remaining masks' unthresholded probability
+    maps and set values to zero if they do not intersect with the outer
+    boundary pixels of a mask."
+    URL: https://arxiv.org/abs/2304.02643
+
+    Args:
+        probs: Tensor of shape [B, H, W] representing batch of probability maps.
+        masks: Tensor of shape [B, H, W] representing batch of binary masks.
+
+    Returns:
+        Tensor of shape [B, H, W] with filtered probability maps.
+    """
+    # probs: [B, H, W]
+    # Add channel dimension to make it [B, 1, H, W]
+    probs = probs.unsqueeze(1)
+
+    # sobel_filter: [1, 1, 3, 3]
+    sobel_filter_x = torch.tensor([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]],
+                                  dtype=torch.float32
+                                  ).to(probs.device).unsqueeze(0)
+    sobel_filter_y = torch.tensor([[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]],
+                                  dtype=torch.float32
+                                  ).to(probs.device).unsqueeze(0)
+
+    # Apply the Sobel filters
+    G_x = F.conv2d(probs, sobel_filter_x, padding=1)
+    G_y = F.conv2d(probs, sobel_filter_y, padding=1)
+
+    # Combine the gradients
+    probs = torch.sqrt(G_x ** 2 + G_y ** 2)
+
+    # Iterate through each image in the batch
+    for i in range(probs.shape[0]):
+        # Convert binary mask to float
+        mask = masks[i].float()
+
+        G_x = F.conv2d(mask[None, None], sobel_filter_x, padding=1)
+        G_y = F.conv2d(mask[None, None], sobel_filter_y, padding=1)
+        edge = torch.sqrt(G_x ** 2 + G_y ** 2)
+        outer_boundary = (edge > 0).float()
+
+        # Set to zero values that don't touch the mask's outer boundary.
+        probs[i, 0] = probs[i, 0] * outer_boundary
+
+        # Boundary zero padding (BZP).
+        # See "Zero-Shot Edge Detection With SCESAME: Spectral
+        # Clustering-Based Ensemble for Segment Anything Model Estimation".
+        if bzp > 0:
+            probs[i, 0, 0:bzp, :] = 0
+            probs[i, 0, -bzp:, :] = 0
+            probs[i, 0, :, 0:bzp] = 0
+            probs[i, 0, :, -bzp:] = 0
+
+    # Remove the channel dimension
+    probs = probs.squeeze(1)
+
+    return probs
+
+
+class SamAutomaticMaskAndProbabilityGenerator(SamAutomaticMaskGenerator):
+    def __init__(
+        self,
+        model: Sam,
+        points_per_side: Optional[int] = 16,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.88,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: Optional[List[np.ndarray]] = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+        nms_threshold: float = 0.7,
+        bzp: int = 0,
+        pred_iou_thresh_filtering=False,
+        stability_score_thresh_filtering=False,
+    ) -> None:
+        """
+        Using a SAM model, generates masks for the entire image.
+        Generates a grid of point prompts over the image, then filters
+        low quality and duplicate masks. The default settings are chosen
+        for SAM with a ViT-H backbone.
+
+        Arguments:
+          model (Sam): The SAM model to use for mask prediction.
+          points_per_side (int or None): The number of points to be sampled
+            along one side of the image. The total number of points is
+            points_per_side**2. If None, 'point_grids' must provide explicit
+            point sampling.
+          points_per_batch (int): Sets the number of points run simultaneously
+            by the model. Higher numbers may be faster but use more GPU memory.
+          pred_iou_thresh (float): A filtering threshold in [0,1], using the
+            model's predicted mask quality.
+          stability_score_thresh (float): A filtering threshold in [0,1], using
+            the stability of the mask under changes to the cutoff used to binarize
+            the model's mask predictions.
+          stability_score_offset (float): The amount to shift the cutoff when
+            calculated the stability score.
+          box_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+            crops of the image. Sets the number of layers to run, where each
+            layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks between different crops.
+          crop_overlap_ratio (float): Sets the degree to which crops overlap.
+            In the first crop layer, crops will overlap by this fraction of
+            the image length. Later layers with more crops scale down this overlap.
+          crop_n_points_downscale_factor (int): The number of points-per-side
+            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+          point_grids (list(np.ndarray) or None): A list over explicit grids
+            of points used for sampling, normalized to [0,1]. The nth grid in the
+            list is used in the nth crop layer. Exclusive with points_per_side.
+          min_mask_region_area (int): If >0, postprocessing will be applied
+            to remove disconnected regions and holes in masks with area smaller
+            than min_mask_region_area. Requires opencv.
+          output_mode (str): The form masks are returned in. Can be 'binary_mask',
+            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+            For large resolutions, 'binary_mask' may consume large amounts of
+            memory.
+          nms_threshold (float): The IoU threshold used for non-maximal suppression
+        """
+        super().__init__(
+            model,
+            points_per_side,
+            points_per_batch,
+            pred_iou_thresh,
+            stability_score_thresh,
+            stability_score_offset,
+            box_nms_thresh,
+            crop_n_layers,
+            crop_nms_thresh,
+            crop_overlap_ratio,
+            crop_n_points_downscale_factor,
+            point_grids,
+            min_mask_region_area,
+            output_mode,
+        )
+        self.nms_threshold = nms_threshold
+        self.bzp = bzp
+        self.pred_iou_thresh_filtering = pred_iou_thresh_filtering
+        self.stability_score_thresh_filtering = \
+            stability_score_thresh_filtering
+
+    @torch.no_grad()
+    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+        """
+        Generates masks for the given image.
+
+        Arguments:
+          image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+        Returns:
+          list(dict(str, any)): A list over records for masks. Each record is
+            a dict containing the following keys:
+              segmentation (dict(str, any) or np.ndarray): The mask. If
+                output_mode='binary_mask', is an array of shape HW. Otherwise,
+                is a dictionary containing the RLE.
+              bbox (list(float)): The box around the mask, in XYWH format.
+              area (int): The area in pixels of the mask.
+              predicted_iou (float): The model's own prediction of the mask's
+                quality. This is filtered by the pred_iou_thresh parameter.
+              point_coords (list(list(float))): The point coordinates input
+                to the model to generate this mask.
+              stability_score (float): A measure of the mask's quality. This
+                is filtered on using the stability_score_thresh parameter.
+              crop_box (list(float)): The crop of the image used to generate
+                the mask, given in XYWH format.
+        """
+
+        # Generate masks
+        mask_data = self._generate_masks(image)
+
+        # Filter small disconnected regions and holes in masks
+        if self.min_mask_region_area > 0:
+            mask_data = self.postprocess_small_regions(
+                mask_data,
+                self.min_mask_region_area,
+                max(self.box_nms_thresh, self.crop_nms_thresh),
+            )
+
+        # Encode masks
+        if self.output_mode == "coco_rle":
+            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+        elif self.output_mode == "binary_mask":
+            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+        else:
+            mask_data["segmentations"] = mask_data["rles"]
+
+        # Write mask records
+        curr_anns = []
+        for idx in range(len(mask_data["segmentations"])):
+            ann = {
+                "segmentation": mask_data["segmentations"][idx],
+                "area": area_from_rle(mask_data["rles"][idx]),
+                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+                "predicted_iou": mask_data["iou_preds"][idx].item(),
+                "point_coords": [mask_data["points"][idx].tolist()],
+                "stability_score": mask_data["stability_score"][idx].item(),
+                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+                "prob": mask_data["probs"][idx],
+            }
+            curr_anns.append(ann)
+
+        return curr_anns
+
+    def _process_crop(
+        self,
+        image: np.ndarray,
+        crop_box: List[int],
+        crop_layer_idx: int,
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        # Crop the image and calculate embeddings
+        x0, y0, x1, y1 = crop_box
+        cropped_im = image[y0:y1, x0:x1, :]
+        cropped_im_size = cropped_im.shape[:2]
+        self.predictor.set_image(cropped_im)
+
+        # Get points for this crop
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+        points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+        # Generate masks for this crop in batches
+        data = MaskData()
+        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+            data.cat(batch_data)
+            del batch_data
+        self.predictor.reset_image()
+
+        # Remove duplicates within this crop.
+        keep_by_nms = batched_nms(
+            data["boxes"].float(),
+            data["iou_preds"],
+            torch.zeros_like(data["boxes"][:, 0]),  # categories
+            iou_threshold=self.box_nms_thresh,
+        )
+        data.filter(keep_by_nms)
+
+        # Return to the original image frame
+        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+        data["points"] = uncrop_points(data["points"], crop_box)
+        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+        padded_probs = torch.zeros((data["probs"].shape[0], *orig_size),
+                                   dtype=torch.float32,
+                                   device=data["probs"].device)
+        padded_probs[:, y0:y1, x0:x1] = data["probs"]
+        data["probs"] = padded_probs
+
+        return data
+
+    def _generate_masks(self, image: np.ndarray) -> MaskData:
+        orig_size = image.shape[:2]
+        crop_boxes, layer_idxs = generate_crop_boxes(
+            orig_size, self.crop_n_layers, self.crop_overlap_ratio
+        )
+
+        # Iterate over image crops
+        data = MaskData()
+        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            data.cat(crop_data)
+
+        # Remove duplicate masks between crops
+        if len(crop_boxes) > 1:
+            # Prefer masks from smaller crops
+            scores = 1 / box_area(data["crop_boxes"])
+            scores = scores.to(data["boxes"].device)
+            keep_by_nms = batched_nms(
+                data["boxes"].float(),
+                scores,
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
+                iou_threshold=self.crop_nms_thresh,
+            )
+            data.filter(keep_by_nms)
+
+        data.to_numpy()
+        return data
+
+    def _process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
+
+        # Run model on this batch
+        transformed_points = self.predictor.transform.apply_coords(points, im_size)
+        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        masks, iou_preds, _ = self.predictor.predict_torch(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=True,
+            return_logits=True,
+        )
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+        )
+        del masks
+
+        if self.pred_iou_thresh_filtering and self.pred_iou_thresh > 0.0:
+            keep_mask = data["iou_preds"] > self.pred_iou_thresh
+            data.filter(keep_mask)
+
+        # Calculate stability score
+        data["stability_score"] = calculate_stability_score(
+            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+        )
+
+        if self.stability_score_thresh_filtering and \
+                self.stability_score_thresh > 0.0:
+            keep_mask = data["stability_score"] >= self.stability_score_thresh
+            data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["probs"] = batched_mask_to_prob(data["masks"])
+        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        # Filter boxes that touch crop boundaries
+        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+        if not torch.all(keep_mask):
+            data.filter(keep_mask)
+
+        # filter by nms
+        if self.nms_threshold > 0.0:
+            keep_mask = batched_nms(
+                data["boxes"].float(),
+                data["iou_preds"],
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
+                iou_threshold=self.nms_threshold,
+            )
+            data.filter(keep_mask)
+
+        # apply sobel filter for probability map
+        data["probs"] = batched_sobel_filter(data["probs"], data["masks"],
+                                             bzp=self.bzp)
+
+        # set prob to 0 for pixels outside of crop box
+        # data["probs"] = batched_crop_probs(data["probs"], data["boxes"])
+
+        # Compress to RLE
+        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+        del data["masks"]
+
+        return data
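A hedged usage sketch of the probability-aware generator on its own (assumes the `segment_anything` package and an already-downloaded ViT-B checkpoint at the controller's default path; inside this repo it is normally constructed through `SegMent`/`SAMController` instead):

```python
# Hypothetical direct use of SamAutomaticMaskAndProbabilityGenerator.
import cv2
import numpy as np
from segment_anything import sam_model_registry
from S2I.samer import SamAutomaticMaskAndProbabilityGenerator

sam = sam_model_registry["vit_b"](checkpoint="root_model/sam_models/vit_b.pth").cuda()
generator = SamAutomaticMaskAndProbabilityGenerator(sam, points_per_side=16, nms_threshold=0.7)

image = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)  # HWC uint8
masks = generator.generate(image)

# Each record carries the usual SAM fields plus a per-mask probability map under "prob".
p_max = None
for m in masks:
    p_max = m["prob"] if p_max is None else np.maximum(p_max, m["prob"])
```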
S2I/samer/model_args.py
ADDED
@@ -0,0 +1,17 @@
+def generate_sam_args(sam_checkpoint="ckpt", model_type="vit_b", points_per_side=16,
+                      pred_iou_thresh=0.8, stability_score_thresh=0.9, crop_n_layers=1,
+                      crop_n_points_downscale_factor=2, min_mask_region_area=200, gpu_id=0):
+    sam_args = {
+        'sam_checkpoint': f'{sam_checkpoint}/{model_type}.pth',
+        'model_type': model_type,
+        'generator_args': {
+            'points_per_side': points_per_side,
+            'pred_iou_thresh': pred_iou_thresh,
+            'stability_score_thresh': stability_score_thresh,
+            'crop_n_layers': crop_n_layers,
+            'crop_n_points_downscale_factor': crop_n_points_downscale_factor,
+            'min_mask_region_area': min_mask_region_area,
+        },
+        'gpu_id': gpu_id}
+
+    return sam_args
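The helper just bundles the checkpoint path and the generator keyword arguments into one dict, for example:

```python
# Example of the dict produced by generate_sam_args (values beyond the two shown are defaults).
from S2I.samer import generate_sam_args

sam_args = generate_sam_args(sam_checkpoint="root_model/sam_models", model_type="vit_b")
print(sam_args["sam_checkpoint"])                      # root_model/sam_models/vit_b.pth
print(sam_args["generator_args"]["points_per_side"])   # 16
```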
S2I/samer/sam_controller.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from S2I.samer import SegMent, generate_sam_args
|
2 |
+
from S2I.logger import logger
|
3 |
+
from tqdm import tqdm
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
import cv2
|
9 |
+
import requests
|
10 |
+
|
11 |
+
|
12 |
+
class SAMController:
|
13 |
+
def __init__(self):
|
14 |
+
self.current_model_type = None
|
15 |
+
self.refine_mask = None
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def clean():
|
19 |
+
return None, None, None, None, None, [[]]
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def save_mask(refined_mask=None, save=False):
|
23 |
+
|
24 |
+
if refined_mask is not None and save:
|
25 |
+
if os.path.exists(os.path.join(os.getcwd(), 'output_render')):
|
26 |
+
shutil.rmtree(os.path.join(os.getcwd(), 'output_render'))
|
27 |
+
save_path = os.path.join(os.getcwd(), 'output_render')
|
28 |
+
os.makedirs(save_path, exist_ok=True)
|
29 |
+
cv2.imwrite(os.path.join(save_path, f'refined_mask_result.png'), (refined_mask * 255).astype('uint8'))
|
30 |
+
elif refined_mask is None and save:
|
31 |
+
return os.path.join(os.path.join(os.getcwd(), 'output_render'), f'refined_mask_result.png')
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
def download_models(model_type):
|
35 |
+
dir_path = os.path.join(os.getcwd(), 'root_model')
|
36 |
+
sam_models_path = os.path.join(dir_path, 'sam_models')
|
37 |
+
|
38 |
+
# Models URLs
|
39 |
+
models_urls = {
|
40 |
+
'sam_models': {
|
41 |
+
'vit_b': 'https://huggingface.co/ybelkada/segment-anything/resolve/main/checkpoints/sam_vit_b_01ec64.pth?download=true',
|
42 |
+
'vit_l': 'https://huggingface.co/segments-arnaud/sam_vit_l/resolve/main/sam_vit_l_0b3195.pth?download=true',
|
43 |
+
'vit_h': 'https://huggingface.co/segments-arnaud/sam_vit_h/resolve/main/sam_vit_h_4b8939.pth?download=true'
|
44 |
+
}
|
45 |
+
}
|
46 |
+
|
47 |
+
# Download specified model type
|
48 |
+
if model_type in models_urls['sam_models']:
|
49 |
+
model_url = models_urls['sam_models'][model_type]
|
50 |
+
os.makedirs(sam_models_path, exist_ok=True)
|
51 |
+
model_path = os.path.join(sam_models_path, model_type + '.pth')
|
52 |
+
|
53 |
+
if not os.path.exists(model_path):
|
54 |
+
logger.info(f"Downloading {model_type} model...")
|
55 |
+
response = requests.get(model_url, stream=True)
|
56 |
+
response.raise_for_status() # Raise an exception for non-2xx status codes
|
57 |
+
|
58 |
+
total_size = int(response.headers.get('content-length', 0)) # Get file size from headers
|
59 |
+
with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_type} model") as pbar:
|
60 |
+
with open(model_path, 'wb') as f:
|
61 |
+
for chunk in response.iter_content(chunk_size=1024):
|
62 |
+
f.write(chunk)
|
63 |
+
pbar.update(len(chunk))
|
64 |
+
logger.info(f"{model_type} model downloaded.")
|
65 |
+
else:
|
66 |
+
logger.info(f"{model_type} model already exists.")
|
67 |
+
return logger.info(f"{model_type} model download complete.")
|
68 |
+
else:
|
69 |
+
return logger.info(f"Invalid model type: {model_type}")
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def get_models_path(model_type=None, segment=False):
|
73 |
+
sam_models_path = os.path.join(os.getcwd(), 'root_model', 'sam_models')
|
74 |
+
|
75 |
+
if segment:
|
76 |
+
sam_args = generate_sam_args(sam_checkpoint=sam_models_path, model_type=model_type)
|
77 |
+
return sam_args, sam_models_path
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def get_click_prompt(click_stack, point):
|
81 |
+
click_stack[0].append(point["coord"])
|
82 |
+
click_stack[1].append(point["mode"]
|
83 |
+
)
|
84 |
+
|
85 |
+
prompt = {
|
86 |
+
"points_coord": click_stack[0],
|
87 |
+
"points_mode": click_stack[1],
|
88 |
+
"multi_mask": "True",
|
89 |
+
}
|
90 |
+
|
91 |
+
return prompt
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def read_temp_file(temp_file_wrapper):
|
95 |
+
name = temp_file_wrapper.name
|
96 |
+
with open(temp_file_wrapper.name, 'rb') as f:
|
97 |
+
# Read the content of the file
|
98 |
+
file_content = f.read()
|
99 |
+
return file_content, name
|
100 |
+
|
101 |
+
def get_meta_from_image(self, input_img):
|
102 |
+
file_content, _ = self.read_temp_file(input_img)
|
103 |
+
np_arr = np.frombuffer(file_content, np.uint8)
|
104 |
+
|
105 |
+
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
106 |
+
first_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
107 |
+
return first_frame, first_frame
|
108 |
+
|
109 |
+
def is_sam_model(self, model_type):
|
110 |
+
sam_args, sam_models_dir = self.get_models_path(model_type=model_type, segment=True)
|
111 |
+
model_path = os.path.join(sam_models_dir, model_type + '.pth')
|
112 |
+
if not os.path.exists(model_path):
|
113 |
+
self.download_models(model_type=model_type)
|
114 |
+
return 'Model is downloaded', sam_args
|
115 |
+
else:
|
116 |
+
return 'Model is already downloaded', sam_args
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def init_segment(
|
120 |
+
points_per_side,
|
121 |
+
origin_frame,
|
122 |
+
sam_args,
|
123 |
+
predict_iou_thresh=0.8,
|
124 |
+
stability_score_thresh=0.9,
|
125 |
+
crop_n_layers=1,
|
126 |
+
crop_n_points_downscale_factor=2,
|
127 |
+
min_mask_region_area=200):
|
128 |
+
if origin_frame is None:
|
129 |
+
return None, origin_frame, [[], []]
|
130 |
+
sam_args["generator_args"]["points_per_side"] = points_per_side
|
131 |
+
sam_args["generator_args"]["pred_iou_thresh"] = predict_iou_thresh
|
132 |
+
sam_args["generator_args"]["stability_score_thresh"] = stability_score_thresh
|
133 |
+
sam_args["generator_args"]["crop_n_layers"] = crop_n_layers
|
134 |
+
sam_args["generator_args"]["crop_n_points_downscale_factor"] = crop_n_points_downscale_factor
|
135 |
+
sam_args["generator_args"]["min_mask_region_area"] = min_mask_region_area
|
136 |
+
|
137 |
+
segment = SegMent(sam_args)
|
138 |
+
logger.info(f"Model Init: {sam_args}")
|
139 |
+
return segment, origin_frame, [[], []]
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def seg_acc_click(segment, prompt, origin_frame):
|
143 |
+
# seg acc to click
|
144 |
+
refined_mask, masked_frame = segment.seg_acc_click(
|
145 |
+
origin_frame=origin_frame,
|
146 |
+
coords=np.array(prompt["points_coord"]),
|
147 |
+
modes=np.array(prompt["points_mode"]),
|
148 |
+
multimask=prompt["multi_mask"],
|
149 |
+
)
|
150 |
+
return refined_mask, masked_frame
|
151 |
+
|
152 |
+
def undo_click_stack_and_refine_seg(self, segment, origin_frame, click_stack):
|
153 |
+
if segment is None:
|
154 |
+
return segment, origin_frame, [[], []]
|
155 |
+
|
156 |
+
logger.info("Undo !")
|
157 |
+
        if len(click_stack[0]) > 0:
            click_stack[0] = click_stack[0][: -1]
            click_stack[1] = click_stack[1][: -1]

        if len(click_stack[0]) > 0:
            prompt = {
                "points_coord": click_stack[0],
                "points_mode": click_stack[1],
                "multi_mask": "True",
            }

            _, masked_frame = self.seg_acc_click(segment, prompt, origin_frame)
            return segment, masked_frame, click_stack
        else:
            return segment, origin_frame, [[], []]

    def reload_segment(self,
                       check_sam,
                       segment,
                       model_type,
                       point_per_sides,
                       origin_frame,
                       predict_iou_thresh,
                       stability_score_thresh,
                       crop_n_layers,
                       crop_n_points_downscale_factor,
                       min_mask_region_area):
        status, sam_args = check_sam(model_type)
        if segment is None or status == 'Model is downloaded':
            segment, _, _ = self.init_segment(point_per_sides,
                                              origin_frame,
                                              sam_args,
                                              predict_iou_thresh,
                                              stability_score_thresh,
                                              crop_n_layers,
                                              crop_n_points_downscale_factor,
                                              min_mask_region_area)
        self.current_model_type = model_type
        return segment, self.current_model_type, status

    def sam_click(self,
                  evt: gr.SelectData,
                  segment,
                  origin_frame,
                  model_type,
                  point_mode,
                  click_stack,
                  point_per_sides,
                  predict_iou_thresh,
                  stability_score_thresh,
                  crop_n_layers,
                  crop_n_points_downscale_factor,
                  min_mask_region_area):
        logger.info("Click")
        if point_mode == "Positive":
            point = {"coord": [evt.index[0], evt.index[1]], "mode": 1}
        else:
            point = {"coord": [evt.index[0], evt.index[1]], "mode": 0}
        click_prompt = self.get_click_prompt(click_stack, point)
        segment, self.current_model_type, status = self.reload_segment(
            self.is_sam_model,
            segment,
            model_type,
            point_per_sides,
            origin_frame,
            predict_iou_thresh,
            stability_score_thresh,
            crop_n_layers,
            crop_n_points_downscale_factor,
            min_mask_region_area)
        if segment is not None and model_type != self.current_model_type:
            segment = None
            segment, _, status = self.reload_segment(
                self.is_sam_model,
                segment,
                model_type,
                point_per_sides,
                origin_frame,
                predict_iou_thresh,
                stability_score_thresh,
                crop_n_layers,
                crop_n_points_downscale_factor,
                min_mask_region_area)
        refined_mask, masked_frame = self.seg_acc_click(segment, click_prompt, origin_frame)
        self.save_mask(refined_mask, save=True)
        self.refine_mask = refined_mask
        return segment, masked_frame, click_stack, status

    @staticmethod
    def normalize_image(image):
        # Normalize the image to the range [0, 1]
        min_val = image.min()
        max_val = image.max()
        image = (image - min_val) / (max_val - min_val)

        return image

    @staticmethod
    def compute_probability(masks):
        p_max = None
        for mask in masks:
            p = mask['prob']
            if p_max is None:
                p_max = p
            else:
                p_max = np.maximum(p_max, p)
        return p_max

    @staticmethod
    def download_opencv_model(model_url):
        opencv_model_path = os.path.join(os.getcwd(), 'edges_detection')
        os.makedirs(opencv_model_path, exist_ok=True)
        model_path = os.path.join(opencv_model_path, 'edges_detection' + '.yml.gz')
        response = requests.get(model_url, stream=True)
        response.raise_for_status()  # Raise an exception for non-2xx status codes

        total_size = int(response.headers.get('content-length', 0))  # Get file size from headers
        with tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading opencv model") as pbar:
            with open(model_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024):
                    f.write(chunk)
                    pbar.update(len(chunk))
        return model_path

    def automatic_sam2sketch(self,
                             segment,
                             image,
                             origin_frame,
                             model_type
                             ):
        _, sam_args = self.is_sam_model(model_type)
        if segment is None or model_type != sam_args['model_type']:
            segment, _, _ = self.init_segment(
                points_per_side=16,
                origin_frame=origin_frame,
                sam_args=sam_args,
                predict_iou_thresh=0.8,
                stability_score_thresh=0.9,
                crop_n_layers=1,
                crop_n_points_downscale_factor=2,
                min_mask_region_area=200)
        model_path = self.download_opencv_model(model_url='https://github.com/nipunmanral/Object-Detection-using-OpenCV/raw/master/model.yml.gz')
        masks = segment.automatic_generate_mask(image)
        p_max = self.compute_probability(masks)
        edges = self.normalize_image(p_max)
        edge_detection = cv2.ximgproc.createStructuredEdgeDetection(model_path)
        orimap = edge_detection.computeOrientation(edges)
        edges = edge_detection.edgesNms(edges, orimap)
        edges = (edges * 255).astype('uint8')
        edges = 255 - edges
        edges = np.stack((edges,) * 3, axis=-1)
        return edges
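For orientation, a minimal NumPy-only sketch of the probability-map-to-edges step implemented by compute_probability, normalize_image and automatic_sam2sketch above. It assumes masks is the list of dicts (each carrying a per-pixel 'prob' array) returned by the automatic generator, the helper name is ours, and it omits the OpenCV structured-edge NMS refinement for brevity:

import numpy as np

def prob_masks_to_sketch(masks):
    # Pixel-wise maximum over the per-mask probability maps (as in compute_probability).
    p_max = None
    for m in masks:
        p_max = m['prob'] if p_max is None else np.maximum(p_max, m['prob'])
    # Normalize to [0, 1], scale to 8-bit, and invert so strokes are dark on a white background.
    edges = (p_max - p_max.min()) / (p_max.max() - p_max.min())
    edges = 255 - (edges * 255).astype('uint8')
    # Replicate to 3 channels, matching the shape automatic_sam2sketch returns.
    return np.stack((edges,) * 3, axis=-1)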
S2I/samer/seg_anything.py
ADDED
@@ -0,0 +1,54 @@
import os
from PIL import Image
import numpy as np
from scipy.ndimage import binary_dilation

np.random.seed(200)
_palette = ((np.random.random((3 * 255)) * 0.7 + 0.3) * 255).astype(np.uint8).tolist()
_palette = [0, 0, 0] + _palette


def save_prediction(predict_mask, output_dir, file_name):
    save_mask = Image.fromarray(predict_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask.save(os.path.join(output_dir, file_name))


def colorize_mask(predict_mask):
    save_mask = Image.fromarray(predict_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask = save_mask.convert(mode='RGB')
    return np.array(save_mask)


def draw_mask(img, mask, alpha=0.5, id_cnt=False):
    img_mask = img
    if id_cnt:
        # very slow ~ 1s per image
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        for ids in obj_ids:
            # Overlay color on binary mask
            if ids <= 255:
                color = _palette[ids * 3:ids * 3 + 3]
            else:
                color = [0, 0, 0]
            foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
            binary_mask = (mask == ids)

            # Compose image
            img_mask[binary_mask] = foreground[binary_mask]

            cnt = binary_dilation(binary_mask, iterations=1) ^ binary_mask
            img_mask[cnt, :] = 0
    else:
        binary_mask = (mask != 0)
        cnt = binary_dilation(binary_mask, iterations=1) ^ binary_mask
        foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[cnt, :] = 0

    return img_mask.astype(img.dtype)
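A small usage sketch of the draw_mask helper above on synthetic data; the image and mask values here are dummies made up for illustration:

import numpy as np
from S2I.samer.seg_anything import draw_mask

img = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)   # dummy RGB frame
mask = np.zeros((128, 128), dtype=np.uint8)
mask[32:96, 32:96] = 1                                        # a single object with id 1
overlay = draw_mask(img.copy(), mask, alpha=0.5)              # colorized overlay with a black contour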
S2I/samer/segment.py
ADDED
@@ -0,0 +1,69 @@
import sys

sys.path.append("../../..")
sys.path.append("")
import cv2
import numpy as np
from S2I.samer.segmentor import Segmentor
from S2I.samer.transfer_tools import draw_outline, draw_points
from S2I.samer.seg_anything import draw_mask


class SegMent:
    def __init__(self, sam_args):
        self.sam = Segmentor(sam_args)
        self.reference_objs_list = []
        self.object_idx = 1
        self.curr_idx = 1
        self.origin_merged_mask = None  # init by segment-everything or update
        self.first_frame_mask = None

        # debug
        self.everything_points = []
        self.everything_labels = []
        print("SegTracker has been initialized")

    def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray, ):
        # get interactive_mask
        interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
        refined_merged_mask = self.add_mask(interactive_mask)

        # draw mask
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

        # draw bbox
        masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))

        return refined_merged_mask, masked_frame

    def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
        # get interactive_mask
        interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)

        refined_merged_mask = self.add_mask(interactive_mask)

        # draw mask
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
        masked_frame = draw_points(coords, modes, masked_frame)

        # draw outline
        masked_frame = draw_outline(interactive_mask, masked_frame)

        return refined_merged_mask, masked_frame

    def add_mask(self, interactive_mask: np.ndarray):
        if self.origin_merged_mask is None:
            self.origin_merged_mask = np.zeros(interactive_mask.shape, dtype=np.uint8)

        refined_merged_mask = self.origin_merged_mask.copy()
        refined_merged_mask[interactive_mask > 0] = self.curr_idx

        return refined_merged_mask

    def automatic_generate_mask(self, image):
        masks = self.sam.automatic_segment(image)
        return masks


if __name__ == '__main__':
    pass
S2I/samer/segmentor.py
ADDED
@@ -0,0 +1,103 @@
import torch
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from .automatic_mask_generator_prob import SamAutomaticMaskAndProbabilityGenerator


class Segmentor:
    def __init__(self, sam_args):
        """
        sam_args:
            sam_checkpoint: path of SAM checkpoint
            generator_args: args for everything_generator
            gpu_id: device
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"])
        self.sam.to(device=self.device)
        # self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args'])
        self.automatic_generator = SamAutomaticMaskAndProbabilityGenerator(model=self.sam, **sam_args['generator_args'])
        self.interactive_predictor = self.automatic_generator.predictor
        self.have_embedded = False

    @torch.no_grad()
    def set_image(self, image):
        # calculate the embedding only once per frame.
        if not self.have_embedded:
            self.interactive_predictor.set_image(image)
            self.have_embedded = True

    @torch.no_grad()
    def interactive_predict(self, prompts, mode, multimask=True):
        assert self.have_embedded, 'image embedding for sam need be set before predict.'

        if mode == 'point':
            masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'],
                                                                       point_labels=prompts['point_modes'],
                                                                       multimask_output=multimask)
        elif mode == 'mask':
            masks, scores, logits = self.interactive_predictor.predict(mask_input=prompts['mask_prompt'],
                                                                       multimask_output=multimask)
        elif mode == 'point_mask':
            masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'],
                                                                       point_labels=prompts['point_modes'],
                                                                       mask_input=prompts['mask_prompt'],
                                                                       multimask_output=multimask)

        return masks, scores, logits

    @torch.no_grad()
    def automatic_segment(self, image):
        masks = self.automatic_generator.generate(image)
        return masks

    @torch.no_grad()
    def segment_with_click(self, origin_frame, coords, modes, multimask=True):
        '''

        return:
            mask: one-hot
        '''
        self.set_image(origin_frame)

        prompts = {
            'point_coords': coords,
            'point_modes': modes,
        }
        masks, scores, logits = self.interactive_predict(prompts, 'point', multimask)
        mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
        prompts = {
            'point_coords': coords,
            'point_modes': modes,
            'mask_prompt': logit[None, :, :]
        }
        masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask)

        mask = masks[np.argmax(scores)]

        return mask.astype(np.uint8)

    def segment_with_box(self, origin_frame, bbox, reset_image=False):
        if reset_image:
            self.interactive_predictor.set_image(origin_frame)
        else:
            self.set_image(origin_frame)

        masks, scores, logits = self.interactive_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array([bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]),
            multimask_output=True
        )
        mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

        masks, scores, logits = self.interactive_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array([[bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]]),
            mask_input=logit[None, :, :],
            multimask_output=True
        )
        mask = masks[np.argmax(scores)]

        return [mask]
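For context, a hedged example of the sam_args dict that Segmentor expects. The model type, checkpoint path and generator values below are illustrative assumptions rather than settings taken from this commit; generator_args is forwarded to SamAutomaticMaskAndProbabilityGenerator, which is assumed here to accept the usual SamAutomaticMaskGenerator keywords:

from S2I.samer.segmentor import Segmentor

sam_args = {
    "model_type": "vit_b",                            # key into segment_anything's sam_model_registry
    "sam_checkpoint": "./ckpt/sam_vit_b_01ec64.pth",  # assumed local path to a SAM checkpoint
    "generator_args": {
        "points_per_side": 16,
        "pred_iou_thresh": 0.8,
        "stability_score_thresh": 0.9,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        "min_mask_region_area": 200,
    },
}
segmentor = Segmentor(sam_args)  # builds SAM, the automatic generator and the interactive predictor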
S2I/samer/transfer_tools.py
ADDED
@@ -0,0 +1,47 @@
import cv2
import numpy as np


def mask2bbox(mask):
    if len(np.where(mask > 0)[0]) == 0:
        print(f'not mask')
        return np.array([[0, 0], [0, 0]]).astype(np.int64)

    x_ = np.sum(mask, axis=0)
    y_ = np.sum(mask, axis=1)

    x0 = np.min(np.nonzero(x_)[0])
    x1 = np.max(np.nonzero(x_)[0])
    y0 = np.min(np.nonzero(y_)[0])
    y1 = np.max(np.nonzero(y_)[0])

    return np.array([[x0, y0], [x1, y1]]).astype(np.int64)


def draw_outline(mask, frame):
    _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)

    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    cv2.drawContours(frame, contours, -1, (0, 0, 255), 2)

    return frame


def draw_points(points, modes, frame):
    neg_points = points[np.argwhere(modes == 0)[:, 0]]
    pos_points = points[np.argwhere(modes == 1)[:, 0]]

    for i in range(len(neg_points)):
        point = neg_points[i]
        cv2.circle(frame, (point[0], point[1]), 8, (255, 80, 80), -1)

    for i in range(len(pos_points)):
        point = pos_points[i]
        cv2.circle(frame, (point[0], point[1]), 8, (0, 153, 255), -1)

    return frame


if __name__ == '__main__':
    pass
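A quick sanity check of mask2bbox above on a toy mask; the values are illustrative only:

import numpy as np
from S2I.samer.transfer_tools import mask2bbox

mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 30:50] = 1
print(mask2bbox(mask))  # [[30 10]
                        #  [49 19]]  i.e. [[x0, y0], [x1, y1]] of the non-zero region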
app.py
CHANGED
@@ -1,146 +1,327 @@
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch.cuda.max_memory_allocated(device=device)
-pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-pipe.enable_xformers_memory_efficient_attention()
-pipe = pipe.to(device)
-else:
-pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-pipe = pipe.to(device)
-generator = torch.Generator().manual_seed(seed)
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 520px;
-}
-"""
-run_button = gr.Button("Run", scale=0)
-result = gr.Image(label="Result", show_label=False)
-negative_prompt = gr.Text(
-    label="Negative prompt",
-    max_lines=1,
-    placeholder="Enter a negative prompt",
-    visible=False,
-randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

+import os
 import numpy as np
+import io
+os.system("pip install gradio==4.29.0")
+os.system("pip install opencv-python")
+import cv2
+import gradio as gr
 import random
+import warnings
+import spaces
+from PIL import Image
+from S2I import Sketch2ImageController, css, scripts
+
+
+dark_mode_theme = """
+function refresh() {
+    const url = new URL(window.location);
+
+    if (url.searchParams.get('__theme') !== 'dark') {
+        url.searchParams.set('__theme', 'dark');
+        window.location.href = url.href;
+    }
+}
+"""
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+warnings.filterwarnings("ignore")
+controller = Sketch2ImageController(gr)
+
+
+def run_gpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag):
+    return controller.artwork(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
+
+def run_cpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag):
+    return controller.artwork(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
+
+def get_dark_mode():
+    return """
+    () => {
+        document.body.classList.toggle('dark');
+    }
+    """
+
+def clear_session():
+    return gr.update(value=None), gr.update(value=None)

+def assign_gpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag):
+    if options == 'GPU':
+        decorated_run = spaces.GPU(run_gpu)
+        return decorated_run(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
+    else:
+        return run_cpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)

+def read_temp_file(temp_file_wrapper):
+    name = temp_file_wrapper.name
+    with open(temp_file_wrapper.name, 'rb') as f:
+        # Read the content of the file
+        file_content = f.read()
+    return file_content, name

+def convert_to_pencil_sketch(image):
+    if image is None:
+        raise ValueError(f"Image at path {image} could not be loaded.")
+
+    # Converting it into grayscale
+    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+    # Inverting the image
+    inverted_image = 255 - gray_image
+
+    # Blurring the image
+    blurred = cv2.GaussianBlur(inverted_image, (25, 25), 0)
+    inverted_blurred = 255 - blurred
+
+    # Creating the pencil sketch
+    pencil_sketch = cv2.divide(gray_image, inverted_blurred, scale=256.0)
+
+    return pencil_sketch
+
+def get_meta_from_image(input_img, type_image):
+    if input_img is None:
+        return gr.update(value=None)
+
+    file_content, _ = read_temp_file(input_img)

+    # Read the image using Pillow
+    img = Image.open(io.BytesIO(file_content)).convert("RGB")
+    img_np = np.array(img)

+    if type_image == 'RGB':
+        sketch = convert_to_pencil_sketch(img_np)
+        processed_img = 255 - sketch
+    elif type_image == 'SKETCH':
+        processed_img = 255 - img_np

+    # Convert the processed image back to PIL Image
+    img_pil = Image.fromarray(processed_img.astype('uint8'))

+    return img_pil
+
+
+with gr.Blocks(css=css) as demo:
+    gr.HTML(
+        """
+        <!DOCTYPE html>
+        <html lang="en">
+        <head>
+        <meta charset="UTF-8">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
+        <title>S2I-Artwork Animation</title>
+        <style>
+
+        @keyframes blinkCursor {
+            from { border-right-color: rgba(255, 255, 255, 0.75); }
+            to { border-right-color: transparent; }
+        }
+
+        @keyframes fadeIn {
+            0% { opacity: 0; transform: translateY(-10px); }
+            100% { opacity: 1; transform: translateY(0); }
+        }
+
+        @keyframes bounce {
+            0%, 20%, 50%, 80%, 100% {
+                transform: translateY(0);
+            }
+            40% {
+                transform: translateY(-10px);
+            }
+            60% {
+                transform: translateY(-5px);
+            }
+        }
+        .typewriter h1 {
+            overflow: hidden;
+            border-right: .15em solid rgba(255, 255, 255, 0.75);
+            white-space: nowrap;
+            margin: 0 auto;
+            letter-spacing: .15em;
+            animation:
+                zoomInOut 4s infinite;
+        }
+        .animated-heading {
+            animation: fadeIn 2s ease-in-out;
+        }
+
+        .animated-link {
+            display: inline-block;
+            animation: bounce 3s infinite;
+        }
+        </style>
+        </head>
+        <body>
+        <div>
+            <div class="typewriter">
+                <h1 style="display: flex; align-items: center; justify-content: center; margin-bottom: 10px; text-align: center;">
+                    <img src="https://imgur.com/H2SLps2.png" alt="icon" style="margin-left: 10px; height: 30px;">
+                    S2I-Artwork
+                    <img src="https://imgur.com/cNMKSAy.png" alt="icon" style="margin-left: 10px; height: 30px;">:
+                    Personalized Sketch-to-Art 🧨 Diffusion Models
+                    <img src="https://imgur.com/yDnDd1p.png" alt="icon" style="margin-left: 10px; height: 30px;">
+                </h1>
+            </div>
+            <h3 class="animated-heading" style="text-align: center; margin-bottom: 10px;">Authors: Vo Nguyen An Tin, Nguyen Thiet Su</h3>
+            <h4 class="animated-heading" style="margin-bottom: 10px;">* This project fine-tunes sketch2image models with LoRA on large datasets: COCO-2017, LHQ, Danbooru, LandScape and Mid-Journey V6</h4>
+            <h4 class="animated-heading" style="margin-bottom: 10px;">* We release 2 sketch2image LoRA models trained for 30K and 60K steps, with skip-connection and Transformer super-resolution variants</h4>
+            <h4 class="animated-heading" style="margin-bottom: 10px;">* Inference is fast: the first run may be slow, but after that each generation takes about 1.5 ~ 2s</h4>
+            <h4 class="animated-heading" style="margin-bottom: 10px;">* View the full code project:
+                <a class="animated-link" href="https://github.com/aihacker111/S2I-Artwork-Sketch-to-Image/" target="_blank">GitHub Repository</a>
+            </h4>
+            <h4 class="animated-heading" style="margin-bottom: 10px;">
+                <a class="animated-link" href="https://github.com/aihacker111/S2I-Artwork-Sketch-to-Image/" target="_blank">
+                    <img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="100">
+                </a>
+            </h4>
+        </div>
+        </body>
+        </html>
+        """
+    )
+    with gr.Row(elem_id="main_row"):
+        with gr.Column(elem_id="column_input"):
+            gr.Markdown("## SKETCH", elem_id="input_header")
+            image = gr.Sketchpad(
+                type="pil",
+                height=512,
+                width=512,
+                min_width=512,
+                image_mode="RGBA",
                 show_label=False,
+                mirror_webcam=False,
+                show_download_button=True,
+                elem_id='input_image',
+                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=4),
+                canvas_size=(1024, 1024),
+                layers=False
             )
+            input_image = gr.File(label='Input image')

+            download_sketch = gr.Button(
+                "Download sketch", scale=1, elem_id="download_sketch"
             )
+
+        with gr.Column(elem_id="column_output"):
+            gr.Markdown("## GENERATED IMAGE", elem_id="output_header")
+            result = gr.Image(
+                label="Result",
+                height=440,
+                width=440,
+                elem_id="output_image",
+                show_label=False,
+                show_download_button=True,
             )
             with gr.Row():
+                run_button = gr.Button("Generate 🪄", min_width=5, variant='primary')
+                randomize_seed = gr.Button(value='\U0001F3B2', variant='primary')
+                clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
+            prompt = gr.Textbox(label="Personalized Text", value="", show_label=True)
+            with gr.Accordion("S2I Advanced Options", open=True):
+                with gr.Row():
+                    ui_mode = gr.Radio(
+                        choices=["Light Mode", "Dark Mode"],
+                        value="Light Mode",
+                        label="Switch Light/Dark Mode UI",
+                        interactive=True)
+                    type_image = gr.Radio(
+                        choices=["RGB", "SKETCH"],
+                        value="SKETCH",
+                        label="Type of Image (Color Image or Sketch Image)",
+                        interactive=True)
+                    input_type = gr.Radio(
+                        choices=["live-sketch", "upload"],
+                        value="live-sketch",
+                        label="Type Sketch2Image models",
+                        interactive=True)
+                    style = gr.Dropdown(
+                        label="Style",
+                        choices=controller.STYLE_NAMES,
+                        value=controller.DEFAULT_STYLE_NAME,
+                        scale=1,
+                    )
+                    prompt_temp = gr.Textbox(
+                        label="Prompt Style Template",
+                        value=controller.styles[controller.DEFAULT_STYLE_NAME],
+                        scale=2,
+                        max_lines=1,
+                    )
+                    seed = gr.Textbox(label="Seed", value='42', scale=1, min_width=50)
+                    zero_gpu_options = gr.Radio(
+                        choices=["GPU", "CPU"],
+                        value="GPU",
+                        label="GPU & CPU Options Spaces",
+                        interactive=True)
+                    half_model = gr.Radio(
+                        choices=["float32", "float16"],
+                        value="float16",
+                        label="Demo Speed",
+                        interactive=True)
+                    model_options = gr.Radio(
+                        choices=["100k", "350k"],
+                        value="350k",
+                        label="Type Sketch2Image models",
+                        interactive=True)

+                    val_r = gr.Slider(
+                        label="Sketch guidance: ",
+                        show_label=True,
+                        minimum=0,
+                        maximum=1,
+                        value=0.4,
+                        step=0.01,
+                        scale=3,
+                    )
+
+    demo.load(None, None, None, js=scripts)
+    ui_mode.change(None, [], [], js=get_dark_mode())
+    randomize_seed.click(
+        lambda: random.randint(0, controller.MAX_SEED),
+        inputs=[],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    )
+    inputs = [zero_gpu_options, image, prompt, prompt_temp, style, seed, val_r, half_model, model_options, input_type]
+    outputs = [result, download_sketch]
+    prompt.submit(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
+
+    input_image.change(
+        fn=get_meta_from_image,
+        inputs=[
+            input_image, type_image
+        ],
+        outputs=[
+            image
+        ]
+    )
+
+    style.change(
+        lambda x: controller.styles[x],
+        inputs=[style],
+        outputs=[prompt_temp],
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=assign_gpu,
+        inputs=inputs,
+        outputs=outputs,
+        api_name=False,
+    )
+    clear_button.click(fn=clear_session, inputs=[], outputs=[image, result]).then(
+        fn=assign_gpu,
+        inputs=inputs,
+        outputs=outputs,
+        api_name=False,
     )
+    val_r.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
+    run_button.click(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
+    image.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)

+if __name__ == '__main__':
+    demo.queue()
+    demo.launch(debug=True, share=False)
requirements.txt
CHANGED
@@ -1,6 +1,87 @@
-accelerate
+accelerate==0.30.1
+aiofiles==23.2.1
+altair==5.3.0
+annotated-types==0.7.0
+anyio==4.4.0
+attrs==23.2.0
+certifi==2024.2.2
+charset-normalizer==3.3.2
+click==8.1.7
+contourpy==1.2.1
+cycler==0.12.1
+diffusers==0.25.1
+dnspython==2.6.1
+email_validator==2.1.1
+exceptiongroup==1.2.1
+fastapi==0.111.0
+fastapi-cli==0.0.4
+ffmpy==0.3.2
+filelock==3.14.0
+fonttools==4.52.4
+fsspec==2024.5.0
+gradio==4.29.0
+h11==0.14.0
+httpcore==1.0.5
+httptools==0.6.1
+httpx==0.27.0
+huggingface-hub==0.23.0
+idna==3.7
+importlib_metadata==7.1.0
+importlib_resources==6.4.0
+Jinja2==3.1.4
+jsonschema==4.22.0
+jsonschema-specifications==2023.12.1
+kiwisolver==1.4.5
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.0
+mdurl==0.1.2
+mpmath==1.3.0
+networkx==3.3
+numpy==1.26.4
+orjson==3.10.3
+packaging==24.0
+pandas==2.2.2
+peft==0.11.1
+pillow==10.3.0
+psutil==5.9.8
+pydantic==2.7.2
+pydantic_core==2.18.3
+pydub==0.25.1
+Pygments==2.18.0
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.9
+pytz==2024.1
+PyYAML==6.0.1
+referencing==0.35.1
+regex==2024.5.15
+requests==2.32.0
+rich==13.7.1
+rpds-py==0.18.1
+ruff==0.4.6
+safetensors==0.4.3
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+starlette==0.37.2
+sympy==1.12
+tokenizers==0.19.1
+tomlkit==0.12.0
+toolz==0.12.1
+torch==2.3.0
+torchvision==0.18.0
+tqdm==4.66.4
+transformers==4.41.0
+typer==0.12.3
+typing_extensions==4.11.0
+tzdata==2024.1
+ujson==5.10.0
+urllib3==2.2.1
+uvicorn==0.30.0
+uvloop==0.19.0
+watchfiles==0.22.0
+websockets==11.0.3
+zipp==3.18.2