enhance prompt module
- S2I/commons/controller.py +17 -22
- S2I/modules/models.py +10 -3
- S2I/modules/sketch2image.py +19 -0
- app.py +18 -8
S2I/commons/controller.py
CHANGED
@@ -47,31 +47,27 @@ class Sketch2ImageController():
         self.pipe = Sketch2ImagePipeline()
         self.zero_options = zero_options
 
-    def update_canvas(self, use_line, use_eraser):
-        brush_size = 20 if use_eraser else 4
-        _color = "#ffffff" if use_eraser else "#000000"
-        return self.gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
-
-    def upload_sketch(self, file):
-        _img = Image.open(file.name).convert("L")
-        return self.gr.update(value=_img, source="upload", interactive=True)
-
     @staticmethod
     def pil_image_to_data_uri(img, format="PNG"):
         buffered = BytesIO()
         img.save(buffered, format=format)
         img_str = base64.b64encode(buffered.getvalue()).decode()
         return f"data:image/{format.lower()};base64,{img_str}"
 
-    def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag):
+    def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag, prompt_quality):
         self.load_pipeline(zero_options=options)
+        prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
+        prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)
 
-        if type_flag == 'live-sketch':
-            img = Image.fromarray(np.array(image["composite"])[:, :, -1])
-        elif type_flag == 'url-sketch':
+        # if type_flag == 'live-sketch':
+        #     img = Image.fromarray(np.array(image["composite"])[:, :, -1])
+        # elif type_flag == 'url-sketch':
+        #     img = image["composite"]
 
+        if type_flag == 'URL':
             img = image["composite"]
+        else:
+            img = Image.fromarray(np.array(image["composite"])[:, :, -1])
 
         img = img.convert("RGB")
         img = img.resize((512, 512))
@@ -84,14 +80,13 @@ class Sketch2ImageController():
         noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
 
         with torch.no_grad():
-            output_image = self.pipe.generate(c_t, prompt, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
+            output_image = self.pipe.generate(c_t, prompt_enhanced, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
         output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
 
-        if type_flag == 'live-sketch':
-            input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
-        else:
-            input_uri = self.pil_image_to_data_uri(img)
+        # if type_flag == 'live-sketch':
+        #     input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
+        # else:
+        #     input_uri = self.pil_image_to_data_uri(img)
 
-        return output_pil
-        # , self.gr.update(link=input_uri)
+        return output_pil
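In short, artwork() now enhances the raw prompt before splicing it into the style template, and branches on a 'URL' flag instead of the old live-sketch/url-sketch values. A condensed, self-contained sketch of that input-preparation flow, with the enhance callable standing in for Sketch2ImagePipeline.automatic_enhance_prompt and the usage values made up for illustration:

import numpy as np
from PIL import Image

def prepare_inputs(prompt, prompt_template, prompt_quality, type_flag, image, enhance):
    # Enhance the raw prompt, then splice the result into the style template.
    prompt_enhanced = enhance(prompt, prompt_quality)
    prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)
    # URL images arrive as a ready PIL composite; live sketches need the
    # last channel of the canvas array extracted first.
    if type_flag == 'URL':
        img = image["composite"]
    else:
        img = Image.fromarray(np.array(image["composite"])[:, :, -1])
    return prompt_enhanced, img.convert("RGB").resize((512, 512))

# Hypothetical usage with a stub enhancer:
prompt_enhanced, img = prepare_inputs(
    "a cat", "cinematic photo of {prompt}", "short-sentences", "URL",
    {"composite": Image.new("RGB", (640, 480))},
    lambda p, q: p + ", highly detailed")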
S2I/modules/models.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 import copy
 import os
 from diffusers import DDPMScheduler
-from transformers import AutoTokenizer, CLIPTextModel
+from transformers import AutoTokenizer, CLIPTextModel, pipeline
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from peft import LoraConfig
 from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home
@@ -29,6 +29,8 @@ class PrimaryModel:
         self.global_tokenizer = None
         self.global_text_encoder = None
         self.global_scheduler = None
+        self.global_medium_prompt = None
+        self.global_long_prompt = None
 
     @staticmethod
     def _load_model(path, model_class, unet_mode=False):
@@ -62,9 +64,14 @@ class PrimaryModel:
         sd = torch.load(p_ckpt, map_location="cpu")
         return sd
     def from_pretrained(self, model_name, r):
+
+        if self.global_medium_prompt is None:
+            self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device='cuda')
+
+        if self.global_long_prompt is None:
+            self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device='cuda')
+
         if self.global_tokenizer is None:
-            # self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
-            #                                                       subfolder="tokenizer")
             self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
 
         if self.global_text_encoder is None:
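The two enhancer pipelines are loaded lazily on first use and then reused, the same pattern from_pretrained() already applies to the tokenizer and text encoder. A minimal stand-alone sketch of that caching; the model IDs and the CUDA device come from the diff, and get_enhancer is a hypothetical helper, not code from this repo:

from transformers import pipeline

_enhancers = {}

def get_enhancer(length):
    # Load each summarization pipeline once, then return the cached instance,
    # mirroring the global_medium_prompt / global_long_prompt fields above.
    if length not in _enhancers:
        model_id = ("gokaygokay/Lamini-Prompt-Enchance" if length == "short-sentences"
                    else "gokaygokay/Lamini-Prompt-Enchance-Long")
        # Pass device="cpu" (or -1) on machines without a GPU.
        _enhancers[length] = pipeline("summarization", model=model_id, device="cuda")
    return _enhancers[length]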
S2I/modules/sketch2image.py
CHANGED
@@ -72,6 +72,25 @@ class Sketch2ImagePipeline(PrimaryModel):
         self.global_unet.set_adapters(["default"], weights=[r])
         set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
 
+    def automatic_enhance_prompt(self, input_prompt, model_choice):
+
+        if model_choice == "short-sentences":
+            result = self.global_medium_prompt("Enhance the description: " + input_prompt)
+            enhanced_text = result[0]['summary_text']
+
+            pattern = r'^.*?of\s+(.*?(?:\.|$))'
+            match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
+
+            if match:
+                remaining_text = enhanced_text[match.end():].strip()
+                modified_sentence = match.group(1).capitalize()
+                enhanced_text = modified_sentence + ' ' + remaining_text
+        else:
+            result = self.global_long_prompt("Enhance the description: " + input_prompt)
+            enhanced_text = result[0]['summary_text']
+
+        return enhanced_text
+
     def _move_to_cpu(self, module):
         module.to("cpu")
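Note that automatic_enhance_prompt() assumes re is imported at the top of this module. The regex in the short-sentences branch strips a leading "... of" clause that the enhancer tends to prepend (e.g. "A detailed illustration of ..."), keeping the first sentence after "of" plus whatever follows. A stand-alone sketch of that post-processing step with a made-up example string:

import re

def trim_leading_of_clause(enhanced_text):
    # Capture the first sentence after the first "of", then re-attach the rest.
    pattern = r'^.*?of\s+(.*?(?:\.|$))'
    match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
    if not match:
        return enhanced_text
    remaining_text = enhanced_text[match.end():].strip()
    return match.group(1).capitalize() + ' ' + remaining_text

# "A detailed illustration of a cat on a windowsill. Soft morning light."
# becomes "A cat on a windowsill. Soft morning light."
print(trim_leading_of_clause("A detailed illustration of a cat on a windowsill. Soft morning light."))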
app.py
CHANGED
@@ -118,7 +118,7 @@ def get_meta_from_image(input_img, type_image):
     # Convert the processed image back to PIL Image
     img_pil = Image.fromarray(processed_img.astype('uint8'))
 
-    return img_pil
+    return img_pil, 'URL'
 
 
 with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
@@ -267,10 +267,20 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
             clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
         with gr.Accordion("S2I Advances Option", open=True):
             with gr.Row():
-                input_type = gr.Radio(
-                    choices=["live-sketch", "url-sketch"],
-                    value="live-sketch",
-                    label="Type Sketch2Image models",
+                # input_type = gr.Radio(
+                #     choices=["live-sketch", "url-sketch"],
+                #     value="live-sketch",
+                #     label="Type Sketch2Image models",
+                #     interactive=True)
+
+                input_type = gr.Textbox(
+                    label="Check URL or Real-time Input",
+                    interactive=True)
+
+                prompt_quality = gr.Radio(
+                    choices=["short-sentences", "long-sentences"],
+                    value="short-sentences",
+                    label="Long/Short of Text Prompt",
                     interactive=True)
 
                 style = gr.Dropdown(
@@ -307,7 +317,7 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
                 queue=False,
                 api_name=False,
             )
-    inputs = [zero_gpu_options, image, prompt, prompt_temp, style, seed, val_r, half_model, model_options, input_type]
+    inputs = [zero_gpu_options, image, prompt, prompt_temp, style, seed, val_r, half_model, model_options, input_type, prompt_quality]
     outputs = [result]
     prompt.submit(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
@@ -328,8 +338,8 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
     val_r.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
     run_button.click(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
     image.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
-    url_image.submit(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image])
-    url_image.change(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image])
+    url_image.submit(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image, input_type])
+    url_image.change(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image, input_type])
 if __name__ == '__main__':
     demo.queue()
     demo.launch(debug=True, share=False)
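With get_meta_from_image() now returning a second value, fetching a URL image both fills the canvas and writes 'URL' into the input_type textbox, which artwork() later receives as type_flag. A minimal Gradio sketch of that two-output wiring; the fetch body is a placeholder, and only the component names and return shape follow the diff:

import gradio as gr
from PIL import Image

def get_meta_from_image(url):
    # Placeholder: a real implementation downloads and preprocesses the URL image.
    img = Image.new("RGB", (512, 512), "white")
    return img, "URL"  # the flag routes artwork() into its URL branch

with gr.Blocks() as demo:
    url_image = gr.Textbox(label="Image URL")
    image = gr.Image(type="pil")
    input_type = gr.Textbox(label="Check URL or Real-time Input", interactive=True)
    # One event, two outputs: the image component and the type flag update together.
    url_image.submit(fn=get_meta_from_image, inputs=[url_image], outputs=[image, input_type])

if __name__ == "__main__":
    demo.launch()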