Tzktz's picture
Upload 7664 files
6fc683c verified
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines a set of commonly used utility functions.
# ------------------------------------------
import os
import re
import cv2
import math
import shutil
import string
import textwrap
import numpy as np
from PIL import Image, ImageFont, ImageDraw, ImageOps
from typing import *
# Character vocabulary: digits, lowercase letters, uppercase letters, punctuation and the space.
alphabet = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' '  # len(alphabet) = 95
# Map each character to a 1-based id; id 0 is reserved for "non-character".
alphabet_dic = {char: char_id for char_id, char in enumerate(alphabet, start=1)}
def transform_mask_pil(mask_root):
    """
    This function turns an in-memory mask image into a binary mask array.

    Args:
        mask_root: The mask image itself (not a path), typically a PIL image.
            * The white area (gray value above 250) is the unmasked area
            * Everything darker (e.g. the gray area) is the masked area

    Returns:
        np.ndarray: float32 array of size 512x512 where 1.0 marks masked
        pixels and 0.0 marks unmasked pixels.
    """
    if isinstance(mask_root, Image.Image):
        # Normalise palette / grayscale / alpha modes so cvtColor always gets three RGB channels.
        mask_root = mask_root.convert('RGB')
    img = np.array(mask_root)
    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    # PIL stores channels in RGB order, so the RGB conversion is the right one here
    # (the BGR variant would swap the red and blue luminance weights).
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # pixel value is set to 0 or 255 according to the threshold
    _, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY)
    return 1 - (binary.astype(np.float32) / 255)
def transform_mask(mask_root: str):
    """
    This function loads a mask image from disk and turns it into a binary mask array.

    Args:
        mask_root (str): The path of mask image.
            * The white area (gray value above 250) is the unmasked area
            * Everything darker (e.g. the gray area) is the masked area

    Returns:
        np.ndarray: float32 array of size 512x512 where 1.0 marks masked
        pixels and 0.0 marks unmasked pixels.

    Raises:
        FileNotFoundError: If the image cannot be read from ``mask_root``.
    """
    img = cv2.imread(mask_root)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None instead of raising.
        raise FileNotFoundError(f'Cannot read mask image: {mask_root}')
    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # pixel value is set to 0 or 255 according to the threshold
    _, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY)
    return 1 - (binary.astype(np.float32) / 255)
def segmentation_mask_visualization(font_path: str, segmentation_mask: np.array):
    """
    Render a character-level segmentation mask as an image of green glyphs.

    The mask is resized to a 64x64 grid (nearest neighbour) and every cell
    holding a valid character id is drawn as that character in an 8x8 pixel
    slot of a black 512x512 canvas.

    Args:
        font_path (str): The path of font. Arial.ttf is recommended.
        segmentation_mask (np.array): The character-level segmentation mask.

    Returns:
        Image: The 512x512 RGB visualization.
    """
    grid_size, cell_size = 64, 8
    grid = cv2.resize(segmentation_mask, (grid_size, grid_size), interpolation=cv2.INTER_NEAREST)
    glyph_font = ImageFont.truetype(font_path, cell_size)
    canvas = Image.new('RGB', (512, 512), (0, 0, 0))
    painter = ImageDraw.Draw(canvas)
    for row in range(grid_size):
        for col in range(grid_size):
            char_id = int(grid[row][col])
            # id 0 is "non-character"; ids past the alphabet have no glyph to draw.
            if char_id != 0 and char_id - 1 < len(alphabet):
                painter.text(
                    (col * cell_size, row * cell_size),
                    alphabet[char_id - 1],
                    font=glyph_font,
                    fill=(0, 255, 0),
                )
    return canvas
def make_caption_pil(font_path: str, captions: List[str]):
    """
    This function converts captions into pil images.

    Each caption becomes a 512x48 image: a white 508x44 area surrounded by a
    2-pixel gray border, with the caption wrapped at 40 characters per line.

    Args:
        font_path (str): The path of font. Arial.ttf is recommended.
        captions (List[str]): List of captions.

    Returns:
        List[Image]: One caption image per entry of ``captions``, in order.
    """
    border_size = 2
    font = ImageFont.truetype(font_path, 18)
    # FreeTypeFont.getsize() was removed in Pillow 10; the bottom coordinate of
    # getbbox() is the height getsize() used to report. Fall back for old Pillow.
    if hasattr(font, 'getbbox'):
        glyph_height = font.getbbox('A')[3]
    else:
        glyph_height = font.getsize('A')[1]
    line_height = glyph_height + 4

    caption_pil_list = []
    for caption in captions:
        img = Image.new('RGB', (512 - 2 * border_size, 48 - 2 * border_size), (255, 255, 255))
        img = ImageOps.expand(img, border=(border_size, border_size, border_size, border_size), fill=(127, 127, 127))
        draw = ImageDraw.Draw(img)
        x, y = 4, 4
        for line in textwrap.wrap(caption, width=40):
            draw.text((x, y), line, font=font, fill=(200, 127, 0))
            y += line_height
        caption_pil_list.append(img)
    return caption_pil_list
