Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import argparse | |
from easydict import EasyDict as edict | |
import yaml | |
import os.path as osp | |
import random | |
import numpy.random as npr | |
import sys | |
import imageio | |
import numpy as np | |
# sys.path.append('./code')
sys.path.append('/home/user/app/code')

# Set up diffvg (differentiable vector-graphics rasteriser) at import time.
# Every step is best-effort: exit codes are not checked, so a failed build would
# only show up later at `import pydiffvg`.
# os.system('git clone https://github.com/BachiLi/diffvg.git')
import glob
import site
import subprocess

os.system('git submodule update --init')
os.chdir('diffvg')
os.system('git submodule update --init --recursive')
# Build/install with the interpreter running this app: a bare `python` on PATH may
# be a different version whose site-packages this process never reads.
subprocess.call([sys.executable, 'setup.py', 'install', '--user'])
# The egg is installed after `site` initialised this process, so sys.path must be
# extended by hand. Match the egg for the running Python version instead of
# hard-coding python3.8; fall back to the historical path if nothing matches.
_diffvg_eggs = sorted(glob.glob(osp.join(
    site.getusersitepackages(),
    f"diffvg-*-py{sys.version_info.major}.{sys.version_info.minor}-*.egg")))
sys.path.extend(_diffvg_eggs or ["/home/user/.local/lib/python3.8/site-packages/diffvg-0.0.1-py3.8-linux-x86_64.egg"])
os.chdir('/home/user/app')
import torch | |
from diffusers import StableDiffusionPipeline | |
# Run on the GPU when one is available; tensors and the pipeline below use `device`.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
# Stable Diffusion v1.5 pipeline, loaded once at startup and handed to SDSLoss on
# every run.
model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
model = model.to(device)
from typing import Mapping | |
from tqdm import tqdm | |
import torch | |
from torch.optim.lr_scheduler import LambdaLR | |
import pydiffvg | |
import save_svg | |
from losses import SDSLoss, ToneLoss, ConformalLoss | |
from utils import ( | |
edict_2_dict, | |
update, | |
check_and_create_dir, | |
get_data_augs, | |
save_image, | |
preprocess, | |
learning_rate_decay, | |
combine_word) | |
import warnings | |
# Page header and description shown at the top of the Gradio app.
TITLE = """<h1 style="font-size: 42px;" align="center">Word-To-Image: Morphing Arabic Text to a Visual Representation</h1>"""
DESCRIPTION = """This demo builds on the [Word-As-Image for Semantic Typography](https://wordasimage.github.io/Word-As-Image-Page/) work to support Arabic fonts and morphing whole words and phrases to a visual representation of a semantic concept. This is part of an ongoing effort with the [ARBML](https://arbml.github.io/website/) community to build open-source Arabic tools using machine learning."""
# DESCRIPTION += '\n<p>This demo is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/"> Creative Commons Attribution-ShareAlike 4.0 International License</a>.</p>'
# The app writes an MP4 (final.mp4), so the note says "video" rather than "GIF".
DESCRIPTION += '\n<p>Note: it takes about 5 minutes for 500 iterations to generate the final video. For faster inference without waiting in queue, you can <a href="https://colab.research.google.com/drive/1wobOAsnLpkIzaRxG5yac8NcV7iCrlycP"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a></p>'
# When running as a Hugging Face Space, continue the note's sentence with a
# "Duplicate the Space" badge: drop the closing </p>, append the badge, re-close.
if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
    DESCRIPTION = DESCRIPTION.replace("</p>", " ")
    DESCRIPTION += f'or <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate the Space"/></a> and upgrade to GPU in settings.</p>'
