Spaces: JingyeChen22
Commit 9de996f — "Update app.py"
Parent(s): 4595437
app.py CHANGED
@@ -26,10 +26,6 @@
 if not os.path.exists('Arial.ttf'):
     os.system('wget https://huggingface.co/datasets/JingyeChen22/TextDiffuser/resolve/main/Arial.ttf')
 
-
-os.system('echo finish')
-os.system('ls -a')
-
 import cv2
 import random
 import logging
@@ -67,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
 import transformers
 from transformers import CLIPTextModel, CLIPTokenizer
 
-from util import segmentation_mask_visualization, make_caption_pil, combine_image,
+from util import segmentation_mask_visualization, make_caption_pil, combine_image, transform_mask_pil, filter_segmentation_mask, inpainting_merge_image
 from model.layout_generator import get_layout_from_prompt
 from model.text_segmenter.unet import UNet
 
@@ -364,20 +360,40 @@ if accelerator.is_main_process:
 print(args.output_dir)
 
 # Load scheduler, tokenizer and models.
-
-
+tokenizer15 = CLIPTokenizer.from_pretrained(
+    'runwayml/stable-diffusion-v1-5', subfolder="tokenizer", revision=args.revision
 )
-
-
+tokenizer21 = CLIPTokenizer.from_pretrained(
+    'stabilityai/stable-diffusion-2-1', subfolder="tokenizer", revision=args.revision
+)
+
+text_encoder15 = CLIPTextModel.from_pretrained(
+    'runwayml/stable-diffusion-v1-5', subfolder="text_encoder", revision=args.revision
 )
-
-
-
+text_encoder21 = CLIPTextModel.from_pretrained(
+    'stabilityai/stable-diffusion-2-1', subfolder="text_encoder", revision=args.revision
+)
+
+vae15 = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae", revision=args.revision).cuda()
+unet15 = UNet2DConditionModel.from_pretrained(
+    'textdiffuser-ckpt/diffusion_backbone_1.5', subfolder="unet", revision=None
+).cuda()
+
+vae21 = AutoencoderKL.from_pretrained('stabilityai/stable-diffusion-2-1', subfolder="vae", revision=args.revision).cuda()
+unet21 = UNet2DConditionModel.from_pretrained(
+    'textdiffuser-ckpt/diffusion_backbone_2.1', subfolder="unet", revision=None
 ).cuda()
 
+scheduler15 = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
+scheduler21 = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-2-1', subfolder="scheduler")
+
+
+
 # Freeze vae and text_encoder
-
-
+vae15.requires_grad_(False)
+vae21.requires_grad_(False)
+text_encoder15.requires_grad_(False)
+text_encoder21.requires_grad_(False)
 
 if args.enable_xformers_memory_efficient_attention:
     if is_xformers_available():
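
The hunk above drops the single pretrained pipeline and loads two frozen TextDiffuser stacks side by side: tokenizer, text encoder, VAE, UNet backbone and DDPM scheduler for Stable Diffusion v1.5 (512x512) and for v2.1 (768x768). A minimal sketch of the same idea, reusing the checkpoint names from the diff; the load_stack helper and the dict-style registry are illustrative, not the layout used in app.py:

    # Sketch only: one frozen stack per base model, kept in a dict keyed by the UI label.
    from transformers import CLIPTokenizer, CLIPTextModel
    from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler

    def load_stack(base, unet_dir, resolution):
        tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder").cuda()
        vae = AutoencoderKL.from_pretrained(base, subfolder="vae").cuda()
        unet = UNet2DConditionModel.from_pretrained(unet_dir, subfolder="unet").cuda()
        scheduler = DDPMScheduler.from_pretrained(base, subfolder="scheduler")
        vae.requires_grad_(False)            # frozen for inference, as in the hunk above
        text_encoder.requires_grad_(False)
        return {"tokenizer": tokenizer, "text_encoder": text_encoder, "vae": vae,
                "unet": unet, "scheduler": scheduler, "size": resolution}

    stacks = {
        "Stable Diffusion v1.5": load_stack("runwayml/stable-diffusion-v1-5",
                                            "textdiffuser-ckpt/diffusion_backbone_1.5", 512),
        "Stable Diffusion v2.1": load_stack("stabilityai/stable-diffusion-2-1",
                                            "textdiffuser-ckpt/diffusion_backbone_2.1", 768),
    }
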
@@ -421,7 +437,6 @@ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
 
 
 # setup schedulers
-scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 # sample_num = args.vis_num
 
 def to_tensor(image):
@@ -461,7 +476,25 @@ def has_chinese_char(string):
 
 image_404 = Image.open('404.jpg')
 
-def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
+def text_to_image(prompt,slider_step,slider_guidance,slider_batch, version):
+    print(f'【version】{version}')
+    if version == 'Stable Diffusion v2.1':
+        vae = vae21
+        unet = unet21
+        text_encoder = text_encoder21
+        tokenizer = tokenizer21
+        scheduler = scheduler21
+        slider_batch = min(slider_batch, 2)
+        size = 768
+    elif version == 'Stable Diffusion v1.5':
+        vae = vae15
+        unet = unet15
+        text_encoder = text_encoder15
+        tokenizer = tokenizer15
+        scheduler = scheduler15
+        size = 512
+    else:
+        assert False, 'Version Not Found'
 
     if has_chinese_char(prompt):
         print('trigger')
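
Each of the three handlers in this commit starts with the same if/elif block: pick the v2.1 or v1.5 components, cap the batch at 2 for the 768x768 model, and set the working resolution. A sketch of that dispatch factored into one helper; select_stack is a hypothetical name, and the 15/21 globals are the ones loaded above:

    def select_stack(version, slider_batch):
        if version == 'Stable Diffusion v2.1':
            # 768x768 stack: cap the batch size at 2 to avoid running out of GPU memory
            return vae21, unet21, text_encoder21, tokenizer21, scheduler21, 768, min(slider_batch, 2)
        if version == 'Stable Diffusion v1.5':
            return vae15, unet15, text_encoder15, tokenizer15, scheduler15, 512, slider_batch
        raise ValueError(f'Version Not Found: {version}')

    # inside a handler:
    # vae, unet, text_encoder, tokenizer, scheduler, size, slider_batch = select_stack(version, slider_batch)
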
@@ -484,7 +517,7 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
     set_seed(seed)
     scheduler.set_timesteps(slider_step)
 
-    noise = torch.randn((sample_num, 4,
+    noise = torch.randn((sample_num, 4, size//8, size//8)).to("cuda") # (b, 4, 64, 64)
     input = noise # (b, 4, 64, 64)
 
     captions = [args.prompt] * sample_num
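
The noise that seeds the reverse diffusion is sampled in latent space, not pixel space: the Stable Diffusion VAE downsamples by a factor of 8, so the latent side length is size//8. A quick check of the shapes the two resolutions produce:

    # 512 -> (4, 64, 64) latents, 768 -> (4, 96, 96) latents
    for size in (512, 768):
        print(size, '->', (4, size // 8, size // 8))
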
@@ -504,25 +537,18 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
     encoder_hidden_states_nocond = text_encoder(inputs_nocond)[0].cuda() # (b, 77, 768)
     print(f'{colored("[√]", "green")} encoder_hidden_states_nocond: {encoder_hidden_states_nocond.shape}.')
 
-    # load character-level segmenter
-    segmenter = UNet(3, 96, True).cuda()
-    segmenter = torch.nn.DataParallel(segmenter)
-    segmenter.load_state_dict(torch.load(args.character_segmenter_path))
-    segmenter.eval()
-    print(f'{colored("[√]", "green")} Text segmenter is successfully loaded.')
-
     #### text-to-image ####
     render_image, segmentation_mask_from_pillow = get_layout_from_prompt(args)
 
     segmentation_mask = torch.Tensor(np.array(segmentation_mask_from_pillow)).cuda() # (512, 512)
 
     segmentation_mask = filter_segmentation_mask(segmentation_mask)
-    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(
+    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(size//2, size//2), mode='nearest')
     segmentation_mask = segmentation_mask.squeeze(1).repeat(sample_num, 1, 1).long().to('cuda') # (1, 1, 256, 256)
     print(f'{colored("[√]", "green")} character-level segmentation_mask: {segmentation_mask.shape}.')
 
-    feature_mask = torch.ones(sample_num, 1,
-    masked_image = torch.zeros(sample_num, 3,
+    feature_mask = torch.ones(sample_num, 1, size//8, size//8).to('cuda') # (b, 1, 64, 64)
+    masked_image = torch.zeros(sample_num, 3, size, size).to('cuda') # (b, 3, 512, 512)
     masked_feature = vae.encode(masked_image).latent_dist.sample() # (b, 4, 64, 64)
     masked_feature = masked_feature * vae.config.scaling_factor
     print(f'{colored("[√]", "green")} feature_mask: {feature_mask.shape}.')
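
For plain text-to-image there is no user image to preserve, so the hunk above feeds the UNet an all-ones feature_mask at latent resolution and encodes an all-zero masked_image through the VAE. A small sketch of that step, assuming a loaded AutoencoderKL as vae; the helper name is illustrative:

    import torch

    def empty_masked_feature(vae, sample_num, size):
        # nothing is preserved, but the UNet still expects the VAE latent of a "masked" image
        masked_image = torch.zeros(sample_num, 3, size, size, device='cuda')
        latent = vae.encode(masked_image).latent_dist.sample()   # (b, 4, size//8, size//8)
        return latent * vae.config.scaling_factor

    # feature_mask of ones = every latent position may be rewritten by the model
    # feature_mask = torch.ones(sample_num, 1, size // 8, size // 8, device='cuda')
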
@@ -543,10 +569,11 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
     input = 1 / vae.config.scaling_factor * input
     sample_images = vae.decode(input.float(), return_dict=False)[0] # (b, 3, 512, 512)
 
-    image_pil = render_image.resize((
+    image_pil = render_image.resize((size,size))
     segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
-    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((
+    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((size,size))
     character_mask_highlight_pil = segmentation_mask_visualization(args.font_path,segmentation_mask)
+    character_mask_highlight_pil = character_mask_highlight_pil.resize((size, size))
     caption_pil = make_caption_pil(args.font_path, captions)
 
     # save pred_img
@@ -557,12 +584,12 @@ def text_to_image(prompt,slider_step,slider_guidance,slider_batch):
         image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
         pred_image_list.append(image)
 
-    blank_pil = combine_image(args, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
+    blank_pil = combine_image(args, size, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
 
-    intermediate_result = Image.new('RGB', (
+    intermediate_result = Image.new('RGB', (size*3, size))
     intermediate_result.paste(image_pil, (0, 0))
-    intermediate_result.paste(character_mask_pil, (
-    intermediate_result.paste(character_mask_highlight_pil, (
+    intermediate_result.paste(character_mask_pil, (size, 0))
+    intermediate_result.paste(character_mask_highlight_pil, (size*2, 0))
 
     return blank_pil, intermediate_result
 
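
The intermediate_result returned to the UI is a horizontal strip: the rendered layout, the binary character mask and the colour-coded mask pasted onto a size*3 by size canvas. A sketch of that composition with Pillow; make_preview is a hypothetical helper, not a function in app.py:

    from PIL import Image

    def make_preview(panels, size):
        strip = Image.new('RGB', (size * len(panels), size))
        for i, panel in enumerate(panels):
            strip.paste(panel.resize((size, size)), (i * size, 0))
        return strip

    # e.g. make_preview([image_pil, character_mask_pil, character_mask_highlight_pil], size)
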
@@ -577,7 +604,25 @@ print(f'{colored("[√]", "green")} Text segmenter is successfully loaded.')
 
 
 
-def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary):
+def text_to_image_with_template(prompt,template_image,slider_step,slider_guidance,slider_batch, binary, version):
+
+    if version == 'Stable Diffusion v2.1':
+        vae = vae21
+        unet = unet21
+        text_encoder = text_encoder21
+        tokenizer = tokenizer21
+        scheduler = scheduler21
+        slider_batch = min(slider_batch, 2)
+        size = 768
+    elif version == 'Stable Diffusion v1.5':
+        vae = vae15
+        unet = unet15
+        text_encoder = text_encoder15
+        tokenizer = tokenizer15
+        scheduler = scheduler15
+        size = 512
+    else:
+        assert False, 'Version Not Found'
 
     if has_chinese_char(prompt):
         print('trigger')
@@ -586,7 +631,7 @@
     if slider_step>=50:
         slider_step = 50
 
-    orig_template_image = template_image.resize((
+    orig_template_image = template_image.resize((size,size)).convert('RGB')
     args.prompt = prompt
     sample_num = slider_batch
     # If passed along, set the training seed now.
@@ -595,7 +640,7 @@
     set_seed(seed)
     scheduler.set_timesteps(slider_step)
 
-    noise = torch.randn((sample_num, 4,
+    noise = torch.randn((sample_num, 4, size//8, size//8)).to("cuda") # (b, 4, 64, 64)
     input = noise # (b, 4, 64, 64)
 
     captions = [args.prompt] * sample_num
@@ -634,12 +679,12 @@
     segmentation_mask = segmentation_mask.max(1)[1].squeeze(0) # (256, 256)
     segmentation_mask = filter_segmentation_mask(segmentation_mask) # (256, 256)
 
-    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(
+    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(size//2, size//2), mode='nearest') # (b, 1, 256, 256)
     segmentation_mask = segmentation_mask.squeeze(1).repeat(sample_num, 1, 1).long().to('cuda') # (b, 1, 256, 256)
     print(f'{colored("[√]", "green")} Character-level segmentation_mask: {segmentation_mask.shape}.')
 
-    feature_mask = torch.ones(sample_num, 1,
-    masked_image = torch.zeros(sample_num, 3,
+    feature_mask = torch.ones(sample_num, 1, size//8, size//8).to('cuda') # (b, 1, 64, 64)
+    masked_image = torch.zeros(sample_num, 3, size, size).to('cuda') # (b, 3, 512, 512)
     masked_feature = vae.encode(masked_image).latent_dist.sample() # (b, 4, 64, 64)
     masked_feature = masked_feature * vae.config.scaling_factor # (b, 4, 64, 64)
 
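
The segmentation masks hold integer character-class ids (the character segmenter is built as UNet(3, 96, True)), so every resize in these hunks uses mode='nearest'; bilinear resampling would blend neighbouring ids into values that mean nothing. A small check of that behaviour:

    import torch
    import torch.nn.functional as F

    labels = torch.randint(0, 96, (1, 1, 256, 256)).float()   # fake character-class map
    up = F.interpolate(labels, size=(384, 384), mode='nearest')
    # nearest-neighbour only copies existing ids, so no new values appear
    assert set(up.unique().tolist()) <= set(labels.unique().tolist())
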
@@ -660,8 +705,9 @@
 
     image_pil = None
     segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
-    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((
+    character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((size,size))
     character_mask_highlight_pil = segmentation_mask_visualization(args.font_path,segmentation_mask)
+    character_mask_highlight_pil = character_mask_highlight_pil.resize((size, size))
     caption_pil = make_caption_pil(args.font_path, captions)
 
     # save pred_img
@@ -672,17 +718,35 @@
         image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
         pred_image_list.append(image)
 
-    blank_pil = combine_image(args, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
+    blank_pil = combine_image(args, size, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
 
-    intermediate_result = Image.new('RGB', (
+    intermediate_result = Image.new('RGB', (size*3, size))
     intermediate_result.paste(orig_template_image, (0, 0))
-    intermediate_result.paste(character_mask_pil, (
-    intermediate_result.paste(character_mask_highlight_pil, (
+    intermediate_result.paste(character_mask_pil, (size, 0))
+    intermediate_result.paste(character_mask_highlight_pil, (size*2, 0))
 
     return blank_pil, intermediate_result
 
 
-def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch):
+def text_inpainting(prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch, version):
+
+    if version == 'Stable Diffusion v2.1':
+        vae = vae21
+        unet = unet21
+        text_encoder = text_encoder21
+        tokenizer = tokenizer21
+        scheduler = scheduler21
+        slider_batch = min(slider_batch, 2)
+        size = 768
+    elif version == 'Stable Diffusion v1.5':
+        vae = vae15
+        unet = unet15
+        text_encoder = text_encoder15
+        tokenizer = tokenizer15
+        scheduler = scheduler15
+        size = 512
+    else:
+        assert False, 'Version Not Found'
 
     if has_chinese_char(prompt):
         print('trigger')
@@ -699,7 +763,7 @@
     set_seed(seed)
     scheduler.set_timesteps(slider_step)
 
-    noise = torch.randn((sample_num, 4,
+    noise = torch.randn((sample_num, 4, size//8, size//8)).to("cuda") # (b, 4, 64, 64)
     input = noise # (b, 4, 64, 64)
 
     captions = [args.prompt] * sample_num
@@ -719,7 +783,7 @@
     encoder_hidden_states_nocond = text_encoder(inputs_nocond)[0].cuda() # (b, 77, 768)
     print(f'{colored("[√]", "green")} encoder_hidden_states_nocond: {encoder_hidden_states_nocond.shape}.')
 
-    mask_image = cv2.resize(mask_image, (
+    mask_image = cv2.resize(mask_image, (size,size))
     # mask_image = mask_image.resize((512,512)).convert('RGB')
     text_mask = np.array(mask_image)
     threshold = 128
@@ -732,21 +796,21 @@
 
     segmentation_mask = segmentation_mask.max(1)[1].squeeze(0)
     segmentation_mask = filter_segmentation_mask(segmentation_mask)
-    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(
+    segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(size//2, size//2), mode='nearest')
 
-    image_mask = transform_mask_pil(mask_image)
+    image_mask = transform_mask_pil(mask_image, size)
     image_mask = torch.from_numpy(image_mask).cuda().unsqueeze(0).unsqueeze(0)
 
-    orig_image = orig_image.convert('RGB').resize((
+    orig_image = orig_image.convert('RGB').resize((size,size))
     image = orig_image
     image_tensor = to_tensor(image).unsqueeze(0).cuda().sub_(0.5).div_(0.5)
     masked_image = image_tensor * (1-image_mask)
     masked_feature = vae.encode(masked_image).latent_dist.sample().repeat(sample_num, 1, 1, 1)
     masked_feature = masked_feature * vae.config.scaling_factor
 
-    image_mask = torch.nn.functional.interpolate(image_mask, size=(
+    image_mask = torch.nn.functional.interpolate(image_mask, size=(size//2, size//2), mode='nearest').repeat(sample_num, 1, 1, 1)
     segmentation_mask = segmentation_mask * image_mask
-    feature_mask = torch.nn.functional.interpolate(image_mask, size=(
+    feature_mask = torch.nn.functional.interpolate(image_mask, size=(size//8, size//8), mode='nearest')
 
     # diffusion process
    intermediate_images = []
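
In text_inpainting the user-drawn mask is consumed at three scales: full resolution to blank the edited region out of the pixel image before VAE encoding, size//2 to restrict the character segmentation map, and size//8 to tell the UNet which latent positions may change. A sketch of those three views, assuming binary_mask is already a (1, 1, size, size) float tensor of 0/1 values; mask_pyramid is a hypothetical helper:

    import torch.nn.functional as F

    def mask_pyramid(binary_mask, size):
        pixel_mask = binary_mask                                               # (1, 1, size, size)
        seg_mask = F.interpolate(binary_mask, (size // 2, size // 2), mode='nearest')
        feat_mask = F.interpolate(binary_mask, (size // 8, size // 8), mode='nearest')
        return pixel_mask, seg_mask, feat_mask

    # masked_image = image_tensor * (1 - pixel_mask)   # keep everything outside the edited region
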
@@ -767,6 +831,7 @@
     segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
     character_mask_pil = Image.fromarray(((segmentation_mask!=0)*255).astype('uint8')).resize((512,512))
     character_mask_highlight_pil = segmentation_mask_visualization(args.font_path,segmentation_mask)
+    character_mask_highlight_pil = character_mask_highlight_pil.resize((size, size))
     caption_pil = make_caption_pil(args.font_path, captions)
 
     # save pred_img
@@ -786,7 +851,7 @@
     character_mask_highlight_pil.save('character_mask_highlight_pil.png')
 
 
-    blank_pil = combine_image(args, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
+    blank_pil = combine_image(args, size, None, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)
 
 
     background = orig_image.resize((512, 512))
@@ -825,6 +890,11 @@ with gr.Blocks() as demo:
         We propose <b>TextDiffuser</b>, a flexible and controllable framework to generate images with visually appealing text that is coherent with backgrounds.
         Main features include: (a) <b><font color="#A52A2A">Text-to-Image</font></b>: The user provides a prompt and encloses the keywords with single quotes (e.g., a text image of ‘hello’). The model first determines the layout of the keywords and then draws the image based on the layout and prompt. (b) <b><font color="#A52A2A">Text-to-Image with Templates</font></b>: The user provides a prompt and a template image containing text, which can be a printed, handwritten, or scene text image. These template images can be used to determine the layout of the characters. (c) <b><font color="#A52A2A">Text Inpainting</font></b>: The user provides an image and specifies the region to be modified along with the desired text content. The model is able to modify the original text or add text to areas without text.
         </h2>
+        <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
+        🔥 <b>News</b>: We further trained TextDiffuser on top of the <b>Stable Diffusion v2.1</b> pre-trained model, enlarging the resolution from 512x512 to <b>768x768</b> to enhance the legibility of small text. Additionally, we fine-tuned the model on images with <b>high aesthetic scores</b>, enabling it to generate images with richer details.
+        </h2>
+
+
         <img src="file/images/huggingface_blank.jpg" alt="textdiffuser">
         </div>
         """)
@@ -833,9 +903,10 @@ with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column(scale=1):
                 prompt = gr.Textbox(label="Input your prompt here. Please enclose keywords with 'single quotes', you may refer to the examples below. The current version only supports input in English characters.", placeholder="Placeholder 'Team' hat")
+                radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v2.1")
                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
                 slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
-                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled. Maximum number is set to 【2】 for SD v2.1 to avoid OOM.")
                 # slider_seed = gr.Slider(minimum=1, maximum=10000, label="Seed", randomize=True)
                 button = gr.Button("Generate")
 
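
Each tab gains a gr.Radio for choosing the backbone, and the component is simply appended to the click handler's inputs, so the selected label arrives as the handler's last argument. A self-contained sketch of that wiring, assuming Gradio is installed; the handler here is a stand-in, not the app's text_to_image:

    import gradio as gr

    def handler(prompt, steps, version):
        # the Radio's current label is passed through as the last positional argument
        return f"{version}: would sample '{prompt}' for {steps} steps"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"],
                         value="Stable Diffusion v2.1", label="Pre-trained Model")
        steps = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step")
        out = gr.Textbox(label="Result")
        button = gr.Button("Generate")
        button.click(handler, inputs=[prompt, steps, radio], outputs=[out])

    # demo.launch()
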
@@ -851,7 +922,7 @@ with gr.Blocks() as demo:
                 [
                     ["Distinguished poster of 'SPIDERMAN'. Trending on ArtStation and Pixiv. A vibrant digital oil painting. A highly detailed fantasy character illustration by Wayne Reynolds and Charles Monet and Gustave Dore and Carl Critchlow and Bram Sels"],
                     ["A detailed portrait of a fox guardian with a shield with 'Kung Fu' written on it, by victo ngai and justin gerard, digital art, realistic painting, very detailed, fantasy, high definition, cinematic light, dnd, trending on artstation"],
-                    ["portrait
+                    ["portrait of a 'dragon', concept art, sumi - e style, intricate linework, green smoke, artstation, trending, highly detailed, smooth, focus, art by yoji shinkawa,"],
                     ["elderly woman dressed in extremely colorful clothes with many strange patterns posing for a high fashion photoshoot of 'FASHION', haute couture, golden hour, artstation, by J. C. Leyendecker and Peter Paul Rubens"],
                     ["epic digital art of a luxury yacht named 'Time Machine' driving through very dark hard edged city towers from tron movie, faint tall mountains in background, wlop, pixiv"],
                     ["A poster of 'Adventurer'. A beautiful so tall boy with big eyes and small nose is in the jungle, he wears normal clothes and shows his full length, which we see from the front, unreal engine, cozy indoor lighting, artstation, detailed"],
@@ -876,16 +947,17 @@ with gr.Blocks() as demo:
                 examples_per_page=100
             )
 
-        button.click(text_to_image, inputs=[prompt,slider_step,slider_guidance,slider_batch], outputs=[output,intermediate_results])
+        button.click(text_to_image, inputs=[prompt,slider_step,slider_guidance,slider_batch,radio], outputs=[output,intermediate_results])
 
     with gr.Tab("Text-to-Image-with-Template"):
         with gr.Row():
             with gr.Column(scale=1):
                 prompt = gr.Textbox(label='Input your prompt here.')
                 template_image = gr.Image(label='Template image', type="pil")
+                radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v2.1")
                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
                 slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
-                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled. Maximum number is set to 【2】 for SD v2.1 to avoid OOM.")
                 # binary = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
                 binary = gr.Checkbox(label="Binarization", bool=True, info="Whether to binarize the template image? You may need it when using handwritten images as templates.")
                 button = gr.Button("Generate")
@@ -923,7 +995,7 @@ with gr.Blocks() as demo:
                 examples_per_page=100
             )
 
-        button.click(text_to_image_with_template, inputs=[prompt,template_image,slider_step,slider_guidance,slider_batch,binary], outputs=[output,intermediate_results])
+        button.click(text_to_image_with_template, inputs=[prompt,template_image,slider_step,slider_guidance,slider_batch,binary,radio], outputs=[output,intermediate_results])
 
     with gr.Tab("Text-Inpainting"):
         with gr.Row():
@@ -932,9 +1004,10 @@ with gr.Blocks() as demo:
                 with gr.Row():
                     orig_image = gr.Image(label='Original image', type="pil")
                     mask_image = gr.Image(label='Mask image', type="numpy")
+                radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v2.1")
                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
                 slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
-                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled. Maximum number is set to 【2】 for SD v2.1 to avoid OOM.")
                 button = gr.Button("Generate")
             with gr.Column(scale=1):
                 output = gr.Image(label='Generated image')
@@ -969,7 +1042,7 @@ with gr.Blocks() as demo:
             )
 
 
-        button.click(text_inpainting, inputs=[prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch], outputs=[output, intermediate_results])
+        button.click(text_inpainting, inputs=[prompt,orig_image,mask_image,slider_step,slider_guidance,slider_batch,radio], outputs=[output, intermediate_results])
 
 
 