def filter_segmentation_mask(segmentation_mask: np.array):
    """
    Remove noisy predictions from a character-level segmentation mask.

    Entries predicted as '-' or as the space character are reset to 0
    (non-character). The mask is modified in place and also returned.

    Args:
        segmentation_mask (np.array): The character-level segmentation mask.
    """
    for noisy_char in ('-', ' '):
        segmentation_mask[segmentation_mask == alphabet_dic[noisy_char]] = 0
    return segmentation_mask
def combine_image(args, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    """
    Save every predicted image and a combined overview image, and return the overview.

    The overview is three 512-pixel tiles wide. Its first row shows the
    original image, the character mask and the highlighted character mask;
    each following row shows up to three predictions with their 48-pixel
    caption strips directly underneath.

    Args:
        args: The arguments; only ``args.output_dir`` is read here.
        sub_output_dir (str): Sub-folder of ``args.output_dir`` that receives the files.
        pred_image_list (List): List of predicted images.
        image_pil (Image): The original image.
        character_mask_pil (Image): The character-level segmentation mask.
        character_mask_highlight_pil (Image): The character-level segmentation mask highlighting character regions with green color.
        caption_pil_list (List): List of captions.

    Returns:
        Image: The combined overview image in RGB mode.
    """
    tile, caption_height, per_row = 512, 48, 3

    # NOTE: the target folder is not created here; it has to exist before this call.
    for index, img in enumerate(pred_image_list):
        img.save(f'{args.output_dir}/{sub_output_dir}/{index}.jpg')

    count = len(pred_image_list)
    rows = math.ceil(count / per_row)
    canvas = Image.new('RGB', (tile * per_row, tile * (rows + 1) + caption_height * rows), (0, 0, 0))

    # Header row: the inputs, left to right.
    canvas.paste(image_pil, (0, 0))
    canvas.paste(character_mask_pil, (tile, 0))
    canvas.paste(character_mask_highlight_pil, (tile * 2, 0))

    # Prediction rows, each followed by its caption strip.
    for i in range(count):
        row, col = divmod(i, per_row)
        top = tile * (row + 1) + caption_height * row
        canvas.paste(pred_image_list[i], (tile * col, top))
        canvas.paste(caption_pil_list[i], (tile * col, top + tile))

    canvas.save(f'{args.output_dir}/{sub_output_dir}/combine.jpg')
    return canvas.convert('RGB')
def combine_image_gradio(args, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    """
    This function tiles the predicted images into a single image for display.

    A single prediction is returned as-is. Two or four predictions are laid
    out two per row; any other count uses three per row. Each tile is placed
    on a 512x512 grid over a black background. Only ``pred_image_list`` is
    used; the other parameters are accepted so the signature matches
    ``combine_image``.

    Args:
        args: The arguments (unused).
        sub_output_dir (str): Output sub-folder (unused).
        pred_image_list (List): List of predicted images.
        image_pil (Image): The original image (unused).
        character_mask_pil (Image): The character-level segmentation mask (unused).
        character_mask_highlight_pil (Image): The highlighted character-level segmentation mask (unused).
        caption_pil_list (List): List of captions (unused).

    Returns:
        Image: The single prediction, or the tiled overview.

    Raises:
        ValueError: If ``pred_image_list`` is empty.
    """
    size = len(pred_image_list)
    if size == 0:
        raise ValueError('pred_image_list must contain at least one image')
    if size == 1:
        return pred_image_list[0]

    # 2 -> 2x1 and 4 -> 2x2 keep the historical layouts; every other count fills rows of three.
    columns = 2 if size in (2, 4) else 3
    rows = math.ceil(size / columns)
    blank = Image.new('RGB', (512 * columns, 512 * rows), (0, 0, 0))
    for position, pred_image in enumerate(pred_image_list):
        row, col = divmod(position, columns)
        blank.paste(pred_image, (512 * col, 512 * row))
    return blank
def get_width(font_path, text):
    """
    This function calculates the width of the text when rendered at font size 24.

    Args:
        font_path (str): The path of font.
        text (str): The text to measure.

    Returns:
        The rendered width of ``text`` in pixels.
    """
    font = ImageFont.truetype(font_path, 24)
    # FreeTypeFont.getsize() was removed in Pillow 10; measure through the
    # bounding box instead and keep getsize() only for old Pillow releases.
    if hasattr(font, 'getbbox'):
        left, _, right, _ = font.getbbox(text)
        return right - left
    width, _ = font.getsize(text)
    return width
def get_key_words(text: str):
    """
    Extract the keywords (enclosed by single quotes) from a user prompt.

    The keywords are used to guide the layout generation. Every quoted span
    is split on whitespace; if eight or more words are collected in total,
    an empty list is returned instead.

    Args:
        text (str): user prompt.
    """
    quoted_spans = re.findall(r"'(.*?)'", text)  # find the keywords enclosed by ''
    words = [word for span in quoted_spans for word in span.split()]
    return words if len(words) < 8 else []
def adjust_overlap_box(box_output, current_index):
    """
    Move the current box away from the previous one when they overlap.

    Boxes are read as (xmin, ymin, xmax, ymax) from ``box_output[0, i, :]``.
    The current box is shifted, keeping its size, when its xmin lies within
    the previous box's x-range and either its ymin or its ymax lies within
    the previous box's y-range. ``box_output`` is modified in place.

    Args:
        box_output (List): List of predicted boxes.
        current_index (int): the index of current box.

    Returns:
        The same ``box_output`` object.
    """
    # The first box has no predecessor to collide with.
    if current_index == 0:
        return box_output

    prev_x0, prev_y0, prev_x1, prev_y1 = box_output[0, current_index - 1, :]
    cur_x0, cur_y0, cur_x1, cur_y1 = box_output[0, current_index, :]

    # Both overlap cases require the current left edge to sit inside the previous box.
    if not (prev_x0 <= cur_x0 <= prev_x1):
        return box_output

    if prev_y0 <= cur_y0 <= prev_y1:
        # Top-left corner is inside the previous box: push along the axis with the smaller overlap.
        print('adjust overlapping')
        overlap_x = prev_x1 - cur_x0
        overlap_y = prev_y1 - cur_y0
        if overlap_x <= overlap_y:
            shifted_min = prev_x1 + 0.025
            shifted_max = cur_x1 - cur_x0 + prev_x1 + 0.025
            box_output[0, current_index, 0] = shifted_min
            box_output[0, current_index, 2] = shifted_max
        else:
            shifted_min = prev_y1 + 0.025
            shifted_max = cur_y1 - cur_y0 + prev_y1 + 0.025
            box_output[0, current_index, 1] = shifted_min
            box_output[0, current_index, 3] = shifted_max
    elif prev_y0 <= cur_y1 <= prev_y1:
        # Bottom-left corner is inside the previous box: always push to the right.
        print('adjust overlapping')
        shifted_min = prev_x1 + 0.05
        shifted_max = cur_x1 - cur_x0 + prev_x1 + 0.05
        box_output[0, current_index, 0] = shifted_min
        box_output[0, current_index, 2] = shifted_max

    return box_output
def shrink_box(box, scale_factor = 0.9):
    """
    Shrink a box around its centre.

    Args:
        box (List): The box as (x1, y1, x2, y2).
        scale_factor (float): Fraction of the original width and height to keep.

    Returns:
        tuple: The shrunken box as (x1, y1, x2, y2).
    """
    left, top, right, bottom = box
    # Margin removed from each side, per axis.
    inset_x = (right - left) * (1 - scale_factor) / 2
    inset_y = (bottom - top) * (1 - scale_factor) / 2
    return (left + inset_x, top + inset_y, right - inset_x, bottom - inset_y)
def adjust_font_size(args, width, height, draw, text):
    """
    This function finds the largest font size, starting from ``height`` and
    decreasing by one, at which ``text`` is narrower than ``width``.

    Args:
        args: The arguments; only ``args.font_path`` is read here.
        width (int): The width available for the text.
        height (int): The height of the text, used as the initial font size.
        draw (ImageDraw): The ImageDraw object.
        text (str): The text.

    Returns:
        The selected font size. If the text does not fit even at size 1,
        the search stops there and that size is returned.
    """
    size = height
    while True:
        font = ImageFont.truetype(args.font_path, size)
        # ImageDraw.textsize() was removed in Pillow 10; measure through the
        # bounding box and keep textsize() only for old Pillow releases.
        if hasattr(draw, 'textbbox'):
            left, _, right, _ = draw.textbbox((0, 0), text, font=font)
            text_width = right - left
        else:
            text_width, _ = draw.textsize(text, font=font)
        # Stop at size 1: a font of size 0 cannot be loaded.
        if text_width < width or size <= 1:
            return size
        size = size - 1
def inpainting_merge_image(original_image, mask_image, inpainting_image):
    """
    This function merges the original image, mask image and inpainting image.

    All three images are resized to 512x512. Wherever the grayscale mask is
    darker than 250 the inpainting result is used; elsewhere the original
    image is kept.

    Args:
        original_image (PIL.Image): The original image.
        mask_image (PIL.Image): The mask images.
        inpainting_image (PIL.Image): The inpainting images.

    Returns:
        PIL.Image: The merged 512x512 image.
    """
    original_image = original_image.resize((512, 512))
    mask_image = mask_image.resize((512, 512))
    inpainting_image = inpainting_image.resize((512, 512))

    # convert() returns a new image; keep it so the 256-entry table below is
    # applied to a single-band image.
    mask_image = mask_image.convert('L')

    threshold = 250
    # 1 selects the inpainting result, 0 keeps the original pixel.
    table = [1 if level < threshold else 0 for level in range(256)]
    mask_image = mask_image.point(table, "1")

    merged_image = Image.composite(inpainting_image, original_image, mask_image)
    return merged_image