Spaces:
Runtime error
Runtime error
# Copyright 2023 Adobe Research. All rights reserved. | |
# To view a copy of the license, visit LICENSE.md. | |
# import os | |
# CACHE_DIR = "/exp/domain-expansion/.cache" | |
# os.environ["HF_DATASETS_CACHE"] = CACHE_DIR | |
# if not os.path.exists(CACHE_DIR): | |
# os.mkdir(CACHE_DIR) | |
import torch | |
import gradio as gr | |
from generate_aligned import generate_images | |
def main():
    """Download the pretrained generators and serve the Gradio demo.

    Fetches two checkpoints from the ``alvanlii/adobe-domain-expansion``
    Hugging Face Hub repo: ``afhq50.pkl`` is used for the "Doggos" domain
    and ``ffhq100.pkl`` for the "Humans" domain. It then builds a Blocks UI
    with a seed field, a domain radio and two image galleries, and launches
    it on ``0.0.0.0``. ``demo.launch`` blocks until the server stops.

    NOTE(review): ``Gallery.style(...)`` and ``queue(concurrency_count=...)``
    are Gradio 3.x APIs that were removed in Gradio 4 -- pin ``gradio<4`` or
    migrate these calls.
    """
    from huggingface_hub import hf_hub_download

    dog_path = hf_hub_download("alvanlii/adobe-domain-expansion", "afhq50.pkl")
    human_path = hf_hub_download("alvanlii/adobe-domain-expansion", "ffhq100.pkl")

    # Radio choices. With type="index" the callbacks receive the position, so
    # index 1 ("Doggos") selects the dog checkpoint and 0 the human one.
    dog_choice = "Doggos"
    domain_choices = ["Humans", dog_choice]
    example_seed = 32

    def gen_img(fn_seed, fn_is_dog):
        """Seed torch and generate images for both galleries.

        Args:
            fn_seed: Seed from the ``gr.Number`` field; coerced to ``int``.
            fn_is_dog: Radio index; truthy selects the dog checkpoint.

        Returns:
            Whatever ``generate_images`` returns, which is routed to the two
            galleries. Presumably a 2-element sequence of image lists, one
            per gallery -- TODO confirm against generate_aligned.
        """
        # Seeds are integers; be explicit rather than relying on a float
        # coming out of the Number component.
        torch.manual_seed(int(fn_seed))
        with torch.no_grad():
            # NOTE(review): the meaning of the positional args (2, 1) is
            # defined in generate_aligned.generate_images -- TODO confirm.
            imgs = generate_images(dog_path if fn_is_dog else human_path, 2, 1)
        return imgs

    def load_examples():
        """Generate the fixed dog example and update all four UI outputs.

        Returns:
            Tuple of (seed, radio choice, gallery-1 images, gallery-2 images).
        """
        imgs = gen_img(example_seed, domain_choices.index(dog_choice))
        # The Radio must be set to a choice *value*, not its index: the int 1
        # is not one of the choices and leaves the radio in an invalid state.
        return example_seed, dog_choice, imgs[0], imgs[1]

    with gr.Blocks() as demo:
        gr.HTML("""
            <h1 style="font-weight: 900; margin-bottom: 7px;">
            Domain Expansion of Image Generators (https://arxiv.org/abs/2301.05225)
            </h1>
            Yotam Nitzan, Michaël Gharbi, Richard Zhang, Taesung Park, Jun-Yan Zhu, Daniel Cohen-Or, Eli Shechtman <br/>
            Using the pretrained weights for Humans and Dog faces to generate images in new domains. Only a quarter of the new domains are showcased due to large number of images generated
        """)
        with gr.Row():
            # precision=0 makes the component deliver an integer seed.
            seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)
            is_dog = gr.Radio(
                domain_choices,
                value=dog_choice,
                type="index",
                show_label=False,
                interactive=True,
            )
            generate_button = gr.Button("Generate")
            sample_button = gr.Button("Load Example")
        with gr.Row():
            g1 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery1"
            ).style(grid=[10], height="auto")
        with gr.Row():
            g2 = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery2"
            ).style(grid=[10], height="auto")

        generate_button.click(
            fn=gen_img,
            inputs=[seed, is_dog],
            outputs=[g1, g2],
        )
        # Outputs must all be components. The original listed the int 1 where
        # the first gallery belongs, which crashed the app at startup.
        sample_button.click(
            fn=load_examples,
            inputs=[],
            outputs=[seed, is_dog, g1, g2],
        )

    # Serialize requests: generation is heavy, so run one job at a time.
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")
if "__main__" == __name__:
    # Start the demo only when this file is executed directly, not on import.
    main()