import json
import os
import random

import numpy as np
import torch

# Named reference colours as 0-255 RGB triplets. `find_nearest_color` maps an
# arbitrary RGB value to the closest of these names.
COLORS = {
    'brown': [165, 42, 42],
    'red': [255, 0, 0],
    'pink': [253, 108, 158],
    'orange': [255, 165, 0],
    'yellow': [255, 255, 0],
    'purple': [128, 0, 128],
    'green': [0, 128, 0],
    'blue': [0, 0, 255],
    'white': [255, 255, 255],
    'gray': [128, 128, 128],
    'black': [0, 0, 0],
}

# Font name -> style phrase used in "... in the style of <phrase>" prompts.
# Built once at import time instead of on every `font2style` call. The
# strings are part of the generated prompts, so they are kept verbatim.
_FONT_STYLES = {
    'mirza': 'Claud Monet, impressionism, oil on canvas',
    'roboto': 'Ukiyoe',
    'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
    'sofia': 'Pop Art, masterpiece, andy warhol',
    'slabo': 'Vincent Van Gogh',
    'inconsolata': 'Pixel Art, 8 bits, 16 bits',
    'ubuntu': 'Rembrandt',
    'Monoton': 'neon art, colorful light, highly details, octane render',
    'Akronim': 'Abstract Cubism, Pablo Picasso',
}


