Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import argparse | |
from easydict import EasyDict as edict | |
import yaml | |
import os.path as osp | |
import random | |
import numpy.random as npr | |
import sys | |
import imageio | |
import numpy as np | |
# sys.path.append('./code')
sys.path.append('/home/user/app/code')

# Set up diffvg: initialise the git submodule, then build and install it into
# the user site-packages.
# os.system('git clone https://github.com/BachiLi/diffvg.git')
os.system('git submodule update --init')
os.chdir('diffvg')
os.system('git submodule update --init --recursive')
# Use the running interpreter rather than whatever "python" is on PATH, so the
# extension is built for (and installed into) the interpreter that imports it.
os.system(f'"{sys.executable}" setup.py install --user')
# The egg directory name encodes the interpreter version; derive it from the
# running interpreter instead of hard-coding python3.10.
_py_ver = f"{sys.version_info.major}.{sys.version_info.minor}"
sys.path.append(
    f"/home/user/.local/lib/python{_py_ver}/site-packages/"
    f"diffvg-0.0.1-py{_py_ver}-linux-x86_64.egg"
)
os.chdir('/home/user/app')
# os.system('bash code/data/fonts/arabic/download_fonts.sh')

import torch
from diffusers import StableDiffusionPipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Half precision is only used on CUDA. On CPU, several PyTorch kernels used by
# the pipeline have no float16 implementation, and float16 weights would not
# match the float32 images fed to the VAE, so fall back to float32 there.
_dtype = torch.float16 if device.type == 'cuda' else torch.float32
model = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=_dtype
).to(device)
from typing import Mapping | |
from tqdm import tqdm | |
import torch | |
from torch.optim.lr_scheduler import LambdaLR | |
import pydiffvg | |
import save_svg | |
from losses import SDSLoss, ToneLoss, ConformalLoss | |
from utils import ( | |
edict_2_dict, | |
update, | |
check_and_create_dir, | |
get_data_augs, | |
save_image, | |
preprocess, | |
learning_rate_decay, | |
combine_word) | |
import warnings | |
TITLE = """<h1 style="font-size: 42px;" align="center">Font-To-Sketch: Morphing Any Font to a Visual Representation</h1>"""

# Markdown/HTML shown under the title of the demo page.
DESCRIPTION = """This demo builds on the [Word-As-Image for Semantic Typography](https://wordasimage.github.io/Word-As-Image-Page/) work to support **any** font and morphing whole words and phrases to a visual representation of a given semantic concept. This project started as part of an ongoing effort with the [ARBML](https://arbml.github.io/website/) community to build open-source Arabic tools using machine learning."""
# The leading space matters: without it the sentences render as
# "...machine learning.The demo currently...".
DESCRIPTION += """ The demo currently supports the following scripts: **Arabic**, **Simplified Chinese**, **Cyrillic**, **Greek**, **Latin**, **Tamil**. Therefore you can write the text in any language using those scripts. To add support for more fonts please check the [GitHub ReadMe](https://github.com/BKHMSI/Font-To-Sketch)."""
# DESCRIPTION += '\n<p>This demo is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/"> Creative Commons Attribution-ShareAlike 4.0 International License</a>.</p>'
DESCRIPTION += '\n<p>Note: it takes about 5 minutes for 250 iterations to generate the final GIF. For faster inference without waiting in queue, you can <a href="https://colab.research.google.com/drive/1wobOAsnLpkIzaRxG5yac8NcV7iCrlycP"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a></p>'

if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
    # Running as a Hugging Face Space: reopen the note paragraph (drop its
    # closing </p>) and append a "Duplicate the Space" badge to it.
    DESCRIPTION = DESCRIPTION.replace("</p>", " ")
    DESCRIPTION += f'or <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate the Space"/></a> and upgrade to GPU in settings.</p>'

# Centred banner of example outputs. CSS belongs in `style`; the HTML `align`
# attribute does not accept CSS declarations.
DESCRIPTION += "<div style='text-align:center; width:100%'><img src='https://raw.githubusercontent.com/BKHMSI/Font-To-Sketch/main/images/languages_car.gif' alt='Example of Outputs'/></div>"