DESCRIPTION += "<img src='https://raw.githubusercontent.com/BKHMSI/Word-As-Image-Ar/main/collage.gif' alt='Example of Outputs'/>"
# Keep the Space logs quiet: ignore every Python warning category ...
warnings.simplefilter("ignore")
# ... and switch off diffvg's per-call timing printouts.
pydiffvg.set_print_timing(False)
# NOTE(review): `gamma` is not referenced anywhere else in this file.
gamma = 1.
def set_config(semantic_concept, word, prompt_suffix, font_name, num_steps, seed, dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma, angeles_w):
    """Build the experiment config for one run, save it to disk, and seed the RNGs.

    Loads ``code/config/base.yaml``, collects the ``demo`` experiment and its
    ``parent_config`` ancestors, merges them, overlays the UI arguments, writes the
    result to ``<experiment_dir>/config.yaml`` and seeds ``random``, NumPy and torch.

    Args:
        semantic_concept: Concept (in English) the text should morph into.
        word: Text to optimise; the whole text is optimised at once.
        prompt_suffix: Style suffix appended to the diffusion prompt.
        font_name: Font name, used to locate the initial SVG.
        num_steps: Number of optimisation iterations (coerced to ``int``).
        seed: RNG seed (coerced to ``int``).
        dist_loss_weight: Tone-loss weight (coerced to ``int``).
        pixel_dist_kernel_blur: Tone-loss blur kernel size (coerced to ``int``).
        pixel_dist_sigma: Tone-loss blur sigma (coerced to ``int``).
        angeles_w: Weight of the conformal (ACAP) loss.

    Returns:
        The merged ``EasyDict`` config.

    Raises:
        gr.Error: if the optimised letters are not contained in ``word``.
    """
    cfg_d = edict()
    cfg_d.config = "code/config/base.yaml"
    cfg_d.experiment = "demo"
    # Repo-local, trusted config file.
    with open(cfg_d.config, 'r', encoding='utf-8') as f:
        cfg_full = yaml.load(f, Loader=yaml.FullLoader)

    # Follow the parent chain starting at "demo"; stops once a config's
    # parent_config is empty/falsy (a missing key falls back to 'baseline').
    cfg_key = cfg_d.experiment
    cfgs = [cfg_d]
    while cfg_key:
        cfgs.append(cfg_full[cfg_key])
        cfg_key = cfgs[-1].get('parent_config', 'baseline')

    # Apply the root ancestor first and cfg_d last (presumably later updates
    # override earlier ones — see utils.update).
    cfg = edict()
    for options in reversed(cfgs):
        update(cfg, options)
    del cfgs

    cfg.semantic_concept = semantic_concept
    cfg.prompt_suffix = prompt_suffix
    cfg.word = word
    # The demo optimises the whole text, not a single letter.
    cfg.optimized_letter = word
    cfg.font = font_name
    cfg.seed = int(seed)
    # UI number widgets may deliver floats; the training loop calls range() on this.
    cfg.num_iter = int(num_steps)
    cfg.batch_size = 1
    cfg.loss.tone.dist_loss_weight = int(dist_loss_weight)
    cfg.loss.tone.pixel_dist_kernel_blur = int(pixel_dist_kernel_blur)
    cfg.loss.tone.pixel_dist_sigma = int(pixel_dist_sigma)
    cfg.loss.conformal.angeles_w = angeles_w
    cfg.caption = f"a {cfg.semantic_concept}. {cfg.prompt_suffix}"
    cfg.log_dir = f"output/{cfg.experiment}_{cfg.word}"

    if cfg.optimized_letter not in cfg.word:
        raise gr.Error('letter should be in word')

    cfg.letter = f"{cfg.font}_{cfg.optimized_letter}_scaled"
    cfg.target = f"code/data/init/{cfg.letter.replace(' ', '_')}"

    # Experiment dir: <log_dir>/<font>/<signature>; the merged config is saved there.
    signature = f"{cfg.letter}_concept_{cfg.semantic_concept}_seed_{cfg.seed}"
    cfg.experiment_dir = osp.join(cfg.log_dir, cfg.font, signature)
    configfile = osp.join(cfg.experiment_dir, 'config.yaml')
    # Presumably creates the parent directory of `configfile` — see utils.
    check_and_create_dir(configfile)
    with open(configfile, 'w', encoding='utf-8') as f:
        yaml.dump(edict_2_dict(cfg), f)

    # Seed every RNG in use for reproducible runs (cfg.seed is always an int here).
    random.seed(cfg.seed)
    npr.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.benchmark = False
    return cfg
