update enhance prompt
Browse files- S2I/commons/controller.py +4 -13
- S2I/modules/models.py +2 -2
- S2I/modules/sketch2image.py +2 -6
- app.py +6 -15
S2I/commons/controller.py
CHANGED
@@ -58,16 +58,11 @@ class Sketch2ImageController():
|
|
58 |
self.load_pipeline(zero_options=options)
|
59 |
# prompt = prompt_template.replace("{prompt}", prompt)
|
60 |
|
61 |
-
|
62 |
-
# img = Image.fromarray(np.array(image["composite"])[:, :, -1])
|
63 |
-
# elif type_flag == 'url-sketch':
|
64 |
-
# img = image["composite"]
|
65 |
-
|
66 |
-
if type_flag == 'URL':
|
67 |
-
img = image["composite"]
|
68 |
-
else:
|
69 |
img = Image.fromarray(np.array(image["composite"])[:, :, -1])
|
70 |
-
|
|
|
|
|
71 |
img = img.convert("RGB")
|
72 |
img = img.resize((512, 512))
|
73 |
|
@@ -83,9 +78,5 @@ class Sketch2ImageController():
|
|
83 |
|
84 |
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
|
85 |
|
86 |
-
# if type_flag == 'live-sketch':
|
87 |
-
# input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
|
88 |
-
# else:
|
89 |
-
# input_uri = self.pil_image_to_data_uri(img)
|
90 |
|
91 |
return output_pil
|
|
|
58 |
self.load_pipeline(zero_options=options)
|
59 |
# prompt = prompt_template.replace("{prompt}", prompt)
|
60 |
|
61 |
+
if type_flag == 'live-sketch':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
img = Image.fromarray(np.array(image["composite"])[:, :, -1])
|
63 |
+
elif type_flag == 'url-sketch':
|
64 |
+
img = image["composite"]
|
65 |
+
|
66 |
img = img.convert("RGB")
|
67 |
img = img.resize((512, 512))
|
68 |
|
|
|
78 |
|
79 |
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
|
80 |
|
|
|
|
|
|
|
|
|
81 |
|
82 |
return output_pil
|
S2I/modules/models.py
CHANGED
@@ -65,10 +65,10 @@ class PrimaryModel:
|
|
65 |
return sd
|
66 |
def from_pretrained(self, model_name, r):
|
67 |
if self.global_medium_prompt is None:
|
68 |
-
self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device='cuda')
|
69 |
|
70 |
if self.global_long_prompt is None:
|
71 |
-
self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device='cuda')
|
72 |
|
73 |
if self.global_tokenizer is None:
|
74 |
self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
|
|
|
65 |
return sd
|
66 |
def from_pretrained(self, model_name, r):
|
67 |
if self.global_medium_prompt is None:
|
68 |
+
self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device='cuda' if torch.cuda.is_available() else 'cpu')
|
69 |
|
70 |
if self.global_long_prompt is None:
|
71 |
+
self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device='cuda' if torch.cuda.is_available() else 'cpu')
|
72 |
|
73 |
if self.global_tokenizer is None:
|
74 |
self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
|
S2I/modules/sketch2image.py
CHANGED
@@ -75,8 +75,8 @@ class Sketch2ImagePipeline(PrimaryModel):
|
|
75 |
self.global_unet.set_adapters(["default"], weights=[r])
|
76 |
set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
|
77 |
|
78 |
-
def automatic_enhance_prompt(self, input_prompt,
|
79 |
-
if
|
80 |
result = self.global_medium_prompt("Enhance the description: " + input_prompt)
|
81 |
enhanced_text = result[0]['summary_text']
|
82 |
|
@@ -87,10 +87,6 @@ class Sketch2ImagePipeline(PrimaryModel):
|
|
87 |
remaining_text = enhanced_text[match.end():].strip()
|
88 |
modified_sentence = match.group(1).capitalize()
|
89 |
enhanced_text = modified_sentence + ' ' + remaining_text
|
90 |
-
else:
|
91 |
-
result = self.global_long_prompt("Enhance the description: " + input_prompt)
|
92 |
-
enhanced_text = result[0]['summary_text']
|
93 |
-
|
94 |
return enhanced_text
|
95 |
|
96 |
def _move_to_cpu(self, module):
|
|
|
75 |
self.global_unet.set_adapters(["default"], weights=[r])
|
76 |
set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
|
77 |
|
78 |
+
def automatic_enhance_prompt(self, input_prompt, prompt_quality):
|
79 |
+
if prompt_quality:
|
80 |
result = self.global_medium_prompt("Enhance the description: " + input_prompt)
|
81 |
enhanced_text = result[0]['summary_text']
|
82 |
|
|
|
87 |
remaining_text = enhanced_text[match.end():].strip()
|
88 |
modified_sentence = match.group(1).capitalize()
|
89 |
enhanced_text = modified_sentence + ' ' + remaining_text
|
|
|
|
|
|
|
|
|
90 |
return enhanced_text
|
91 |
|
92 |
def _move_to_cpu(self, module):
|
app.py
CHANGED
@@ -260,6 +260,7 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
|
|
260 |
show_download_button=True,
|
261 |
)
|
262 |
with gr.Group():
|
|
|
263 |
prompt = gr.Textbox(label="Personalized Text", value="", show_label=True)
|
264 |
with gr.Row():
|
265 |
run_button = gr.Button("Generate 🪄", min_width=5, variant='primary')
|
@@ -267,22 +268,12 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
|
|
267 |
clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
|
268 |
with gr.Accordion("S2I Advances Option", open=True):
|
269 |
with gr.Row():
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
# interactive=True)
|
275 |
-
|
276 |
-
input_type = gr.Textbox(
|
277 |
-
label="Check URL or Real-time Input",
|
278 |
-
interactive=True)
|
279 |
-
|
280 |
-
prompt_quality = gr.Radio(
|
281 |
-
choices=["short-sentences", "long-sentences"],
|
282 |
-
value="short-sentences",
|
283 |
-
label="Long/Short of Text Prompt",
|
284 |
interactive=True)
|
285 |
-
|
286 |
style = gr.Dropdown(
|
287 |
label="Style",
|
288 |
choices=controller.STYLE_NAMES,
|
|
|
260 |
show_download_button=True,
|
261 |
)
|
262 |
with gr.Group():
|
263 |
+
use_enhancer = gr.Checkbox(label="Use Automatic Prompt High-Quality", value=False)
|
264 |
prompt = gr.Textbox(label="Personalized Text", value="", show_label=True)
|
265 |
with gr.Row():
|
266 |
run_button = gr.Button("Generate 🪄", min_width=5, variant='primary')
|
|
|
268 |
clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
|
269 |
with gr.Accordion("S2I Advances Option", open=True):
|
270 |
with gr.Row():
|
271 |
+
input_type = gr.Radio(
|
272 |
+
choices=["live-sketch", "url-sketch"],
|
273 |
+
value="live-sketch",
|
274 |
+
label="Type Sketch2Image models",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
interactive=True)
|
276 |
+
|
277 |
style = gr.Dropdown(
|
278 |
label="Style",
|
279 |
choices=controller.STYLE_NAMES,
|