ARABIC_EX = "<img src='https://raw.githubusercontent.com/BKHMSI/Font-To-Sketch/main/images/animals_7.gif' alt='Example of Outputs'/>"
# Process-wide runtime settings applied once at import time.
warnings.filterwarnings(action="ignore")  # silence every Python warning category
pydiffvg.set_print_timing(False)  # presumably turns off diffvg's timing printouts
gamma = 1.0  # not referenced by any code visible in this file
def read_font_names(all_scripts, fonts_root="code/data/fonts"):
    """Collect the selectable font names for every script.

    For each script the font list is read from
    ``<fonts_root>/<script>/font_names.txt`` (one name per line). When that
    file is missing, the ``.ttf`` files in the script's directory are used
    instead: their names without extension, in sorted order. The script
    "Simplified Chinese" is stored under the ``chinese`` key and directory.

    Args:
        all_scripts: Script display names, e.g. ``["Arabic", "Latin"]``.
        fonts_root: Directory holding one sub-directory per script.

    Returns:
        ``(font_names, font_dict)``. ``font_names`` is ``"Default"`` followed
        by the sorted ``"<Script>: <font>"`` labels. ``font_dict`` maps each
        lower-case script key to its list of font names.

    Raises:
        FileNotFoundError: if a script has neither ``font_names.txt`` nor a
            font directory.
    """
    font_names = []
    font_dict = {}
    for script in all_scripts:
        script = script.lower()
        # Normalise before touching font_dict so no stray, empty
        # "simplified chinese" entry is left behind.
        if script == "simplified chinese":
            script = "chinese"
        font_dir = os.path.join(fonts_root, script)
        path = os.path.join(font_dir, "font_names.txt")
        if not os.path.exists(path):
            # Match the real extension rather than any "ttf" substring. Sort
            # because os.listdir order is arbitrary, and run_main_app turns a
            # font's list position into its font index.
            font_dict[script] = sorted(
                os.path.splitext(entry)[0]
                for entry in os.listdir(font_dir)
                if entry.lower().endswith(".ttf")
            )
        else:
            with open(path, 'r', encoding="utf-8") as fin:
                font_dict[script] = [line.strip() for line in fin]
        font_names.extend(f"{script.capitalize()}: {f}" for f in font_dict[script])
    return ["Default"] + sorted(font_names), font_dict