def init_shapes(svg_path, trainable: Mapping[str, bool]):
    """Load the initial glyph outlines and collect the parameters to optimise.

    Args:
        svg_path: Path of the SVG file *without* the ``.svg`` extension.
        trainable: Flags selecting what to optimise; only ``"point"`` (path
            control points) is handled here.

    Returns:
        ``(shapes, shape_groups, parameters)``. When ``trainable["point"]`` is
        truthy, ``parameters.point`` lists each path's ``points`` tensor with
        ``requires_grad`` enabled.
    """
    svg = f'{svg_path}.svg'
    # Canvas width/height are not needed; the render size comes from the config.
    _, _, shapes_init, shape_groups_init = pydiffvg.svg_to_scene(svg)

    parameters = edict()
    # Index access works for a plain Mapping (as annotated) as well as an EasyDict.
    if trainable["point"]:
        parameters.point = []
        for path in shapes_init:
            path.points.requires_grad = True
            parameters.point.append(path.points)
    return shapes_init, shape_groups_init, parameters
def run_main_ex(word, semantic_concept, num_steps, seed):
    """Produce the final outputs for one gallery example.

    Registered as the ``gr.Examples`` function. It uses the default prompt suffix,
    the ArefRuqaa font and the default tone/ACAP loss settings.
    NOTE(review): with ``cache_examples=False`` Gradio may never call this — confirm.

    ``example=1`` makes ``run_main_app`` skip its intermediate UI updates. The
    first (and only) value it yields is therefore the final
    (initial image, hidden progress image, result video) triple, rather than the
    "Initializing" preview.
    """
    prompt_suffix = "minimal flat 2d vector. lineal color. trending on artstation"
    font_name = "ArefRuqaa"
    return list(next(run_main_app(semantic_concept, word, prompt_suffix, font_name, num_steps, seed,
                                  100, 201, 30, 0.5, example=1)))
def _render_on_white(render, w, h, shapes, shape_groups, seed):
    """Rasterise the scene with diffvg and alpha-composite it over white (HWC, 3 channels)."""
    scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
    # 2x2 samples per pixel, no background image; `seed` is forwarded as the
    # renderer's 5th argument (presumably its sampling seed).
    img = render(w, h, 2, 2, seed, None, *scene_args)
    alpha = img[:, :, 3:4]
    white = torch.ones(img.shape[0], img.shape[1], 3, device=device)
    return alpha * img[:, :, :3] + white * (1 - alpha)


def _to_uint8_frame(img):
    """Convert an HWC float image tensor (values expected in [0, 1]) to a uint8 video frame."""
    # Explicit conversion: handing floats to the imageio writer makes it rescale
    # each frame by its own min/max instead of using the true [0, 255] range.
    return np.clip(img.detach().cpu().numpy() * 255.0, 0, 255).round().astype(np.uint8)


