myn0908 committed on
Commit 55a3c9a
1 Parent(s): 74ce519

init sketch2image

.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__
+ .idea
+ *.pyc
+ debug
+ .DS_Store
S2I/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .modules import Sketch2ImagePipeline
+ from .commons import Sketch2ImageController, css, scripts
S2I/commons/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .controller import Sketch2ImageController
+ from .css import css, scripts
S2I/commons/controller.py ADDED
@@ -0,0 +1,96 @@
+ from PIL import Image
+ from io import BytesIO
+ import numpy as np
+ import base64
+ import torch
+ import torchvision.transforms.functional as F
+ from S2I import Sketch2ImagePipeline
+
+
+
+ class Sketch2ImageController():
+     def __init__(self, gr):
+         super().__init__()
+         self.gr = gr
+         self.style_list = [
+             {"name": "Comic",
+              "prompt": "comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed"},
+             {"name": "Cinematic",
+              "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"},
+             {"name": "3D Model",
+              "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"},
+             {"name": "Anime",
+              "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"},
+             {"name": "Digital Art",
+              "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"},
+             {"name": "Photographic",
+              "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"},
+             {"name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"},
+             {"name": "Fantasy art",
+              "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"},
+             {"name": "Neonpunk",
+              "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"},
+             {"name": "Manga",
+              "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"},
+         ]
+
+         self.styles = {k["name"]: k["prompt"] for k in self.style_list}
+         self.STYLE_NAMES = list(self.styles.keys())
+         self.DEFAULT_STYLE_NAME = "Fantasy art"
+         self.MAX_SEED = np.iinfo(np.int32).max
+
+         # Initialize the model once here
+         self.pipe = None
+         self.zero_options = None
+     def load_pipeline(self, zero_options):
+         if self.pipe is None or zero_options != self.zero_options:
+             self.pipe = Sketch2ImagePipeline()
+             self.zero_options = zero_options
+
+     def update_canvas(self, use_line, use_eraser):
+         brush_size = 20 if use_eraser else 4
+         _color = "#ffffff" if use_eraser else "#000000"
+         return self.gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
+
+     def upload_sketch(self, file):
+         _img = Image.open(file.name).convert("L")
+         return self.gr.update(value=_img, source="upload", interactive=True)
+
+     @staticmethod
+     def pil_image_to_data_uri(img, format="PNG"):
+         buffered = BytesIO()
+         img.save(buffered, format=format)
+         img_str = base64.b64encode(buffered.getvalue()).decode()
+         return f"data:image/{format.lower()};base64,{img_str}"
+
+     def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag):
+         self.load_pipeline(zero_options=options)
+
+         prompt = prompt_template.replace("{prompt}", prompt)
+
+         if type_flag == 'live-sketch':
+             img = Image.fromarray(np.array(image["composite"])[:, :, -1])
+         elif type_flag == 'upload':
+             img = image["composite"]
+
+         img = img.convert("RGB")
+         img = img.resize((512, 512))
+
+         image_t = F.to_tensor(img) > 0.5
+         c_t = image_t.unsqueeze(0).cuda().float()
+
+         torch.manual_seed(seed)
+         _, _, H, W = c_t.shape
+         noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
+
+         with torch.no_grad():
+             output_image = self.pipe.generate(c_t, prompt, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
+
+         output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
+
+         if type_flag == 'live-sketch':
+             input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
+         else:
+             input_uri = self.pil_image_to_data_uri(img)
+
+         return output_pil, self.gr.update(link=input_uri)
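
Note: the controller is UI-agnostic apart from the gr handle it receives; the real wiring lives in app.py (changed further below). As a rough illustration only, with widget choices and default argument values that are assumptions rather than part of this commit, artwork() could be driven from a minimal Gradio Blocks app like this:

import gradio as gr
from S2I import Sketch2ImageController

controller = Sketch2ImageController(gr)

def run(image, prompt, style_name, seed):
    # Remaining artwork() arguments are filled with plausible defaults here;
    # the real app.py reads them from its own UI widgets.
    template = controller.styles[style_name]
    output, _ = controller.artwork(
        options=None, image=image, prompt=prompt, prompt_template=template,
        style_name=style_name, seed=seed, val_r=0.4, faster='float16',
        model_name='350k', type_flag='upload',
    )
    return output

with gr.Blocks() as demo:
    # artwork() indexes image["composite"], so an editor-style component is assumed
    image = gr.ImageEditor(label="Sketch", type="pil")
    prompt = gr.Textbox(label="Prompt")
    style = gr.Dropdown(controller.STYLE_NAMES, value=controller.DEFAULT_STYLE_NAME, label="Style")
    seed = gr.Slider(0, controller.MAX_SEED, value=42, step=1, label="Seed")
    result = gr.Image(label="Output")
    gr.Button("Generate").click(run, [image, prompt, style, seed], result)

demo.launch()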
S2I/commons/css.py ADDED
@@ -0,0 +1,196 @@
1
+ css = """
2
+ @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
3
+
4
+ /* Outer container */
5
+ .main {
6
+ display: flex;
7
+ justify-content: center;
8
+ align-items: flex-start;
9
+ width: 100%;
10
+ max-width: 1200px;
11
+ margin: 0 auto;
12
+ padding: 10px;
+ /* background: linear-gradient(to right, #6a11cb, #2575fc); */
+ /* animation: diffusionArtAnimation 10s infinite alternate; */
15
+ }
16
+
17
+ @keyframes diffusionArtAnimation {
18
+ 0% {
19
+ background: linear-gradient(135deg, #ff9a9e, #fad0c4);
20
+ }
21
+ 20% {
22
+ background: linear-gradient(135deg, #a1c4fd, #c2e9fb);
23
+ }
24
+ 40% {
25
+ background: linear-gradient(135deg, #fbc2eb, #a6c1ee);
26
+ }
27
+ 60% {
28
+ background: linear-gradient(135deg, #ffecd2, #fcb69f);
29
+ }
30
+ 80% {
31
+ background: linear-gradient(135deg, #cfd9df, #e2ebf0);
32
+ }
33
+ 100% {
34
+ background: linear-gradient(135deg, #ff9a9e, #fad0c4);
35
+ }
36
+ }
37
+ #main_row{
38
+ justify-content: center;
39
+ }
40
+ /* Hide class */
41
+ .svelte-p4aq0j {
42
+ display: none;
43
+ }
44
+
45
+ .wrap.svelte-p4aq0j.svelte-p4aq0j {
46
+ display: none;
47
+ }
48
+
49
+ #download_sketch {
50
+ display: none;
51
+ }
52
+
53
+ #download_output {
54
+ display: none;
55
+ }
56
+
57
+ #column_input, #column_output {
58
+ width: 100%;
59
+ max-width: 500px;
60
+ display: flex;
61
+ flex-direction: column;
62
+ align-items: center;
63
+ padding: 10px;
64
+ }
65
+
66
+ #tools_header, #input_header, #output_header, #process_header {
67
+ display: flex;
68
+ justify-content: center;
69
+ align-items: center;
70
+ width: 100%;
71
+ max-width: 400px;
72
+ font-size: 1.2em;
73
+ color: #fff;
74
+ text-shadow: 1px 1px 2px #000;
75
+ }
76
+
77
+ #nn {
78
+ width: 100px;
79
+ height: 100px;
80
+ }
81
+
82
+ #column_process {
83
+ display: flex;
84
+ justify-content: center;
85
+ align-items: center;
86
+ height: 600px;
87
+ }
88
+
89
+ #output_image, #input_image {
90
+ border-radius: 10px;
91
+ border: 5px solid #fff;
92
+ width: 100%;
93
+ max-width: 500px;
94
+ height: 500px;
95
+ box-sizing: border-box;
96
+ display: flex;
97
+ justify-content: center;
98
+ align-items: center;
99
+ background: rgba(255, 255, 255, 0.1);
100
+ animation: zoomInOut 5s infinite alternate;
101
+ }
102
+
103
+ @keyframes zoomInOut {
104
+ 0% {
105
+ transform: scale(1);
106
+ }
107
+ 50% {
108
+ transform: scale(1.05);
109
+ }
110
+ 100% {
111
+ transform: scale(1);
112
+ }
113
+ }
114
+
115
+ #output_image > img {
116
+ border: 5px solid #fff;
117
+ border-radius: 10px;
118
+ width: 100%;
119
+ height: 100%;
120
+ box-sizing: border-box;
121
+ }
122
+
123
+ #input_image > div.image-container.svelte-p3y7hu > div.wrap.svelte-yigbas > canvas:nth-child(1) {
124
+ border: 5px solid #fff;
125
+ border-radius: 10px;
126
+ width: 100%;
127
+ height: 100%;
128
+ box-sizing: border-box;
129
+ }
130
+
131
+ /* Responsive styles */
132
+ @media (max-width: 768px) {
133
+ .main {
134
+ flex-direction: column;
135
+ width: 100%;
136
+ }
137
+
138
+ #column_input, #column_output {
139
+ width: 100%;
140
+ max-width: 100%;
141
+ padding: 10px 0;
142
+ }
143
+
144
+ #tools_header, #input_header, #output_header, #process_header {
145
+ width: 100%;
146
+ }
147
+
148
+ #column_process {
149
+ height: auto;
150
+ }
151
+
152
+ #output_image, #input_image {
153
+ max-width: 100%;
154
+ height: auto;
155
+ }
156
+ }
157
+
158
+ @media (max-width: 480px) {
159
+ #nn {
160
+ width: 80px;
161
+ height: 80px;
162
+ }
163
+
164
+ #tools_header, #input_header, #output_header, #process_header {
165
+ max-width: 100%;
166
+ font-size: 14px;
167
+ }
168
+
169
+ #column_input, #column_output {
170
+ max-width: 100%;
171
+ padding: 10px;
172
+ }
173
+ }
+ /* .flex { background-color: #0b0f19; } */
177
+ """
178
+
179
+ scripts = """
180
+ async () => {
181
+ globalThis.theSketchDownloadFunction = () => {
182
+ console.log("test")
183
+ var link = document.createElement("a");
184
+ dataUri = document.getElementById('download_sketch').href
185
+ link.setAttribute("href", dataUri)
186
+ link.setAttribute("download", "sketch.png")
187
+ document.body.appendChild(link); // Required for Firefox
188
+ link.click();
189
+ document.body.removeChild(link); // Clean up
190
+
191
+ // also call the output download function
192
+ theOutputDownloadFunction();
193
+ return false
194
+ }
195
+ }
196
+ """
S2I/logger.py ADDED
@@ -0,0 +1,4 @@
+ import logging
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger()
S2I/modules/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sketch2image import *
S2I/modules/models.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import copy
+ from diffusers import DDPMScheduler
+ from transformers import AutoTokenizer, CLIPTextModel
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from peft import LoraConfig
+ from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path
+
+
+ class RelationShipConvolution(torch.nn.Module):
+     def __init__(self, conv_in_pretrained, conv_in_curr, r):
+         super(RelationShipConvolution, self).__init__()
+         self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
+         self.conv_in_curr = copy.deepcopy(conv_in_curr)
+         self.r = r
+
+     def forward(self, x):
+         x1 = self.conv_in_pretrained(x).detach()
+         x2 = self.conv_in_curr(x)
+         return x1 * (1 - self.r) + x2 * self.r
+
+
+ class PrimaryModel:
+     def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
+         self.backbone_diffusion_path = backbone_diffusion_path
+         self.global_unet = None
+         self.global_vae = None
+         self.global_tokenizer = None
+         self.global_text_encoder = None
+         self.global_scheduler = None
+
+     @staticmethod
+     def _load_model(path, model_class, unet_mode=False):
+         model = model_class.from_pretrained(path, subfolder='unet' if unet_mode else 'vae').to('cuda')
+         return model
+
+
+     def one_step_scheduler(self):
+         noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
+         noise_scheduler_1step.set_timesteps(1, device="cuda")
+         noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+         return noise_scheduler_1step
+
+     def skip_connections(self, vae):
+         vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
+         vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
+         vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+         vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+         vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+         vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+         vae.decoder.ignore_skip = False
+         return vae
+
+     def from_pretrained(self, model_name, r):
+         if self.global_tokenizer is None:
+             # self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
+             #                                                       subfolder="tokenizer")
+             self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
+
+         if self.global_text_encoder is None:
+             self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
+                                                                      subfolder="text_encoder").to(device='cuda')
+
+         if self.global_scheduler is None:
+             self.global_scheduler = self.one_step_scheduler()
+
+         if self.global_vae is None:
+             self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
+             self.global_vae = self.skip_connections(self.global_vae)
+
+         if self.global_unet is None:
+             self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
+             p_ckpt_path = download_models()
+             p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
+             sd = torch.load(p_ckpt, map_location="cpu")
+             conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
+             self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
+             unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
+                                           target_modules=sd["unet_lora_target_modules"])
+             vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
+                                          target_modules=sd["vae_lora_target_modules"])
+             self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+             _sd_vae = self.global_vae.state_dict()
+             for k in sd["state_dict_vae"]:
+                 _sd_vae[k] = sd["state_dict_vae"][k]
+             self.global_vae.load_state_dict(_sd_vae)
+             self.global_unet.add_adapter(unet_lora_config)
+             _sd_unet = self.global_unet.state_dict()
+             for k in sd["state_dict_unet"]:
+                 _sd_unet[k] = sd["state_dict_unet"][k]
+             self.global_unet.load_state_dict(_sd_unet, strict=False)
S2I/modules/sketch2image.py ADDED
@@ -0,0 +1,79 @@
+ from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+ from S2I.modules.models import PrimaryModel
+ import gc
+ import torch
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+
+ class Sketch2ImagePipeline(PrimaryModel):
+     def __init__(self):
+         super().__init__()
+         self.timestep = torch.tensor([999], device="cuda").long()
+
+     def generate(self, c_t, prompt=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
+         self.from_pretrained(model_name=model_name, r=r)
+         assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
+
+         if half_model == 'float16':
+             output_image = self._generate_fp16(c_t, prompt, prompt_tokens, r, noise_map)
+         else:
+             output_image = self._generate_full_precision(c_t, prompt, prompt_tokens, r, noise_map)
+
+         return output_image
+
+     def _generate_fp16(self, c_t, prompt, prompt_tokens, r, noise_map):
+         with torch.autocast(device_type='cuda', dtype=torch.float16):
+             caption_enc = self._get_caption_enc(prompt, prompt_tokens)
+
+             self._set_weights_and_activate_adapters(r)
+             encoded_control = self.global_vae.encode(c_t).latent_dist.sample() * self.global_vae.config.scaling_factor
+
+             unet_input = encoded_control * r + noise_map * (1 - r)
+             unet_output = self.global_unet(unet_input, self.timestep, encoder_hidden_states=caption_enc).sample
+             x_denoise = self.global_scheduler.step(unet_output, self.timestep, unet_input, return_dict=True).prev_sample
+
+             self.global_vae.decoder.incoming_skip_acts = self.global_vae.encoder.current_down_blocks
+             self.global_vae.decoder.gamma = r
+
+             output_image = self.global_vae.decode(x_denoise / self.global_vae.config.scaling_factor).sample.clamp(-1, 1)
+
+         return output_image
+
+     def _generate_full_precision(self, c_t, prompt, prompt_tokens, r, noise_map):
+         caption_enc = self._get_caption_enc(prompt, prompt_tokens)
+
+         self._set_weights_and_activate_adapters(r)
+         encoded_control = self.global_vae.encode(c_t).latent_dist.sample() * self.global_vae.config.scaling_factor
+
+         unet_input = encoded_control * r + noise_map * (1 - r)
+         unet_output = self.global_unet(unet_input, self.timestep, encoder_hidden_states=caption_enc).sample
+         x_denoise = self.global_scheduler.step(unet_output, self.timestep, unet_input, return_dict=True).prev_sample
+
+         self.global_vae.decoder.incoming_skip_acts = self.global_vae.encoder.current_down_blocks
+         self.global_vae.decoder.gamma = r
+
+         output_image = self.global_vae.decode(x_denoise / self.global_vae.config.scaling_factor).sample.clamp(-1, 1)
+
+         return output_image
+
+     def _get_caption_enc(self, prompt, prompt_tokens):
+         if prompt is not None:
+             caption_tokens = self.global_tokenizer(prompt, max_length=self.global_tokenizer.model_max_length,
+                                                    padding="max_length", truncation=True,
+                                                    return_tensors="pt").input_ids.cuda()
+         else:
+             caption_tokens = prompt_tokens.cuda()
+
+         return self.global_text_encoder(caption_tokens)[0]
+
+     def _set_weights_and_activate_adapters(self, r):
+         self.global_unet.set_adapters(["default"], weights=[r])
+         set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
+
+     def _move_to_cpu(self, module):
+         module.to("cpu")
+
+     def _move_to_gpu(self, module):
+         module.to("cuda")
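
The pipeline can also be driven without the Gradio controller by supplying the conditioning tensor directly. A minimal sketch (illustrative, not part of this commit; assumes a CUDA device, the '350k' checkpoint, and a 512x512 black-on-white sketch saved as sketch.png), mirroring the preprocessing done in Sketch2ImageController.artwork:

import torch
import torchvision.transforms.functional as F
from PIL import Image
from S2I import Sketch2ImagePipeline

pipe = Sketch2ImagePipeline()

# Binarize the sketch the same way the controller does: values > 0.5 become 1.0.
img = Image.open("sketch.png").convert("RGB").resize((512, 512))
c_t = (F.to_tensor(img) > 0.5).unsqueeze(0).cuda().float()

# r blends the pretrained and fine-tuned conv_in outputs and weights the noise map;
# the latent noise is 1/8 of the image resolution, matching the SD VAE downscaling.
noise = torch.randn((1, 4, 64, 64), device=c_t.device)
with torch.no_grad():
    out = pipe.generate(c_t, prompt="a cozy cabin in the woods", r=0.4,
                        noise_map=noise, half_model='float16', model_name='350k')

F.to_pil_image(out[0].cpu() * 0.5 + 0.5).save("result.png")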
S2I/modules/utils.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import requests
+ from tqdm import tqdm
+ from S2I.logger import logger
+
+ def sc_vae_encoder_fwd(self, sample):
+     sample = self.conv_in(sample)
+     self.current_down_blocks = []
+
+     for down_block in self.down_blocks:
+         self.current_down_blocks.append(sample)
+         sample = down_block(sample)
+
+     sample = self.mid_block(sample)
+     sample = self.conv_norm_out(sample)
+     sample = self.conv_act(sample)
+     sample = self.conv_out(sample)
+     return sample
+
+ def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
+     sample = self.conv_in(sample)
+     upscale_dtype = next(self.up_blocks.parameters()).dtype
+     sample = self.mid_block(sample, latent_embeds)
+     sample = sample.to(upscale_dtype)
+
+     if not self.ignore_skip:
+         skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
+         reversed_skip_acts = self.incoming_skip_acts[::-1]
+         for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
+             skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
+             sample += skip_in
+             sample = up_block(sample, latent_embeds)
+     else:
+         for up_block in self.up_blocks:
+             sample = up_block(sample, latent_embeds)
+
+     sample = self.conv_norm_out(sample, latent_embeds) if latent_embeds else self.conv_norm_out(sample)
+     sample = self.conv_act(sample)
+     sample = self.conv_out(sample)
+     return sample
+
+ def downloading(url, outf):
+     if not os.path.exists(outf):
+         print(f"Downloading checkpoint to {outf}")
+         response = requests.get(url, stream=True)
+         total_size_in_bytes = int(response.headers.get('content-length', 0))
+         block_size = 1024  # 1 Kibibyte
+         progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+         with open(outf, 'wb') as file:
+             for data in response.iter_content(block_size):
+                 progress_bar.update(len(data))
+                 file.write(data)
+         progress_bar.close()
+         if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+             print("ERROR, something went wrong")
+         print(f"Downloaded successfully to {outf}")
+
+
+ def download_models():
+     urls = {
+         '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true',
+         '100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
+     }
+     # Get the current working directory
+     ckpt_folder = os.path.join(os.getcwd(), 'checkpoints')
+     os.makedirs(ckpt_folder, exist_ok=True)
+
+     model_paths = {}
+     for model_name, url in urls.items():
+         outf = os.path.join(ckpt_folder, f"sketch2image_lora_{model_name}.pkl")
+         downloading(url, outf)
+         model_paths[model_name] = outf
+
+     return model_paths
+
+
+ def get_model_path(model_name, model_paths):
+     return model_paths.get(model_name, "Model not found")
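
For reference, the two helpers compose as follows (an illustrative snippet, not part of the commit; paths are relative to the current working directory):

from S2I.modules.utils import download_models, get_model_path

paths = download_models()            # fetches both checkpoints into ./checkpoints
ckpt = get_model_path('350k', paths) # -> ./checkpoints/sketch2image_lora_350k.pkl
print(ckpt)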
S2I/samer/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .model_args import generate_sam_args
+ from .segmentor import *
+ from .seg_anything import *
+ from .segment import *
+ from .transfer_tools import *
+ from .automatic_mask_generator_prob import SamAutomaticMaskAndProbabilityGenerator
+ from .sam_controller import SAMController
S2I/samer/automatic_mask_generator_prob.py ADDED
@@ -0,0 +1,402 @@
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from segment_anything import SamAutomaticMaskGenerator
7
+ from segment_anything.modeling import Sam
8
+ from segment_anything.utils.amg import (MaskData, area_from_rle,
9
+ batched_mask_to_box, box_xyxy_to_xywh,
10
+ batch_iterator,
11
+ uncrop_boxes_xyxy, uncrop_points,
12
+ calculate_stability_score,
13
+ coco_encode_rle, generate_crop_boxes,
14
+ is_box_near_crop_edge,
15
+ mask_to_rle_pytorch, rle_to_mask,
16
+ uncrop_masks)
17
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
18
+
19
+
20
+ def batched_mask_to_prob(masks: torch.Tensor) -> torch.Tensor:
21
+ """
22
+ For implementation, see the following issue comment:
23
+
24
+ "To get the probability map for a mask,
25
+ we simply do element-wise sigmoid over the logits."
26
+ URL: https://github.com/facebookresearch/segment-anything/issues/226
27
+
28
+ Args:
29
+ masks: Tensor of shape [B, H, W] representing batch of binary masks.
30
+
31
+ Returns:
32
+ Tensor of shape [B, H, W] representing batch of probability maps.
33
+ """
34
+ probs = torch.sigmoid(masks).to(masks.device)
35
+ return probs
36
+
37
+
38
+ def batched_sobel_filter(probs: torch.Tensor, masks: torch.Tensor, bzp: int
39
+ ) -> torch.Tensor:
40
+ """
41
+ For implementation, see section D.2 of the paper:
42
+
43
+ "we apply a Sobel filter to the remaining masks' unthresholded probability
44
+ maps and set values to zero if they do not intersect with the outer
45
+ boundary pixels of a mask."
46
+ URL: https://arxiv.org/abs/2304.02643
47
+
48
+ Args:
49
+ probs: Tensor of shape [B, H, W] representing batch of probability maps.
50
+ masks: Tensor of shape [B, H, W] representing batch of binary masks.
51
+
52
+ Returns:
53
+ Tensor of shape [B, H, W] with filtered probability maps.
54
+ """
55
+ # probs: [B, H, W]
56
+ # Add channel dimension to make it [B, 1, H, W]
57
+ probs = probs.unsqueeze(1)
58
+
59
+ # sobel_filter: [1, 1, 3, 3]
60
+ sobel_filter_x = torch.tensor([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]],
61
+ dtype=torch.float32
62
+ ).to(probs.device).unsqueeze(0)
63
+ sobel_filter_y = torch.tensor([[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]],
64
+ dtype=torch.float32
65
+ ).to(probs.device).unsqueeze(0)
66
+
67
+ # Apply the Sobel filters
68
+ G_x = F.conv2d(probs, sobel_filter_x, padding=1)
69
+ G_y = F.conv2d(probs, sobel_filter_y, padding=1)
70
+
71
+ # Combine the gradients
72
+ probs = torch.sqrt(G_x ** 2 + G_y ** 2)
73
+
74
+ # Iterate through each image in the batch
75
+ for i in range(probs.shape[0]):
76
+ # Convert binary mask to float
77
+ mask = masks[i].float()
78
+
79
+ G_x = F.conv2d(mask[None, None], sobel_filter_x, padding=1)
80
+ G_y = F.conv2d(mask[None, None], sobel_filter_y, padding=1)
81
+ edge = torch.sqrt(G_x ** 2 + G_y ** 2)
82
+ outer_boundary = (edge > 0).float()
83
+
84
+ # Set to zero values that don't touch the mask's outer boundary.
85
+ probs[i, 0] = probs[i, 0] * outer_boundary
86
+
87
+ # Boundary zero padding (BZP).
88
+ # See "Zero-Shot Edge Detection With SCESAME: Spectral
89
+ # Clustering-Based Ensemble for Segment Anything Model Estimation".
90
+ if bzp > 0:
91
+ probs[i, 0, 0:bzp, :] = 0
92
+ probs[i, 0, -bzp:, :] = 0
93
+ probs[i, 0, :, 0:bzp] = 0
94
+ probs[i, 0, :, -bzp:] = 0
95
+
96
+ # Remove the channel dimension
97
+ probs = probs.squeeze(1)
98
+
99
+ return probs
100
+
101
+
102
+ class SamAutomaticMaskAndProbabilityGenerator(SamAutomaticMaskGenerator):
103
+ def __init__(
104
+ self,
105
+ model: Sam,
106
+ points_per_side: Optional[int] = 16,
107
+ points_per_batch: int = 64,
108
+ pred_iou_thresh: float = 0.88,
109
+ stability_score_thresh: float = 0.95,
110
+ stability_score_offset: float = 1.0,
111
+ box_nms_thresh: float = 0.7,
112
+ crop_n_layers: int = 0,
113
+ crop_nms_thresh: float = 0.7,
114
+ crop_overlap_ratio: float = 512 / 1500,
115
+ crop_n_points_downscale_factor: int = 1,
116
+ point_grids: Optional[List[np.ndarray]] = None,
117
+ min_mask_region_area: int = 0,
118
+ output_mode: str = "binary_mask",
119
+ nms_threshold: float = 0.7,
120
+ bzp: int = 0,
121
+ pred_iou_thresh_filtering=False,
122
+ stability_score_thresh_filtering=False,
123
+ ) -> None:
124
+ """
125
+ Using a SAM model, generates masks for the entire image.
126
+ Generates a grid of point prompts over the image, then filters
127
+ low quality and duplicate masks. The default settings are chosen
128
+ for SAM with a ViT-H backbone.
129
+
130
+ Arguments:
131
+ model (Sam): The SAM model to use for mask prediction.
132
+ points_per_side (int or None): The number of points to be sampled
133
+ along one side of the image. The total number of points is
134
+ points_per_side**2. If None, 'point_grids' must provide explicit
135
+ point sampling.
136
+ points_per_batch (int): Sets the number of points run simultaneously
137
+ by the model. Higher numbers may be faster but use more GPU memory.
138
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
139
+ model's predicted mask quality.
140
+ stability_score_thresh (float): A filtering threshold in [0,1], using
141
+ the stability of the mask under changes to the cutoff used to binarize
142
+ the model's mask predictions.
143
+ stability_score_offset (float): The amount to shift the cutoff when
144
+ calculated the stability score.
145
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
146
+ suppression to filter duplicate masks.
147
+ crop_n_layers (int): If >0, mask prediction will be run again on
148
+ crops of the image. Sets the number of layers to run, where each
149
+ layer has 2**i_layer number of image crops.
150
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
151
+ suppression to filter duplicate masks between different crops.
152
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
153
+ In the first crop layer, crops will overlap by this fraction of
154
+ the image length. Later layers with more crops scale down this overlap.
155
+ crop_n_points_downscale_factor (int): The number of points-per-side
156
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
157
+ point_grids (list(np.ndarray) or None): A list over explicit grids
158
+ of points used for sampling, normalized to [0,1]. The nth grid in the
159
+ list is used in the nth crop layer. Exclusive with points_per_side.
160
+ min_mask_region_area (int): If >0, postprocessing will be applied
161
+ to remove disconnected regions and holes in masks with area smaller
162
+ than min_mask_region_area. Requires opencv.
163
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
164
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
165
+ For large resolutions, 'binary_mask' may consume large amounts of
166
+ memory.
167
+ nms_threshold (float): The IoU threshold used for non-maximal suppression
168
+ """
169
+ super().__init__(
170
+ model,
171
+ points_per_side,
172
+ points_per_batch,
173
+ pred_iou_thresh,
174
+ stability_score_thresh,
175
+ stability_score_offset,
176
+ box_nms_thresh,
177
+ crop_n_layers,
178
+ crop_nms_thresh,
179
+ crop_overlap_ratio,
180
+ crop_n_points_downscale_factor,
181
+ point_grids,
182
+ min_mask_region_area,
183
+ output_mode,
184
+ )
185
+ self.nms_threshold = nms_threshold
186
+ self.bzp = bzp
187
+ self.pred_iou_thresh_filtering = pred_iou_thresh_filtering
188
+ self.stability_score_thresh_filtering = \
189
+ stability_score_thresh_filtering
190
+
191
+ @torch.no_grad()
192
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
193
+ """
194
+ Generates masks for the given image.
195
+
196
+ Arguments:
197
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
198
+
199
+ Returns:
200
+ list(dict(str, any)): A list over records for masks. Each record is
201
+ a dict containing the following keys:
202
+ segmentation (dict(str, any) or np.ndarray): The mask. If
203
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
204
+ is a dictionary containing the RLE.
205
+ bbox (list(float)): The box around the mask, in XYWH format.
206
+ area (int): The area in pixels of the mask.
207
+ predicted_iou (float): The model's own prediction of the mask's
208
+ quality. This is filtered by the pred_iou_thresh parameter.
209
+ point_coords (list(list(float))): The point coordinates input
210
+ to the model to generate this mask.
211
+ stability_score (float): A measure of the mask's quality. This
212
+ is filtered on using the stability_score_thresh parameter.
213
+ crop_box (list(float)): The crop of the image used to generate
214
+ the mask, given in XYWH format.
215
+ """
216
+
217
+ # Generate masks
218
+ mask_data = self._generate_masks(image)
219
+
220
+ # Filter small disconnected regions and holes in masks
221
+ if self.min_mask_region_area > 0:
222
+ mask_data = self.postprocess_small_regions(
223
+ mask_data,
224
+ self.min_mask_region_area,
225
+ max(self.box_nms_thresh, self.crop_nms_thresh),
226
+ )
227
+
228
+ # Encode masks
229
+ if self.output_mode == "coco_rle":
230
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
231
+ elif self.output_mode == "binary_mask":
232
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
233
+ else:
234
+ mask_data["segmentations"] = mask_data["rles"]
235
+
236
+ # Write mask records
237
+ curr_anns = []
238
+ for idx in range(len(mask_data["segmentations"])):
239
+ ann = {
240
+ "segmentation": mask_data["segmentations"][idx],
241
+ "area": area_from_rle(mask_data["rles"][idx]),
242
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
243
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
244
+ "point_coords": [mask_data["points"][idx].tolist()],
245
+ "stability_score": mask_data["stability_score"][idx].item(),
246
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
247
+ "prob": mask_data["probs"][idx],
248
+ }
249
+ curr_anns.append(ann)
250
+
251
+ return curr_anns
252
+
253
+ def _process_crop(
254
+ self,
255
+ image: np.ndarray,
256
+ crop_box: List[int],
257
+ crop_layer_idx: int,
258
+ orig_size: Tuple[int, ...],
259
+ ) -> MaskData:
260
+ # Crop the image and calculate embeddings
261
+ x0, y0, x1, y1 = crop_box
262
+ cropped_im = image[y0:y1, x0:x1, :]
263
+ cropped_im_size = cropped_im.shape[:2]
264
+ self.predictor.set_image(cropped_im)
265
+
266
+ # Get points for this crop
267
+ points_scale = np.array(cropped_im_size)[None, ::-1]
268
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
269
+
270
+ # Generate masks for this crop in batches
271
+ data = MaskData()
272
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
273
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
274
+ data.cat(batch_data)
275
+ del batch_data
276
+ self.predictor.reset_image()
277
+
278
+ # Remove duplicates within this crop.
279
+ keep_by_nms = batched_nms(
280
+ data["boxes"].float(),
281
+ data["iou_preds"],
282
+ torch.zeros_like(data["boxes"][:, 0]), # categories
283
+ iou_threshold=self.box_nms_thresh,
284
+ )
285
+ data.filter(keep_by_nms)
286
+
287
+ # Return to the original image frame
288
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
289
+ data["points"] = uncrop_points(data["points"], crop_box)
290
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
291
+
292
+ padded_probs = torch.zeros((data["probs"].shape[0], *orig_size),
293
+ dtype=torch.float32,
294
+ device=data["probs"].device)
295
+ padded_probs[:, y0:y1, x0:x1] = data["probs"]
296
+ data["probs"] = padded_probs
297
+
298
+ return data
299
+
300
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
301
+ orig_size = image.shape[:2]
302
+ crop_boxes, layer_idxs = generate_crop_boxes(
303
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
304
+ )
305
+
306
+ # Iterate over image crops
307
+ data = MaskData()
308
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
309
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
310
+ data.cat(crop_data)
311
+
312
+ # Remove duplicate masks between crops
313
+ if len(crop_boxes) > 1:
314
+ # Prefer masks from smaller crops
315
+ scores = 1 / box_area(data["crop_boxes"])
316
+ scores = scores.to(data["boxes"].device)
317
+ keep_by_nms = batched_nms(
318
+ data["boxes"].float(),
319
+ scores,
320
+ torch.zeros_like(data["boxes"][:, 0]), # categories
321
+ iou_threshold=self.crop_nms_thresh,
322
+ )
323
+ data.filter(keep_by_nms)
324
+
325
+ data.to_numpy()
326
+ return data
327
+
328
+ def _process_batch(
329
+ self,
330
+ points: np.ndarray,
331
+ im_size: Tuple[int, ...],
332
+ crop_box: List[int],
333
+ orig_size: Tuple[int, ...],
334
+ ) -> MaskData:
335
+ orig_h, orig_w = orig_size
336
+
337
+ # Run model on this batch
338
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
339
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
340
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
341
+ masks, iou_preds, _ = self.predictor.predict_torch(
342
+ in_points[:, None, :],
343
+ in_labels[:, None],
344
+ multimask_output=True,
345
+ return_logits=True,
346
+ )
347
+
348
+ # Serialize predictions and store in MaskData
349
+ data = MaskData(
350
+ masks=masks.flatten(0, 1),
351
+ iou_preds=iou_preds.flatten(0, 1),
352
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
353
+ )
354
+ del masks
355
+
356
+ if self.pred_iou_thresh_filtering and self.pred_iou_thresh > 0.0:
357
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
358
+ data.filter(keep_mask)
359
+
360
+ # Calculate stability score
361
+ data["stability_score"] = calculate_stability_score(
362
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
363
+ )
364
+
365
+ if self.stability_score_thresh_filtering and \
366
+ self.stability_score_thresh > 0.0:
367
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
368
+ data.filter(keep_mask)
369
+
370
+ # Threshold masks and calculate boxes
371
+ data["probs"] = batched_mask_to_prob(data["masks"])
372
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
373
+ data["boxes"] = batched_mask_to_box(data["masks"])
374
+
375
+ # Filter boxes that touch crop boundaries
376
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
377
+ if not torch.all(keep_mask):
378
+ data.filter(keep_mask)
379
+
380
+ # filter by nms
381
+ if self.nms_threshold > 0.0:
382
+ keep_mask = batched_nms(
383
+ data["boxes"].float(),
384
+ data["iou_preds"],
385
+ torch.zeros_like(data["boxes"][:, 0]), # categories
386
+ iou_threshold=self.nms_threshold,
387
+ )
388
+ data.filter(keep_mask)
389
+
390
+ # apply sobel filter for probability map
391
+ data["probs"] = batched_sobel_filter(data["probs"], data["masks"],
392
+ bzp=self.bzp)
393
+
394
+ # set prob to 0 for pixels outside of crop box
395
+ # data["probs"] = batched_crop_probs(data["probs"], data["boxes"])
396
+
397
+ # Compress to RLE
398
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
399
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
400
+ del data["masks"]
401
+
402
+ return data
S2I/samer/model_args.py ADDED
@@ -0,0 +1,17 @@
+ def generate_sam_args(sam_checkpoint="ckpt", model_type="vit_b", points_per_side=16,
+                       pred_iou_thresh=0.8, stability_score_thresh=0.9, crop_n_layers=1,
+                       crop_n_points_downscale_factor=2, min_mask_region_area=200, gpu_id=0):
+     sam_args = {
+         'sam_checkpoint': f'{sam_checkpoint}/{model_type}.pth',
+         'model_type': model_type,
+         'generator_args': {
+             'points_per_side': points_per_side,
+             'pred_iou_thresh': pred_iou_thresh,
+             'stability_score_thresh': stability_score_thresh,
+             'crop_n_layers': crop_n_layers,
+             'crop_n_points_downscale_factor': crop_n_points_downscale_factor,
+             'min_mask_region_area': min_mask_region_area,
+         },
+         'gpu_id': gpu_id}
+
+     return sam_args
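
generate_sam_args() only assembles the configuration dictionary; the Segmentor/SegMent classes below consume it. A small illustrative sketch (not part of this commit), assuming the vit_b checkpoint has already been downloaded to ./root_model/sam_models/vit_b.pth:

import numpy as np
from S2I.samer import generate_sam_args, SegMent

sam_args = generate_sam_args(sam_checkpoint='./root_model/sam_models', model_type='vit_b')
segment = SegMent(sam_args)                      # loads SAM plus the automatic mask generator
image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HWC uint8 frame
masks = segment.automatic_generate_mask(image)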
S2I/samer/sam_controller.py ADDED
@@ -0,0 +1,307 @@
1
+ from S2I.samer import SegMent, generate_sam_args
2
+ from S2I.logger import logger
3
+ from tqdm import tqdm
4
+ import gradio as gr
5
+ import numpy as np
6
+ import os
7
+ import shutil
8
+ import cv2
9
+ import requests
10
+
11
+
12
+ class SAMController:
13
+ def __init__(self):
14
+ self.current_model_type = None
15
+ self.refine_mask = None
16
+
17
+ @staticmethod
18
+ def clean():
19
+ return None, None, None, None, None, [[]]
20
+
21
+ @staticmethod
22
+ def save_mask(refined_mask=None, save=False):
23
+
24
+ if refined_mask is not None and save:
25
+ if os.path.exists(os.path.join(os.getcwd(), 'output_render')):
26
+ shutil.rmtree(os.path.join(os.getcwd(), 'output_render'))
27
+ save_path = os.path.join(os.getcwd(), 'output_render')
28
+ os.makedirs(save_path, exist_ok=True)
29
+ cv2.imwrite(os.path.join(save_path, f'refined_mask_result.png'), (refined_mask * 255).astype('uint8'))
30
+ elif refined_mask is None and save:
31
+ return os.path.join(os.path.join(os.getcwd(), 'output_render'), f'refined_mask_result.png')
32
+
33
+ @staticmethod
34
+ def download_models(model_type):
35
+ dir_path = os.path.join(os.getcwd(), 'root_model')
36
+ sam_models_path = os.path.join(dir_path, 'sam_models')
37
+
38
+ # Models URLs
39
+ models_urls = {
40
+ 'sam_models': {
41
+ 'vit_b': 'https://huggingface.co/ybelkada/segment-anything/resolve/main/checkpoints/sam_vit_b_01ec64.pth?download=true',
42
+ 'vit_l': 'https://huggingface.co/segments-arnaud/sam_vit_l/resolve/main/sam_vit_l_0b3195.pth?download=true',
43
+ 'vit_h': 'https://huggingface.co/segments-arnaud/sam_vit_h/resolve/main/sam_vit_h_4b8939.pth?download=true'
44
+ }
45
+ }
46
+
47
+ # Download specified model type
48
+ if model_type in models_urls['sam_models']:
49
+ model_url = models_urls['sam_models'][model_type]
50
+ os.makedirs(sam_models_path, exist_ok=True)
51
+ model_path = os.path.join(sam_models_path, model_type + '.pth')
52
+
53
+ if not os.path.exists(model_path):
54
+ logger.info(f"Downloading {model_type} model...")
55
+ response = requests.get(model_url, stream=True)
56
+ response.raise_for_status() # Raise an exception for non-2xx status codes
57
+
58
+ total_size = int(response.headers.get('content-length', 0)) # Get file size from headers
59
+ with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_type} model") as pbar:
60
+ with open(model_path, 'wb') as f:
61
+ for chunk in response.iter_content(chunk_size=1024):
62
+ f.write(chunk)
63
+ pbar.update(len(chunk))
64
+ logger.info(f"{model_type} model downloaded.")
65
+ else:
66
+ logger.info(f"{model_type} model already exists.")
67
+ return logger.info(f"{model_type} model download complete.")
68
+ else:
69
+ return logger.info(f"Invalid model type: {model_type}")
70
+
71
+ @staticmethod
72
+ def get_models_path(model_type=None, segment=False):
73
+ sam_models_path = os.path.join(os.getcwd(), 'root_model', 'sam_models')
74
+
75
+ if segment:
76
+ sam_args = generate_sam_args(sam_checkpoint=sam_models_path, model_type=model_type)
77
+ return sam_args, sam_models_path
78
+
79
+ @staticmethod
80
+ def get_click_prompt(click_stack, point):
81
+ click_stack[0].append(point["coord"])
82
+ click_stack[1].append(point["mode"]
83
+ )
84
+
85
+ prompt = {
86
+ "points_coord": click_stack[0],
87
+ "points_mode": click_stack[1],
88
+ "multi_mask": "True",
89
+ }
90
+
91
+ return prompt
92
+
93
+ @staticmethod
94
+ def read_temp_file(temp_file_wrapper):
95
+ name = temp_file_wrapper.name
96
+ with open(temp_file_wrapper.name, 'rb') as f:
97
+ # Read the content of the file
98
+ file_content = f.read()
99
+ return file_content, name
100
+
101
+ def get_meta_from_image(self, input_img):
102
+ file_content, _ = self.read_temp_file(input_img)
103
+ np_arr = np.frombuffer(file_content, np.uint8)
104
+
105
+ img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
106
+ first_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
107
+ return first_frame, first_frame
108
+
109
+ def is_sam_model(self, model_type):
110
+ sam_args, sam_models_dir = self.get_models_path(model_type=model_type, segment=True)
111
+ model_path = os.path.join(sam_models_dir, model_type + '.pth')
112
+ if not os.path.exists(model_path):
113
+ self.download_models(model_type=model_type)
114
+ return 'Model is downloaded', sam_args
115
+ else:
116
+ return 'Model is already downloaded', sam_args
117
+
118
+ @staticmethod
119
+ def init_segment(
120
+ points_per_side,
121
+ origin_frame,
122
+ sam_args,
123
+ predict_iou_thresh=0.8,
124
+ stability_score_thresh=0.9,
125
+ crop_n_layers=1,
126
+ crop_n_points_downscale_factor=2,
127
+ min_mask_region_area=200):
128
+ if origin_frame is None:
129
+ return None, origin_frame, [[], []]
130
+ sam_args["generator_args"]["points_per_side"] = points_per_side
131
+ sam_args["generator_args"]["pred_iou_thresh"] = predict_iou_thresh
132
+ sam_args["generator_args"]["stability_score_thresh"] = stability_score_thresh
133
+ sam_args["generator_args"]["crop_n_layers"] = crop_n_layers
134
+ sam_args["generator_args"]["crop_n_points_downscale_factor"] = crop_n_points_downscale_factor
135
+ sam_args["generator_args"]["min_mask_region_area"] = min_mask_region_area
136
+
137
+ segment = SegMent(sam_args)
138
+ logger.info(f"Model Init: {sam_args}")
139
+ return segment, origin_frame, [[], []]
140
+
141
+ @staticmethod
142
+ def seg_acc_click(segment, prompt, origin_frame):
143
+ # seg acc to click
144
+ refined_mask, masked_frame = segment.seg_acc_click(
145
+ origin_frame=origin_frame,
146
+ coords=np.array(prompt["points_coord"]),
147
+ modes=np.array(prompt["points_mode"]),
148
+ multimask=prompt["multi_mask"],
149
+ )
150
+ return refined_mask, masked_frame
151
+
152
+ def undo_click_stack_and_refine_seg(self, segment, origin_frame, click_stack):
153
+ if segment is None:
154
+ return segment, origin_frame, [[], []]
155
+
156
+ logger.info("Undo !")
157
+ if len(click_stack[0]) > 0:
158
+ click_stack[0] = click_stack[0][: -1]
159
+ click_stack[1] = click_stack[1][: -1]
160
+
161
+ if len(click_stack[0]) > 0:
162
+ prompt = {
163
+ "points_coord": click_stack[0],
164
+ "points_mode": click_stack[1],
165
+ "multi_mask": "True",
166
+ }
167
+
168
+ _, masked_frame = self.seg_acc_click(segment, prompt, origin_frame)
169
+ return segment, masked_frame, click_stack
170
+ else:
171
+ return segment, origin_frame, [[], []]
172
+
173
+ def reload_segment(self,
174
+ check_sam,
175
+ segment,
176
+ model_type,
177
+ point_per_sides,
178
+ origin_frame,
179
+ predict_iou_thresh,
180
+ stability_score_thresh,
181
+ crop_n_layers,
182
+ crop_n_points_downscale_factor,
183
+ min_mask_region_area):
184
+ status, sam_args = check_sam(model_type)
185
+ if segment is None or status == 'Model is downloaded':
186
+ segment, _, _ = self.init_segment(point_per_sides,
187
+ origin_frame,
188
+ sam_args,
189
+ predict_iou_thresh,
190
+ stability_score_thresh,
191
+ crop_n_layers,
192
+ crop_n_points_downscale_factor,
193
+ min_mask_region_area)
194
+ self.current_model_type = model_type
195
+ return segment, self.current_model_type, status
196
+
197
+ def sam_click(self,
198
+ evt: gr.SelectData,
199
+ segment,
200
+ origin_frame,
201
+ model_type,
202
+ point_mode,
203
+ click_stack,
204
+ point_per_sides,
205
+ predict_iou_thresh,
206
+ stability_score_thresh,
207
+ crop_n_layers,
208
+ crop_n_points_downscale_factor,
209
+ min_mask_region_area):
210
+ logger.info("Click")
211
+ if point_mode == "Positive":
212
+ point = {"coord": [evt.index[0], evt.index[1]], "mode": 1}
213
+ else:
214
+ point = {"coord": [evt.index[0], evt.index[1]], "mode": 0}
215
+ click_prompt = self.get_click_prompt(click_stack, point)
216
+ segment, self.current_model_type, status = self.reload_segment(
217
+ self.is_sam_model,
218
+ segment,
219
+ model_type,
220
+ point_per_sides,
221
+ origin_frame,
222
+ predict_iou_thresh,
223
+ stability_score_thresh,
224
+ crop_n_layers,
225
+ crop_n_points_downscale_factor,
226
+ min_mask_region_area)
227
+ if segment is not None and model_type != self.current_model_type:
228
+ segment = None
229
+ segment, _, status = self.reload_segment(
230
+ self.is_sam_model,
231
+ segment,
232
+ model_type,
233
+ point_per_sides,
234
+ origin_frame,
235
+ predict_iou_thresh,
236
+ stability_score_thresh,
237
+ crop_n_layers,
238
+ crop_n_points_downscale_factor,
239
+ min_mask_region_area)
240
+ refined_mask, masked_frame = self.seg_acc_click(segment, click_prompt, origin_frame)
241
+ self.save_mask(refined_mask, save=True)
242
+ self.refine_mask = refined_mask
243
+ return segment, masked_frame, click_stack, status
244
+
245
+ @staticmethod
246
+ def normalize_image(image):
247
+ # Normalize the image to the range [0, 1]
248
+ min_val = image.min()
249
+ max_val = image.max()
250
+ image = (image - min_val) / (max_val - min_val)
251
+
252
+ return image
253
+
254
+ @staticmethod
255
+ def compute_probability(masks):
256
+ p_max = None
257
+ for mask in masks:
258
+ p = mask['prob']
259
+ if p_max is None:
260
+ p_max = p
261
+ else:
262
+ p_max = np.maximum(p_max, p)
263
+ return p_max
264
+ @staticmethod
265
+ def download_opencv_model(model_url):
266
+ opencv_model_path = os.path.join(os.getcwd(), 'edges_detection')
267
+ os.makedirs(opencv_model_path, exist_ok=True)
268
+ model_path = os.path.join(opencv_model_path, 'edges_detection' + '.yml.gz')
269
+ response = requests.get(model_url, stream=True)
270
+ response.raise_for_status() # Raise an exception for non-2xx status codes
271
+
272
+ total_size = int(response.headers.get('content-length', 0)) # Get file size from headers
273
+ with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading opencv model") as pbar:
274
+ with open(model_path, 'wb') as f:
275
+ for chunk in response.iter_content(chunk_size=1024):
276
+ f.write(chunk)
277
+ pbar.update(len(chunk))
278
+ return model_path
279
+
280
+ def automatic_sam2sketch(self,
281
+ segment,
282
+ image,
283
+ origin_frame,
284
+ model_type
285
+ ):
286
+ _, sam_args = self.is_sam_model(model_type)
287
+ if segment is None or model_type != sam_args['model_type']:
288
+ segment, _, _ = self.init_segment(
289
+ points_per_side=16,
290
+ origin_frame=origin_frame,
291
+ sam_args=sam_args,
292
+ predict_iou_thresh=0.8,
293
+ stability_score_thresh=0.9,
294
+ crop_n_layers=1,
295
+ crop_n_points_downscale_factor=2,
296
+ min_mask_region_area=200)
297
+ model_path = self.download_opencv_model(model_url='https://github.com/nipunmanral/Object-Detection-using-OpenCV/raw/master/model.yml.gz')
298
+ masks = segment.automatic_generate_mask(image)
299
+ p_max = self.compute_probability(masks)
300
+ edges = self.normalize_image(p_max)
301
+ edge_detection = cv2.ximgproc.createStructuredEdgeDetection(model_path)
302
+ orimap = edge_detection.computeOrientation(edges)
303
+ edges = edge_detection.edgesNms(edges, orimap)
304
+ edges = (edges * 255).astype('uint8')
305
+ edges = 255 - edges
306
+ edges = np.stack((edges,) * 3, axis=-1)
307
+ return edges
S2I/samer/seg_anything.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ from scipy.ndimage import binary_dilation
5
+
6
+ np.random.seed(200)
7
+ _palette = ((np.random.random((3 * 255)) * 0.7 + 0.3) * 255).astype(np.uint8).tolist()
8
+ _palette = [0, 0, 0] + _palette
9
+
10
+
11
+ def save_prediction(predict_mask, output_dir, file_name):
12
+ save_mask = Image.fromarray(predict_mask.astype(np.uint8))
13
+ save_mask = save_mask.convert(mode='P')
14
+ save_mask.putpalette(_palette)
15
+ save_mask.save(os.path.join(output_dir, file_name))
16
+
17
+
18
+ def colorize_mask(predict_mask):
19
+ save_mask = Image.fromarray(predict_mask.astype(np.uint8))
20
+ save_mask = save_mask.convert(mode='P')
21
+ save_mask.putpalette(_palette)
22
+ save_mask = save_mask.convert(mode='RGB')
23
+ return np.array(save_mask)
24
+
25
+
26
+ def draw_mask(img, mask, alpha=0.5, id_cnt=False):
27
+ img_mask = img
28
+ if id_cnt:
29
+ # very slow ~ 1s per image
30
+ obj_ids = np.unique(mask)
31
+ obj_ids = obj_ids[obj_ids != 0]
32
+
33
+ for ids in obj_ids:
34
+ # Overlay color on binary mask
35
+ if ids <= 255:
36
+ color = _palette[ids * 3:ids * 3 + 3]
37
+ else:
38
+ color = [0, 0, 0]
39
+ foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
40
+ binary_mask = (mask == ids)
41
+
42
+ # Compose image
43
+ img_mask[binary_mask] = foreground[binary_mask]
44
+
45
+ cnt = binary_dilation(binary_mask, iterations=1) ^ binary_mask
46
+ img_mask[cnt, :] = 0
47
+ else:
48
+ binary_mask = (mask != 0)
49
+ cnt = binary_dilation(binary_mask, iterations=1) ^ binary_mask
50
+ foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
51
+ img_mask[binary_mask] = foreground[binary_mask]
52
+ img_mask[cnt, :] = 0
53
+
54
+ return img_mask.astype(img.dtype)
S2I/samer/segment.py ADDED
@@ -0,0 +1,69 @@
1
+ import sys
2
+
3
+ sys.path.append("../../..")
4
+ sys.path.append("")
5
+ import cv2
6
+ import numpy as np
7
+ from S2I.samer.segmentor import Segmentor
8
+ from S2I.samer.transfer_tools import draw_outline, draw_points
9
+ from S2I.samer.seg_anything import draw_mask
10
+
11
+
12
+ class SegMent:
13
+ def __init__(self, sam_args):
14
+ self.sam = Segmentor(sam_args)
15
+ self.reference_objs_list = []
16
+ self.object_idx = 1
17
+ self.curr_idx = 1
18
+ self.origin_merged_mask = None # init by segment-everything or update
19
+ self.first_frame_mask = None
20
+
21
+ # debug
22
+ self.everything_points = []
23
+ self.everything_labels = []
24
+ print("SegTracker has been initialized")
25
+
26
+ def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray, ):
27
+ # get interactive_mask
28
+ interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
29
+ refined_merged_mask = self.add_mask(interactive_mask)
30
+
31
+ # draw mask
32
+ masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
33
+
34
+ # draw bbox
35
+ masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))
36
+
37
+ return refined_merged_mask, masked_frame
38
+
39
+ def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
40
+ # get interactive_mask
41
+ interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)
42
+
43
+ refined_merged_mask = self.add_mask(interactive_mask)
44
+
45
+ # draw mask
46
+ masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
47
+ masked_frame = draw_points(coords, modes, masked_frame)
48
+
49
+ # draw outline
50
+ masked_frame = draw_outline(interactive_mask, masked_frame)
51
+
52
+ return refined_merged_mask, masked_frame
53
+
54
+ def add_mask(self, interactive_mask: np.ndarray):
55
+ if self.origin_merged_mask is None:
56
+ self.origin_merged_mask = np.zeros(interactive_mask.shape, dtype=np.uint8)
57
+
58
+ refined_merged_mask = self.origin_merged_mask.copy()
59
+ refined_merged_mask[interactive_mask > 0] = self.curr_idx
60
+
61
+ return refined_merged_mask
62
+
63
+ def automatic_generate_mask(self, image):
64
+ masks = self.sam.automatic_segment(image)
65
+ return masks
66
+
67
+
68
+ if __name__ == '__main__':
69
+ pass
S2I/samer/segmentor.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import numpy as np
3
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
4
+ from .automatic_mask_generator_prob import SamAutomaticMaskAndProbabilityGenerator
5
+
6
+
7
+ class Segmentor:
8
+ def __init__(self, sam_args):
9
+ """
10
+ sam_args:
11
+ sam_checkpoint: path of SAM checkpoint
12
+ generator_args: args for everything_generator
13
+ gpu_id: device
14
+ """
15
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+ self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"])
17
+ self.sam.to(device=self.device)
18
+ # self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args'])
19
+ self.automatic_generator = SamAutomaticMaskAndProbabilityGenerator(model=self.sam, **sam_args['generator_args'])
20
+ self.interactive_predictor = self.automatic_generator.predictor
21
+ self.have_embedded = False
22
+
23
+ @torch.no_grad()
24
+ def set_image(self, image):
25
+ # calculate the embedding only once per frame.
26
+ if not self.have_embedded:
27
+ self.interactive_predictor.set_image(image)
28
+ self.have_embedded = True
29
+
30
+ @torch.no_grad()
31
+ def interactive_predict(self, prompts, mode, multimask=True):
32
+ assert self.have_embedded, 'image embedding for sam need be set before predict.'
33
+
34
+ if mode == 'point':
35
+ masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'],
36
+ point_labels=prompts['point_modes'],
37
+ multimask_output=multimask)
38
+ elif mode == 'mask':
39
+ masks, scores, logits = self.interactive_predictor.predict(mask_input=prompts['mask_prompt'],
40
+ multimask_output=multimask)
41
+ elif mode == 'point_mask':
42
+ masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'],
43
+ point_labels=prompts['point_modes'],
44
+ mask_input=prompts['mask_prompt'],
45
+ multimask_output=multimask)
46
+
47
+ return masks, scores, logits
48
+
49
+ @torch.no_grad()
50
+ def automatic_segment(self, image):
51
+ masks = self.automatic_generator.generate(image)
52
+ return masks
53
+
54
+ @torch.no_grad()
55
+ def segment_with_click(self, origin_frame, coords, modes, multimask=True):
56
+ '''Segment the object selected by the click prompts.
57
+
58
+ Returns:
59
+ mask: uint8 one-hot mask of the selected object.
60
+ '''
61
+ self.set_image(origin_frame)
62
+
63
+ prompts = {
64
+ 'point_coords': coords,
65
+ 'point_modes': modes,
66
+ }
67
+ masks, scores, logits = self.interactive_predict(prompts, 'point', multimask)
68
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
69
+ prompts = {
70
+ 'point_coords': coords,
71
+ 'point_modes': modes,
72
+ 'mask_prompt': logit[None, :, :]
73
+ }
74
+ masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask)
75
+
76
+ mask = masks[np.argmax(scores)]
77
+
78
+ return mask.astype(np.uint8)
79
+
80
+ def segment_with_box(self, origin_frame, bbox, reset_image=False):
81
+ if reset_image:
82
+ self.interactive_predictor.set_image(origin_frame)
83
+ else:
84
+ self.set_image(origin_frame)
85
+
86
+ masks, scores, logits = self.interactive_predictor.predict(
87
+ point_coords=None,
88
+ point_labels=None,
89
+ box=np.array([bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]),
90
+ multimask_output=True
91
+ )
92
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
93
+
94
+ masks, scores, logits = self.interactive_predictor.predict(
95
+ point_coords=None,
96
+ point_labels=None,
97
+ box=np.array([[bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]]),
98
+ mask_input=logit[None, :, :],
99
+ multimask_output=True
100
+ )
101
+ mask = masks[np.argmax(scores)]
102
+
103
+ return [mask]
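`Segmentor` wraps both the automatic generator and the interactive predictor; the click flow above runs two passes, first predicting from the points alone and then feeding the best logit back as a mask prompt. A minimal usage sketch, assuming a local SAM ViT-B checkpoint and that the `S2I.samer` package is importable (the paths and test image below are hypothetical):

```python
import cv2
import numpy as np
from S2I.samer.segmentor import Segmentor  # hypothetical import path into this repo

sam_args = {
    "model_type": "vit_b",                     # key into segment_anything's sam_model_registry
    "sam_checkpoint": "./ckpt/sam_vit_b.pth",  # hypothetical checkpoint location
    "generator_args": {},                      # use the generator defaults
}
segmentor = Segmentor(sam_args)

frame = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

# One positive click at (x=320, y=240); 1 = positive point, 0 = negative point.
coords = np.array([[320, 240]])
modes = np.array([1])
click_mask = segmentor.segment_with_click(frame, coords, modes)  # uint8 H x W mask

# Box prompt as [[x0, y0], [x1, y1]].
bbox = np.array([[100, 100], [400, 300]])
box_mask = segmentor.segment_with_box(frame, bbox)[0]
```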
S2I/samer/transfer_tools.py ADDED
@@ -0,0 +1,47 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def mask2bbox(mask):
6
+ if len(np.where(mask > 0)[0]) == 0:
7
+ print('no mask found')
8
+ return np.array([[0, 0], [0, 0]]).astype(np.int64)
9
+
10
+ x_ = np.sum(mask, axis=0)
11
+ y_ = np.sum(mask, axis=1)
12
+
13
+ x0 = np.min(np.nonzero(x_)[0])
14
+ x1 = np.max(np.nonzero(x_)[0])
15
+ y0 = np.min(np.nonzero(y_)[0])
16
+ y1 = np.max(np.nonzero(y_)[0])
17
+
18
+ return np.array([[x0, y0], [x1, y1]]).astype(np.int64)
19
+
20
+
21
+ def draw_outline(mask, frame):
22
+ _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)
23
+
24
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
25
+
26
+ cv2.drawContours(frame, contours, -1, (0, 0, 255), 2)
27
+
28
+ return frame
29
+
30
+
31
+ def draw_points(points, modes, frame):
32
+ neg_points = points[np.argwhere(modes == 0)[:, 0]]
33
+ pos_points = points[np.argwhere(modes == 1)[:, 0]]
34
+
35
+ for i in range(len(neg_points)):
36
+ point = neg_points[i]
37
+ cv2.circle(frame, (point[0], point[1]), 8, (255, 80, 80), -1)
38
+
39
+ for i in range(len(pos_points)):
40
+ point = pos_points[i]
41
+ cv2.circle(frame, (point[0], point[1]), 8, (0, 153, 255), -1)
42
+
43
+ return frame
44
+
45
+
46
+ if __name__ == '__main__':
47
+ pass
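These helpers are small numpy/OpenCV utilities: `mask2bbox` returns the tight `[[x0, y0], [x1, y1]]` box around the non-zero region, while `draw_outline` and `draw_points` annotate a frame in place. A tiny self-contained check of `mask2bbox` (hypothetical import path, assuming the package is on the Python path):

```python
import numpy as np
from S2I.samer.transfer_tools import mask2bbox  # hypothetical import path

mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:5, 3:8] = 1              # rows 2-4 and columns 3-7 are foreground

print(mask2bbox(mask))
# [[3 2]
#  [7 4]]   -> [[x0, y0], [x1, y1]] of the non-zero region
```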
app.py CHANGED
@@ -1,146 +1,327 @@
1
- import gradio as gr
2
  import numpy as np
 
 
 
 
 
3
  import random
4
- from diffusers import DiffusionPipeline
5
- import torch
6
 
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- if torch.cuda.is_available():
10
- torch.cuda.max_memory_allocated(device=device)
11
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
12
- pipe.enable_xformers_memory_efficient_attention()
13
- pipe = pipe.to(device)
14
- else:
15
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
16
- pipe = pipe.to(device)
17
 
18
- MAX_SEED = np.iinfo(np.int32).max
19
- MAX_IMAGE_SIZE = 1024
 
 
 
20
 
21
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
 
 
 
 
 
22
 
23
- if randomize_seed:
24
- seed = random.randint(0, MAX_SEED)
25
-
26
- generator = torch.Generator().manual_seed(seed)
27
 
28
- image = pipe(
29
- prompt = prompt,
30
- negative_prompt = negative_prompt,
31
- guidance_scale = guidance_scale,
32
- num_inference_steps = num_inference_steps,
33
- width = width,
34
- height = height,
35
- generator = generator
36
- ).images[0]
37
 
38
- return image
39
-
40
- examples = [
41
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
42
- "An astronaut riding a green horse",
43
- "A delicious ceviche cheesecake slice",
44
- ]
45
-
46
- css="""
47
- #col-container {
48
- margin: 0 auto;
49
- max-width: 520px;
50
- }
51
- """
52
 
53
- if torch.cuda.is_available():
54
- power_device = "GPU"
55
- else:
56
- power_device = "CPU"
 
57
 
58
- with gr.Blocks(css=css) as demo:
 
59
 
60
- with gr.Column(elem_id="col-container"):
61
- gr.Markdown(f"""
62
- # Text-to-Image Gradio Template
63
- Currently running on {power_device}.
64
- """)
65
-
66
- with gr.Row():
67
-
68
- prompt = gr.Text(
69
- label="Prompt",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  show_label=False,
71
- max_lines=1,
72
- placeholder="Enter your prompt",
73
- container=False,
 
 
 
74
  )
75
-
76
- run_button = gr.Button("Run", scale=0)
77
-
78
- result = gr.Image(label="Result", show_label=False)
79
 
80
- with gr.Accordion("Advanced Settings", open=False):
81
-
82
- negative_prompt = gr.Text(
83
- label="Negative prompt",
84
- max_lines=1,
85
- placeholder="Enter a negative prompt",
86
- visible=False,
87
  )
88
-
89
- seed = gr.Slider(
90
- label="Seed",
91
- minimum=0,
92
- maximum=MAX_SEED,
93
- step=1,
94
- value=0,
 
 
 
95
  )
96
-
97
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
-
99
  with gr.Row():
100
-
101
- width = gr.Slider(
102
- label="Width",
103
- minimum=256,
104
- maximum=MAX_IMAGE_SIZE,
105
- step=32,
106
- value=512,
107
- )
108
-
109
- height = gr.Slider(
110
- label="Height",
111
- minimum=256,
112
- maximum=MAX_IMAGE_SIZE,
113
- step=32,
114
- value=512,
115
- )
116
 
117
- with gr.Row():
118
-
119
- guidance_scale = gr.Slider(
120
- label="Guidance scale",
121
- minimum=0.0,
122
- maximum=10.0,
123
- step=0.1,
124
- value=0.0,
125
- )
126
-
127
- num_inference_steps = gr.Slider(
128
- label="Number of inference steps",
129
- minimum=1,
130
- maximum=12,
131
- step=1,
132
- value=2,
133
- )
134
-
135
- gr.Examples(
136
- examples = examples,
137
- inputs = [prompt]
138
- )
139
-
140
- run_button.click(
141
- fn = infer,
142
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
- outputs = [result]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
 
 
 
145
 
146
- demo.queue().launch()
 
 
 
1
+ import os
2
  import numpy as np
3
+ import io
4
+ os.system("pip install gradio==4.29.0")
5
+ os.system("pip install opencv-python")
6
+ import cv2
7
+ import gradio as gr
8
  import random
9
+ import warnings
10
+ import spaces
11
+ from PIL import Image
12
+ from S2I import Sketch2ImageController, css, scripts
13
+
14
+
15
+ dark_mode_theme = """
16
+ function refresh() {
17
+ const url = new URL(window.location);
18
+
19
+ if (url.searchParams.get('__theme') !== 'dark') {
20
+ url.searchParams.set('__theme', 'dark');
21
+ window.location.href = url.href;
22
+ }
23
+ }
24
+ """
25
+
26
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
+ warnings.filterwarnings("ignore")
28
+ controller = Sketch2ImageController(gr)
29
+
30
+
31
+ def run_gpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag):
32
+ return controller.artwork(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
33
+
34
+ def run_cpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag):
35
+ return controller.artwork(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
36
+
37
+ def get_dark_mode():
38
+ return """
39
+ () => {
40
+ document.body.classList.toggle('dark');
41
+ }
42
+ """
43
+
44
+ def clear_session():
45
+ return gr.update(value=None), gr.update(value=None)
46
 
 
47
 
48
+ def assign_gpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag):
 
 
 
 
 
 
 
49
 
50
+ if options == 'GPU':
51
+ decorated_run = spaces.GPU(run_gpu)
52
+ return decorated_run(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
53
+ else:
54
+ return run_cpu(options, img_init, text_init, prompt_template_init, style_name_init, seeds_init, val_r_values_init, faster_init, model_name_init, clear_flag)
55
 
56
+ def read_temp_file(temp_file_wrapper):
57
+ name = temp_file_wrapper.name
58
+ with open(temp_file_wrapper.name, 'rb') as f:
59
+ # Read the content of the file
60
+ file_content = f.read()
61
+ return file_content, name
62
 
63
+ def convert_to_pencil_sketch(image):
64
+ if image is None:
65
+ raise ValueError(f"Image at path {image} could not be loaded.")
 
66
 
67
+ # Converting it into grayscale
68
+ gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
69
+
70
+ # Inverting the image
71
+ inverted_image = 255 - gray_image
72
+
73
+ # Blurring the image
74
+ blurred = cv2.GaussianBlur(inverted_image, (25, 25), 0)
75
+ inverted_blurred = 255 - blurred
76
+
77
+ # Creating the pencil sketch
78
+ pencil_sketch = cv2.divide(gray_image, inverted_blurred, scale=256.0)
79
+
80
+ return pencil_sketch
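`convert_to_pencil_sketch` is the standard colour-dodge sketch: dividing the grayscale image by the inverse of its blurred negative pushes smooth regions toward white, while edges, where the blur lags behind the original, stay dark. A tiny self-contained numeric check of the same trick (synthetic values, independent of the app):

```python
import cv2
import numpy as np

# A flat bright patch with one dark pixel standing in for an edge.
gray = np.array([[200, 200],
                 [200,  20]], dtype=np.uint8)
inverted_blurred = 255 - cv2.GaussianBlur(255 - gray, (3, 3), 0)

sketch = cv2.divide(gray, inverted_blurred, scale=256.0)
print(sketch)  # flat pixels saturate to ~255 (white paper); the dark pixel stays well below 255
```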
81
+
82
+ def get_meta_from_image(input_img, type_image):
83
+ if input_img is None:
84
+ return gr.update(value=None)
85
+
86
+ file_content, _ = read_temp_file(input_img)
87
 
88
+ # Read the image using Pillow
89
+ img = Image.open(io.BytesIO(file_content)).convert("RGB")
90
+ img_np = np.array(img)
91
 
92
+ if type_image == 'RGB':
93
+ sketch = convert_to_pencil_sketch(img_np)
94
+ processed_img = 255 - sketch
95
+ elif type_image == 'SKETCH':
96
+ processed_img = 255 - img_np
97
 
98
+ # Convert the processed image back to PIL Image
99
+ img_pil = Image.fromarray(processed_img.astype('uint8'))
100
 
101
+ return img_pil
102
+
103
+
104
+
105
+ with gr.Blocks(css=css) as demo:
106
+ gr.HTML(
107
+ """
108
+ <!DOCTYPE html>
109
+ <html lang="en">
110
+ <head>
111
+ <meta charset="UTF-8">
112
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
113
+ <title>S2I-Artwork Animation</title>
114
+ <style>
115
+
116
+ @keyframes blinkCursor {
117
+ from { border-right-color: rgba(255, 255, 255, 0.75); }
118
+ to { border-right-color: transparent; }
119
+ }
120
+
121
+
122
+
123
+ @keyframes fadeIn {
124
+ 0% { opacity: 0; transform: translateY(-10px); }
125
+ 100% { opacity: 1; transform: translateY(0); }
126
+ }
127
+
128
+ @keyframes bounce {
129
+ 0%, 20%, 50%, 80%, 100% {
130
+ transform: translateY(0);
131
+ }
132
+ 40% {
133
+ transform: translateY(-10px);
134
+ }
135
+ 60% {
136
+ transform: translateY(-5px);
137
+ }
138
+ }
139
+ .typewriter h1 {
140
+ overflow: hidden;
141
+ border-right: .15em solid rgba(255, 255, 255, 0.75);
142
+ white-space: nowrap;
143
+ margin: 0 auto;
144
+ letter-spacing: .15em;
145
+ animation:
146
+ zoomInOut 4s infinite;
147
+ }
148
+ .animated-heading {
149
+ animation: fadeIn 2s ease-in-out;
150
+ }
151
+
152
+ .animated-link {
153
+ display: inline-block;
154
+ animation: bounce 3s infinite;
155
+ }
156
+ </style>
157
+ </head>
158
+ <body>
159
+ <div>
160
+ <div class="typewriter">
161
+ <h1 style="display: flex; align-items: center; justify-content: center; margin-bottom: 10px; text-align: center;">
162
+ <img src="https://imgur.com/H2SLps2.png" alt="icon" style="margin-left: 10px; height: 30px;">
163
+ S2I-Artwork
164
+ <img src="https://imgur.com/cNMKSAy.png" alt="icon" style="margin-left: 10px; height: 30px;">:
165
+ Personalized Sketch-to-Art 🧨 Diffusion Models
166
+ <img src="https://imgur.com/yDnDd1p.png" alt="icon" style="margin-left: 10px; height: 30px;">
167
+ </h1>
168
+ </div>
169
+ <h3 class="animated-heading" style="text-align: center; margin-bottom: 10px;">Authors: Vo Nguyen An Tin, Nguyen Thiet Su</h3>
170
+ <h4 class="animated-heading" style="margin-bottom: 10px;">*This project is the fine-tuning task with LorA on large datasets included: COCO-2017, LHQ, Danbooru, LandScape and Mid-Journey V6</h4>
171
+ <h4 class="animated-heading" style="margin-bottom: 10px;">* We public 2 sketch2image-models-lora training on 30K and 60K steps with skip-connection and Transformers Super-Resolution variables</h4>
172
+ <h4 class="animated-heading" style="margin-bottom: 10px;">* The inference and demo time of model is faster, you can slowly in the first runtime, but after that, the time process over 1.5 ~ 2s</h4>
173
+ <h4 class="animated-heading" style="margin-bottom: 10px;">* View the full code project:
174
+ <a class="animated-link" href="https://github.com/aihacker111/S2I-Artwork-Sketch-to-Image/" target="_blank">GitHub Repository</a>
175
+ </h4>
176
+ <h4 class="animated-heading" style="margin-bottom: 10px;">
177
+ <a class="animated-link" href="https://github.com/aihacker111/S2I-Artwork-Sketch-to-Image/" target="_blank">
178
+ <img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="100">
179
+ </a>
180
+ </h4>
181
+ </div>
182
+ </body>
183
+ </html>
184
+ """
185
+ )
186
+ with gr.Row(elem_id="main_row"):
187
+ with gr.Column(elem_id="column_input"):
188
+ gr.Markdown("## SKETCH", elem_id="input_header")
189
+ image = gr.Sketchpad(
190
+ type="pil",
191
+ height=512,
192
+ width=512,
193
+ min_width=512,
194
+ image_mode="RGBA",
195
  show_label=False,
196
+ mirror_webcam=False,
197
+ show_download_button=True,
198
+ elem_id='input_image',
199
+ brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=4),
200
+ canvas_size=(1024, 1024),
201
+ layers=False
202
  )
203
+ input_image = gr.File(label='Input image')
 
 
 
204
 
205
+ download_sketch = gr.Button(
206
+ "Download sketch", scale=1, elem_id="download_sketch"
 
 
 
 
 
207
  )
208
+
209
+ with gr.Column(elem_id="column_output"):
210
+ gr.Markdown("## IMAGE GENERATE", elem_id="output_header")
211
+ result = gr.Image(
212
+ label="Result",
213
+ height=440,
214
+ width=440,
215
+ elem_id="output_image",
216
+ show_label=False,
217
+ show_download_button=True,
218
  )
 
 
 
219
  with gr.Row():
220
+ run_button = gr.Button("Generate 🪄", min_width=5, variant='primary')
221
+ randomize_seed = gr.Button(value='\U0001F3B2', variant='primary')
222
+ clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
223
+ prompt = gr.Textbox(label="Personalized Text", value="", show_label=True)
224
+ with gr.Accordion("S2I Advances Option", open=True):
225
+ with gr.Row():
226
+ ui_mode = gr.Radio(
227
+ choices=["Light Mode", "Dark Mode"],
228
+ value="Light Mode",
229
+ label="Switch Light/Dark Mode UI",
230
+ interactive=True)
231
+ type_image = gr.Radio(
232
+ choices=["RGB", "SKETCH"],
233
+ value="SKETCH",
234
+ label="Type of Image (Color Image or Sketch Image)",
235
+ interactive=True)
236
+ input_type = gr.Radio(
237
+ choices=["live-sketch", "upload"],
238
+ value="live-sketch",
239
+ label="Type Sketch2Image models",
240
+ interactive=True)
241
+ style = gr.Dropdown(
242
+ label="Style",
243
+ choices=controller.STYLE_NAMES,
244
+ value=controller.DEFAULT_STYLE_NAME,
245
+ scale=1,
246
+ )
247
+ prompt_temp = gr.Textbox(
248
+ label="Prompt Style Template",
249
+ value=controller.styles[controller.DEFAULT_STYLE_NAME],
250
+ scale=2,
251
+ max_lines=1,
252
+ )
253
+ seed = gr.Textbox(label="Seed", value='42', scale=1, min_width=50)
254
+ zero_gpu_options = gr.Radio(
255
+ choices=["GPU", "CPU"],
256
+ value="GPU",
257
+ label="GPU & CPU Options Spaces",
258
+ interactive=True)
259
+ half_model = gr.Radio(
260
+ choices=["float32", "float16"],
261
+ value="float16",
262
+ label="Demo Speed",
263
+ interactive=True)
264
+ model_options = gr.Radio(
265
+ choices=["100k", "350k"],
266
+ value="350k",
267
+ label="Type Sketch2Image models",
268
+ interactive=True)
269
 
270
+ val_r = gr.Slider(
271
+ label="Sketch guidance: ",
272
+ show_label=True,
273
+ minimum=0,
274
+ maximum=1,
275
+ value=0.4,
276
+ step=0.01,
277
+ scale=3,
278
+ )
279
+
280
+ demo.load(None, None, None, js=scripts)
281
+ ui_mode.change(None, [], [], js=get_dark_mode())
282
+ randomize_seed.click(
283
+ lambda: random.randint(0, controller.MAX_SEED),
284
+ inputs=[],
285
+ outputs=seed,
286
+ queue=False,
287
+ api_name=False,
288
+ )
289
+ inputs = [zero_gpu_options, image, prompt, prompt_temp, style, seed, val_r, half_model, model_options, input_type]
290
+ outputs = [result, download_sketch]
291
+ prompt.submit(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
292
+
293
+ input_image.change(
294
+ fn=get_meta_from_image,
295
+ inputs=[
296
+ input_image, type_image
297
+ ],
298
+ outputs=[
299
+ image
300
+ ]
301
+ )
302
+
303
+ style.change(
304
+ lambda x: controller.styles[x],
305
+ inputs=[style],
306
+ outputs=[prompt_temp],
307
+ queue=False,
308
+ api_name=False,
309
+ ).then(
310
+ fn=assign_gpu,
311
+ inputs=inputs,
312
+ outputs=outputs,
313
+ api_name=False,
314
+ )
315
+ clear_button.click(fn=clear_session, inputs=[], outputs=[image, result]).then(
316
+ fn=assign_gpu,
317
+ inputs=inputs,
318
+ outputs=outputs,
319
+ api_name=False,
320
  )
321
+ val_r.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
322
+ run_button.click(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
323
+ image.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
324
 
325
+ if __name__ == '__main__':
326
+ demo.queue()
327
+ demo.launch(debug=True, share=False)
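The UI wiring above follows one pattern throughout: every control change (`style.change`, `val_r.change`, `image.change`, the clear button) funnels back into `assign_gpu` with the same `inputs`/`outputs` lists, and `.then()` chains the regeneration after the first update. A minimal standalone Gradio sketch of that chaining pattern (hypothetical components, not the real pipeline):

```python
import gradio as gr

def generate(text, style):
    # Stand-in for assign_gpu / the real Sketch2Image pipeline call.
    return f"[{style}] generated for: {text}"

with gr.Blocks() as mini_demo:
    prompt = gr.Textbox(label="Prompt")
    style = gr.Dropdown(choices=["Anime", "Comic"], value="Anime", label="Style")
    output = gr.Textbox(label="Output")

    # First refresh the prompt from the new style, then re-run generation,
    # mirroring style.change(...).then(fn=assign_gpu, ...) above.
    style.change(lambda s: f"{s} artwork of ...", inputs=[style], outputs=[prompt]).then(
        fn=generate, inputs=[prompt, style], outputs=[output]
    )

if __name__ == "__main__":
    mini_demo.launch()
```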
requirements.txt CHANGED
@@ -1,6 +1,87 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
1
+ accelerate==0.30.1
2
+ aiofiles==23.2.1
3
+ altair==5.3.0
4
+ annotated-types==0.7.0
5
+ anyio==4.4.0
6
+ attrs==23.2.0
7
+ certifi==2024.2.2
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ contourpy==1.2.1
11
+ cycler==0.12.1
12
+ diffusers==0.25.1
13
+ dnspython==2.6.1
14
+ email_validator==2.1.1
15
+ exceptiongroup==1.2.1
16
+ fastapi==0.111.0
17
+ fastapi-cli==0.0.4
18
+ ffmpy==0.3.2
19
+ filelock==3.14.0
20
+ fonttools==4.52.4
21
+ fsspec==2024.5.0
22
+ gradio==4.29.0
23
+ h11==0.14.0
24
+ httpcore==1.0.5
25
+ httptools==0.6.1
26
+ httpx==0.27.0
27
+ huggingface-hub==0.23.0
28
+ idna==3.7
29
+ importlib_metadata==7.1.0
30
+ importlib_resources==6.4.0
31
+ Jinja2==3.1.4
32
+ jsonschema==4.22.0
33
+ jsonschema-specifications==2023.12.1
34
+ kiwisolver==1.4.5
35
+ markdown-it-py==3.0.0
36
+ MarkupSafe==2.1.5
37
+ matplotlib==3.9.0
38
+ mdurl==0.1.2
39
+ mpmath==1.3.0
40
+ networkx==3.3
41
+ numpy==1.26.4
42
+ orjson==3.10.3
43
+ packaging==24.0
44
+ pandas==2.2.2
45
+ peft==0.11.1
46
+ pillow==10.3.0
47
+ psutil==5.9.8
48
+ pydantic==2.7.2
49
+ pydantic_core==2.18.3
50
+ pydub==0.25.1
51
+ Pygments==2.18.0
52
+ pyparsing==3.1.2
53
+ python-dateutil==2.9.0.post0
54
+ python-dotenv==1.0.1
55
+ python-multipart==0.0.9
56
+ pytz==2024.1
57
+ PyYAML==6.0.1
58
+ referencing==0.35.1
59
+ regex==2024.5.15
60
+ requests==2.32.0
61
+ rich==13.7.1
62
+ rpds-py==0.18.1
63
+ ruff==0.4.6
64
+ safetensors==0.4.3
65
+ semantic-version==2.10.0
66
+ shellingham==1.5.4
67
+ six==1.16.0
68
+ sniffio==1.3.1
69
+ starlette==0.37.2
70
+ sympy==1.12
71
+ tokenizers==0.19.1
72
+ tomlkit==0.12.0
73
+ toolz==0.12.1
74
+ torch==2.3.0
75
+ torchvision==0.18.0
76
+ tqdm==4.66.4
77
+ transformers==4.41.0
78
+ typer==0.12.3
79
+ typing_extensions==4.11.0
80
+ tzdata==2024.1
81
+ ujson==5.10.0
82
+ urllib3==2.2.1
83
+ uvicorn==0.30.0
84
+ uvloop==0.19.0
85
+ watchfiles==0.22.0
86
+ websockets==11.0.3
87
+ zipp==3.18.2