from lib.get_model import get_model, device
from lib.sampling import edm_sampler
from lib.embedding import extract_features
from lib.encoders import StabilityVAEEncoder
from lib.cond_gen import get_vae_decoder
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage
import torch
import gradio as gr
import json

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)

# Load the EDM denoiser and the VAE decoder that generates condition vectors.
net = get_model()
net.load_state_dict(load_file("model_weights/1girl-edm-xs-test-1.safetensors"))
cond_gen = get_vae_decoder().to(device)
cond_gen.load_state_dict(load_file("model_weights/condgen_vae_decoder.safetensors"))
stability_encoder = StabilityVAEEncoder()


def guided(net, scale=1):
    """Wrap `net` with classifier-free guidance at the given scale."""

    def f(x, t, label):
        if scale == 1:
            return net(x, t, label)
        # Lerp from the unconditional to the conditional prediction;
        # scale > 1 extrapolates past the conditional prediction.
        return torch.lerp(net(x, t, net.uncond_emb), net(x, t, label), float(scale))

    return f


@torch.no_grad()
def generate_image(label, guidance_scale, n_steps, seed):
    label = torch.tensor(label)[None].to(device)
    gen = torch.Generator(device).manual_seed(seed)
    # Initial latent noise; all sampler noise is drawn from the seeded generator
    # so results are reproducible for a given seed.
    x = torch.randn((1, 4, 88, 64), device=device, generator=gen)
    randn_like = lambda *a, **ka: torch.zeros_like(*a, **ka).normal_(generator=gen)
    im = edm_sampler(
        guided(net, guidance_scale), x, label, num_steps=n_steps, randn_like=randn_like
    )
    # Decode latents to pixels and convert to a PIL image.
    im = stability_encoder.decode(im)
    return ToPILImage()(im[0])


with gr.Blocks() as demo:
    # Index of the currently selected conditioning tab (mutable closure state).
    selected = [0]
    with gr.Row():
        gr.Markdown(
            """# 1girl-EDM2-XS-test-1 Demo

Demo of a 125M-param model trained in 1 GPU-day for generating `1girl solo` images.

For those who don't want to wait 10 minutes for an image, a [Colab demo](https://colab.research.google.com/drive/1aBopomQ6wLJtQUTZ1CZn2i-zGpyLGWKM?usp=sharing) is also available.

[More info](https://huggingface.co/gustproof/1girl-EDM2-XS-test-1)
"""
        )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                btn = gr.Button("Generate", variant="primary")
                guidance = gr.Slider(1, 15, 5, step=0.1, label="Guidance scale")
                n_steps = gr.Slider(2, 35, 24, step=1, label="Inference steps")
                seed = gr.Slider(
                    -1, 2147483647, -1, step=1, label="Random seed (-1: randomize)"
                )
            with gr.Tab("Condition: auto") as auto_tab:
                gr.Markdown("Conditioning is generated with an external model")
            with gr.Tab("Condition: from image") as img_tab:
                gr.Markdown(
                    "Conditioning is extracted from the image with a [tagger](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3). "
                    "This works like [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and not like [img2img](https://github.com/CompVis/stable-diffusion/blob/main/scripts/img2img.py)"
                )
                ref_im = gr.Image(label="Reference image", type="pil")
            with gr.Tab("Condition: precomputed") as txt_tab:
                gr.Markdown("Use a precomputed 1024D vector as the condition")
                ref_txt = gr.TextArea(
                    label="Precomputed conditioning",
                    placeholder="Copy & Paste from the output",
                )
        with gr.Column():
            out_im = gr.Image(label="Generated Image", show_download_button=True)
            out_seed = gr.Textbox(label="Seed", show_copy_button=True)
            out_emb = gr.TextArea(label="Condition vector", show_copy_button=True)

    @torch.no_grad()
    def get_label(tab_index, cond_img=None, cond_txt=None):
        # Tab 0: sample a condition vector from the VAE decoder.
        if tab_index == 0:
            return cond_gen(torch.randn((1, 512), device=device))[0].detach().cpu()
        # Tab 1: extract tagger features from the reference image.
        if tab_index == 1:
            return extract_features(cond_img, device)
        # Tab 2: parse a precomputed vector pasted as JSON.
        return torch.tensor(json.loads(cond_txt))

    def on_select(e: gr.SelectData):
        selected[0] = e.index

    for t in [auto_tab, img_tab, txt_tab]:
        t.select(on_select)

    def main(guidance, n_steps, seed, cond_img=None, cond_txt=None):
        if seed < 0:
            seed = torch.randint(0, 2147483647, ()).item()
        label = get_label(selected[0], cond_img, cond_txt)
        im = generate_image(label, guidance, n_steps, seed)
        label_txt = json.dumps(label.numpy().astype(float).round(3).tolist())
        return im, seed, label_txt

    btn.click(
        main, [guidance, n_steps, seed, ref_im, ref_txt], [out_im, out_seed, out_emb]
    )

demo.launch()