def set_config(semantic_concept, word, script, prompt_suffix, font_name, num_steps, seed, is_seed_rand, dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma, angeles_w):
    """Build the config for one generation run, persist it, and seed all RNGs.

    Loads ``code/config/base.yaml``, merges the ``default`` experiment with its
    ``parent_config`` ancestors, then overlays the UI-provided values.

    Args:
        semantic_concept: Concept the text should morph into (used in the prompt).
        word: Text to render and optimise.
        script: Script name; stored lower-cased in ``cfg.script`` and used as log dir.
        prompt_suffix: Text appended to the diffusion prompt.
        font_name: Font identifier stored in ``cfg.font``.
        num_steps: Number of optimisation iterations.
        seed: Seed used when ``is_seed_rand`` is not ``"Random Seed"``.
        is_seed_rand: ``"Random Seed"`` to draw a seed in [0, 10000), else use ``seed``.
        dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma: Tone-loss
            settings, truncated to int.
        angeles_w: Weight of the conformal (angle) loss.

    Returns:
        The merged ``edict`` config. ``cfg.experiment_dir`` is created and
        ``config.yaml`` is written there.

    Raises:
        gr.Error: if the optimised text is not contained in ``word``.
        ValueError: if the ``parent_config`` chain in the YAML is cyclic.
    """
    cfg_d = edict()
    cfg_d.config = "code/config/base.yaml"
    cfg_d.experiment = "default"

    # The config file ships with the app (it is not user input).
    with open(cfg_d.config, 'r', encoding="utf-8") as f:
        cfg_full = yaml.load(f, Loader=yaml.FullLoader)

    # Walk experiment -> parent -> ... A section without 'parent_config'
    # defaults to 'baseline', so the chain only ends at a section whose
    # parent_config is empty. Guard against cycles instead of looping forever.
    cfg_key = cfg_d.experiment
    cfgs = [cfg_d]
    visited = set()
    while cfg_key:
        if cfg_key in visited:
            raise ValueError(f"cyclic parent_config chain in {cfg_d.config} at {cfg_key!r}")
        visited.add(cfg_key)
        cfgs.append(cfg_full[cfg_key])
        cfg_key = cfgs[-1].get('parent_config', 'baseline')

    # Apply the most generic section first. Presumably `update` overwrites, so
    # the more specific sections (and cfg_d) take precedence.
    cfg = edict()
    for options in reversed(cfgs):
        update(cfg, options)
    del cfgs

    cfg.semantic_concept = semantic_concept
    cfg.prompt_suffix = prompt_suffix
    cfg.word = word
    cfg.optimized_letter = word  # the whole text is optimised
    cfg.script = script.lower()
    cfg.font = font_name
    if is_seed_rand == "Random Seed":
        # Built-in int: numpy integers are not accepted by random.seed() on
        # newer Pythons, and yaml.dump() would emit a python-object tag.
        cfg.seed = int(np.random.randint(10000))
    else:
        cfg.seed = int(seed)
    # range(num_iter) in the training loop needs an int. UI numbers may
    # arrive as floats (assumption -- depends on the Gradio version).
    cfg.num_iter = int(num_steps)
    cfg.batch_size = 1
    cfg.loss.tone.dist_loss_weight = int(dist_loss_weight)
    cfg.loss.tone.pixel_dist_kernel_blur = int(pixel_dist_kernel_blur)
    cfg.loss.tone.pixel_dist_sigma = int(pixel_dist_sigma)
    cfg.loss.conformal.angeles_w = angeles_w
    cfg.caption = f"a {cfg.semantic_concept}. {cfg.prompt_suffix}"
    cfg.log_dir = f"{cfg.script}"

    if cfg.optimized_letter not in cfg.word:
        raise gr.Error('letter should be in word')

    cfg.letter = f"{cfg.font}_{cfg.optimized_letter}_scaled"
    # Spaces become underscores, matching the init file name built in run_main_app.
    cfg.target = f"code/data/init/{cfg.letter}".replace(' ', '_')

    # Experiment dir: <script>/<font>/<word>_<concept>_<seed>. Word and concept
    # come straight from the UI, so path separators are neutralised to keep the
    # directory inside log_dir (no "../" escape, no accidental nesting).
    signature = f"{cfg.word}_{cfg.semantic_concept}_{cfg.seed}"
    signature = signature.replace("/", "_").replace("\\", "_")
    cfg.experiment_dir = osp.join(cfg.log_dir, cfg.font, signature)
    configfile = osp.join(cfg.experiment_dir, 'config.yaml')

    # Create the experiment dir and save the resolved config.
    check_and_create_dir(configfile)
    with open(configfile, 'w', encoding="utf-8") as f:
        yaml.dump(edict_2_dict(cfg), f)

    # cfg.seed is always an int here, so seeding is unconditional.
    random.seed(cfg.seed)
    npr.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.benchmark = False

    return cfg
def init_shapes(svg_path, trainable: Mapping[str, bool]):
    """Load ``<svg_path>.svg`` into diffvg shapes and mark trainable parameters.

    Args:
        svg_path: Path of the SVG file without the ``.svg`` extension.
        trainable: Flags keyed by parameter kind; only ``"point"`` is read.
            Any mapping works, including the EasyDict config section.

    Returns:
        ``(shapes, shape_groups, parameters)``: the shapes and shape groups
        parsed by ``pydiffvg.svg_to_scene``, plus an ``edict``. When
        ``trainable["point"]`` is truthy, its ``point`` entry lists every
        path's point tensor with ``requires_grad`` enabled.
    """
    svg = f'{svg_path}.svg'
    _canvas_width, _canvas_height, shapes_init, shape_groups_init = pydiffvg.svg_to_scene(svg)

    parameters = edict()
    # Subscript access honours the Mapping annotation; EasyDict supports it
    # too, whereas plain mappings have no `.point` attribute.
    if trainable["point"]:
        parameters.point = []
        for path in shapes_init:
            path.points.requires_grad = True
            parameters.point.append(path.points)
    return shapes_init, shape_groups_init, parameters
def run_main_ex(word, semantic_concept, script, font_selector, num_steps, seed):
    """Produce the outputs for one Gradio example row.

    Runs run_main_app with a fixed seed, the default prompt suffix and the
    default loss settings. Passing ``example=1`` suppresses its intermediate
    yields, so the first value it yields is the final result.
    """
    runner = run_main_app(
        semantic_concept,
        word,
        script,
        font_selector,
        "minimal flat 2d vector. lineal color. trending on artstation",  # prompt_suffix
        num_steps,
        seed,
        "Use Set Value",  # is_seed_rand: always honour the given seed
        100,   # dist_loss_weight
        201,   # pixel_dist_kernel_blur
        30,    # pixel_dist_sigma
        0.5,   # angeles_w
        1,     # example
    )
    final_outputs = next(runner)
    return list(final_outputs)
