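# Gradio demo for 1girl-EDM2-XS-test-1: a small EDM2 diffusion model that
# generates `1girl solo` images from 1024-D tagger-embedding condition vectors.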
from lib.get_model import get_model, device
from lib.sampling import edm_sampler
from lib.embedding import extract_features
from lib.encoders import StabilityVAEEncoder
from lib.cond_gen import get_vae_decoder
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage
import torch
import gradio as gr
import json
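
# Inference only: disable autograd globally.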
torch.set_grad_enabled(False)
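
# Load pretrained weights: the diffusion model, the conditioning-vector
# generator, and the Stability VAE used to decode latents into images.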
net = get_model()
net.load_state_dict(load_file("model_weights/1girl-edm-xs-test-1.safetensors"))
cond_gen = get_vae_decoder().to(device)
cond_gen.load_state_dict(load_file("model_weights/condgen_vae_decoder.safetensors"))
stability_encoder = StabilityVAEEncoder()
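

# Classifier-free guidance: interpolate between the unconditional and
# conditional model outputs; scale == 1 reduces to the plain conditional model.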
def guided(net, scale=1):
    def f(x, t, label):
        if scale == 1:
            return net(x, t, label)
        return torch.lerp(net(x, t, net.uncond_emb), net(x, t, label), float(scale))

    return f
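

# Sample from seeded Gaussian noise in latent space (1x4x88x64) with the EDM
# sampler, then decode the resulting latent to a PIL image.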
@torch.no_grad()
def generate_image(label, guidance_scale, n_steps, seed):
    label = torch.tensor(label)[None].to(device)
    gen = torch.Generator(device).manual_seed(seed)
    x = torch.randn((1, 4, 88, 64), device=device, generator=gen)
    # Route all sampler noise through the seeded generator for reproducibility
    randn_like = lambda *a, **ka: torch.zeros_like(*a, **ka).normal_(generator=gen)
    im = edm_sampler(
        guided(net, guidance_scale), x, label, num_steps=n_steps, randn_like=randn_like
    )
    im = stability_encoder.decode(im)
    return ToPILImage()(im[0])
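

# Build the Gradio UI: sampling controls and three conditioning tabs on the
# left, generated image and copyable outputs on the right.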
with gr.Blocks() as demo:
    selected = [0]  # index of the currently selected conditioning tab
    with gr.Row():
        gr.Markdown(
            """# 1girl-EDM2-XS-test-1 Demo
Demo of a 125M param model trained in 1 GPU-day for generating `1girl solo` images.
For those who don't want to wait 10 minutes for an image, a [Colab demo](https://colab.research.google.com/drive/1aBopomQ6wLJtQUTZ1CZn2i-zGpyLGWKM?usp=sharing) is also available.
[More info](https://huggingface.co/gustproof/1girl-EDM2-XS-test-1)
"""
        )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                btn = gr.Button("Generate", variant="primary")
                guidance = gr.Slider(1, 15, 5, step=0.1, label="Guidance scale")
                n_steps = gr.Slider(2, 35, 24, step=1, label="Inference steps")
                seed = gr.Slider(
                    -1, 2147483647, -1, step=1, label="Random seed (-1: randomize)"
                )
with gr.Tab("Condition: auto") as auto_tab:
gr.Markdown("Conditioning is generated with an external model")
with gr.Tab("Condition: from image") as img_tab:
gr.Markdown(
"Conditioning is extracted from the image with a [tagger](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3). "
"This is works like [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and not like [img2img](https://github.com/CompVis/stable-diffusion/blob/main/scripts/img2img.py)"
)
ref_im = gr.Image(label="Reference image", type="pil")
with gr.Tab("Condition: precomputed") as txt_tab:
gr.Markdown("Use a precomputed 1024D vector a the condition")
ref_txt = gr.TextArea(
label="Precomputed conditioning",
placeholder="Copy & Paste from the output",
)
with gr.Column():
out_im = gr.Image(label="Generated Image", show_download_button=True)
out_seed = gr.Textbox(label="Seed", show_copy_button=True)
out_emb = gr.TextArea(label="Condition vector", show_copy_button=True)
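
    # Produce the conditioning vector according to the selected tab.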
    @torch.no_grad()
    def get_label(tab_index, cond_img=None, cond_txt=None):
        if tab_index == 0:
            # Auto: sample a conditioning vector from the condition generator
            return cond_gen(torch.randn((1, 512), device=device))[0].detach().cpu()
        if tab_index == 1:
            # From image: extract a conditioning vector with the tagger
            return extract_features(cond_img, device)
        # Precomputed: parse the JSON vector pasted by the user
        return torch.tensor(json.loads(cond_txt))

    def on_select(e: gr.SelectData):
        selected[0] = e.index

    # Track which conditioning tab is active
    for t in [auto_tab, img_tab, txt_tab]:
        t.select(on_select)

    def main(guidance, n_steps, seed, cond_img=None, cond_txt=None):
        if seed < 0:
            seed = torch.randint(0, 2147483647, ()).item()
        label = get_label(selected[0], cond_img, cond_txt)
        im = generate_image(label, guidance, n_steps, seed)
        label_txt = json.dumps(label.numpy().astype(float).round(3).tolist())
        return im, seed, label_txt

    btn.click(
        main, [guidance, n_steps, seed, ref_im, ref_txt], [out_im, out_seed, out_emb]
    )
demo.launch()