def seed_everything(seed):
    r"""
    Seed every random number generator this module relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) and
    exports `PYTHONHASHSEED`.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime only affects interpreters that
    # are started afterwards (e.g. worker subprocesses), not this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU rather than only the current one.
    torch.cuda.manual_seed_all(seed)


def hex_to_rgb(hex_string, return_nearest_color=False, device='cuda'):
    r"""
    Convert a hex colour triplet to an RGB tensor.

    Args:
        hex_string: Colour such as '#ff8800' or 'ff8800'. The 3-digit
            shorthand ('#f80') is expanded to six digits first.
        return_nearest_color: When true, also return the name of the closest
            entry in `COLORS`.
        device: Device the returned tensor is moved to.

    Returns:
        A float tensor of shape (1, 3, 1, 1) scaled to [0, 1], or a
        `(tensor, color_name)` tuple when `return_nearest_color` is true.

    Raises:
        ValueError: If the digits cannot be parsed as hexadecimal.
    """
    # Remove '#' symbol if present
    digits = hex_string.lstrip('#')
    if len(digits) == 3:
        # 'f80' is shorthand for 'ff8800'
        digits = ''.join(ch * 2 for ch in digits)

    # Convert hex values to integers
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)

    rgb = torch.FloatTensor((red, green, blue))[None, :, None, None] / 255.
    if return_nearest_color:
        # The name lookup runs on the CPU tensor, before the device transfer.
        nearest_color = find_nearest_color(rgb)
        return rgb.to(device), nearest_color
    return rgb.to(device)


def find_nearest_color(rgb):
    r"""
    Find the name of the `COLORS` entry closest to the given RGB value.

    Args:
        rgb: Either a list/tuple of 0-255 channel values, or a tensor that is
            already scaled to [0, 1] and broadcastable against (1, 3, 1, 1).

    Returns:
        The key of `COLORS` with the smallest Euclidean distance to `rgb`;
        ties resolve to the entry defined first.
    """
    if isinstance(rgb, (list, tuple)):
        rgb = torch.FloatTensor(rgb)[None, :, None, None] / 255.
    # Compare on a detached CPU copy so CUDA tensors and tensors that track
    # gradients are accepted as well.
    rgb = torch.as_tensor(rgb).detach().cpu().float()

    names = list(COLORS)
    color_distance = torch.stack([
        (rgb - torch.FloatTensor(COLORS[name])[None, :, None, None] / 255.)
        .pow(2).sum().sqrt()
        for name in names
    ])
    return names[torch.argmin(color_distance).item()]


def font2style(font, device='cuda'):
    r"""
    Convert the font name to the style name.

    Args:
        font: Font name; must be one of the keys of the font/style table.
        device: Unused; kept so existing callers keep working.

    Returns:
        The style phrase associated with `font`.

    Raises:
        KeyError: If `font` has no associated style.
    """
    return _FONT_STYLES[font]


def parse_json(json_str, device):
    r"""
    Convert the rich-text JSON to prompt attributes.

    Args:
        json_str: A dict with an 'ops' list of spans, or the JSON text of
            such a dict. Each span has an 'insert' string and optionally an
            'attributes' dict with 'font', 'link', 'size', 'strike' and
            'color' entries.
        device: Device for the colour tensors.

    Returns:
        A 9-tuple `(base_text_prompt, style_text_prompts,
        footnote_text_prompts, footnote_target_tokens, color_text_prompts,
        color_names, color_rgbs, size_text_prompts_and_sizes,
        use_grad_guidance)`:

        * `base_text_prompt`: all span texts concatenated (trailing newlines
          of each span removed).
        * `style_text_prompts`: "<text> in the style of <style>" per styled
          region.
        * `footnote_text_prompts` / `footnote_target_tokens`: the link value
          and the text it is attached to.
        * `color_text_prompts` / `color_names` / `color_rgbs`: text, nearest
          named colour and RGB tensor per coloured region.
        * `size_text_prompts_and_sizes`: `[text, weight]` pairs for spans
          whose weight differs from 1.
        * `use_grad_guidance`: true when at least one span has a colour.

    Raises:
        KeyError: If a span uses a font that has no associated style.
    """
    if isinstance(json_str, str):
        # Accept the raw JSON text as well as the already-decoded dict.
        json_str = json.loads(json_str)

    # initialize region-based attributes.
    base_text_prompt = ''
    style_text_prompts = []
    footnote_text_prompts = []
    footnote_target_tokens = []
    color_text_prompts = []
    color_rgbs = []
    color_names = []
    size_text_prompts_and_sizes = []

    # parse the attributes from JSON.
    prev_style = None
    prev_color = None
    use_grad_guidance = False
    for span in json_str['ops']:
        text_prompt = span['insert'].rstrip('\n')
        base_text_prompt += text_prompt
        # A lone space only separates words: it contributes to the base
        # prompt but neither starts a region nor interrupts a running one.
        if text_prompt == ' ':
            continue
        if 'attributes' not in span:
            continue
        attributes = span['attributes']

        if 'font' in attributes:
            style = font2style(attributes['font'])
            if prev_style == style:
                # Same style as the running region: extend that region's
                # text and re-append the style suffix.
                prev_text_prompt = style_text_prompts[-1].split(
                    'in the style of')[0]
                style_text_prompts[-1] = prev_text_prompt + \
                    ' ' + text_prompt + f' in the style of {style}'
            else:
                style_text_prompts.append(
                    text_prompt + f' in the style of {style}')
                prev_style = style
        else:
            prev_style = None

        if 'link' in attributes:
            footnote_text_prompts.append(attributes['link'])
            footnote_target_tokens.append(text_prompt)

        # Size strings carry a two-character unit suffix that is dropped
        # (presumably 'px' - confirm against the editor). A struck-through
        # sized span gets the negated weight; strike without size stays 1.
        font_size = 1
        if 'size' in attributes:
            font_size = float(attributes['size'][:-2]) / 3.
            if 'strike' in attributes:
                font_size = -font_size

        if 'color' in attributes:
            use_grad_guidance = True
            # Compare colours by their normalised hex string. Comparing the
            # RGB tensors with `==` is not usable inside an `if`.
            color_key = attributes['color'].lstrip('#').lower()
            if prev_color == color_key:
                # Same colour as the running region: extend its text.
                color_text_prompts[-1] = color_text_prompts[-1] + \
                    ' ' + text_prompt
            else:
                color_rgb, nearest_color = hex_to_rgb(
                    attributes['color'], True, device=device)
                color_rgbs.append(color_rgb)
                color_names.append(nearest_color)
                color_text_prompts.append(text_prompt)
                prev_color = color_key
        else:
            prev_color = None

        if font_size != 1:
            size_text_prompts_and_sizes.append([text_prompt, font_size])

    return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
        color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance


def _token_positions(model, base_tokens, text_prompt):
    r"""
    Return the 1-based positions in `base_tokens` of each token of `text_prompt`.

    The +1 offset presumably accounts for a start-of-text token in front of
    the prompt tokens - confirm against the text encoder.

    NOTE: `list.index` returns the first match, so a token that occurs more
    than once in the base prompt always maps to its first occurrence.

    Raises:
        ValueError: If a token of `text_prompt` is not in `base_tokens`.
    """
    return [base_tokens.index(token) + 1
            for token in model.tokenizer._tokenize(text_prompt)]


def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
                               footnote_target_tokens, color_text_prompts, color_names):
    r"""
    Algorithm 1 in the paper.

    Builds one text prompt and one set of target token positions per region
    (style, footnote and colour regions, in that order), followed by a final
    region for the base prompt that owns every token not claimed by another
    region.

    Args:
        model: Object exposing `tokenizer._tokenize(text)`.
        base_text_prompt: The full plain-text prompt.
        style_text_prompts: "<text> in the style of <style>" prompts.
        footnote_text_prompts: Footnote texts, one per footnote.
        footnote_target_tokens: The prompt text each footnote refers to.
        color_text_prompts: Text of each coloured region.
        color_names: Colour name of each coloured region.

    Returns:
        `(region_text_prompts, region_target_token_ids, base_tokens)` where
        `region_target_token_ids` is a list of `torch.LongTensor`, one per
        entry of `region_text_prompts`.

    Raises:
        ValueError: If a region token does not occur in the base prompt.
    """
    region_text_prompts = []
    region_target_token_ids = []
    base_tokens = model.tokenizer._tokenize(base_text_prompt)

    # process the style text prompt
    for text_prompt in style_text_prompts:
        region_text_prompts.append(text_prompt)
        # Only the text in front of the style suffix names prompt tokens.
        region_target_token_ids.append(_token_positions(
            model, base_tokens, text_prompt.split('in the style of')[0]))

    # process the complementary text prompt
    for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
        region_text_prompts.append(footnote_text_prompt)
        region_target_token_ids.append(
            _token_positions(model, base_tokens, text_prompt))

    # process the color text prompt
    for color_text_prompt, color_name in zip(color_text_prompts, color_names):
        region_text_prompts.append(color_name + ' ' + color_text_prompt)
        region_target_token_ids.append(
            _token_positions(model, base_tokens, color_text_prompt))

    # process the remaining tokens without any attributes
    region_text_prompts.append(base_text_prompt)
    claimed_token_ids = {
        token_id for token_ids in region_target_token_ids for token_id in token_ids}
    target_token_ids_rest = [
        token_id for token_id in range(1, len(base_tokens) + 1)
        if token_id not in claimed_token_ids]
    region_target_token_ids.append(target_token_ids_rest)

    region_target_token_ids = [
        torch.LongTensor(token_ids) for token_ids in region_target_token_ids]
    return region_text_prompts, region_target_token_ids, base_tokens


def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
    r"""
    Control the token impact using font sizes.

    Args:
        model: Object exposing `tokenizer._tokenize(text)` and `device`.
        base_tokens: Tokens of the base prompt.
        size_text_prompts_and_sizes: `[text, weight]` pairs.

    Returns:
        A dict with 'word_pos' (LongTensor of 1-based token positions) and
        'font_size' (FloatTensor with the weight of each position), both on
        `model.device`. Both values are None when no token is affected.

    Raises:
        ValueError: If a token does not occur in `base_tokens`.
    """
    word_pos = []
    font_sizes = []
    for text_prompt, font_size in size_text_prompts_and_sizes:
        positions = _token_positions(model, base_tokens, text_prompt)
        word_pos.extend(positions)
        # Every token of the span shares the span's weight.
        font_sizes.extend([font_size] * len(positions))

    if word_pos:
        word_pos = torch.LongTensor(word_pos).to(model.device)
        font_sizes = torch.FloatTensor(font_sizes).to(model.device)
    else:
        word_pos = None
        font_sizes = None
    text_format_dict = {
        'word_pos': word_pos,
        'font_size': font_sizes,
    }
    return text_format_dict


def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
                                guidance_start_step=999, color_guidance_weight=1):
    r"""
    Prepare the inputs for colour guidance.

    Maps each coloured text to its token positions and stores the guidance
    settings in `text_format_dict`.

    Args:
        model: Object exposing `tokenizer._tokenize(text)`.
        base_tokens: Tokens of the base prompt.
        color_text_prompts: Text of each coloured region.
        color_rgbs: Target RGB value of each coloured region.
        text_format_dict: Dict that receives the 'target_RGB',
            'guidance_start_step' and 'color_guidance_weight' entries. It is
            modified in place.
        guidance_start_step: Stored under 'guidance_start_step'.
        color_guidance_weight: Stored under 'color_guidance_weight'.

    Returns:
        `(text_format_dict, color_target_token_ids)`. The second item is a
        list of `torch.LongTensor`: one per coloured region, plus a final
        tensor holding every position that belongs to no coloured region.

    Raises:
        ValueError: If a token does not occur in `base_tokens`.
    """
    color_target_token_ids = [
        _token_positions(model, base_tokens, text_prompt)
        for text_prompt in color_text_prompts]

    colored_token_ids = {
        token_id for token_ids in color_target_token_ids for token_id in token_ids}
    color_target_token_ids_rest = [
        token_id for token_id in range(1, len(base_tokens) + 1)
        if token_id not in colored_token_ids]
    color_target_token_ids.append(color_target_token_ids_rest)
    color_target_token_ids = [
        torch.LongTensor(token_ids) for token_ids in color_target_token_ids]

    text_format_dict['target_RGB'] = color_rgbs
    text_format_dict['guidance_start_step'] = guidance_start_step
    text_format_dict['color_guidance_weight'] = color_guidance_weight
    return text_format_dict, color_target_token_ids