def _resolve_font(script, font_selected):
    """Map the UI script/font selection to ``(script, font_name)`` for set_config.

    "simplified chinese" is renamed to "chinese". If the selected font belongs
    to a different script (or is "Default"), the script's default font is used.
    The font name is the font's own name when the script has exactly one font,
    otherwise its two-digit position in ``font_dict[script]``.
    """
    if font_selected.lower() != "default":
        # Split on the first ':' only, so font names containing ':' survive.
        font_key, font_val = font_selected.split(":", 1)
        font_key = font_key.lower().strip()
        font_val = font_val.strip()
    else:
        font_key = "default"
        font_val = "default"
    if script.lower() == "simplified chinese":
        script = "chinese"
    if font_key != script.lower():
        print(f"Setting font to {script} default font")
        font_key = script.lower()
        # The selected font is not in this script's list; looking it up with
        # .index() would raise ValueError, so use the default as announced.
        font_val = "default"
    if len(font_dict[font_key]) == 1:
        font_name = font_dict[font_key][0]
    elif font_val == "default":
        font_name = "00"
    else:
        font_name = str(font_dict[font_key].index(font_val)).zfill(2)
    print(font_name)
    return script, font_name


def _composite_on_white(img):
    """Alpha-composite a render over white; alpha is channel 3, returns 3 channels."""
    alpha = img[:, :, 3:4]
    background = torch.ones(img.shape[0], img.shape[1], 3, device=device)
    return alpha * img[:, :, :3] + background * (1 - alpha)


def _to_gif_frame(x):
    """Turn an NCHW image tensor with batch size 1 into a 300x300 HWC numpy frame."""
    frame = x.detach().cpu()
    frame = torch.nn.functional.interpolate(frame, size=(300, 300), mode='bilinear', align_corners=False)
    return frame.permute(0, 2, 3, 1).squeeze(0).numpy()


