"""Gradio demo for TR0N face generation.

Exposes two tabs backed by the same StyleGAN2 model:

* text-to-face generation, using the translator checkpoint trained on text;
* face-to-face interpolation, using the translator checkpoint trained on
  images.

Both tasks start from the translator's output and then refine the GAN latent
with a short gradient-based optimization against a target CLIP latent.
"""
import os
import re

import gradio as gr
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.transforms import transforms as T
import clip

from tr0n.config import parse_args
from tr0n.modules.models.model_stylegan import Model
from tr0n.modules.models.loss import AugCosineSimLatent
from tr0n.modules.optimizers.sgld import SGLD
from bad_words import bad_words

device = "cuda" if torch.cuda.is_available() else "cpu"

model_modes = {
    "text": {
        "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-text.pth",
    },
    "image": {
        "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-image.pth",
    }
}

os.environ['TOKENIZERS_PARALLELISM'] = "false"

# Number of latent-refinement steps and of image augmentations per step,
# shared by both demo tasks.
NUM_OPTIMIZATION_STEPS = 100
TIMES_AUGMENT_PRED_IMAGE = 50

# set config params
config = parse_args(is_demo=True)
config_vars = vars(config)
config_vars["stylegan_gen"] = "sg2-ffhq-1024"
config_vars["with_gmm"] = True
config_vars["num_mixtures"] = 10

model = Model(config, device, None)
model.to(device)
model.eval()
# The translator is only used for initialization; it is never trained here.
for p in model.translator.parameters():
    p.requires_grad = False
loss = AugCosineSimLatent()