def run_main_app(semantic_concept, word, prompt_suffix, font_name, num_steps, seed, dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma, angeles_w, example=0):
    """Run the Word-As-Image optimisation for one request, streaming UI updates.

    Builds the config (``set_config``) and prepares the initial SVG of ``word``.
    It then optimises the outline control points with the SDS (diffusion), tone
    and conformal (ACAP) losses. Every 5th step an SVG snapshot is saved and a
    video frame recorded. Finally the output SVG and ``final.mp4`` are written
    under ``cfg.experiment_dir``.

    Args are the UI inputs forwarded to ``set_config``. ``example``: when truthy,
    intermediate updates are not yielded, so only the final tuple is produced.

    Yields:
        3-tuples of ``gr.update`` for (initial-word image, optimisation-progress
        image, final-result video).
    """
    cfg = set_config(semantic_concept, word, prompt_suffix, font_name, num_steps, seed, dist_loss_weight, pixel_dist_kernel_blur, pixel_dist_sigma, angeles_w)
    pydiffvg.set_use_gpu(torch.cuda.is_available())

    print("preprocessing")
    preprocess(cfg.font, cfg.word, cfg.optimized_letter, cfg.level_of_cc)
    filename_init = os.path.join("code/data/init/", f"{cfg.font}_{cfg.word}_scaled.svg").replace(" ", "_")
    if not example:
        yield gr.update(value=filename_init, visible=True), gr.update(visible=True, label='Initializing'), gr.update(visible=False)

    sds_loss = SDSLoss(cfg, device, model)
    h, w = cfg.render_size, cfg.render_size
    data_augs = get_data_augs(cfg.cut_size)
    render = pydiffvg.RenderFunction.apply

    print('initializing shape')
    shapes, shape_groups, parameters = init_shapes(svg_path=cfg.target, trainable=cfg.trainable)
    img_init = _render_on_white(render, w, h, shapes, shape_groups, 0)
    tone_loss = ToneLoss(cfg)
    tone_loss.set_image_init(img_init)

    # range() needs an int; the value originates from a UI widget.
    num_iter = int(cfg.num_iter)
    pg = [{'params': parameters["point"], 'lr': cfg.lr_base["point"]}]
    optim = torch.optim.Adam(pg, betas=(0.9, 0.9), eps=1e-6)
    conformal_loss = ConformalLoss(parameters, device, cfg.optimized_letter, shape_groups)
    # LambdaLR scales the base lr by this factor, hence the division by lr_init.
    lr_lambda = lambda step: learning_rate_decay(step, cfg.lr.lr_init, cfg.lr.lr_final, num_iter,
                                                 lr_delay_steps=cfg.lr.lr_delay_steps,
                                                 lr_delay_mult=cfg.lr.lr_delay_mult) / cfg.lr.lr_init
    scheduler = LambdaLR(optim, lr_lambda=lr_lambda, last_epoch=-1)

    print("start training")
    gif_frames = []
    skip = 5  # record a frame / SVG snapshot every `skip` steps
    for step in tqdm(range(num_iter)):
        optim.zero_grad()
        img = _render_on_white(render, w, h, shapes, shape_groups, step)

        if step % skip == 0:
            gif_frames.append(_to_uint8_frame(img))
            filename = os.path.join(cfg.experiment_dir, "video-svg", f"iter{step:04d}.svg")
            check_and_create_dir(filename)
            save_svg.save_svg(filename, w, h, shapes, shape_groups)
            if not example:
                yield gr.update(visible=True), gr.update(value=filename, label=f'iters: {step} / {num_iter}', visible=True), gr.update(visible=False)

        x = img.unsqueeze(0).permute(0, 3, 1, 2)  # HWC -> NCHW
        x = x.repeat(cfg.batch_size, 1, 1, 1)
        x_aug = data_augs.forward(x)

        # Total loss: SDS (on augmented views) + tone + weighted conformal (ACAP).
        loss = sds_loss(x_aug)
        loss = loss + tone_loss(x, step)
        loss = loss + cfg.loss.conformal.angeles_w * conformal_loss()

        loss.backward()
        optim.step()
        scheduler.step()

    # Record the fully optimised state as the last frame; this also keeps the
    # video non-empty when num_iter == 0.
    with torch.no_grad():
        gif_frames.append(_to_uint8_frame(_render_on_white(render, w, h, shapes, shape_groups, num_iter)))

    filename = os.path.join(cfg.experiment_dir, "output-svg", "output.svg")
    check_and_create_dir(filename)
    save_svg.save_svg(filename, w, h, shapes, shape_groups)

    filename = os.path.join(cfg.experiment_dir, "final.mp4")
    # The context manager finalises the video even if appending a frame fails.
    with imageio.get_writer(filename, fps=20) as writer:
        for frame in gif_frames:
            writer.append_data(frame)

    yield gr.update(value=filename_init, visible=True), gr.update(visible=False), gr.update(value=filename, visible=True)