def run_main_app(semantic_concept, word, script, font_selected, prompt_suffix, num_steps, seed, is_seed_rand, dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma, angeles_w, example=0):
    """Optimise the text outline towards ``semantic_concept``, streaming progress.

    This is a generator. Each yielded value is a 5-tuple of ``gr.update``
    objects for (initial SVG, optimisation-progress SVG, final GIF, prompt,
    seed).
    - With ``example`` falsy, it yields once after preprocessing, once per
      optimisation step, and once at the end.
    - With ``example`` truthy, it yields only the final tuple.

    Side effects: writes per-step SVGs, ``output-svg/output.svg`` and
    ``final.gif`` under ``cfg.experiment_dir``.
    """
    script, font_name = _resolve_font(script, font_selected)
    cfg = set_config(semantic_concept, word, script, prompt_suffix, font_name, num_steps, seed, is_seed_rand, dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma, angeles_w)
    pydiffvg.set_use_gpu(torch.cuda.is_available())

    print("preprocessing")
    preprocess(cfg.font, cfg.word, cfg.optimized_letter, cfg.script, cfg.level_of_cc)
    filename_init = os.path.join("code/data/init/", f"{cfg.font}_{cfg.word}_scaled.svg").replace(" ", "_")
    if not example:
        yield gr.update(value=filename_init,visible=True),gr.update(visible=True, label='Initializing'),gr.update(visible=False),gr.update(value=cfg.caption,visible=True),gr.update(value=cfg.seed,visible=True)

    sds_loss = SDSLoss(cfg, device, model)
    h, w = cfg.render_size, cfg.render_size
    data_augs = get_data_augs(cfg.cut_size)
    render = pydiffvg.RenderFunction.apply

    print('initializing shape')
    shapes, shape_groups, parameters = init_shapes(svg_path=cfg.target, trainable=cfg.trainable)

    scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
    img_init = _composite_on_white(render(w, h, 2, 2, 0, None, *scene_args))
    tone_loss = ToneLoss(cfg)
    tone_loss.set_image_init(img_init)

    num_iter = int(cfg.num_iter)
    pg = [{'params': parameters["point"], 'lr': cfg.lr_base["point"]}]
    optim = torch.optim.Adam(pg, betas=(0.9, 0.9), eps=1e-6)
    conformal_loss = ConformalLoss(parameters, device, cfg.optimized_letter, shape_groups)

    def lr_lambda(step):
        # LambdaLR multiplies lr_base by this factor: the decayed rate relative to lr_init.
        return learning_rate_decay(step, cfg.lr.lr_init, cfg.lr.lr_final, num_iter,
                                   lr_delay_steps=cfg.lr.lr_delay_steps,
                                   lr_delay_mult=cfg.lr.lr_delay_mult) / cfg.lr.lr_init

    scheduler = LambdaLR(optim, lr_lambda=lr_lambda, last_epoch=-1)

    print("start training")
    gif_frames = []
    skip = 5  # keep every 5th iteration as a GIF frame
    for step in tqdm(range(num_iter)):
        optim.zero_grad()
        # Render the current shapes over a white background.
        scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
        img = _composite_on_white(render(w, h, 2, 2, step, None, *scene_args))

        filename = os.path.join(cfg.experiment_dir, "video-svg", f"iter{step:04d}.svg")
        check_and_create_dir(filename)
        save_svg.save_svg(filename, w, h, shapes, shape_groups)
        if not example:
            yield gr.update(visible=True),gr.update(value=filename, label=f'iters: {step} / {num_iter}', visible=True),gr.update(visible=False),gr.update(value=cfg.caption,visible=True),gr.update(value=cfg.seed,visible=True)

        x = img.unsqueeze(0).permute(0, 3, 1, 2)  # HWC -> NCHW
        if step % skip == 0:
            gif_frames.append(_to_gif_frame(x))
        x = x.repeat(cfg.batch_size, 1, 1, 1)
        x_aug = data_augs.forward(x)

        # Same loss evaluation order as before: SDS, tone, then conformal.
        loss = sds_loss(x_aug)
        loss = loss + tone_loss(x, step)
        loss = loss + cfg.loss.conformal.angeles_w * conformal_loss()
        loss.backward()
        optim.step()
        scheduler.step()

    filename = os.path.join(cfg.experiment_dir, "output-svg", "output.svg")
    check_and_create_dir(filename)
    save_svg.save_svg(filename, w, h, shapes, shape_groups)

    if not gif_frames:
        # num_iter can be 0 (the UI slider's minimum). Writing an empty GIF
        # fails, so fall back to a single frame of the initial rendering.
        gif_frames.append(_to_gif_frame(img_init.unsqueeze(0).permute(0, 3, 1, 2)))

    filename = os.path.join(cfg.experiment_dir, "final.gif")
    # Frames are expected in [0, 1]. Clip before the uint8 cast so slight
    # overshoots saturate instead of wrapping around.
    gif_array = np.clip(np.array(gif_frames) * 255, 0, 255)
    imageio.mimsave(filename, gif_array.astype(np.uint8))
    yield gr.update(value=filename_init,visible=True),gr.update(visible=False),gr.update(value=filename,visible=True),gr.update(value=cfg.caption,visible=True),gr.update(value=cfg.seed,visible=True)
# Scripts offered in the "Font Script" dropdown; also passed to
# read_font_names to build the font list.
all_scripts = ["Arabic", "Simplified Chinese", "Cyrillic", "Greek", "Latin", "Tamil"]