transforms_image = T.Compose([
    T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

checkpoint_text = torch.hub.load_state_dict_from_url(model_modes["text"]["checkpoint"], map_location="cpu")
translator_state_dict_text = checkpoint_text['translator_state_dict']
checkpoint_image = torch.hub.load_state_dict_from_url(model_modes["image"]["checkpoint"], map_location="cpu")
translator_state_dict_image = checkpoint_image['translator_state_dict']

# default
model.translator.load_state_dict(translator_state_dict_text)

# One case-insensitive pattern for the whole word list. Words are escaped so
# that regex metacharacters inside an entry are matched literally.
# NOTE(review): assumes `bad_words` is an iterable of plain strings - confirm.
_bad_word_alternatives = "|".join(re.escape(word) for word in bad_words)
_BAD_WORDS_PATTERN = (
    re.compile(rf"\b(?:{_bad_word_alternatives})\b", re.IGNORECASE)
    if _bad_word_alternatives
    else None
)

css = """
    a {
        display: inline-block;
        color: black !important;
        text-decoration: none !important;
    }
    #image-gen {
        height: 256px;
        width: 256px;
        margin-left: auto;
        margin-right: auto;
    }
"""


def _slerp(val, low, high, eps=1e-6):
    """Spherically interpolate between two batches of row vectors.

    Args:
        val: Interpolation coefficient; 0 yields `low`, 1 yields `high`.
        low: Tensor whose rows (dim 1) are the start vectors.
        high: Tensor whose rows (dim 1) are the end vectors.
        eps: Threshold on ``|sin(omega)|`` below which the rows are treated
            as (anti-)parallel and linear interpolation is used instead.

    Returns:
        Tensor with the same shape as `low` holding the interpolated rows.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Rounding can push the cosine slightly outside [-1, 1], where acos
    # returns NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    # When the vectors are (nearly) parallel, sin(omega) ~ 0 and the slerp
    # weights become 0/0; fall back to plain linear interpolation there.
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)
    weight_low = torch.where(
        degenerate,
        torch.ones_like(so) * (1.0 - val),
        torch.sin((1.0 - val) * omega) / safe_so,
    )
    weight_high = torch.where(
        degenerate,
        torch.ones_like(so) * val,
        torch.sin(val * omega) / safe_so,
    )
    return weight_low.unsqueeze(1) * low + weight_high.unsqueeze(1) * high


def model_mode_text_select():
    """Load the text-conditioned translator weights into the shared model."""
    model.translator.load_state_dict(translator_state_dict_text)


def model_mode_image_select():
    """Load the image-conditioned translator weights into the shared model."""
    model.translator.load_state_dict(translator_state_dict_image)


def _latent_to_pil(w):
    """Run the generator on GAN latent `w` and return the first image as PIL."""
    with torch.no_grad():
        _, _, _, _, pred_image_raw = model(x=w, x_type='gan_latent')

    # Map from [-1, 1] to [0, 1]. ToPILImage multiplies floats by 255 and
    # casts to uint8, so values outside [0, 1] would wrap around; clamp first.
    pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)


def text_to_face_generate(text):
    """Generate a face image from a text prompt.

    Args:
        text: The user prompt.

    Returns:
        A PIL image produced by the generator from the optimized latent.

    Raises:
        gr.Error: If the prompt is empty or contains a word from `bad_words`.
    """
    if not text or not text.strip():
        raise gr.Error("You need to provide a prompt.")

    if _BAD_WORDS_PATTERN is not None and _BAD_WORDS_PATTERN.search(text):
        raise gr.Error("Unsafe content found. Please try again with a different prompt.")

    # The translator weights are shared global state that tab-select events
    # also swap; make sure the text translator is the one loaded right now.
    # NOTE(review): presumably the queue runs one request at a time - confirm.
    model_mode_text_select()

    text_tok = clip.tokenize([text], truncate=True).to(device)

    # initialize optimization from the translator's output
    with torch.no_grad():
        target_clip_latent, w_mixture_logits, w_means = model(
            x=text_tok, x_type='text', return_after_translator=True, no_sample=True
        )
        pi = w_mixture_logits.unsqueeze(-1).repeat(1, 1, w_means.shape[-1])  # 1 x num_mixtures x w_dim
        w = w_means  # 1 x num_mixtures x w_dim

    w.requires_grad = True
    pi.requires_grad = True

    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)
    optimizer_pi = Adam((pi,), lr=5e-3)

    # optimization
    for _ in range(NUM_OPTIMIZATION_STEPS):
        # Collapse the mixture components into a single latent using the
        # softmax-normalized mixture weights.
        soft_pi = F.softmax(pi, dim=1)
        w_prime = (soft_pi * w).sum(dim=1)

        _, _, pred_clip_latent, _, _ = model(
            x=w_prime, x_type='gan_latent', times_augment_pred_image=TIMES_AUGMENT_PRED_IMAGE
        )

        loss_value = loss(target_clip_latent, pred_clip_latent)
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        torch.nn.utils.clip_grad_norm_((pi,), 1.)
        optimizer_w.step()
        optimizer_pi.step()
        optimizer_w.zero_grad()
        optimizer_pi.zero_grad()

    # generate final image
    with torch.no_grad():
        soft_pi = F.softmax(pi, dim=1)
        w_prime = (soft_pi * w).sum(dim=1)

    return _latent_to_pil(w_prime)


def face_to_face_interpolate(image1, image2, interp_lambda=0.5):
    """Generate a face whose CLIP latent interpolates between two images.

    Args:
        image1: First input image (PIL), or None if not provided.
        image2: Second input image (PIL), or None if not provided.
        interp_lambda: Interpolation coefficient passed to `_slerp`; 0 targets
            the CLIP latent of `image1`, 1 that of `image2`.

    Returns:
        A PIL image produced by the generator from the optimized latent.

    Raises:
        gr.Error: If either image is missing.
    """
    if image1 is None or image2 is None:
        raise gr.Error("You need to provide two images as input.")

    # The translator weights are shared global state that tab-select events
    # also swap; make sure the image translator is the one loaded right now.
    # NOTE(review): presumably the queue runs one request at a time - confirm.
    model_mode_image_select()

    image1_pt = transforms_image(image1).to(device)
    image2_pt = transforms_image(image2).to(device)

    # initialize optimization from the translator's output
    with torch.no_grad():
        images_pt = torch.stack([image1_pt, image2_pt])
        target_clip_latents = model.clip.encode_image(images_pt).detach().float()
        target_clip_latent = _slerp(
            interp_lambda,
            target_clip_latents[0].unsqueeze(0),
            target_clip_latents[1].unsqueeze(0),
        )
        _, _, w = model(x=target_clip_latent, x_type='clip_latent', return_after_translator=True)

    w.requires_grad = True

    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)

    # optimization
    for _ in range(NUM_OPTIMIZATION_STEPS):
        _, _, pred_clip_latent, _, _ = model(
            x=w, x_type='gan_latent', times_augment_pred_image=TIMES_AUGMENT_PRED_IMAGE
        )

        loss_value = loss(target_clip_latent, pred_clip_latent)
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        optimizer_w.step()
        optimizer_w.zero_grad()

    # generate final image
    return _latent_to_pil(w)


examples_text = [
    "Muhammad Ali",
    "Tinker Bell",
    "A man with glasses, long black hair with sideburns and a goatee",
    "A child with blue eyes and straight brown hair in the sunshine",
    "A hairdresser",
    "A young boy with glasses and an angry face",
    "Denzel Washington",
    "A portrait of Angela Merkel",
    "President Emmanuel Macron",
    "President Xi Jinping"
]

examples_image = [
    ["./examples/example_1_1.jpg", "./examples/example_1_2.jpg"],
    ["./examples/example_2_1.jpg", "./examples/example_2_2.jpg"],
    ["./examples/example_3_1.jpg", "./examples/example_3_2.jpg"],
    ["./examples/example_4_1.jpg", "./examples/example_4_2.jpg"],
]

with gr.Blocks(css=css) as demo:
    # NOTE(review): the header strings below look like their original markup
    # (heading tags / links) was stripped; the text is kept as found - restore
    # the markup from the upstream demo if available.
    gr.Markdown("\n\nTR0N Face Generation Demo\n\n")
    gr.Markdown("\n\nby Layer 6 AI\n\n")
    gr.Markdown("""\n\n""")
    gr.Markdown(
        "We introduce TR0N, a simple and efficient method to add any type of "
        "conditioning to pre-trained generative models. For this demo, we add "
        "two types of conditioning to a StyleGAN2 model pre-trained on images "
        "of human faces. First, we add text-conditioning to turn StyleGAN2 "
        "into a text-to-face model. Second, we add image semantic conditioning "
        "to StyleGAN2 to enable face-to-face interpolation. For more details "
        "and results on many other generative models, please refer to our "
        "paper linked above."
    )

    with gr.Tab("Text-to-face generation") as text_to_face_generation_demo:
        text_to_face_generation_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g. A man with a beard and glasses",
            max_lines=1,
        )
        text_to_face_generation_button = gr.Button("Generate")
        text_to_face_generation_output = gr.Image(label="Generated image", elem_id="image-gen")
        text_to_face_generation_examples = gr.Examples(
            examples=examples_text,
            fn=text_to_face_generate,
            inputs=text_to_face_generation_input,
            outputs=text_to_face_generation_output,
        )

    with gr.Tab("Face-to-face interpolation") as face_to_face_interpolation_demo:
        gr.Markdown("We note that interpolations are not expected to recover the given images, even when the coefficient is 0 or 1.")
        with gr.Row():
            face_to_face_interpolation_input1 = gr.Image(label="Image 1", type="pil")
            face_to_face_interpolation_input2 = gr.Image(label="Image 2", type="pil")
        face_to_face_interpolation_lambda = gr.Slider(
            label="Interpolation coefficient", minimum=0, maximum=1, value=0.5, step=0.01
        )
        face_to_face_interpolation_button = gr.Button("Interpolate")
        face_to_face_interpolation_output = gr.Image(label="Interpolated image", elem_id="image-gen")
        face_to_face_interpolation_examples = gr.Examples(
            examples=examples_image,
            fn=face_to_face_interpolate,
            inputs=[
                face_to_face_interpolation_input1,
                face_to_face_interpolation_input2,
                face_to_face_interpolation_lambda,
            ],
            outputs=face_to_face_interpolation_output,
        )

    # Selecting a tab eagerly swaps in the matching translator weights; the
    # generation functions also do this themselves before running.
    text_to_face_generation_demo.select(fn=model_mode_text_select)
    text_to_face_generation_input.submit(
        fn=text_to_face_generate,
        inputs=text_to_face_generation_input,
        outputs=text_to_face_generation_output,
    )
    text_to_face_generation_button.click(
        fn=text_to_face_generate,
        inputs=text_to_face_generation_input,
        outputs=text_to_face_generation_output,
    )

    face_to_face_interpolation_demo.select(fn=model_mode_image_select)
    face_to_face_interpolation_button.click(
        fn=face_to_face_interpolate,
        inputs=[
            face_to_face_interpolation_input1,
            face_to_face_interpolation_input2,
            face_to_face_interpolation_lambda,
        ],
        outputs=face_to_face_interpolation_output,
    )

demo.queue()
demo.launch()