def change_prompt(concept, prompt_suffix):
    """Return the prompt preview for a concept and style suffix.

    An empty concept is rendered as the literal placeholder ``{concept}``.
    """
    shown = "{concept}" if concept == "" else concept
    return "a {}. {}".format(shown, prompt_suffix)
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: inputs, advanced settings and the Generate button.
        with gr.Column():
            word = gr.Text(
                label='Text',
                max_lines=1,
                placeholder='Enter text. For example: حصان',
            )
            semantic_concept = gr.Text(
                label='Concept',
                max_lines=1,
                placeholder='Enter a semantic concept that you want your text to morph into (in English). For example: horse',
            )
            prompt_suffix = gr.Text(
                label='Prompt Suffix',
                max_lines=1,
                value="minimal flat 2d vector. lineal color. trending on artstation",
            )
            # Prompt preview, kept in sync by change_prompt. The initial value equals
            # change_prompt("", <default suffix>) and the caption format built in
            # set_config (no trailing period).
            prompt = gr.Text(
                label='Prompt',
                max_lines=1,
                value="a {concept}. minimal flat 2d vector. lineal color. trending on artstation",
            )
            with gr.Row():
                with gr.Accordion("Advanced Parameters", open=False, visible=True):
                    seed = gr.Number(label='Seed', value=42)
                    angeles_w = gr.Number(label='ACAP Deformation Loss Weight', value=0.5)
                    dist_loss_weight = gr.Number(label='Tone Loss: dist_loss_weight', value=100)
                    pixel_dist_kernel_blur = gr.Number(label='Tone Loss: pixel_dist_kernel_blur', value=201)
                    pixel_dist_sigma = gr.Number(label='Tone Loss: pixel_dist_sigma', value=30)
            # Refresh the prompt preview whenever the concept or the suffix changes.
            semantic_concept.change(change_prompt, [semantic_concept, prompt_suffix], prompt)
            prompt_suffix.change(change_prompt, [semantic_concept, prompt_suffix], prompt)
            num_steps = gr.Slider(label='Optimization Iterations',
                                  minimum=0,
                                  maximum=500,
                                  step=10,
                                  value=250)
            # Hidden input: this demo always uses the ArefRuqaa font.
            font_name = gr.Text(value="ArefRuqaa", visible=False, label="Font Name")
            run = gr.Button('Generate')
        # Right column: initial word, live optimisation preview, final video.
        with gr.Column():
            result0 = gr.Image(type="filepath", label="Initial Word").style(height=250)
            result1 = gr.Image(type="filepath", label="Optimization Process").style(height=300)
            result2 = gr.Video(type="filepath", label="Final Result", visible=False).style(height=300)
    with gr.Row():
        # Gallery examples: [word, concept, iterations, seed].
        examples = [
            ["قطة", "Cat", 250, 42],
            ["جمل جميل", "Camel", 250, 42],
            ["كلب", "Dog", 250, 42],
            ["أخطبوط", "Octopus", 250, 42],
        ]
        gr.Examples(examples=examples,
                    inputs=[
                        word,
                        semantic_concept,
                        num_steps,
                        seed,
                    ],
                    outputs=[
                        result0,
                        result1,
                        result2,
                    ],
                    fn=run_main_ex,
                    cache_examples=False)
    # Positional order must match run_main_app's parameters.
    inputs = [
        semantic_concept,
        word,
        prompt_suffix,
        font_name,
        num_steps,
        seed,
        dist_loss_weight,
        pixel_dist_kernel_blur,
        pixel_dist_sigma,
        angeles_w,
    ]
    outputs = [
        result0,
        result1,
        result2,
    ]
    run.click(fn=run_main_app, inputs=inputs, outputs=outputs, queue=True)

# Process one queued event at a time and hold at most 10 waiting requests.
demo.queue(max_size=10, concurrency_count=1)
demo.launch(share=False)