# Gradio UI definition.
# NOTE(review): this copy of the file had all indentation stripped; the
# Row/Column/Accordion nesting below is a reconstruction. Verify it against
# the intended layout.
# NOTE(review): `.style(...)` and `queue(concurrency_count=...)` are Gradio 3.x
# APIs that were removed in Gradio 4. Pin gradio<4 or migrate these calls.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            word = gr.Text(
                label='Text',
                max_lines=1,
                placeholder=
                'Enter text. For example: قطة|猫|γάτα|кошка|பூனை|Cat'
            )
            semantic_concept = gr.Text(
                label='Concept',
                max_lines=1,
                placeholder=
                'Enter a semantic concept that you want your text to morph into (in English). For example: cat'
            )
            with gr.Row():
                script_selector = gr.Dropdown(
                    all_scripts,
                    value="Arabic",
                    label="Font Script"
                )
                # Also binds the module-level font_dict that run_main_app reads
                # when resolving the selected font.
                font_names, font_dict = read_font_names(all_scripts)
                # Lists the fonts of every script, labelled "<Script>: <font>",
                # with "Default" first.
                font_selector = gr.Dropdown(
                    font_names,
                    value=font_names[0],
                    label="Font Name",
                    visible=True,
                )
            prompt_suffix = gr.Text(
                label='Prompt Suffix',
                max_lines=1,
                value="minimal flat 2d vector. lineal color. trending on artstation"
            )
            with gr.Row():
                with gr.Accordion("Advanced Parameters", open=False, visible=True):
                    with gr.Row():
                        # "Random Seed" makes set_config draw a new seed;
                        # "Use Set Value" uses the number below.
                        is_seed_rand = gr.Radio(["Random Seed", "Use Set Value"], label="Use Random Seed", value="Random Seed")
                        seed = gr.Number(
                            label='Seed (Set Value)',
                            value=42
                        )
                    angeles_w = gr.Number(
                        label='ACAP Deformation Loss Weight',
                        value=0.5
                    )
                    dist_loss_weight = gr.Number(
                        label='Tone Loss: dist_loss_weight',
                        value=100
                    )
                    pixel_dist_kernel_blur = gr.Number(
                        label='Tone Loss: pixel_dist_kernel_blur',
                        value=201
                    )
                    pixel_dist_sigma = gr.Number(
                        label='Tone Loss: pixel_dist_sigma',
                        value=30
                    )
                    # Number of optimisation steps; 0 is allowed by the slider.
                    num_steps = gr.Slider(label='Optimization Iterations',
                                          minimum=0,
                                          maximum=500,
                                          step=10,
                                          value=250)
            run = gr.Button('Generate')
        # Right column: outputs. Prompt/seed and the final GIF start hidden and
        # are revealed by the gr.update values run_main_app yields.
        with gr.Column():
            with gr.Row():
                prompt = gr.Text(
                    label='Prompt',
                    visible=False,
                    max_lines=1,
                    interactive=False,
                )
                seed_value = gr.Text(
                    label='Seed Used',
                    visible=False,
                    max_lines=1,
                    interactive=False,
                )
            result0 = gr.Image(type="filepath", label="Initial Word").style(height=250)
            result1 = gr.Image(type="filepath", label="Optimization Process").style(height=300)
            result2 = gr.Image(type="filepath", label="Final GIF",visible=False).style(height=300)
    with gr.Row():
        # examples
        # Each row maps positionally to the Examples inputs below:
        # (word, concept, script, font, iterations, seed).
        examples = [
            ["موسيقى", "music", "Arabic", "Arabic: حر طويل", 250, 42],
            ["音乐", "music", "Simplified Chinese", "Chinese: ZhiMangXing-Regular", 250, 42],
            ["μουσική", "music", "Greek", "Greek: EBGaramond-Regular", 250, 42],
            ["музыка", "music", "Cyrillic", "Cyrillic: Geologica_Auto-Regular", 250, 42],
        ]
        # At most 10 queued requests, 2 processed concurrently.
        demo.queue(max_size=10, concurrency_count=2)
        # cache_examples=True: Gradio precomputes each example's outputs by
        # calling run_main_ex (presumably at startup -- confirm for the
        # installed Gradio version), so every example runs a full optimisation.
        gr.Examples(examples=examples,
                    inputs=[
                        word,
                        semantic_concept,
                        script_selector,
                        font_selector,
                        num_steps,
                        seed
                    ],
                    outputs=[
                        result0,
                        result1,
                        result2,
                        prompt,
                        seed_value
                    ],
                    fn=run_main_ex,
                    cache_examples=True)
    gr.Markdown(ARABIC_EX)
    # inputs
    # Order must match run_main_app's positional parameters (note that
    # semantic_concept comes before word).
    inputs = [
        semantic_concept,
        word,
        script_selector,
        font_selector,
        prompt_suffix,
        num_steps,
        seed,
        is_seed_rand,
        dist_loss_weight,
        pixel_dist_kernel_blur,
        pixel_dist_sigma,
        angeles_w
    ]
    # Order matches the 5-tuple of gr.update values yielded by run_main_app.
    outputs = [
        result0,
        result1,
        result2,
        prompt,
        seed_value
    ]
    # run_main_app is a generator, so the UI updates after every yielded step.
    run.click(fn=run_main_app, inputs=inputs, outputs=outputs, queue=True)
demo.launch(share=False)